red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
+  data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
+  data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
data/.gitignore
CHANGED
data/.travis.yml
CHANGED
@@ -1,8 +1,13 @@
+notifications:
+  webhooks:
+    - https://webhook.commit-email.info/
 sudo: false
 language: ruby
 rvm:
-  - 2.3.6
   - 2.4.3
   - 2.5.0
-
-
+  - 2.6.0
+before_install: gem install bundler
+script:
+  - ruby test/run_test.rb
+  - yardoc --fail-on-warning
data/.yardopts
ADDED
@@ -0,0 +1 @@
+-p templates
data/Gemfile
CHANGED
@@ -1,4 +1,9 @@
 source "https://rubygems.org"
 
-# Specify your gem's dependencies in red-chainer.gemspec
 gemspec
+
+local_gemfile = File.join(File.dirname(__FILE__), "Gemfile.local")
+if File.exist?(local_gemfile)
+  puts "Loading Gemfile.local ..." if $DEBUG # `ruby -d` or `bundle -v`
+  instance_eval File.read(local_gemfile)
+end
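The Gemfile hunk above adds an optional hook: if an uncommitted `Gemfile.local` sits next to the Gemfile, its contents are evaluated into the bundle. As a minimal sketch (this file is an assumption for illustration, not part of the package), a developer on a GPU machine could opt into the Cumo backend locally:

```ruby
# Gemfile.local (hypothetical; kept out of version control)
# Loaded by the instance_eval hook in the Gemfile above.
gem 'cumo' # enables the Cumo/GPU backend described in the README
```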
data/README.md
CHANGED
@@ -8,7 +8,7 @@ It ported python's [Chainer](https://github.com/chainer/chainer) with Ruby.
 
 ## Requirements
 
-* Ruby 2.
+* Ruby 2.4 or later
 
 ## Installation
 
@@ -31,7 +31,10 @@ $ gem install red-chainer
 ```
 
 ## Usage
-
+
+### Run MNIST example
+
+MNIST sample program is [here](./examples/mnist/mnist.rb)
 
 ```bash
 # when install Gemfile
@@ -40,6 +43,34 @@ $ bundle exec ruby examples/mnist/mnist.rb
 $ ruby examples/mnist/mnist.rb
 ```
 
+### Run MNIST example with GPU
+
+On GPU machine, add `gem 'cumo'` on Gemfile and do `bundle install`.
+
+Run the example with `--gpu` option whose value indicates GPU device ID such as:
+
+```
+$ bundle exec ruby examples/mnist/mnist.rb --gpu 0
+```
+
+## Development
+
+### Run tests
+
+```
+$ bundle exec ruby test/run_test.rb
+```
+
+### Run tests with Cumo
+
+On GPU machine, add `gem 'cumo'` on Gemfile and do `bundle install`.
+
+Run tests with `RED_CHAINER_GPU` environment variable whose value indicates GPU device ID such as:
+
+```
+$ bundle exec env RED_CHAINER_GPU=0 ruby test/run_test.rb
+```
+
 ## License
 
 The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
@@ -54,4 +85,4 @@ The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
 | [connection](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/connection) | 12 | 2 | Linear, Convolution2D |
 | [pooling](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/pooling) | 14 | 3 | Pooling2D, MaxPooling2D, AveragePooling2D |
 | [example](https://github.com/red-data-tools/red-chainer/tree/master/examples) | 31 | 3 | MNIST, Iris, CIFAR |
-| GPU | use
+| GPU | use CuPy | use [Cumo](https://github.com/sonots/cumo) ||
data/examples/cifar/train_cifar.rb
CHANGED
@@ -9,6 +9,7 @@ args = {
   batchsize: 64,
   learnrate: 0.05,
   epoch: 300,
+  gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
   out: 'result',
   resume: nil,
   model: 'vgg',
@@ -21,11 +22,21 @@ opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default:
 opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
 opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
 opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
 opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
 opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
 opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
 opt.parse!(ARGV)
 
+puts "GPU: #{args[:gpu]}"
+puts "# unit: #{args[:unit]}"
+puts "# Minibatch-size: #{args[:batchsize]}"
+puts "# epoch: #{args[:epoch]}"
+puts
+
+device = Chainer::Device.create(args[:gpu])
+Chainer::Device.change_default(device)
+
 # Set up a neural network to train.
 # Classifier reports softmax cross entropy loss and accuracy at every
 # iteration, which will be used by the PrintReport extension below.
@@ -57,10 +68,10 @@ optimizer.setup(model)
 train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
 test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
 
-updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device:
+updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
 trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
 
-trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device:
+trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: device))
 
 trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
 
data/examples/iris/iris.rb
CHANGED
@@ -25,6 +25,10 @@ class IrisChain < Chainer::Chain
   end
 end
 
+device = Chainer::Device.create(Integer(ENV['RED_CHAINER_GPU'] || -1))
+Chainer::Device.change_default(device)
+xm = device.xm
+
 model = IrisChain.new(6,3)
 
 optimizer = Chainer::Optimizers::Adam.new
@@ -35,7 +39,7 @@ iris_table = iris.to_table
 x = iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
 
 # target
-y_class = iris_table[:
+y_class = iris_table[:label]
 
 # class index array
 # ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
@@ -46,7 +50,7 @@ y = y_class.map{|s|
 }
 
 # y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
-y_onehot =
+y_onehot = xm::SFloat.eye(class_name.size)[y, false]
 
 puts "Iris Datasets"
 puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
@@ -57,9 +61,9 @@ x.each_with_index{|r, i|
 # [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
 # [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
 
-x =
-y =
-y_onehot =
+x = xm::SFloat.cast(x)
+y = xm::SFloat.cast(y)
+y_onehot = xm::SFloat.cast(y_onehot)
 
 x_train = x[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
 y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
data/examples/mnist/mnist.rb
CHANGED
@@ -27,6 +27,7 @@ args = {
   batchsize: 100,
   frequency: -1,
   epoch: 20,
+  gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
   resume: nil,
   unit: 1000,
   out: 'result'
@@ -35,25 +36,36 @@ args = {
 opt = OptionParser.new
 opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
 opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
 opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
 opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
 opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
 opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
 opt.parse!(ARGV)
 
-
+puts "GPU: #{args[:gpu]}"
+puts "# unit: #{args[:unit]}"
+puts "# Minibatch-size: #{args[:batchsize]}"
+puts "# epoch: #{args[:epoch]}"
+puts
+
+device = Chainer::Device.create(args[:gpu])
+Chainer::Device.change_default(device)
+
+lossfun = -> (x, t) { Chainer::Functions::Loss::SoftmaxCrossEntropy.new(ignore_label: nil).(x, t) }
+model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10), lossfun)
 
 optimizer = Chainer::Optimizers::Adam.new
 optimizer.setup(model)
-train, test = Chainer::Datasets::
+train, test = Chainer::Datasets::MNIST.get_mnist
 
 train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
 test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
 
-updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device:
+updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
 trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
 
-trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device:
+trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: args[:gpu]))
 
 # Take a snapshot for each specified epoch
 frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
data/lib/chainer.rb
CHANGED
@@ -3,8 +3,11 @@ require "weakref"
 require "chainer/version"
 
 require 'chainer/cuda'
+require 'chainer/backend'
 require 'chainer/configuration'
+require 'chainer/device'
 require 'chainer/function'
+require 'chainer/function_node'
 require 'chainer/optimizer'
 require 'chainer/gradient_method'
 require 'chainer/gradient_check'
@@ -15,6 +18,7 @@ require 'chainer/initializer'
 require 'chainer/initializers/init'
 require 'chainer/initializers/constant'
 require 'chainer/initializers/normal'
+require 'chainer/initializers/uniform'
 require 'chainer/iterators/serial_iterator'
 require 'chainer/link'
 require 'chainer/links/connection/convolution_2d'
@@ -30,15 +34,28 @@ require 'chainer/utils/variable'
 require 'chainer/utils/array'
 require 'chainer/functions/activation/leaky_relu'
 require 'chainer/functions/activation/relu'
+require 'chainer/functions/activation/relu_grad2'
 require 'chainer/functions/activation/sigmoid'
+require 'chainer/functions/activation/sigmoid_grad'
 require 'chainer/functions/activation/tanh'
 require 'chainer/functions/activation/log_softmax'
+require 'chainer/functions/array/broadcast_to'
+require 'chainer/functions/array/cast'
+require 'chainer/functions/array/reshape'
+require 'chainer/functions/array/rollaxis'
+require 'chainer/functions/array/select_item'
+require 'chainer/functions/array/squeeze'
+require 'chainer/functions/array/transpose'
 require 'chainer/functions/evaluation/accuracy'
 require 'chainer/functions/math/basic_math'
 require 'chainer/functions/math/identity'
+require 'chainer/functions/math/sum'
+require 'chainer/functions/math/exp'
 require 'chainer/functions/loss/mean_squared_error'
 require 'chainer/functions/loss/softmax_cross_entropy'
 require 'chainer/functions/connection/convolution_2d'
+require 'chainer/functions/connection/deconvolution_2d'
+require 'chainer/functions/connection/convolution_2d_grad_w'
 require 'chainer/functions/connection/linear'
 require 'chainer/functions/noise/dropout'
 require 'chainer/functions/normalization/batch_normalization'
@@ -61,7 +78,6 @@ require 'chainer/training/triggers/interval'
 require 'chainer/parameter'
 require 'chainer/optimizers/adam'
 require 'chainer/optimizers/momentum_sgd'
-require 'chainer/dataset/download'
 require 'chainer/datasets/mnist'
 require 'chainer/datasets/cifar'
 require 'chainer/datasets/tuple_dataset'
data/lib/chainer/backend.rb
ADDED
@@ -0,0 +1,27 @@
+module Chainer
+  # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+ from given arrays.
+  #
+  # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
+  # @return [Class] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of the arguments.
+  def get_array_module(*args)
+    arrays = args.map {|v| v.kind_of?(Chainer::Variable) ? v.data : v }
+    if CUDA.available?
+      return Cumo if arrays.any? {|a| a.kind_of?(Cumo::NArray) }
+    end
+    return Numo
+  end
+  module_function :get_array_module
+
+  # Returns true if the argument is either of +Numo::NArray+ or +Cumo::NArray+.
+  #
+  # @param [Object] obj
+  # @return [Boolean]
+  def array?(obj)
+    if CUDA.available?
+      return true if obj.kind_of?(Cumo::NArray)
+    end
+    return true if obj.kind_of?(Numo::NArray)
+    false
+  end
+  module_function :array?
+end
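The new `backend.rb` centralizes Numo/Cumo dispatch. A minimal sketch of how the two helpers could be combined (the `zeros_like` helper below is hypothetical, not part of the package):

```ruby
require 'chainer'

# Hypothetical helper: allocate a zero-filled array on the same backend
# (Numo on CPU, Cumo on GPU) as its argument.
def zeros_like(a)
  raise ArgumentError, 'expected an NArray' unless Chainer.array?(a)
  xm = Chainer.get_array_module(a) # => Numo, or Cumo when a is a Cumo::NArray
  xm::SFloat.zeros(a.shape)
end

p zeros_like(Numo::SFloat[1, 2, 3]) # => Numo::SFloat#shape=[3] [0, 0, 0]
```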
data/lib/chainer/cuda.rb
CHANGED
@@ -1,18 +1,40 @@
+begin
+  require 'cumo'
+  $chainer_cuda_available = true
+rescue LoadError => e
+  $chainer_cuda_available = false
+  # A trick to make Cumo::NArray always exists
+  module Cumo
+    class NArray; end
+    class NMath; end
+    class Bit; end
+  end
+end
+
 module Chainer
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  module CUDA
+    # Returns whether CUDA is available.
+    #
+    # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
+    # @return [Boolean]
+    def available?(id = nil)
+      return false unless $chainer_cuda_available
+      if id
+        raise 'id must be non negative' if id < 0
+        @device_count ||= Cumo::CUDA::Runtime.cudaGetDeviceCount
+        return @device_count > id
+      end
+      true
+    end
+    module_function :available?
+
+    # Checks if CUDA is available.
+    #
+    # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
+    # @raise [RuntimeError] if not available
+    def check_available(id = nil)
+      raise 'CUDA is not available' unless available?(id)
+    end
+    module_function :check_available
   end
-  module_function :get_array_module
 end
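Together with the new `device.rb`, the rewritten `cuda.rb` lets a script probe for a usable GPU before choosing a device. A small sketch built only from calls shown in this diff (`CUDA.available?`, `Device.create`, `Device.change_default`, `device.xm`):

```ruby
require 'chainer'

# Fall back to the CPU (-1) unless the cumo gem is loadable and GPU 0 exists.
gpu = Chainer::CUDA.available?(0) ? 0 : -1
device = Chainer::Device.create(gpu)
Chainer::Device.change_default(device)
xm = device.xm # Cumo on the GPU path, Numo on the CPU path
```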
data/lib/chainer/dataset/convert.rb
CHANGED
@@ -2,12 +2,13 @@ module Chainer
   module Dataset
     module Convert
       def self.to_device(device, x)
-        # TODO:
+        # TODO(sonots): Implement after Cumo supports transferring between devices
         x
       end
 
       def self.concat_examples(batch, device: nil, padding: nil)
         raise "batch is empty" if batch.size == 0
+        device = device ? Chainer::Device.create(device) : Chainer::Device.default # takes care of int and nil
         first_elem = batch[0]
 
         if first_elem.kind_of?(Array)
@@ -17,28 +18,29 @@ module Chainer
           end
 
           first_elem.size.times do |i|
-            x =
+            x = _concat_arrays(batch.map { |b| b[i] }, padding[i], device)
             result.push(to_device(device, x))
           end
 
           return result
         else
-          return
+          return _concat_arrays(batch, padding, device)
         end
       end
 
-      def self.
-
+      def self._concat_arrays(arrays, padding, device)
+        xm = device.xm
+        unless arrays[0].kind_of?(xm::NArray)
          # [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
-          arrays =
+          arrays = xm::NArray.cast(arrays)
          if padding
-            return
+            return _concat_arrays_with_padding(arrays, padding, device)
          end
          return arrays
        end
 
        if padding
-          return
+          return _concat_arrays_with_padding(arrays, padding, device)
        end
 
        # [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
@@ -48,12 +50,14 @@ module Chainer
        a[0].concatenate(*a[1..-1])
      end
 
-      def self.
-
-
+      def self._concat_arrays_with_padding(arrays, padding, device)
+        xm = device.xm
+        if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
+          xm = Chainer.get_array_module(arrays[0])
+          shape = xm::Int32.cast(arrays[0].shape)
          arrays[1..-1].each do |array|
-            if
-              shape =
+            if xm::Bit.[](shape != array.shape).any?
+              shape = xm::Int32.maximum(shape, array.shape)
            end
          end
        else # Integer
@@ -61,15 +65,15 @@ module Chainer
        end
 
        shape = shape.insert(0, arrays.size).to_a
-        if arrays[0].
+        if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
          result = arrays[0].class.new(shape).fill(padding)
        else # Integer
-          result =
+          result = xm::Int32.new(shape).fill(padding)
        end
 
        arrays.size.times do |i|
          src = arrays[i]
-          if
+          if Chainer.array?(src) and src.ndim > 0
            result[i, 0...src.shape[0], 0...src.shape[1]] = src
          else # Integer
            result[i] = src
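For context on `concat_examples` as reworked above, a minimal CPU-side sketch; the shape noted in the comment is the expected stacking behavior, not output captured from this diff:

```ruby
require 'chainer'

# Stack per-example arrays into one mini-batch array on the default device.
batch = [Numo::SFloat[1, 2, 3], Numo::SFloat[4, 5, 6]]
x = Chainer::Dataset::Convert.concat_examples(batch)
p x.shape # expected: [2, 3]
```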
|