red-chainer 0.3.2 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
|
4
|
+
data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
|
7
|
+
data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
|
data/.gitignore
CHANGED
data/.travis.yml
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
+
notifications:
|
2
|
+
webhooks:
|
3
|
+
- https://webhook.commit-email.info/
|
1
4
|
sudo: false
|
2
5
|
language: ruby
|
3
6
|
rvm:
|
4
|
-
- 2.3.6
|
5
7
|
- 2.4.3
|
6
8
|
- 2.5.0
|
7
|
-
|
8
|
-
|
9
|
+
- 2.6.0
|
10
|
+
before_install: gem install bundler
|
11
|
+
script:
|
12
|
+
- ruby test/run_test.rb
|
13
|
+
- yardoc --fail-on-warning
|
data/.yardopts
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
-p templates
|
data/Gemfile
CHANGED
@@ -1,4 +1,9 @@
|
|
1
1
|
source "https://rubygems.org"
|
2
2
|
|
3
|
-
# Specify your gem's dependencies in red-chainer.gemspec
|
4
3
|
gemspec
|
4
|
+
|
5
|
+
local_gemfile = File.join(File.dirname(__FILE__), "Gemfile.local")
|
6
|
+
if File.exist?(local_gemfile)
|
7
|
+
puts "Loading Gemfile.local ..." if $DEBUG # `ruby -d` or `bundle -v`
|
8
|
+
instance_eval File.read(local_gemfile)
|
9
|
+
end
|
data/README.md
CHANGED
@@ -8,7 +8,7 @@ It ported python's [Chainer](https://github.com/chainer/chainer) with Ruby.
|
|
8
8
|
|
9
9
|
## Requirements
|
10
10
|
|
11
|
-
* Ruby 2.
|
11
|
+
* Ruby 2.4 or later
|
12
12
|
|
13
13
|
## Installation
|
14
14
|
|
@@ -31,7 +31,10 @@ $ gem install red-chainer
|
|
31
31
|
```
|
32
32
|
|
33
33
|
## Usage
|
34
|
-
|
34
|
+
|
35
|
+
### Run MNIST example
|
36
|
+
|
37
|
+
The MNIST sample program is [here](./examples/mnist/mnist.rb).
|
35
38
|
|
36
39
|
```bash
|
37
40
|
# when install Gemfile
|
@@ -40,6 +43,34 @@ $ bundle exec ruby examples/mnist/mnist.rb
|
|
40
43
|
$ ruby examples/mnist/mnist.rb
|
41
44
|
```
|
42
45
|
|
46
|
+
### Run MNIST example with GPU
|
47
|
+
|
48
|
+
On a GPU machine, add `gem 'cumo'` to your Gemfile and run `bundle install`.
|
49
|
+
|
50
|
+
Run the example with the `--gpu` option, whose value is the GPU device ID, for example:
|
51
|
+
|
52
|
+
```
|
53
|
+
$ bundle exec ruby examples/mnist/mnist.rb --gpu 0
|
54
|
+
```
|
55
|
+
|
56
|
+
## Development
|
57
|
+
|
58
|
+
### Run tests
|
59
|
+
|
60
|
+
```
|
61
|
+
$ bundle exec ruby test/run_test.rb
|
62
|
+
```
|
63
|
+
|
64
|
+
### Run tests with Cumo
|
65
|
+
|
66
|
+
On a GPU machine, add `gem 'cumo'` to your Gemfile and run `bundle install`.
|
67
|
+
|
68
|
+
Run the tests with the `RED_CHAINER_GPU` environment variable set to the GPU device ID, for example:
|
69
|
+
|
70
|
+
```
|
71
|
+
$ bundle exec env RED_CHAINER_GPU=0 ruby test/run_test.rb
|
72
|
+
```
|
73
|
+
|
43
74
|
## License
|
44
75
|
|
45
76
|
The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
|
@@ -54,4 +85,4 @@ The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
|
|
54
85
|
| [connection](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/connection) | 12 | 2 | Linear, Convolution2D |
|
55
86
|
| [pooling](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/pooling) | 14 | 3 | Pooling2D, MaxPooling2D, AveragePooling2D |
|
56
87
|
| [example](https://github.com/red-data-tools/red-chainer/tree/master/examples) | 31 | 3 | MNIST, Iris, CIFAR |
|
57
|
-
| GPU | use
|
88
|
+
| GPU | use CuPy | use [Cumo](https://github.com/sonots/cumo) ||
|
@@ -9,6 +9,7 @@ args = {
|
|
9
9
|
batchsize: 64,
|
10
10
|
learnrate: 0.05,
|
11
11
|
epoch: 300,
|
12
|
+
gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
|
12
13
|
out: 'result',
|
13
14
|
resume: nil,
|
14
15
|
model: 'vgg',
|
@@ -21,11 +22,21 @@ opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default:
|
|
21
22
|
opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
|
22
23
|
opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
|
23
24
|
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
|
25
|
+
opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
|
24
26
|
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
|
25
27
|
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
|
26
28
|
opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
|
27
29
|
opt.parse!(ARGV)
|
28
30
|
|
31
|
+
puts "GPU: #{args[:gpu]}"
|
32
|
+
puts "# unit: #{args[:unit]}"
|
33
|
+
puts "# Minibatch-size: #{args[:batchsize]}"
|
34
|
+
puts "# epoch: #{args[:epoch]}"
|
35
|
+
puts
|
36
|
+
|
37
|
+
device = Chainer::Device.create(args[:gpu])
|
38
|
+
Chainer::Device.change_default(device)
|
39
|
+
|
29
40
|
# Set up a neural network to train.
|
30
41
|
# Classifier reports softmax cross entropy loss and accuracy at every
|
31
42
|
# iteration, which will be used by the PrintReport extension below.
|
@@ -57,10 +68,10 @@ optimizer.setup(model)
|
|
57
68
|
train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
|
58
69
|
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
|
59
70
|
|
60
|
-
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device:
|
71
|
+
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
|
61
72
|
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
|
62
73
|
|
63
|
-
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device:
|
74
|
+
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: device))
|
64
75
|
|
65
76
|
trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
|
66
77
|
|
data/examples/iris/iris.rb
CHANGED
@@ -25,6 +25,10 @@ class IrisChain < Chainer::Chain
|
|
25
25
|
end
|
26
26
|
end
|
27
27
|
|
28
|
+
device = Chainer::Device.create(Integer(ENV['RED_CHAINER_GPU'] || -1))
|
29
|
+
Chainer::Device.change_default(device)
|
30
|
+
xm = device.xm
|
31
|
+
|
28
32
|
model = IrisChain.new(6,3)
|
29
33
|
|
30
34
|
optimizer = Chainer::Optimizers::Adam.new
|
@@ -35,7 +39,7 @@ iris_table = iris.to_table
|
|
35
39
|
x = iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
|
36
40
|
|
37
41
|
# target
|
38
|
-
y_class = iris_table[:
|
42
|
+
y_class = iris_table[:label]
|
39
43
|
|
40
44
|
# class index array
|
41
45
|
# ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
|
@@ -46,7 +50,7 @@ y = y_class.map{|s|
|
|
46
50
|
}
|
47
51
|
|
48
52
|
# y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
|
49
|
-
y_onehot =
|
53
|
+
y_onehot = xm::SFloat.eye(class_name.size)[y, false]
|
50
54
|
|
51
55
|
puts "Iris Datasets"
|
52
56
|
puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
|
@@ -57,9 +61,9 @@ x.each_with_index{|r, i|
|
|
57
61
|
# [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
|
58
62
|
# [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
|
59
63
|
|
60
|
-
x =
|
61
|
-
y =
|
62
|
-
y_onehot =
|
64
|
+
x = xm::SFloat.cast(x)
|
65
|
+
y = xm::SFloat.cast(y)
|
66
|
+
y_onehot = xm::SFloat.cast(y_onehot)
|
63
67
|
|
64
68
|
x_train = x[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
|
65
69
|
y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
|
data/examples/mnist/mnist.rb
CHANGED
@@ -27,6 +27,7 @@ args = {
|
|
27
27
|
batchsize: 100,
|
28
28
|
frequency: -1,
|
29
29
|
epoch: 20,
|
30
|
+
gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
|
30
31
|
resume: nil,
|
31
32
|
unit: 1000,
|
32
33
|
out: 'result'
|
@@ -35,25 +36,36 @@ args = {
|
|
35
36
|
opt = OptionParser.new
|
36
37
|
opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
|
37
38
|
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
|
39
|
+
opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
|
38
40
|
opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
|
39
41
|
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
|
40
42
|
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
|
41
43
|
opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
|
42
44
|
opt.parse!(ARGV)
|
43
45
|
|
44
|
-
|
46
|
+
puts "GPU: #{args[:gpu]}"
|
47
|
+
puts "# unit: #{args[:unit]}"
|
48
|
+
puts "# Minibatch-size: #{args[:batchsize]}"
|
49
|
+
puts "# epoch: #{args[:epoch]}"
|
50
|
+
puts
|
51
|
+
|
52
|
+
device = Chainer::Device.create(args[:gpu])
|
53
|
+
Chainer::Device.change_default(device)
|
54
|
+
|
55
|
+
lossfun = -> (x, t) { Chainer::Functions::Loss::SoftmaxCrossEntropy.new(ignore_label: nil).(x, t) }
|
56
|
+
model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10), lossfun)
|
45
57
|
|
46
58
|
optimizer = Chainer::Optimizers::Adam.new
|
47
59
|
optimizer.setup(model)
|
48
|
-
train, test = Chainer::Datasets::
|
60
|
+
train, test = Chainer::Datasets::MNIST.get_mnist
|
49
61
|
|
50
62
|
train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
|
51
63
|
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
|
52
64
|
|
53
|
-
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device:
|
65
|
+
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
|
54
66
|
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
|
55
67
|
|
56
|
-
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device:
|
68
|
+
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: args[:gpu]))
|
57
69
|
|
58
70
|
# Take a snapshot for each specified epoch
|
59
71
|
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
|
data/lib/chainer.rb
CHANGED
@@ -3,8 +3,11 @@ require "weakref"
|
|
3
3
|
require "chainer/version"
|
4
4
|
|
5
5
|
require 'chainer/cuda'
|
6
|
+
require 'chainer/backend'
|
6
7
|
require 'chainer/configuration'
|
8
|
+
require 'chainer/device'
|
7
9
|
require 'chainer/function'
|
10
|
+
require 'chainer/function_node'
|
8
11
|
require 'chainer/optimizer'
|
9
12
|
require 'chainer/gradient_method'
|
10
13
|
require 'chainer/gradient_check'
|
@@ -15,6 +18,7 @@ require 'chainer/initializer'
|
|
15
18
|
require 'chainer/initializers/init'
|
16
19
|
require 'chainer/initializers/constant'
|
17
20
|
require 'chainer/initializers/normal'
|
21
|
+
require 'chainer/initializers/uniform'
|
18
22
|
require 'chainer/iterators/serial_iterator'
|
19
23
|
require 'chainer/link'
|
20
24
|
require 'chainer/links/connection/convolution_2d'
|
@@ -30,15 +34,28 @@ require 'chainer/utils/variable'
|
|
30
34
|
require 'chainer/utils/array'
|
31
35
|
require 'chainer/functions/activation/leaky_relu'
|
32
36
|
require 'chainer/functions/activation/relu'
|
37
|
+
require 'chainer/functions/activation/relu_grad2'
|
33
38
|
require 'chainer/functions/activation/sigmoid'
|
39
|
+
require 'chainer/functions/activation/sigmoid_grad'
|
34
40
|
require 'chainer/functions/activation/tanh'
|
35
41
|
require 'chainer/functions/activation/log_softmax'
|
42
|
+
require 'chainer/functions/array/broadcast_to'
|
43
|
+
require 'chainer/functions/array/cast'
|
44
|
+
require 'chainer/functions/array/reshape'
|
45
|
+
require 'chainer/functions/array/rollaxis'
|
46
|
+
require 'chainer/functions/array/select_item'
|
47
|
+
require 'chainer/functions/array/squeeze'
|
48
|
+
require 'chainer/functions/array/transpose'
|
36
49
|
require 'chainer/functions/evaluation/accuracy'
|
37
50
|
require 'chainer/functions/math/basic_math'
|
38
51
|
require 'chainer/functions/math/identity'
|
52
|
+
require 'chainer/functions/math/sum'
|
53
|
+
require 'chainer/functions/math/exp'
|
39
54
|
require 'chainer/functions/loss/mean_squared_error'
|
40
55
|
require 'chainer/functions/loss/softmax_cross_entropy'
|
41
56
|
require 'chainer/functions/connection/convolution_2d'
|
57
|
+
require 'chainer/functions/connection/deconvolution_2d'
|
58
|
+
require 'chainer/functions/connection/convolution_2d_grad_w'
|
42
59
|
require 'chainer/functions/connection/linear'
|
43
60
|
require 'chainer/functions/noise/dropout'
|
44
61
|
require 'chainer/functions/normalization/batch_normalization'
|
@@ -61,7 +78,6 @@ require 'chainer/training/triggers/interval'
|
|
61
78
|
require 'chainer/parameter'
|
62
79
|
require 'chainer/optimizers/adam'
|
63
80
|
require 'chainer/optimizers/momentum_sgd'
|
64
|
-
require 'chainer/dataset/download'
|
65
81
|
require 'chainer/datasets/mnist'
|
66
82
|
require 'chainer/datasets/cifar'
|
67
83
|
require 'chainer/datasets/tuple_dataset'
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Chainer
|
2
|
+
# Gets the appropriate array module, +Numo+ or +Cumo+, for the given arrays.
|
3
|
+
#
|
4
|
+
# @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
|
5
|
+
# @return [Module] +Cumo+ if CUDA is available and any argument is (or wraps) a +Cumo::NArray+; otherwise +Numo+.
|
6
|
+
def get_array_module(*args)
|
7
|
+
arrays = args.map {|v| v.kind_of?(Chainer::Variable) ? v.data : v }
|
8
|
+
if CUDA.available?
|
9
|
+
return Cumo if arrays.any? {|a| a.kind_of?(Cumo::NArray) }
|
10
|
+
end
|
11
|
+
return Numo
|
12
|
+
end
|
13
|
+
module_function :get_array_module
|
14
|
+
|
15
|
+
# Returns true if the argument is either of +Numo::NArray+ or +Cumo::NArray+.
|
16
|
+
#
|
17
|
+
# @param [Object] obj
|
18
|
+
# @return [Boolean]
|
19
|
+
def array?(obj)
|
20
|
+
if CUDA.available?
|
21
|
+
return true if obj.kind_of?(Cumo::NArray)
|
22
|
+
end
|
23
|
+
return true if obj.kind_of?(Numo::NArray)
|
24
|
+
false
|
25
|
+
end
|
26
|
+
module_function :array?
|
27
|
+
end
|
data/lib/chainer/cuda.rb
CHANGED
@@ -1,18 +1,40 @@
|
|
1
|
+
begin
|
2
|
+
require 'cumo'
|
3
|
+
$chainer_cuda_available = true
|
4
|
+
rescue LoadError => e
|
5
|
+
$chainer_cuda_available = false
|
6
|
+
# A trick to make Cumo::NArray always exist, even when the cumo gem cannot be loaded
|
7
|
+
module Cumo
|
8
|
+
class NArray; end
|
9
|
+
class NMath; end
|
10
|
+
class Bit; end
|
11
|
+
end
|
12
|
+
end
|
13
|
+
|
1
14
|
module Chainer
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
15
|
+
module CUDA
|
16
|
+
# Returns whether CUDA is available.
|
17
|
+
#
|
18
|
+
# @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
|
19
|
+
# @return [Boolean]
|
20
|
+
def available?(id = nil)
|
21
|
+
return false unless $chainer_cuda_available
|
22
|
+
if id
|
23
|
+
raise 'id must be non negative' if id < 0
|
24
|
+
@device_count ||= Cumo::CUDA::Runtime.cudaGetDeviceCount
|
25
|
+
return @device_count > id
|
26
|
+
end
|
27
|
+
true
|
28
|
+
end
|
29
|
+
module_function :available?
|
30
|
+
|
31
|
+
# Checks if CUDA is available.
|
32
|
+
#
|
33
|
+
# @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
|
34
|
+
# @raise [RuntimeError] if not available
|
35
|
+
def check_available(id = nil)
|
36
|
+
raise 'CUDA is not available' unless available?(id)
|
37
|
+
end
|
38
|
+
module_function :check_available
|
16
39
|
end
|
17
|
-
module_function :get_array_module
|
18
40
|
end
|
@@ -2,12 +2,13 @@ module Chainer
|
|
2
2
|
module Dataset
|
3
3
|
module Convert
|
4
4
|
def self.to_device(device, x)
|
5
|
-
# TODO:
|
5
|
+
# TODO(sonots): Implement after Cumo supports transferring between devices
|
6
6
|
x
|
7
7
|
end
|
8
8
|
|
9
9
|
def self.concat_examples(batch, device: nil, padding: nil)
|
10
10
|
raise "batch is empty" if batch.size == 0
|
11
|
+
device = device ? Chainer::Device.create(device) : Chainer::Device.default # takes care of int and nil
|
11
12
|
first_elem = batch[0]
|
12
13
|
|
13
14
|
if first_elem.kind_of?(Array)
|
@@ -17,28 +18,29 @@ module Chainer
|
|
17
18
|
end
|
18
19
|
|
19
20
|
first_elem.size.times do |i|
|
20
|
-
x =
|
21
|
+
x = _concat_arrays(batch.map { |b| b[i] }, padding[i], device)
|
21
22
|
result.push(to_device(device, x))
|
22
23
|
end
|
23
24
|
|
24
25
|
return result
|
25
26
|
else
|
26
|
-
return
|
27
|
+
return _concat_arrays(batch, padding, device)
|
27
28
|
end
|
28
29
|
end
|
29
30
|
|
30
|
-
def self.
|
31
|
-
|
31
|
+
def self._concat_arrays(arrays, padding, device)
|
32
|
+
xm = device.xm
|
33
|
+
unless arrays[0].kind_of?(xm::NArray)
|
32
34
|
# [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
|
33
|
-
arrays =
|
35
|
+
arrays = xm::NArray.cast(arrays)
|
34
36
|
if padding
|
35
|
-
return
|
37
|
+
return _concat_arrays_with_padding(arrays, padding, device)
|
36
38
|
end
|
37
39
|
return arrays
|
38
40
|
end
|
39
41
|
|
40
42
|
if padding
|
41
|
-
return
|
43
|
+
return _concat_arrays_with_padding(arrays, padding, device)
|
42
44
|
end
|
43
45
|
|
44
46
|
# [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
|
@@ -48,12 +50,14 @@ module Chainer
|
|
48
50
|
a[0].concatenate(*a[1..-1])
|
49
51
|
end
|
50
52
|
|
51
|
-
def self.
|
52
|
-
|
53
|
-
|
53
|
+
def self._concat_arrays_with_padding(arrays, padding, device)
|
54
|
+
xm = device.xm
|
55
|
+
if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
|
56
|
+
xm = Chainer.get_array_module(arrays[0])
|
57
|
+
shape = xm::Int32.cast(arrays[0].shape)
|
54
58
|
arrays[1..-1].each do |array|
|
55
|
-
if
|
56
|
-
shape =
|
59
|
+
if xm::Bit.[](shape != array.shape).any?
|
60
|
+
shape = xm::Int32.maximum(shape, array.shape)
|
57
61
|
end
|
58
62
|
end
|
59
63
|
else # Integer
|
@@ -61,15 +65,15 @@ module Chainer
|
|
61
65
|
end
|
62
66
|
|
63
67
|
shape = shape.insert(0, arrays.size).to_a
|
64
|
-
if arrays[0].
|
68
|
+
if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
|
65
69
|
result = arrays[0].class.new(shape).fill(padding)
|
66
70
|
else # Integer
|
67
|
-
result =
|
71
|
+
result = xm::Int32.new(shape).fill(padding)
|
68
72
|
end
|
69
73
|
|
70
74
|
arrays.size.times do |i|
|
71
75
|
src = arrays[i]
|
72
|
-
if
|
76
|
+
if Chainer.array?(src) and src.ndim > 0
|
73
77
|
result[i, 0...src.shape[0], 0...src.shape[1]] = src
|
74
78
|
else # Integer
|
75
79
|
result[i] = src
|