red-chainer 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (81) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33a95bf098a08c334e6a9d29ff791350e6ac0bd9de843054f3ed14f2a005b79b
4
- data.tar.gz: 6f4e53b84e93d01e26363b5d43dba73ec4fc515ecbebbd1574ac6a864f52fb83
3
+ metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
4
+ data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
5
5
  SHA512:
6
- metadata.gz: 40054365541bb8956c4a8211fbd489e75d1557a9e5c23bcb49f2cc2b393f0d96574d18259bad8e147525a3b720a5329f904be18c7c4f0891f1d886241b72a65d
7
- data.tar.gz: fcb8e641d0efc1ffacc2014c2954d1c5a1e5f947fdcc18306f8946fb307c659e37bb3ccf08b9661d17e13184463b0282ae3594a6c252813d23a53fa0c6182a67
6
+ metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
7
+ data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
data/.gitignore CHANGED
@@ -1,13 +1,13 @@
1
1
  /.bundle/
2
- /.yardoc
2
+ /.yardoc/
3
3
  /Gemfile.lock
4
- /_yardoc/
5
4
  /coverage/
6
5
  /doc/
7
6
  /pkg/
8
7
  /spec/reports/
9
8
  /tmp/
10
9
  result
10
+ Gemfile.local
11
11
 
12
12
  # rspec failure tracking
13
13
  .rspec_status
data/.travis.yml CHANGED
@@ -1,8 +1,13 @@
1
+ notifications:
2
+ webhooks:
3
+ - https://webhook.commit-email.info/
1
4
  sudo: false
2
5
  language: ruby
3
6
  rvm:
4
- - 2.3.6
5
7
  - 2.4.3
6
8
  - 2.5.0
7
- before_install: gem install bundler -v 1.15.1
8
- script: ruby test/run_test.rb
9
+ - 2.6.0
10
+ before_install: gem install bundler
11
+ script:
12
+ - ruby test/run_test.rb
13
+ - yardoc --fail-on-warning
data/.yardopts ADDED
@@ -0,0 +1 @@
1
+ -p templates
data/Gemfile CHANGED
@@ -1,4 +1,9 @@
1
1
  source "https://rubygems.org"
2
2
 
3
- # Specify your gem's dependencies in red-chainer.gemspec
4
3
  gemspec
4
+
5
+ local_gemfile = File.join(File.dirname(__FILE__), "Gemfile.local")
6
+ if File.exist?(local_gemfile)
7
+ puts "Loading Gemfile.local ..." if $DEBUG # `ruby -d` or `bundle -v`
8
+ instance_eval File.read(local_gemfile)
9
+ end
data/README.md CHANGED
@@ -8,7 +8,7 @@ It ported python's [Chainer](https://github.com/chainer/chainer) with Ruby.
8
8
 
9
9
  ## Requirements
10
10
 
11
- * Ruby 2.3 or later
11
+ * Ruby 2.4 or later
12
12
 
13
13
  ## Installation
14
14
 
@@ -31,7 +31,10 @@ $ gem install red-chainer
31
31
  ```
32
32
 
33
33
  ## Usage
34
- mnist sample program is [here](./examples/mnist/mnist.rb)
34
+
35
+ ### Run MNIST example
36
+
37
+ MNIST sample program is [here](./examples/mnist/mnist.rb)
35
38
 
36
39
  ```bash
37
40
  # when install Gemfile
@@ -40,6 +43,34 @@ $ bundle exec ruby examples/mnist/mnist.rb
40
43
  $ ruby examples/mnist/mnist.rb
41
44
  ```
42
45
 
46
+ ### Run MNIST example with GPU
47
+
48
+ On a GPU machine, add `gem 'cumo'` to your Gemfile and run `bundle install`.
49
+
50
+ Run the example with the `--gpu` option, whose value is the GPU device ID, for example:
51
+
52
+ ```
53
+ $ bundle exec ruby examples/mnist/mnist.rb --gpu 0
54
+ ```
55
+
56
+ ## Development
57
+
58
+ ### Run tests
59
+
60
+ ```
61
+ $ bundle exec ruby test/run_test.rb
62
+ ```
63
+
64
+ ### Run tests with Cumo
65
+
66
+ On a GPU machine, add `gem 'cumo'` to your Gemfile and run `bundle install`.
67
+
68
+ Run the tests with the `RED_CHAINER_GPU` environment variable, whose value is the GPU device ID, for example:
69
+
70
+ ```
71
+ $ bundle exec env RED_CHAINER_GPU=0 ruby test/run_test.rb
72
+ ```
73
+
43
74
  ## License
44
75
 
45
76
  The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
@@ -54,4 +85,4 @@ The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
54
85
  | [connection](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/connection) | 12 | 2 | Linear, Convolution2D |
55
86
  | [pooling](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/pooling) | 14 | 3 | Pooling2D, MaxPooling2D, AveragePooling2D |
56
87
  | [example](https://github.com/red-data-tools/red-chainer/tree/master/examples) | 31 | 3 | MNIST, Iris, CIFAR |
57
- | GPU | use cupy | ToDo | want to support [Cumo](https://github.com/sonots/cumo) |
88
+ | GPU | use CuPy | use [Cumo](https://github.com/sonots/cumo) ||
@@ -9,6 +9,7 @@ args = {
9
9
  batchsize: 64,
10
10
  learnrate: 0.05,
11
11
  epoch: 300,
12
+ gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
12
13
  out: 'result',
13
14
  resume: nil,
14
15
  model: 'vgg',
@@ -21,11 +22,21 @@ opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default:
21
22
  opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
22
23
  opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
23
24
  opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
25
+ opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
24
26
  opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
25
27
  opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
26
28
  opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
27
29
  opt.parse!(ARGV)
28
30
 
31
+ puts "GPU: #{args[:gpu]}"
32
+ puts "# unit: #{args[:unit]}"
33
+ puts "# Minibatch-size: #{args[:batchsize]}"
34
+ puts "# epoch: #{args[:epoch]}"
35
+ puts
36
+
37
+ device = Chainer::Device.create(args[:gpu])
38
+ Chainer::Device.change_default(device)
39
+
29
40
  # Set up a neural network to train.
30
41
  # Classifier reports softmax cross entropy loss and accuracy at every
31
42
  # iteration, which will be used by the PrintReport extension below.
@@ -57,10 +68,10 @@ optimizer.setup(model)
57
68
  train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
58
69
  test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
59
70
 
60
- updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
71
+ updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
61
72
  trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
62
73
 
63
- trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
74
+ trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: device))
64
75
 
65
76
  trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
66
77
 
@@ -25,6 +25,10 @@ class IrisChain < Chainer::Chain
25
25
  end
26
26
  end
27
27
 
28
+ device = Chainer::Device.create(Integer(ENV['RED_CHAINER_GPU'] || -1))
29
+ Chainer::Device.change_default(device)
30
+ xm = device.xm
31
+
28
32
  model = IrisChain.new(6,3)
29
33
 
30
34
  optimizer = Chainer::Optimizers::Adam.new
@@ -35,7 +39,7 @@ iris_table = iris.to_table
35
39
  x = iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
36
40
 
37
41
  # target
38
- y_class = iris_table[:class]
42
+ y_class = iris_table[:label]
39
43
 
40
44
  # class index array
41
45
  # ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
@@ -46,7 +50,7 @@ y = y_class.map{|s|
46
50
  }
47
51
 
48
52
  # y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
49
- y_onehot = Numo::SFloat.eye(class_name.size)[y,false]
53
+ y_onehot = xm::SFloat.eye(class_name.size)[y, false]
50
54
 
51
55
  puts "Iris Datasets"
52
56
  puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
@@ -57,9 +61,9 @@ x.each_with_index{|r, i|
57
61
  # [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
58
62
  # [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
59
63
 
60
- x = Numo::SFloat.cast(x)
61
- y = Numo::SFloat.cast(y)
62
- y_onehot = Numo::SFloat.cast(y_onehot)
64
+ x = xm::SFloat.cast(x)
65
+ y = xm::SFloat.cast(y)
66
+ y_onehot = xm::SFloat.cast(y_onehot)
63
67
 
64
68
  x_train = x[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
65
69
  y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
@@ -27,6 +27,7 @@ args = {
27
27
  batchsize: 100,
28
28
  frequency: -1,
29
29
  epoch: 20,
30
+ gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
30
31
  resume: nil,
31
32
  unit: 1000,
32
33
  out: 'result'
@@ -35,25 +36,36 @@ args = {
35
36
  opt = OptionParser.new
36
37
  opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
37
38
  opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
39
+ opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
38
40
  opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
39
41
  opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
40
42
  opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
41
43
  opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
42
44
  opt.parse!(ARGV)
43
45
 
44
- model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10))
46
+ puts "GPU: #{args[:gpu]}"
47
+ puts "# unit: #{args[:unit]}"
48
+ puts "# Minibatch-size: #{args[:batchsize]}"
49
+ puts "# epoch: #{args[:epoch]}"
50
+ puts
51
+
52
+ device = Chainer::Device.create(args[:gpu])
53
+ Chainer::Device.change_default(device)
54
+
55
+ lossfun = -> (x, t) { Chainer::Functions::Loss::SoftmaxCrossEntropy.new(ignore_label: nil).(x, t) }
56
+ model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10), lossfun)
45
57
 
46
58
  optimizer = Chainer::Optimizers::Adam.new
47
59
  optimizer.setup(model)
48
- train, test = Chainer::Datasets::Mnist.get_mnist
60
+ train, test = Chainer::Datasets::MNIST.get_mnist
49
61
 
50
62
  train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
51
63
  test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
52
64
 
53
- updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
65
+ updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
54
66
  trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
55
67
 
56
- trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
68
+ trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: args[:gpu]))
57
69
 
58
70
  # Take a snapshot for each specified epoch
59
71
  frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
data/lib/chainer.rb CHANGED
@@ -3,8 +3,11 @@ require "weakref"
3
3
  require "chainer/version"
4
4
 
5
5
  require 'chainer/cuda'
6
+ require 'chainer/backend'
6
7
  require 'chainer/configuration'
8
+ require 'chainer/device'
7
9
  require 'chainer/function'
10
+ require 'chainer/function_node'
8
11
  require 'chainer/optimizer'
9
12
  require 'chainer/gradient_method'
10
13
  require 'chainer/gradient_check'
@@ -15,6 +18,7 @@ require 'chainer/initializer'
15
18
  require 'chainer/initializers/init'
16
19
  require 'chainer/initializers/constant'
17
20
  require 'chainer/initializers/normal'
21
+ require 'chainer/initializers/uniform'
18
22
  require 'chainer/iterators/serial_iterator'
19
23
  require 'chainer/link'
20
24
  require 'chainer/links/connection/convolution_2d'
@@ -30,15 +34,28 @@ require 'chainer/utils/variable'
30
34
  require 'chainer/utils/array'
31
35
  require 'chainer/functions/activation/leaky_relu'
32
36
  require 'chainer/functions/activation/relu'
37
+ require 'chainer/functions/activation/relu_grad2'
33
38
  require 'chainer/functions/activation/sigmoid'
39
+ require 'chainer/functions/activation/sigmoid_grad'
34
40
  require 'chainer/functions/activation/tanh'
35
41
  require 'chainer/functions/activation/log_softmax'
42
+ require 'chainer/functions/array/broadcast_to'
43
+ require 'chainer/functions/array/cast'
44
+ require 'chainer/functions/array/reshape'
45
+ require 'chainer/functions/array/rollaxis'
46
+ require 'chainer/functions/array/select_item'
47
+ require 'chainer/functions/array/squeeze'
48
+ require 'chainer/functions/array/transpose'
36
49
  require 'chainer/functions/evaluation/accuracy'
37
50
  require 'chainer/functions/math/basic_math'
38
51
  require 'chainer/functions/math/identity'
52
+ require 'chainer/functions/math/sum'
53
+ require 'chainer/functions/math/exp'
39
54
  require 'chainer/functions/loss/mean_squared_error'
40
55
  require 'chainer/functions/loss/softmax_cross_entropy'
41
56
  require 'chainer/functions/connection/convolution_2d'
57
+ require 'chainer/functions/connection/deconvolution_2d'
58
+ require 'chainer/functions/connection/convolution_2d_grad_w'
42
59
  require 'chainer/functions/connection/linear'
43
60
  require 'chainer/functions/noise/dropout'
44
61
  require 'chainer/functions/normalization/batch_normalization'
@@ -61,7 +78,6 @@ require 'chainer/training/triggers/interval'
61
78
  require 'chainer/parameter'
62
79
  require 'chainer/optimizers/adam'
63
80
  require 'chainer/optimizers/momentum_sgd'
64
- require 'chainer/dataset/download'
65
81
  require 'chainer/datasets/mnist'
66
82
  require 'chainer/datasets/cifar'
67
83
  require 'chainer/datasets/tuple_dataset'
@@ -0,0 +1,27 @@
1
+ module Chainer
2
+ # Gets the appropriate array module, +Numo+ or +Cumo+, for the given arrays.
3
+ #
4
+ # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
5
+ # @return [Module] +Cumo+ if CUDA is available and any argument is a +Cumo::NArray+; otherwise +Numo+.
6
+ def get_array_module(*args)
7
+ arrays = args.map {|v| v.kind_of?(Chainer::Variable) ? v.data : v }
8
+ if CUDA.available?
9
+ return Cumo if arrays.any? {|a| a.kind_of?(Cumo::NArray) }
10
+ end
11
+ return Numo
12
+ end
13
+ module_function :get_array_module
14
+
15
+ # Returns true if the argument is either of +Numo::NArray+ or +Cumo::NArray+.
16
+ #
17
+ # @param [Object] obj
18
+ # @return [Boolean]
19
+ def array?(obj)
20
+ if CUDA.available?
21
+ return true if obj.kind_of?(Cumo::NArray)
22
+ end
23
+ return true if obj.kind_of?(Numo::NArray)
24
+ false
25
+ end
26
+ module_function :array?
27
+ end
data/lib/chainer/cuda.rb CHANGED
@@ -1,18 +1,40 @@
1
+ begin
2
+ require 'cumo'
3
+ $chainer_cuda_available = true
4
+ rescue LoadError => e
5
+ $chainer_cuda_available = false
6
+ # A trick to make sure Cumo::NArray (and the other stub constants below) always exist, even when cumo cannot be loaded
7
+ module Cumo
8
+ class NArray; end
9
+ class NMath; end
10
+ class Bit; end
11
+ end
12
+ end
13
+
1
14
  module Chainer
2
- # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+.
3
- #
4
- # This is almost equivalent to +Chainer::get_array_module+. The differences
5
- # are that this function can be used even if CUDA is not available and that
6
- # it will return their data arrays' array module for
7
- # +Chainer::Variable+ arguments.
8
- #
9
- # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
10
- # @return [Numo::NArray] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of
11
- # the arguments.
12
- # @todo CUDA is not supported, yet.
13
- #
14
- def get_array_module(*args)
15
- return Numo::NArray
15
+ module CUDA
16
+ # Returns whether CUDA is available.
17
+ #
18
+ # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
19
+ # @return [Boolean]
20
+ def available?(id = nil)
21
+ return false unless $chainer_cuda_available
22
+ if id
23
+ raise 'id must be non negative' if id < 0
24
+ @device_count ||= Cumo::CUDA::Runtime.cudaGetDeviceCount
25
+ return @device_count > id
26
+ end
27
+ true
28
+ end
29
+ module_function :available?
30
+
31
+ # Checks if CUDA is available.
32
+ #
33
+ # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
34
+ # @raise [RuntimeError] if not available
35
+ def check_available(id = nil)
36
+ raise 'CUDA is not available' unless available?(id)
37
+ end
38
+ module_function :check_available
16
39
  end
17
- module_function :get_array_module
18
40
  end
@@ -2,12 +2,13 @@ module Chainer
2
2
  module Dataset
3
3
  module Convert
4
4
  def self.to_device(device, x)
5
- # TODO: support cuda
5
+ # TODO(sonots): Implement after Cumo supports transferring between devices
6
6
  x
7
7
  end
8
8
 
9
9
  def self.concat_examples(batch, device: nil, padding: nil)
10
10
  raise "batch is empty" if batch.size == 0
11
+ device = device ? Chainer::Device.create(device) : Chainer::Device.default # takes care of int and nil
11
12
  first_elem = batch[0]
12
13
 
13
14
  if first_elem.kind_of?(Array)
@@ -17,28 +18,29 @@ module Chainer
17
18
  end
18
19
 
19
20
  first_elem.size.times do |i|
20
- x = concat_arrays(batch.map { |b| b[i] }, padding[i])
21
+ x = _concat_arrays(batch.map { |b| b[i] }, padding[i], device)
21
22
  result.push(to_device(device, x))
22
23
  end
23
24
 
24
25
  return result
25
26
  else
26
- return to_device(device, concat_arrays(batch, padding))
27
+ return _concat_arrays(batch, padding, device)
27
28
  end
28
29
  end
29
30
 
30
- def self.concat_arrays(arrays, padding)
31
- unless arrays[0].kind_of?(Numo::NArray)
31
+ def self._concat_arrays(arrays, padding, device)
32
+ xm = device.xm
33
+ unless arrays[0].kind_of?(xm::NArray)
32
34
  # [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
33
- arrays = Numo::NArray.cast(arrays)
35
+ arrays = xm::NArray.cast(arrays)
34
36
  if padding
35
- return concat_arrays_with_padding(arrays, padding)
37
+ return _concat_arrays_with_padding(arrays, padding, device)
36
38
  end
37
39
  return arrays
38
40
  end
39
41
 
40
42
  if padding
41
- return concat_arrays_with_padding(arrays, padding)
43
+ return _concat_arrays_with_padding(arrays, padding, device)
42
44
  end
43
45
 
44
46
  # [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
@@ -48,12 +50,14 @@ module Chainer
48
50
  a[0].concatenate(*a[1..-1])
49
51
  end
50
52
 
51
- def self.concat_arrays_with_padding(arrays, padding)
52
- if arrays[0].is_a? Numo::NArray
53
- shape = Numo::Int32.cast(arrays[0].shape)
53
+ def self._concat_arrays_with_padding(arrays, padding, device)
54
+ xm = device.xm
55
+ if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
56
+ xm = Chainer.get_array_module(arrays[0])
57
+ shape = xm::Int32.cast(arrays[0].shape)
54
58
  arrays[1..-1].each do |array|
55
- if Numo::Bit.[](shape != array.shape).any?
56
- shape = Numo::Int32.maximum(shape, array.shape)
59
+ if xm::Bit.[](shape != array.shape).any?
60
+ shape = xm::Int32.maximum(shape, array.shape)
57
61
  end
58
62
  end
59
63
  else # Integer
@@ -61,15 +65,15 @@ module Chainer
61
65
  end
62
66
 
63
67
  shape = shape.insert(0, arrays.size).to_a
64
- if arrays[0].is_a? Numo::NArray
68
+ if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
65
69
  result = arrays[0].class.new(shape).fill(padding)
66
70
  else # Integer
67
- result = Numo::Int32.new(shape).fill(padding)
71
+ result = xm::Int32.new(shape).fill(padding)
68
72
  end
69
73
 
70
74
  arrays.size.times do |i|
71
75
  src = arrays[i]
72
- if src.is_a? Numo::NArray
76
+ if Chainer.array?(src) and src.ndim > 0
73
77
  result[i, 0...src.shape[0], 0...src.shape[1]] = src
74
78
  else # Integer
75
79
  result[i] = src