red-chainer 0.2.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 2fae8ab4c6d546356b7171ba9823e1537f02ad5e08358d97b29d907e67e11b20
- data.tar.gz: 87b807d2cdf48b060100ade6c0f294e249db8f8a6a131754c4a108758867a1e5
+ metadata.gz: 4bd5f287aafbc100878391131a0a184c9c6497ddb26e0c5124010166aa455def
+ data.tar.gz: 6625f1d862ba17db5867b208494419f1e7e7bc1ed5b5276bec785151140a47bd
  SHA512:
- metadata.gz: b533669c998ef246222a001658193e489c5500a4d9276c53e43c332543342bb258c2a8cf8958b47cd5201c52c8f6c3d1378af33c20c0f3644d0d155addeafc04
- data.tar.gz: 16c2e65110c8a83e42d85e482a35209463c36766d2eb6d6eb76045f44d55472f5f6c94b8050221f2273f5e8d2ba644648fb3c10b2a0b9cf4d03b86495470d7a0
+ metadata.gz: 659986f00b88051471ca54a88eafb4a6f9ebbdf8f27a9f51c8ff32a27f584d2a78015d47baf20e74c6e9605c09bda0777b4780c4ef79313d194a362e0d3ff013
+ data.tar.gz: a468488ee913244bcdbc058fe27c93beec897bd4476f1eec50ae2bc3071124701aaf37b9f15cd07848bbb30372e4796c04b24ead672c28e5fcb0dbb1721b7bd6
data/README.md CHANGED
@@ -2,7 +2,7 @@

  ## Name

- Red Cahiner
+ Red Chainer

  ## Description
  Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
@@ -39,4 +39,4 @@ $ ruby examples/mnist.rb

  ## License

- The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
+ The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
data/examples/cifar/models/vgg.rb ADDED
@@ -0,0 +1,84 @@
+ class Block < Chainer::Chain
+   def initialize(out_channels, ksize, pad: 1)
+     super()
+     init_scope do
+       @conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
+       @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
+     end
+   end
+
+   def call(x)
+     h = @conv.(x)
+     h = @bn.(h)
+     Chainer::Functions::Activation::Relu.relu(h)
+   end
+ end
+
+ class VGG < Chainer::Chain
+   def initialize(class_labels: 10)
+     super()
+     init_scope do
+       @block1_1 = Block.new(64, 3)
+       @block1_2 = Block.new(64, 3)
+       @block2_1 = Block.new(128, 3)
+       @block2_2 = Block.new(128, 3)
+       @block3_1 = Block.new(256, 3)
+       @block3_2 = Block.new(256, 3)
+       @block3_3 = Block.new(256, 3)
+       @block4_1 = Block.new(512, 3)
+       @block4_2 = Block.new(512, 3)
+       @block4_3 = Block.new(512, 3)
+       @block5_1 = Block.new(512, 3)
+       @block5_2 = Block.new(512, 3)
+       @block5_3 = Block.new(512, 3)
+       @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
+       @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
+       @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
+     end
+   end
+
+   def call(x)
+     # 64 channel blocks:
+     h = @block1_1.(x)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.3)
+     h = @block1_2.(h)
+     h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+     # 128 channel blocks:
+     h = @block2_1.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block2_2.(h)
+     h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+     # 256 channel blocks:
+     h = @block3_1.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block3_2.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block3_3.(h)
+     h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+     # 512 channel blocks:
+     h = @block4_1.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block4_2.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block4_3.(h)
+     h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+     # 512 channel blocks:
+     h = @block5_1.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block5_2.(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+     h = @block5_3.(h)
+     h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+     h = @fc1.(h)
+     h = @bn_fc1.(h)
+     h = Chainer::Functions::Activation::Relu.relu(h)
+     h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+     @fc2.(h)
+   end
+ end
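For orientation (not part of the diff): a minimal sketch of driving this model directly, assuming the usual red-chainer `Variable` API and a CIFAR-shaped batch as produced by the new dataset code later in this release.

```ruby
# Hypothetical usage sketch, not from the package.
require 'chainer'
require_relative 'models/vgg' # path assumed relative to the example dir

model = VGG.new(class_labels: 10)
x = Chainer::Variable.new(Numo::SFloat.new(8, 3, 32, 32).rand) # batch of 8 CIFAR images
logits = model.(x)   # 13 conv blocks, 5 poolings, then 2 fully connected layers
p logits.data.shape  # expected: [8, 10], one raw score per class
```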
data/examples/cifar/train_cifar.rb ADDED
@@ -0,0 +1,70 @@
+ require 'chainer'
+ require __dir__ + '/models/vgg'
+ require 'optparse'
+
+ args = {
+   dataset: 'cifar10',
+   frequency: -1,
+   batchsize: 64,
+   learnrate: 0.05,
+   epoch: 300,
+   out: 'result',
+   resume: nil
+ }
+
+
+ opt = OptionParser.new
+ opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
+ opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
+ opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
+ opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
+ opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+ opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
+ opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
+ opt.parse!(ARGV)
+
+ # Set up a neural network to train.
+ # Classifier reports softmax cross entropy loss and accuracy at every
+ # iteration, which will be used by the PrintReport extension below.
+ if args[:dataset] == 'cifar10'
+   puts 'Using CIFAR10 dataset.'
+   class_labels = 10
+   train, test = Chainer::Datasets::CIFAR.get_cifar10
+ elsif args[:dataset] == 'cifar100'
+   puts 'Using CIFAR100 dataset.'
+   class_labels = 100
+   train, test = Chainer::Datasets::CIFAR.get_cifar100
+ else
+   raise 'Invalid dataset choice.'
+ end
+
+ puts "setup..."
+
+ model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))
+
+ optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
+ optimizer.setup(model)
+
+ train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
+ test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
+
+ updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
+ trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
+
+ trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
+
+ trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
+
+ frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
+ trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])
+
+ trainer.extend(Chainer::Training::Extensions::LogReport.new)
+ trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
+ trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
+
+ if args[:resume]
+   Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
+ end
+
+ trainer.run
+
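Worth noting how the new ExponentialShift extension interacts with the defaults above: with `learnrate` 0.05 halved every 25 epochs, the schedule is a plain geometric decay. A small illustrative computation (not from the package):

```ruby
# Illustrative only: the decay implied by
# ExponentialShift.new('lr', 0.5) with trigger: [25, 'epoch'] and lr = 0.05.
lr0 = 0.05
[0, 25, 50, 75].each do |epoch|
  puts format('after epoch %3d: lr = %.6f', epoch, lr0 * 0.5**(epoch / 25))
end
# after epoch   0: lr = 0.050000
# after epoch  25: lr = 0.025000
# after epoch  50: lr = 0.012500
# after epoch  75: lr = 0.006250
```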
data/examples/iris.rb ADDED
@@ -0,0 +1,103 @@
+ require 'numo/narray'
+ require 'chainer'
+ require "datasets"
+
+ class IrisChain < Chainer::Chain
+   L = Chainer::Links::Connection::Linear
+   F = Chainer::Functions
+
+   def initialize(n_units, n_out)
+     super()
+     init_scope do
+       @l1 = L.new(nil, out_size: n_units)
+       @l2 = L.new(nil, out_size: n_out)
+     end
+   end
+
+   def call(x, y)
+     return F::Loss::MeanSquaredError.mean_squared_error(fwd(x), y)
+   end
+
+   def fwd(x)
+     h1 = F::Activation::Sigmoid.sigmoid(@l1.(x))
+     h2 = @l2.(h1)
+     return h2
+   end
+ end
+
+ model = IrisChain.new(6, 3)
+
+ optimizer = Chainer::Optimizers::Adam.new
+ optimizer.setup(model)
+
+ iris = Datasets::Iris.new
+ x = iris.each.map { |r| r.each.to_a[0..3] }
+
+ # target
+ y_class = iris.each.map { |r| r.class }
+
+ # class index array
+ # ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
+ class_name = y_class.uniq
+ # y => [0, 0, 0, ..., 1, 1, ..., 2, 2]
+ y = y_class.map { |s|
+   class_name.index(s)
+ }
+
+ # y_onehot => One-hot [[1.0, 0.0, 0.0], ..., [0.0, 1.0, 0.0], ..., [0.0, 0.0, 1.0]]
+ y_onehot = y_class.map { |s|
+   i = class_name.index(s)
+   a = Array.new(class_name.size, 0.0)
+   a[i] = 1.0
+   a
+ }
+
+ puts "Iris Datasets"
+ puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
+ x.each_with_index { |r, i|
+   puts "#{'%3d' % i} : [#{r.join(', ')}] #{y_onehot[i]} #=> #{y_class[i]}(#{y[i]})"
+ }
+ # [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] => 50 data
+ # [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
+ # [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data
+
+ x = Numo::SFloat.cast(x)
+ y = Numo::SFloat.cast(y)
+ y_onehot = Numo::SFloat.cast(y_onehot)
+
+ x_train = x[(1..-1).step(2), true]        #=> 75 data (Iris-setosa: 25, Iris-versicolor: 25, Iris-virginica: 25)
+ y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa: 25, Iris-versicolor: 25, Iris-virginica: 25)
+ x_test = x[(0..-1).step(2), true]         #=> 75 data (Iris-setosa: 25, Iris-versicolor: 25, Iris-virginica: 25)
+ y_test = y[(0..-1).step(2)]               #=> 75 data (Iris-setosa: 25, Iris-versicolor: 25, Iris-virginica: 25)
+
+ # Train
+ 10000.times { |i|
+   x = Chainer::Variable.new(x_train)
+   y = Chainer::Variable.new(y_train)
+   model.cleargrads()
+   loss = model.(x, y)
+   loss.backward()
+   optimizer.update()
+ }
+
+ # Test
+ xt = Chainer::Variable.new(x_test)
+ yt = model.fwd(xt)
+ n_row, n_col = yt.data.shape
+
+ puts "Result : Correct Answer : Answer <= One-Hot"
+ ok = 0
+ n_row.times { |i|
+   ans = yt.data[i, true].max_index()
+   if ans == y_test[i]
+     ok += 1
+     printf("OK")
+   else
+     printf("--")
+   end
+   printf(" : #{y_test[i].to_i} :")
+
+   puts " #{ans.to_i} <= #{yt.data[i, 0..-1].to_a}"
+ }
+ puts "Row: #{n_row}, Column: #{n_col}"
+ puts "Accuracy rate : #{ok}/#{n_row} = #{ok.to_f / n_row}"
data/lib/chainer.rb CHANGED
@@ -2,10 +2,12 @@ require "weakref"

  require "chainer/version"

+ require 'chainer/cuda'
  require 'chainer/configuration'
  require 'chainer/function'
  require 'chainer/optimizer'
  require 'chainer/gradient_method'
+ require 'chainer/gradient_check'
  require 'chainer/hyperparameter'
  require 'chainer/dataset/iterator'
  require 'chainer/dataset/convert'
@@ -15,11 +17,15 @@ require 'chainer/initializers/constant'
  require 'chainer/initializers/normal'
  require 'chainer/iterators/serial_iterator'
  require 'chainer/link'
+ require 'chainer/links/connection/convolution_2d'
  require 'chainer/links/connection/linear'
+ require 'chainer/links/normalization/batch_normalization'
  require 'chainer/links/model/classifier'
  require 'chainer/variable'
  require 'chainer/variable_node'
+ require 'chainer/utils/conv'
  require 'chainer/utils/initializer'
+ require 'chainer/utils/math'
  require 'chainer/utils/variable'
  require 'chainer/utils/array'
  require 'chainer/functions/activation/leaky_relu'
@@ -29,10 +35,19 @@ require 'chainer/functions/activation/tanh'
  require 'chainer/functions/activation/log_softmax'
  require 'chainer/functions/evaluation/accuracy'
  require 'chainer/functions/math/basic_math'
+ require 'chainer/functions/math/identity'
+ require 'chainer/functions/loss/mean_squared_error'
  require 'chainer/functions/loss/softmax_cross_entropy'
+ require 'chainer/functions/connection/convolution_2d'
  require 'chainer/functions/connection/linear'
+ require 'chainer/functions/noise/dropout'
+ require 'chainer/functions/normalization/batch_normalization'
+ require 'chainer/functions/pooling/pooling_2d'
+ require 'chainer/functions/pooling/max_pooling_2d'
+ require 'chainer/testing/array'
  require 'chainer/training/extension'
  require 'chainer/training/extensions/evaluator'
+ require 'chainer/training/extensions/exponential_shift'
  require 'chainer/training/extensions/log_report'
  require 'chainer/training/extensions/print_report'
  require 'chainer/training/extensions/progress_bar'
@@ -44,8 +59,10 @@ require 'chainer/training/standard_updater'
  require 'chainer/training/triggers/interval'
  require 'chainer/parameter'
  require 'chainer/optimizers/adam'
+ require 'chainer/optimizers/momentum_sgd'
  require 'chainer/dataset/download'
  require 'chainer/datasets/mnist'
+ require 'chainer/datasets/cifar'
  require 'chainer/datasets/tuple_dataset'
  require 'chainer/reporter'
  require 'chainer/serializer'
data/lib/chainer/configuration.rb CHANGED
@@ -1,9 +1,10 @@
  module Chainer
    class Configuration
-     attr_accessor :enable_backprop
+     attr_accessor :enable_backprop, :train

      def initialize
        @enable_backprop = true
+       @train = true
      end
    end
  end
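The new `train` flag is what the dropout and batch-normalization code added elsewhere in this release consults to switch between training and inference behavior. A hedged sketch, assuming the process-wide instance is exposed as `Chainer.configuration` as in the rest of the library:

```ruby
# Sketch only; Chainer.configuration assumed to return the global Configuration.
Chainer.configuration.train = true   # dropout masks activations, BN updates running stats
# ... run training iterations ...

Chainer.configuration.train = false  # dropout passes through, BN uses its stored stats
# ... run evaluation ...
```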
data/lib/chainer/cuda.rb ADDED
@@ -0,0 +1,18 @@
+ module Chainer
+   # Gets an appropriate one from +Numo::NArray+ or +Cumo::NArray+.
+   #
+   # This is almost equivalent to +Chainer::get_array_module+. The differences
+   # are that this function can be used even if CUDA is not available and that
+   # it will return their data arrays' array module for
+   # +Chainer::Variable+ arguments.
+   #
+   # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
+   # @return [Numo::NArray] +Cumo::NArray+ or +Numo::NArray+ is returned based on the types of
+   #   the arguments.
+   # @todo CUDA is not supported, yet.
+   #
+   def get_array_module(*args)
+     return Numo::NArray
+   end
+   module_function :get_array_module
+ end
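A quick usage sketch (not part of the diff); since the @todo notes CUDA is unsupported, every call currently resolves to Numo:

```ruby
xm = Chainer.get_array_module(Numo::SFloat[1.0, 2.0])
p xm                        # => Numo::NArray (always, until Cumo support lands)
p xm.cast([1, 2, 3]).class  # => Numo::Int32; cast picks a concrete subclass
```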
data/lib/chainer/dataset/convert.rb CHANGED
@@ -29,30 +29,51 @@ module Chainer

  def self.concat_arrays(arrays, padding)
    unless arrays[0].kind_of?(Numo::NArray)
+     # [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
      arrays = Numo::NArray.cast(arrays)
+     if padding
+       return concat_arrays_with_padding(arrays, padding)
+     end
+     return arrays
    end

    if padding
      return concat_arrays_with_padding(arrays, padding)
    end

-   Numo::NArray.[](*arrays.to_a.map { |arr| arr.kind_of?(Numeric) ? arr : Numo::NArray.[](*arr) })
+   # [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
+   # => Numo::SFloat#shape=[2,2]
+   # [[1, 2], [3, 4]]
+   a = arrays.map { |arr| arr[:-, false] }
+   a[0].concatenate(*a[1..-1])
  end

  def self.concat_arrays_with_padding(arrays, padding)
-   shape = Numo::Int32.[](arrays[0].shape)
-   arrays[1...arrays.len].each do |array|
-     if Numo::Bit.[](shape != array.shape).any?
-       # TODO: numpy maximum
+   if arrays[0].is_a? Numo::NArray
+     shape = Numo::Int32.cast(arrays[0].shape)
+     arrays[1..-1].each do |array|
+       if Numo::Bit.[](shape != array.shape).any?
+         shape = Numo::Int32.maximum(shape, array.shape)
+       end
      end
+   else # Integer
+     shape = []
+   end
+
+   shape = shape.insert(0, arrays.size).to_a
+   if arrays[0].is_a? Numo::NArray
+     result = arrays[0].class.new(shape).fill(padding)
+   else # Integer
+     result = Numo::Int32.new(shape).fill(padding)
    end

-   shape = [shape.insert(0, arrays.size)]
-   result = arrays[0].dtype.[](*shape).full(padding)
    arrays.size.times do |i|
      src = arrays[i]
-     slices = src.shape.map { |s| [s] }
-     result[[i] + slices] = src
+     if src.is_a? Numo::NArray
+       result[i, 0...src.shape[0], 0...src.shape[1]] = src
+     else # Integer
+       result[i] = src
+     end
    end

    result
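A sketch of the two paths through the rewritten helper (illustrative; the enclosing module is assumed to be `Chainer::Dataset::Convert` per the file path):

```ruby
# Path 1: equally shaped arrays are stacked along a new leading axis.
batch = [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
p Chainer::Dataset::Convert.concat_arrays(batch, nil).shape  # => [2, 2]

# Path 2: with a padding value, the target shape is the elementwise maximum
# and missing entries are filled; the indexing above assumes 2-D sources.
ragged = [Numo::SFloat[[1, 2], [3, 4]], Numo::SFloat[[5, 6]]]
padded = Chainer::Dataset::Convert.concat_arrays(ragged, -1)
p padded.shape              # => [2, 2, 2]
p padded[1, 1, true].to_a   # => [-1.0, -1.0], the padded row
```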
data/lib/chainer/datasets/cifar.rb ADDED
@@ -0,0 +1,56 @@
+ require 'datasets'
+
+ module Chainer
+   module Datasets
+     module CIFAR
+       def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
+         get_cifar(10, with_label, ndim, scale)
+       end
+
+       def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
+         get_cifar(100, with_label, ndim, scale)
+       end
+
+       def self.get_cifar(n_classes, with_label, ndim, scale)
+         train_data = []
+         train_labels = []
+         ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record|
+           train_data << record.pixels
+           train_labels << (n_classes == 10 ? record.label : record.fine_label)
+         end
+
+         test_data = []
+         test_labels = []
+         ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record|
+           test_data << record.pixels
+           test_labels << (n_classes == 10 ? record.label : record.fine_label)
+         end
+
+         [
+           preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
+           preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
+         ]
+       end
+
+       def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
+         if ndim == 1
+           images = images.reshape(images.shape[0], 3072)
+         elsif ndim == 3
+           images = images.reshape(images.shape[0], 3, 32, 32)
+         else
+           raise 'invalid ndim for CIFAR dataset'
+         end
+         images = images.cast_to(Numo::DFloat)
+         images *= scale / 255.0
+
+         if withlabel
+           labels = labels.cast_to(Numo::Int32)
+           TupleDataset.new(images, labels)
+         else
+           images
+         end
+       end
+     end
+   end
+ end
+
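A usage sketch for the new loaders (not part of the diff): the first call downloads via the red-datasets gem, shapes follow the `ndim: 3` default above, and TupleDataset is assumed to yield [image, label] pairs on integer indexing.

```ruby
train, test = Chainer::Datasets::CIFAR.get_cifar10
image, label = train[0]  # assumed [image, label] pair from the TupleDataset
p image.shape            # => [3, 32, 32], Numo::DFloat scaled into [0, 1]
p label                  # => Int32 class id in 0..9
```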