red-chainer 0.2.1 → 0.3.0

Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 2fae8ab4c6d546356b7171ba9823e1537f02ad5e08358d97b29d907e67e11b20
-  data.tar.gz: 87b807d2cdf48b060100ade6c0f294e249db8f8a6a131754c4a108758867a1e5
+  metadata.gz: 4bd5f287aafbc100878391131a0a184c9c6497ddb26e0c5124010166aa455def
+  data.tar.gz: 6625f1d862ba17db5867b208494419f1e7e7bc1ed5b5276bec785151140a47bd
 SHA512:
-  metadata.gz: b533669c998ef246222a001658193e489c5500a4d9276c53e43c332543342bb258c2a8cf8958b47cd5201c52c8f6c3d1378af33c20c0f3644d0d155addeafc04
-  data.tar.gz: 16c2e65110c8a83e42d85e482a35209463c36766d2eb6d6eb76045f44d55472f5f6c94b8050221f2273f5e8d2ba644648fb3c10b2a0b9cf4d03b86495470d7a0
+  metadata.gz: 659986f00b88051471ca54a88eafb4a6f9ebbdf8f27a9f51c8ff32a27f584d2a78015d47baf20e74c6e9605c09bda0777b4780c4ef79313d194a362e0d3ff013
+  data.tar.gz: a468488ee913244bcdbc058fe27c93beec897bd4476f1eec50ae2bc3071124701aaf37b9f15cd07848bbb30372e4796c04b24ead672c28e5fcb0dbb1721b7bd6
data/README.md CHANGED
@@ -2,7 +2,7 @@
 
 ## Name
 
-Red Cahiner
+Red Chainer
 
 ## Description
 Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
@@ -39,4 +39,4 @@ $ ruby examples/mnist.rb
 
 ## License
 
-The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
+The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
data/examples/cifar/models/vgg.rb ADDED
@@ -0,0 +1,84 @@
+class Block < Chainer::Chain
+  def initialize(out_channels, ksize, pad: 1)
+    super()
+    init_scope do
+      @conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
+      @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
+    end
+  end
+
+  def call(x)
+    h = @conv.(x)
+    h = @bn.(h)
+    Chainer::Functions::Activation::Relu.relu(h)
+  end
+end
+
+class VGG < Chainer::Chain
+  def initialize(class_labels: 10)
+    super()
+    init_scope do
+      @block1_1 = Block.new(64, 3)
+      @block1_2 = Block.new(64, 3)
+      @block2_1 = Block.new(128, 3)
+      @block2_2 = Block.new(128, 3)
+      @block3_1 = Block.new(256, 3)
+      @block3_2 = Block.new(256, 3)
+      @block3_3 = Block.new(256, 3)
+      @block4_1 = Block.new(512, 3)
+      @block4_2 = Block.new(512, 3)
+      @block4_3 = Block.new(512, 3)
+      @block5_1 = Block.new(512, 3)
+      @block5_2 = Block.new(512, 3)
+      @block5_3 = Block.new(512, 3)
+      @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
+      @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
+      @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
+    end
+  end
+
+  def call(x)
+    # 64 channel blocks:
+    h = @block1_1.(x)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.3)
+    h = @block1_2.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 128 channel blocks:
+    h = @block2_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block2_2.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 256 channel blocks:
+    h = @block3_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block3_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block3_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 512 channel blocks:
+    h = @block4_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block4_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block4_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 512 channel blocks:
+    h = @block5_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block5_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block5_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+    h = @fc1.(h)
+    h = @bn_fc1.(h)
+    h = Chainer::Functions::Activation::Relu.relu(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+    @fc2.(h)
+  end
+end
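A quick smoke test for this model (an editor's sketch, not part of the gem; it assumes a working red-chainer install and the Numo::NArray API):

    require 'chainer'
    require_relative 'models/vgg'

    model = VGG.new(class_labels: 10)
    # One fake 3x32x32 CIFAR-sized image; Numo::SFloat#rand fills with values in [0, 1).
    x = Chainer::Variable.new(Numo::SFloat.new(1, 3, 32, 32).rand)
    y = model.(x)    # forward pass through all blocks
    p y.data.shape   # => [1, 10], one score per class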
data/examples/cifar/train_cifar.rb ADDED
@@ -0,0 +1,70 @@
+require 'chainer'
+require __dir__ + '/models/vgg'
+require 'optparse'
+
+args = {
+  dataset: 'cifar10',
+  frequency: -1,
+  batchsize: 64,
+  learnrate: 0.05,
+  epoch: 300,
+  out: 'result',
+  resume: nil
+}
+
+
+opt = OptionParser.new
+opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
+opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
+opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
+opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
+opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
+opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
+opt.parse!(ARGV)
+
+# Set up a neural network to train.
+# Classifier reports softmax cross entropy loss and accuracy at every
+# iteration, which will be used by the PrintReport extension below.
+if args[:dataset] == 'cifar10'
+  puts 'Using CIFAR10 dataset.'
+  class_labels = 10
+  train, test = Chainer::Datasets::CIFAR.get_cifar10
+elsif args[:dataset] == 'cifar100'
+  puts 'Using CIFAR100 dataset.'
+  class_labels = 100
+  train, test = Chainer::Datasets::CIFAR.get_cifar100
+else
+  raise 'Invalid dataset choice.'
+end
+
+puts "setup..."
+
+model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))
+
+optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
+optimizer.setup(model)
+
+train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
+test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
+
+updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
+trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
+
+trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
+
+trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
+
+frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
+trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])
+
+trainer.extend(Chainer::Training::Extensions::LogReport.new)
+trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
+trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
+
+if args[:resume]
+  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
+end
+
+trainer.run
+
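The option flags above map one-to-one onto the args hash, so a typical run looks like this (paths relative to the gem's data directory):

    $ ruby examples/cifar/train_cifar.rb --dataset cifar10 --batchsize 64 --epoch 300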
data/examples/iris.rb ADDED
@@ -0,0 +1,103 @@
+require 'numo/narray'
+require 'chainer'
+require "datasets"
+
+class IrisChain < Chainer::Chain
+  L = Chainer::Links::Connection::Linear
+  F = Chainer::Functions
+
+  def initialize(n_units, n_out)
+    super()
+    init_scope do
+      @l1 = L.new(nil, out_size: n_units)
+      @l2 = L.new(nil, out_size: n_out)
+    end
+  end
+
+  def call(x, y)
+    return F::Loss::MeanSquaredError.mean_squared_error(fwd(x), y)
+  end
+
+  def fwd(x)
+    h1 = F::Activation::Sigmoid.sigmoid(@l1.(x))
+    h2 = @l2.(h1)
+    return h2
+  end
+end
+
+model = IrisChain.new(6, 3)
+
+optimizer = Chainer::Optimizers::Adam.new
+optimizer.setup(model)
+
+iris = Datasets::Iris.new
+x = iris.each.map { |r| r.each.to_a[0..3] }
+
+# target
+y_class = iris.each.map { |r| r.class }
+
+# class index array
+# ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
+class_name = y_class.uniq
+# y => [0, 0, 0, ..., 1, 1, ..., 2, 2]
+y = y_class.map { |s|
+  class_name.index(s)
+}
+
+# y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], ..., [0.0, 1.0, 0.0], ..., [0.0, 0.0, 1.0]]
+y_onehot = y_class.map { |s|
+  i = class_name.index(s)
+  a = Array.new(class_name.size, 0.0)
+  a[i] = 1.0
+  a
+}
+
+puts "Iris Datasets"
+puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
+x.each_with_index { |r, i|
+  puts "#{'%3d' % i} : [#{r.join(', ')}] #{y_onehot[i]} #=> #{y_class[i]}(#{y[i]})"
+}
+# [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] => 50 data
+# [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
+# [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
+
+x = Numo::SFloat.cast(x)
+y = Numo::SFloat.cast(y)
+y_onehot = Numo::SFloat.cast(y_onehot)
+
+x_train = x[(1..-1).step(2), true]        #=> 75 examples (25 of each class)
+y_train = y_onehot[(1..-1).step(2), true] #=> 75 examples (25 of each class)
+x_test = x[(0..-1).step(2), true]         #=> 75 examples (25 of each class)
+y_test = y[(0..-1).step(2)]               #=> 75 examples (25 of each class)
+
+# Train
+10000.times { |i|
+  x = Chainer::Variable.new(x_train)
+  y = Chainer::Variable.new(y_train)
+  model.cleargrads()
+  loss = model.(x, y)
+  loss.backward()
+  optimizer.update()
+}
+
+# Test
+xt = Chainer::Variable.new(x_test)
+yt = model.fwd(xt)
+n_row, n_col = yt.data.shape
+
+puts "Result : Correct Answer : Answer <= One-Hot"
+ok = 0
+n_row.times { |i|
+  ans = yt.data[i, true].max_index()
+  if ans == y_test[i]
+    ok += 1
+    printf("OK")
+  else
+    printf("--")
+  end
+  printf(" : #{y_test[i].to_i} :")
+
+  puts " #{ans.to_i} <= #{yt.data[i, 0..-1].to_a}"
+}
+puts "Row: #{n_row}, Column: #{n_col}"
+puts "Accuracy rate : #{ok}/#{n_row} = #{ok.to_f / n_row}"
data/lib/chainer.rb CHANGED
@@ -2,10 +2,12 @@ require "weakref"
 
 require "chainer/version"
 
+require 'chainer/cuda'
 require 'chainer/configuration'
 require 'chainer/function'
 require 'chainer/optimizer'
 require 'chainer/gradient_method'
+require 'chainer/gradient_check'
 require 'chainer/hyperparameter'
 require 'chainer/dataset/iterator'
 require 'chainer/dataset/convert'
@@ -15,11 +17,15 @@ require 'chainer/initializers/constant'
 require 'chainer/initializers/normal'
 require 'chainer/iterators/serial_iterator'
 require 'chainer/link'
+require 'chainer/links/connection/convolution_2d'
 require 'chainer/links/connection/linear'
+require 'chainer/links/normalization/batch_normalization'
 require 'chainer/links/model/classifier'
 require 'chainer/variable'
 require 'chainer/variable_node'
+require 'chainer/utils/conv'
 require 'chainer/utils/initializer'
+require 'chainer/utils/math'
 require 'chainer/utils/variable'
 require 'chainer/utils/array'
 require 'chainer/functions/activation/leaky_relu'
@@ -29,10 +35,19 @@ require 'chainer/functions/activation/tanh'
 require 'chainer/functions/activation/log_softmax'
 require 'chainer/functions/evaluation/accuracy'
 require 'chainer/functions/math/basic_math'
+require 'chainer/functions/math/identity'
+require 'chainer/functions/loss/mean_squared_error'
 require 'chainer/functions/loss/softmax_cross_entropy'
+require 'chainer/functions/connection/convolution_2d'
 require 'chainer/functions/connection/linear'
+require 'chainer/functions/noise/dropout'
+require 'chainer/functions/normalization/batch_normalization'
+require 'chainer/functions/pooling/pooling_2d'
+require 'chainer/functions/pooling/max_pooling_2d'
+require 'chainer/testing/array'
 require 'chainer/training/extension'
 require 'chainer/training/extensions/evaluator'
+require 'chainer/training/extensions/exponential_shift'
 require 'chainer/training/extensions/log_report'
 require 'chainer/training/extensions/print_report'
 require 'chainer/training/extensions/progress_bar'
@@ -44,8 +59,10 @@ require 'chainer/training/standard_updater'
 require 'chainer/training/triggers/interval'
 require 'chainer/parameter'
 require 'chainer/optimizers/adam'
+require 'chainer/optimizers/momentum_sgd'
 require 'chainer/dataset/download'
 require 'chainer/datasets/mnist'
+require 'chainer/datasets/cifar'
 require 'chainer/datasets/tuple_dataset'
 require 'chainer/reporter'
 require 'chainer/serializer'
data/lib/chainer/configuration.rb CHANGED
@@ -1,9 +1,10 @@
 module Chainer
   class Configuration
-    attr_accessor :enable_backprop
+    attr_accessor :enable_backprop, :train
 
     def initialize
       @enable_backprop = true
+      @train = true
     end
   end
 end
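A sketch of how the new train flag is consumed (assuming Chainer.configuration exposes the global Configuration instance, as in upstream Chainer): train-time-only behaviour such as dropout checks it, so evaluation code flips it off around the forward pass.

    Chainer.configuration.train = false  # inference mode: e.g. dropout passes inputs through
    predictions = model.fwd(x_test)      # hypothetical model, as in the iris example above
    Chainer.configuration.train = true   # restore training mode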
data/lib/chainer/cuda.rb ADDED
@@ -0,0 +1,18 @@
+module Chainer
+  # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+.
+  #
+  # This is almost equivalent to +Chainer::get_array_module+. The differences
+  # are that this function can be used even if CUDA is not available and that
+  # it will return their data arrays' array module for
+  # +Chainer::Variable+ arguments.
+  #
+  # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
+  # @return [Numo::NArray] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of
+  #   the arguments.
+  # @todo CUDA is not supported, yet.
+  #
+  def get_array_module(*args)
+    return Numo::NArray
+  end
+  module_function :get_array_module
+end
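Usage sketch (editor's example): since CUDA is not wired up yet, the result is always Numo::NArray, but call sites can already dispatch through it generically.

    xm = Chainer.get_array_module(Chainer::Variable.new(Numo::SFloat[1, 2, 3]))
    p xm                  # => Numo::NArray
    p xm.cast([1, 2, 3])  # => Numo::Int32[1, 2, 3]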
data/lib/chainer/dataset/convert.rb CHANGED
@@ -29,30 +29,51 @@ module Chainer
 
     def self.concat_arrays(arrays, padding)
       unless arrays[0].kind_of?(Numo::NArray)
+        # [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
         arrays = Numo::NArray.cast(arrays)
+        if padding
+          return concat_arrays_with_padding(arrays, padding)
+        end
+        return arrays
       end
 
       if padding
         return concat_arrays_with_padding(arrays, padding)
       end
 
-      Numo::NArray.[](*arrays.to_a.map { |arr| arr.kind_of?(Numeric) ? arr : Numo::NArray.[](*arr) })
+      # [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
+      # => Numo::SFloat#shape=[2,2]
+      # [[1, 2], [3, 4]]
+      a = arrays.map { |arr| arr[:-, false] }
+      a[0].concatenate(*a[1..-1])
     end
 
     def self.concat_arrays_with_padding(arrays, padding)
-      shape = Numo::Int32.[](arrays[0].shape)
-      arrays[1...arrays.len].each do |array|
-        if Numo::Bit.[](shape != array.shape).any?
-          # TODO: numpy maximum
+      if arrays[0].is_a? Numo::NArray
+        shape = Numo::Int32.cast(arrays[0].shape)
+        arrays[1..-1].each do |array|
+          if Numo::Bit.[](shape != array.shape).any?
+            shape = Numo::Int32.maximum(shape, array.shape)
+          end
         end
+      else # Integer
+        shape = []
+      end
+
+      shape = shape.insert(0, arrays.size).to_a
+      if arrays[0].is_a? Numo::NArray
+        result = arrays[0].class.new(shape).fill(padding)
+      else # Integer
+        result = Numo::Int32.new(shape).fill(padding)
       end
 
-      shape = [shape.insert(0, arrays.size)]
-      result = arrays[0].dtype.[](*shape).full(padding)
       arrays.size.times do |i|
        src = arrays[i]
-        slices = src.shape.map { |s| [s] }
-        result[[i] + slices] = src
+        if src.is_a? Numo::NArray
+          result[i, 0...src.shape[0], 0...src.shape[1]] = src
+        else # Integer
+          result[i] = src
+        end
       end
 
       result
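A hedged example of the new padding path (assuming the enclosing module is addressable as Chainer::Dataset::Convert, per this file's path): examples of different sizes are padded with the given value up to a common shape.

    examples = [Numo::SFloat[[1, 2], [3, 4]], Numo::SFloat[[5, 6]]]
    batch = Chainer::Dataset::Convert.concat_arrays(examples, 0)
    p batch.shape  # => [2, 2, 2]; the shorter example is zero-padded to 2x2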
data/lib/chainer/datasets/cifar.rb ADDED
@@ -0,0 +1,56 @@
+require 'datasets'
+
+module Chainer
+  module Datasets
+    module CIFAR
+      def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
+        get_cifar(10, with_label, ndim, scale)
+      end
+
+      def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
+        get_cifar(100, with_label, ndim, scale)
+      end
+
+      def self.get_cifar(n_classes, with_label, ndim, scale)
+        train_data = []
+        train_labels = []
+        ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record|
+          train_data << record.pixels
+          train_labels << (n_classes == 10 ? record.label : record.fine_label)
+        end
+
+        test_data = []
+        test_labels = []
+        ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record|
+          test_data << record.pixels
+          test_labels << (n_classes == 10 ? record.label : record.fine_label)
+        end
+
+        [
+          preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
+          preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
+        ]
+      end
+
+      def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
+        if ndim == 1
+          images = images.reshape(images.shape[0], 3072)
+        elsif ndim == 3
+          images = images.reshape(images.shape[0], 3, 32, 32)
+        else
+          raise 'invalid ndim for CIFAR dataset'
+        end
+        images = images.cast_to(Numo::DFloat)
+        images *= scale / 255.0
+
+        if withlabel
+          labels = labels.cast_to(Numo::Int32)
+          TupleDataset.new(images, labels)
+        else
+          images
+        end
+      end
+    end
+  end
+end
+
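A usage sketch (editor's example; it assumes TupleDataset indexing returns the image/label pair, and that the first call downloads CIFAR via red-datasets):

    train, test = Chainer::Datasets::CIFAR.get_cifar10
    image, label = train[0]
    p image.shape  # => [3, 32, 32], Numo::DFloat scaled into [0, 1]
    p label        # class id in 0..9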