red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33a95bf098a08c334e6a9d29ff791350e6ac0bd9de843054f3ed14f2a005b79b
4
- data.tar.gz: 6f4e53b84e93d01e26363b5d43dba73ec4fc515ecbebbd1574ac6a864f52fb83
3
+ metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
4
+ data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
5
5
  SHA512:
6
- metadata.gz: 40054365541bb8956c4a8211fbd489e75d1557a9e5c23bcb49f2cc2b393f0d96574d18259bad8e147525a3b720a5329f904be18c7c4f0891f1d886241b72a65d
7
- data.tar.gz: fcb8e641d0efc1ffacc2014c2954d1c5a1e5f947fdcc18306f8946fb307c659e37bb3ccf08b9661d17e13184463b0282ae3594a6c252813d23a53fa0c6182a67
6
+ metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
7
+ data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
data/.gitignore CHANGED
@@ -1,13 +1,13 @@
1
1
  /.bundle/
2
- /.yardoc
2
+ /.yardoc/
3
3
  /Gemfile.lock
4
- /_yardoc/
5
4
  /coverage/
6
5
  /doc/
7
6
  /pkg/
8
7
  /spec/reports/
9
8
  /tmp/
10
9
  result
10
+ Gemfile.local
11
11
 
12
12
  # rspec failure tracking
13
13
  .rspec_status
data/.travis.yml CHANGED
@@ -1,8 +1,13 @@
1
+ notifications:
2
+ webhooks:
3
+ - https://webhook.commit-email.info/
1
4
  sudo: false
2
5
  language: ruby
3
6
  rvm:
4
- - 2.3.6
5
7
  - 2.4.3
6
8
  - 2.5.0
7
- before_install: gem install bundler -v 1.15.1
8
- script: ruby test/run_test.rb
9
+ - 2.6.0
10
+ before_install: gem install bundler
11
+ script:
12
+ - ruby test/run_test.rb
13
+ - yardoc --fail-on-warning
data/.yardopts ADDED
@@ -0,0 +1 @@
1
+ -p templates
data/Gemfile CHANGED
@@ -1,4 +1,9 @@
1
1
  source "https://rubygems.org"
2
2
 
3
- # Specify your gem's dependencies in red-chainer.gemspec
4
3
  gemspec
4
+
5
+ local_gemfile = File.join(File.dirname(__FILE__), "Gemfile.local")
6
+ if File.exist?(local_gemfile)
7
+ puts "Loading Gemfile.local ..." if $DEBUG # `ruby -d` or `bundle -v`
8
+ instance_eval File.read(local_gemfile)
9
+ end
data/README.md CHANGED
@@ -8,7 +8,7 @@ It ported python's [Chainer](https://github.com/chainer/chainer) with Ruby.
8
8
 
9
9
  ## Requirements
10
10
 
11
- * Ruby 2.3 or later
11
+ * Ruby 2.4 or later
12
12
 
13
13
  ## Installation
14
14
 
@@ -31,7 +31,10 @@ $ gem install red-chainer
31
31
  ```
32
32
 
33
33
  ## Usage
34
- mnist sample program is [here](./examples/mnist/mnist.rb)
34
+
35
+ ### Run MNIST example
36
+
37
+ MNIST sample program is [here](./examples/mnist/mnist.rb)
35
38
 
36
39
  ```bash
37
40
  # when install Gemfile
@@ -40,6 +43,34 @@ $ bundle exec ruby examples/mnist/mnist.rb
40
43
  $ ruby examples/mnist/mnist.rb
41
44
  ```
42
45
 
46
+ ### Run MNIST example with GPU
47
+
48
+ On GPU machine, add `gem 'cumo'` on Gemfile and do `bundle install`.
49
+
50
+ Run the example with `--gpu` option whose value indicates GPU device ID such as:
51
+
52
+ ```
53
+ $ bundle exec ruby examples/mnist/mnist.rb --gpu 0
54
+ ```
55
+
56
+ ## Development
57
+
58
+ ### Run tests
59
+
60
+ ```
61
+ $ bundle exec ruby test/run_test.rb
62
+ ```
63
+
64
+ ### Run tests with Cumo
65
+
66
+ On GPU machine, add `gem 'cumo'` on Gemfile and do `bundle install`.
67
+
68
+ Run tests with `RED_CHAINER_GPU` environment variable whose value indicates GPU device ID such as:
69
+
70
+ ```
71
+ $ bundle exec env RED_CHAINER_GPU=0 ruby test/run_test.rb
72
+ ```
73
+
43
74
  ## License
44
75
 
45
76
  The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
@@ -54,4 +85,4 @@ The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
54
85
  | [connection](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/connection) | 12 | 2 | Linear, Convolution2D |
55
86
  | [pooling](https://github.com/red-data-tools/red-chainer/tree/master/lib/chainer/functions/pooling) | 14 | 3 | Pooling2D, MaxPooling2D, AveragePooling2D |
56
87
  | [example](https://github.com/red-data-tools/red-chainer/tree/master/examples) | 31 | 3 | MNIST, Iris, CIFAR |
57
- | GPU | use cupy | ToDo | want to support [Cumo](https://github.com/sonots/cumo) |
88
+ | GPU | use CuPy | use [Cumo](https://github.com/sonots/cumo) ||
@@ -9,6 +9,7 @@ args = {
9
9
  batchsize: 64,
10
10
  learnrate: 0.05,
11
11
  epoch: 300,
12
+ gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
12
13
  out: 'result',
13
14
  resume: nil,
14
15
  model: 'vgg',
@@ -21,11 +22,21 @@ opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default:
21
22
  opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
22
23
  opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
23
24
  opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
25
+ opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
24
26
  opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
25
27
  opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
26
28
  opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
27
29
  opt.parse!(ARGV)
28
30
 
31
+ puts "GPU: #{args[:gpu]}"
32
+ puts "# unit: #{args[:unit]}"
33
+ puts "# Minibatch-size: #{args[:batchsize]}"
34
+ puts "# epoch: #{args[:epoch]}"
35
+ puts
36
+
37
+ device = Chainer::Device.create(args[:gpu])
38
+ Chainer::Device.change_default(device)
39
+
29
40
  # Set up a neural network to train.
30
41
  # Classifier reports softmax cross entropy loss and accuracy at every
31
42
  # iteration, which will be used by the PrintReport extension below.
@@ -57,10 +68,10 @@ optimizer.setup(model)
57
68
  train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
58
69
  test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
59
70
 
60
- updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
71
+ updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
61
72
  trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
62
73
 
63
- trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
74
+ trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: device))
64
75
 
65
76
  trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
66
77
 
@@ -25,6 +25,10 @@ class IrisChain < Chainer::Chain
25
25
  end
26
26
  end
27
27
 
28
+ device = Chainer::Device.create(Integer(ENV['RED_CHAINER_GPU'] || -1))
29
+ Chainer::Device.change_default(device)
30
+ xm = device.xm
31
+
28
32
  model = IrisChain.new(6,3)
29
33
 
30
34
  optimizer = Chainer::Optimizers::Adam.new
@@ -35,7 +39,7 @@ iris_table = iris.to_table
35
39
  x = iris_table.fetch_values(:sepal_length, :sepal_width, :petal_length, :petal_width).transpose
36
40
 
37
41
  # target
38
- y_class = iris_table[:class]
42
+ y_class = iris_table[:label]
39
43
 
40
44
  # class index array
41
45
  # ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
@@ -46,7 +50,7 @@ y = y_class.map{|s|
46
50
  }
47
51
 
48
52
  # y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
49
- y_onehot = Numo::SFloat.eye(class_name.size)[y,false]
53
+ y_onehot = xm::SFloat.eye(class_name.size)[y, false]
50
54
 
51
55
  puts "Iris Datasets"
52
56
  puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
@@ -57,9 +61,9 @@ x.each_with_index{|r, i|
57
61
  # [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
58
62
  # [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
59
63
 
60
- x = Numo::SFloat.cast(x)
61
- y = Numo::SFloat.cast(y)
62
- y_onehot = Numo::SFloat.cast(y_onehot)
64
+ x = xm::SFloat.cast(x)
65
+ y = xm::SFloat.cast(y)
66
+ y_onehot = xm::SFloat.cast(y_onehot)
63
67
 
64
68
  x_train = x[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
65
69
  y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
@@ -27,6 +27,7 @@ args = {
27
27
  batchsize: 100,
28
28
  frequency: -1,
29
29
  epoch: 20,
30
+ gpu: Integer(ENV['RED_CHAINER_GPU'] || -1),
30
31
  resume: nil,
31
32
  unit: 1000,
32
33
  out: 'result'
@@ -35,25 +36,36 @@ args = {
35
36
  opt = OptionParser.new
36
37
  opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
37
38
  opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
39
+ opt.on('-g', '--gpu VALUE', "GPU ID (negative value indicates CPU) (default: #{args[:gpu]})") { |v| args[:gpu] = v.to_i }
38
40
  opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
39
41
  opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
40
42
  opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
41
43
  opt.on('-u', '--unit VALUE', "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v.to_i }
42
44
  opt.parse!(ARGV)
43
45
 
44
- model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10))
46
+ puts "GPU: #{args[:gpu]}"
47
+ puts "# unit: #{args[:unit]}"
48
+ puts "# Minibatch-size: #{args[:batchsize]}"
49
+ puts "# epoch: #{args[:epoch]}"
50
+ puts
51
+
52
+ device = Chainer::Device.create(args[:gpu])
53
+ Chainer::Device.change_default(device)
54
+
55
+ lossfun = -> (x, t) { Chainer::Functions::Loss::SoftmaxCrossEntropy.new(ignore_label: nil).(x, t) }
56
+ model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10), lossfun)
45
57
 
46
58
  optimizer = Chainer::Optimizers::Adam.new
47
59
  optimizer.setup(model)
48
- train, test = Chainer::Datasets::Mnist.get_mnist
60
+ train, test = Chainer::Datasets::MNIST.get_mnist
49
61
 
50
62
  train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
51
63
  test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
52
64
 
53
- updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
65
+ updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: device)
54
66
  trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
55
67
 
56
- trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
68
+ trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: args[:gpu]))
57
69
 
58
70
  # Take a snapshot for each specified epoch
59
71
  frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
data/lib/chainer.rb CHANGED
@@ -3,8 +3,11 @@ require "weakref"
3
3
  require "chainer/version"
4
4
 
5
5
  require 'chainer/cuda'
6
+ require 'chainer/backend'
6
7
  require 'chainer/configuration'
8
+ require 'chainer/device'
7
9
  require 'chainer/function'
10
+ require 'chainer/function_node'
8
11
  require 'chainer/optimizer'
9
12
  require 'chainer/gradient_method'
10
13
  require 'chainer/gradient_check'
@@ -15,6 +18,7 @@ require 'chainer/initializer'
15
18
  require 'chainer/initializers/init'
16
19
  require 'chainer/initializers/constant'
17
20
  require 'chainer/initializers/normal'
21
+ require 'chainer/initializers/uniform'
18
22
  require 'chainer/iterators/serial_iterator'
19
23
  require 'chainer/link'
20
24
  require 'chainer/links/connection/convolution_2d'
@@ -30,15 +34,28 @@ require 'chainer/utils/variable'
30
34
  require 'chainer/utils/array'
31
35
  require 'chainer/functions/activation/leaky_relu'
32
36
  require 'chainer/functions/activation/relu'
37
+ require 'chainer/functions/activation/relu_grad2'
33
38
  require 'chainer/functions/activation/sigmoid'
39
+ require 'chainer/functions/activation/sigmoid_grad'
34
40
  require 'chainer/functions/activation/tanh'
35
41
  require 'chainer/functions/activation/log_softmax'
42
+ require 'chainer/functions/array/broadcast_to'
43
+ require 'chainer/functions/array/cast'
44
+ require 'chainer/functions/array/reshape'
45
+ require 'chainer/functions/array/rollaxis'
46
+ require 'chainer/functions/array/select_item'
47
+ require 'chainer/functions/array/squeeze'
48
+ require 'chainer/functions/array/transpose'
36
49
  require 'chainer/functions/evaluation/accuracy'
37
50
  require 'chainer/functions/math/basic_math'
38
51
  require 'chainer/functions/math/identity'
52
+ require 'chainer/functions/math/sum'
53
+ require 'chainer/functions/math/exp'
39
54
  require 'chainer/functions/loss/mean_squared_error'
40
55
  require 'chainer/functions/loss/softmax_cross_entropy'
41
56
  require 'chainer/functions/connection/convolution_2d'
57
+ require 'chainer/functions/connection/deconvolution_2d'
58
+ require 'chainer/functions/connection/convolution_2d_grad_w'
42
59
  require 'chainer/functions/connection/linear'
43
60
  require 'chainer/functions/noise/dropout'
44
61
  require 'chainer/functions/normalization/batch_normalization'
@@ -61,7 +78,6 @@ require 'chainer/training/triggers/interval'
61
78
  require 'chainer/parameter'
62
79
  require 'chainer/optimizers/adam'
63
80
  require 'chainer/optimizers/momentum_sgd'
64
- require 'chainer/dataset/download'
65
81
  require 'chainer/datasets/mnist'
66
82
  require 'chainer/datasets/cifar'
67
83
  require 'chainer/datasets/tuple_dataset'
@@ -0,0 +1,27 @@
1
+ module Chainer
2
+ # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+ from given arrays.
3
+ #
4
+ # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
5
+ # @return [Class] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of the arguments.
6
+ def get_array_module(*args)
7
+ arrays = args.map {|v| v.kind_of?(Chainer::Variable) ? v.data : v }
8
+ if CUDA.available?
9
+ return Cumo if arrays.any? {|a| a.kind_of?(Cumo::NArray) }
10
+ end
11
+ return Numo
12
+ end
13
+ module_function :get_array_module
14
+
15
+ # Returns true if the argument is either of +Numo::NArray+ or +Cumo::NArray+.
16
+ #
17
+ # @param [Object] obj
18
+ # @return [Boolean]
19
+ def array?(obj)
20
+ if CUDA.available?
21
+ return true if obj.kind_of?(Cumo::NArray)
22
+ end
23
+ return true if obj.kind_of?(Numo::NArray)
24
+ false
25
+ end
26
+ module_function :array?
27
+ end
data/lib/chainer/cuda.rb CHANGED
@@ -1,18 +1,40 @@
1
+ begin
2
+ require 'cumo'
3
+ $chainer_cuda_available = true
4
+ rescue LoadError => e
5
+ $chainer_cuda_available = false
6
+ # A trick to make Cumo::NArray always exists
7
+ module Cumo
8
+ class NArray; end
9
+ class NMath; end
10
+ class Bit; end
11
+ end
12
+ end
13
+
1
14
  module Chainer
2
- # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+.
3
- #
4
- # This is almost equivalent to +Chainer::get_array_module+. The differences
5
- # are that this function can be used even if CUDA is not available and that
6
- # it will return their data arrays' array module for
7
- # +Chainer::Variable+ arguments.
8
- #
9
- # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
10
- # @return [Numo::NArray] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of
11
- # the arguments.
12
- # @todo CUDA is not supported, yet.
13
- #
14
- def get_array_module(*args)
15
- return Numo::NArray
15
+ module CUDA
16
+ # Returns whether CUDA is available.
17
+ #
18
+ # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
19
+ # @return [Boolean]
20
+ def available?(id = nil)
21
+ return false unless $chainer_cuda_available
22
+ if id
23
+ raise 'id must be non negative' if id < 0
24
+ @device_count ||= Cumo::CUDA::Runtime.cudaGetDeviceCount
25
+ return @device_count > id
26
+ end
27
+ true
28
+ end
29
+ module_function :available?
30
+
31
+ # Checks if CUDA is available.
32
+ #
33
+ # @param [Integer or nil] id If a non negative integer is given, check availability of GPU ID.
34
+ # @raise [RuntimeError] if not available
35
+ def check_available(id = nil)
36
+ raise 'CUDA is not available' unless available?(id)
37
+ end
38
+ module_function :check_available
16
39
  end
17
- module_function :get_array_module
18
40
  end
@@ -2,12 +2,13 @@ module Chainer
2
2
  module Dataset
3
3
  module Convert
4
4
  def self.to_device(device, x)
5
- # TODO: support cuda
5
+ # TODO(sonots): Implement after Cumo supports transferring between devices
6
6
  x
7
7
  end
8
8
 
9
9
  def self.concat_examples(batch, device: nil, padding: nil)
10
10
  raise "batch is empty" if batch.size == 0
11
+ device = device ? Chainer::Device.create(device) : Chainer::Device.default # takes care of int and nil
11
12
  first_elem = batch[0]
12
13
 
13
14
  if first_elem.kind_of?(Array)
@@ -17,28 +18,29 @@ module Chainer
17
18
  end
18
19
 
19
20
  first_elem.size.times do |i|
20
- x = concat_arrays(batch.map { |b| b[i] }, padding[i])
21
+ x = _concat_arrays(batch.map { |b| b[i] }, padding[i], device)
21
22
  result.push(to_device(device, x))
22
23
  end
23
24
 
24
25
  return result
25
26
  else
26
- return to_device(device, concat_arrays(batch, padding))
27
+ return _concat_arrays(batch, padding, device)
27
28
  end
28
29
  end
29
30
 
30
- def self.concat_arrays(arrays, padding)
31
- unless arrays[0].kind_of?(Numo::NArray)
31
+ def self._concat_arrays(arrays, padding, device)
32
+ xm = device.xm
33
+ unless arrays[0].kind_of?(xm::NArray)
32
34
  # [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
33
- arrays = Numo::NArray.cast(arrays)
35
+ arrays = xm::NArray.cast(arrays)
34
36
  if padding
35
- return concat_arrays_with_padding(arrays, padding)
37
+ return _concat_arrays_with_padding(arrays, padding, device)
36
38
  end
37
39
  return arrays
38
40
  end
39
41
 
40
42
  if padding
41
- return concat_arrays_with_padding(arrays, padding)
43
+ return _concat_arrays_with_padding(arrays, padding, device)
42
44
  end
43
45
 
44
46
  # [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
@@ -48,12 +50,14 @@ module Chainer
48
50
  a[0].concatenate(*a[1..-1])
49
51
  end
50
52
 
51
- def self.concat_arrays_with_padding(arrays, padding)
52
- if arrays[0].is_a? Numo::NArray
53
- shape = Numo::Int32.cast(arrays[0].shape)
53
+ def self._concat_arrays_with_padding(arrays, padding, device)
54
+ xm = device.xm
55
+ if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
56
+ xm = Chainer.get_array_module(arrays[0])
57
+ shape = xm::Int32.cast(arrays[0].shape)
54
58
  arrays[1..-1].each do |array|
55
- if Numo::Bit.[](shape != array.shape).any?
56
- shape = Numo::Int32.maximum(shape, array.shape)
59
+ if xm::Bit.[](shape != array.shape).any?
60
+ shape = xm::Int32.maximum(shape, array.shape)
57
61
  end
58
62
  end
59
63
  else # Integer
@@ -61,15 +65,15 @@ module Chainer
61
65
  end
62
66
 
63
67
  shape = shape.insert(0, arrays.size).to_a
64
- if arrays[0].is_a? Numo::NArray
68
+ if Chainer.array?(arrays[0]) and arrays[0].ndim > 0
65
69
  result = arrays[0].class.new(shape).fill(padding)
66
70
  else # Integer
67
- result = Numo::Int32.new(shape).fill(padding)
71
+ result = xm::Int32.new(shape).fill(padding)
68
72
  end
69
73
 
70
74
  arrays.size.times do |i|
71
75
  src = arrays[i]
72
- if src.is_a? Numo::NArray
76
+ if Chainer.array?(src) and src.ndim > 0
73
77
  result[i, 0...src.shape[0], 0...src.shape[1]] = src
74
78
  else # Integer
75
79
  result[i] = src