red-chainer 0.2.1 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4bd5f287aafbc100878391131a0a184c9c6497ddb26e0c5124010166aa455def
|
4
|
+
data.tar.gz: 6625f1d862ba17db5867b208494419f1e7e7bc1ed5b5276bec785151140a47bd
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 659986f00b88051471ca54a88eafb4a6f9ebbdf8f27a9f51c8ff32a27f584d2a78015d47baf20e74c6e9605c09bda0777b4780c4ef79313d194a362e0d3ff013
|
7
|
+
data.tar.gz: a468488ee913244bcdbc058fe27c93beec897bd4476f1eec50ae2bc3071124701aaf37b9f15cd07848bbb30372e4796c04b24ead672c28e5fcb0dbb1721b7bd6
|
data/README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
## Name
|
4
4
|
|
5
|
-
Red
|
5
|
+
Red Chainer
|
6
6
|
|
7
7
|
## Description
|
8
8
|
Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
|
@@ -39,4 +39,4 @@ $ ruby examples/mnist.rb
|
|
39
39
|
|
40
40
|
## License
|
41
41
|
|
42
|
-
The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
|
42
|
+
The MIT license. See [LICENSE.txt](./LICENSE.txt) for details.
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# Basic VGG building unit: a 2D convolution followed by batch normalization
# and a ReLU activation.
class Block < Chainer::Chain
  # @param out_channels [Integer] number of output channels of the convolution
  # @param ksize [Integer] convolution kernel size
  # @param pad [Integer] spatial padding passed to the convolution (default 1)
  def initialize(out_channels, ksize, pad: 1)
    super()
    init_scope do
      # in_channels is nil; presumably inferred lazily from the first input
      # (as in Python Chainer) — confirm in Links::Connection::Convolution2D.
      # nobias: true because the batch normalization right after it presumably
      # supplies its own learned shift.
      @conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
      @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
    end
  end

  # Applies convolution, batch normalization and ReLU to +x+, in that order.
  def call(x)
    normalized = @bn.(@conv.(x))
    Chainer::Functions::Activation::Relu.relu(normalized)
  end
end
|
16
|
+
|
17
|
+
# VGG-style network used by the CIFAR example.
#
# Five convolutional stages of Block units (64, 128, 256, 512 and 512
# channels). Dropout sits between the blocks of a stage, and every stage
# ends with 2x2 max pooling (stride 2). The head is dropout → fc1 (512
# units, no bias) → batch normalization → ReLU → dropout → fc2, which
# emits one score per class.
class VGG < Chainer::Chain
  # @param class_labels [Integer] number of output classes (10 for CIFAR-10, 100 for CIFAR-100)
  def initialize(class_labels: 10)
    super()
    init_scope do
      # 64-channel stage
      @block1_1 = Block.new(64, 3)
      @block1_2 = Block.new(64, 3)
      # 128-channel stage
      @block2_1 = Block.new(128, 3)
      @block2_2 = Block.new(128, 3)
      # 256-channel stage
      @block3_1 = Block.new(256, 3)
      @block3_2 = Block.new(256, 3)
      @block3_3 = Block.new(256, 3)
      # first 512-channel stage
      @block4_1 = Block.new(512, 3)
      @block4_2 = Block.new(512, 3)
      @block4_3 = Block.new(512, 3)
      # second 512-channel stage
      @block5_1 = Block.new(512, 3)
      @block5_2 = Block.new(512, 3)
      @block5_3 = Block.new(512, 3)
      # classifier head
      @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
      @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
      @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
    end
  end

  # Forward pass.
  #
  # @param x input batch; presumably (batch, 3, 32, 32) CIFAR images — confirm against the loader
  # @return the output of fc2 (one score per class for each example)
  def call(x)
    h = conv_stage(x, [@block1_1, @block1_2], 0.3)
    h = conv_stage(h, [@block2_1, @block2_2], 0.4)
    h = conv_stage(h, [@block3_1, @block3_2, @block3_3], 0.4)
    h = conv_stage(h, [@block4_1, @block4_2, @block4_3], 0.4)
    h = conv_stage(h, [@block5_1, @block5_2, @block5_3], 0.4)

    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
    h = @bn_fc1.(@fc1.(h))
    h = Chainer::Functions::Activation::Relu.relu(h)
    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
    @fc2.(h)
  end

  private

  # Runs +blocks+ in order, applying dropout with +ratio+ between consecutive
  # blocks (not before the first one), then 2x2 max pooling with stride 2.
  def conv_stage(h, blocks, ratio)
    first, *rest = blocks
    h = first.(h)
    rest.each do |block|
      h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: ratio)
      h = block.(h)
    end
    Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
  end
end
|
@@ -0,0 +1,70 @@
|
|
1
|
+
require 'chainer'
|
2
|
+
require __dir__ + '/models/vgg'
|
3
|
+
require 'optparse'
|
4
|
+
|
5
|
+
# Default hyperparameters; each one can be overridden on the command line.
args = {
  dataset: 'cifar10',
  frequency: -1,
  batchsize: 64,
  learnrate: 0.05,
  epoch: 300,
  out: 'result',
  resume: nil
}

# Numeric options use OptionParser's Integer/Float acceptors, so malformed
# values (e.g. "-b abc") abort with an OptionParser error instead of
# silently becoming 0 as they would with String#to_i / String#to_f.
opt = OptionParser.new
opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
opt.on('-b', '--batchsize VALUE', Integer, "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v }
opt.on('-f', '--frequency VALUE', Integer, "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v }
opt.on('-l', '--learnrate VALUE', Float, "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v }
opt.on('-e', '--epoch VALUE', Integer, "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v }
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
opt.parse!(ARGV)

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
case args[:dataset]
when 'cifar10'
  puts 'Using CIFAR10 dataset.'
  class_labels = 10
  train, test = Chainer::Datasets::CIFAR.get_cifar10
when 'cifar100'
  puts 'Using CIFAR100 dataset.'
  class_labels = 100
  train, test = Chainer::Datasets::CIFAR.get_cifar100
else
  raise 'Invalid dataset choice.'
end

puts "setup..."

model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))

optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
optimizer.setup(model)

train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
# The test iterator makes a single, ordered pass per evaluation.
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

# device: -1 — presumably CPU (cuda.rb notes CUDA is not supported yet).
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])

trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))

# Presumably multiplies the optimizer's 'lr' by 0.5 every 25 epochs.
trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])

# -1 (the default) means one snapshot at the final epoch; any other value is clamped to at least 1.
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

if args[:resume]
  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
end

trainer.run
|
70
|
+
|
data/examples/iris.rb
ADDED
@@ -0,0 +1,103 @@
|
|
1
|
+
require 'numo/narray'
|
2
|
+
require 'chainer'
|
3
|
+
require "datasets"
|
4
|
+
|
5
|
+
# Two-layer perceptron for the Iris example: Linear → sigmoid → Linear.
class IrisChain < Chainer::Chain
  # Short aliases used inside this class.
  L = Chainer::Links::Connection::Linear
  F = Chainer::Functions

  # @param n_units [Integer] width of the hidden layer
  # @param n_out [Integer] number of outputs
  def initialize(n_units, n_out)
    super()
    init_scope do
      # Input sizes are nil, presumably inferred from the first input — confirm in Linear.
      @l1 = L.new(nil, out_size: n_units)
      @l2 = L.new(nil, out_size: n_out)
    end
  end

  # @return the mean squared error between fwd(x) and the target +y+
  def call(x, y)
    prediction = fwd(x)
    F::Loss::MeanSquaredError.mean_squared_error(prediction, y)
  end

  # @return the raw network output: l2(sigmoid(l1(x)))
  def fwd(x)
    hidden = F::Activation::Sigmoid.sigmoid(@l1.(x))
    @l2.(hidden)
  end
end
|
27
|
+
|
28
|
+
model = IrisChain.new(6, 3)

optimizer = Chainer::Optimizers::Adam.new
optimizer.setup(model)

iris = Datasets::Iris.new
# Enumerate the dataset once and reuse the records for features and targets.
records = iris.each.to_a

# features: the first four fields of each record
x = records.map { |r| r.each.to_a[0..3] }

# target
# NOTE(review): r.class is used as the record's species field here (it yields
# the names listed below), not Object#class — confirm against the installed
# red-datasets version.
y_class = records.map { |r| r.class }

# class index array
# ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
class_name = y_class.uniq
# name => index lookup table instead of Array#index for every record
class_index = class_name.each_with_index.to_h
# y => [0, 0, 0, 0, ,,, 1, 1, ,,, ,2, 2]
y = y_class.map { |s| class_index[s] }

# y_onehot => One-hot [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0],,, [0.0, 1.0, 0.0], ,, [0.0, 0.0, 1.0]]
y_onehot = y.map do |i|
  row = Array.new(class_name.size, 0.0)
  row[i] = 1.0
  row
end

puts "Iris Datasets"
puts "No. [sepal_length, sepal_width, petal_length, petal_width] one-hot #=> class"
x.each_with_index do |r, i|
  puts "#{'%3d' % i} : [#{r.join(', ')}] #{y_onehot[i]} #=> #{y_class[i]}(#{y[i]})"
end
# [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] => 50 data
# [7.0, 3.2, 4.7, 1.4, "Iris-versicolor"] => 50 data
# [6.3, 3.3, 6.0, 2.5, "Iris-virginica"] => 50 data

x = Numo::SFloat.cast(x)
y = Numo::SFloat.cast(y)
y_onehot = Numo::SFloat.cast(y_onehot)

# Odd rows train, even rows test; the test targets are class indices, not one-hot.
x_train = x[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
y_train = y_onehot[(1..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
x_test = x[(0..-1).step(2), true] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)
y_test = y[(0..-1).step(2)] #=> 75 data (Iris-setosa : 25, Iris-versicolor : 25, Iris-virginica : 25)

# Train: 10,000 full-batch updates on the training split.
10_000.times do
  # Fresh names: assigning x/y here would overwrite the outer dataset variables,
  # because Ruby blocks share locals already defined in the enclosing scope.
  x_batch = Chainer::Variable.new(x_train)
  t_batch = Chainer::Variable.new(y_train)
  model.cleargrads
  loss = model.(x_batch, t_batch)
  loss.backward
  optimizer.update
end

# Test
xt = Chainer::Variable.new(x_test)
yt = model.fwd(xt)
n_row, n_col = yt.data.shape

puts "Result : Correct Answer : Answer <= One-Hot"
ok = 0
n_row.times do |i|
  # The predicted class is the output column with the largest value.
  ans = yt.data[i, true].max_index
  if ans == y_test[i]
    ok += 1
    print "OK"
  else
    print "--"
  end
  # print, not printf: the interpolated text must not be parsed as a format string.
  print " : #{y_test[i].to_i} :"

  puts " #{ans.to_i} <= #{yt.data[i, 0..-1].to_a}"
end
puts "Row: #{n_row}, Column: #{n_col}"
puts "Accuracy rate : #{ok}/#{n_row} = #{ok.to_f / n_row}"
|
data/lib/chainer.rb
CHANGED
@@ -2,10 +2,12 @@ require "weakref"
|
|
2
2
|
|
3
3
|
require "chainer/version"
|
4
4
|
|
5
|
+
require 'chainer/cuda'
|
5
6
|
require 'chainer/configuration'
|
6
7
|
require 'chainer/function'
|
7
8
|
require 'chainer/optimizer'
|
8
9
|
require 'chainer/gradient_method'
|
10
|
+
require 'chainer/gradient_check'
|
9
11
|
require 'chainer/hyperparameter'
|
10
12
|
require 'chainer/dataset/iterator'
|
11
13
|
require 'chainer/dataset/convert'
|
@@ -15,11 +17,15 @@ require 'chainer/initializers/constant'
|
|
15
17
|
require 'chainer/initializers/normal'
|
16
18
|
require 'chainer/iterators/serial_iterator'
|
17
19
|
require 'chainer/link'
|
20
|
+
require 'chainer/links/connection/convolution_2d'
|
18
21
|
require 'chainer/links/connection/linear'
|
22
|
+
require 'chainer/links/normalization/batch_normalization'
|
19
23
|
require 'chainer/links/model/classifier'
|
20
24
|
require 'chainer/variable'
|
21
25
|
require 'chainer/variable_node'
|
26
|
+
require 'chainer/utils/conv'
|
22
27
|
require 'chainer/utils/initializer'
|
28
|
+
require 'chainer/utils/math'
|
23
29
|
require 'chainer/utils/variable'
|
24
30
|
require 'chainer/utils/array'
|
25
31
|
require 'chainer/functions/activation/leaky_relu'
|
@@ -29,10 +35,19 @@ require 'chainer/functions/activation/tanh'
|
|
29
35
|
require 'chainer/functions/activation/log_softmax'
|
30
36
|
require 'chainer/functions/evaluation/accuracy'
|
31
37
|
require 'chainer/functions/math/basic_math'
|
38
|
+
require 'chainer/functions/math/identity'
|
39
|
+
require 'chainer/functions/loss/mean_squared_error'
|
32
40
|
require 'chainer/functions/loss/softmax_cross_entropy'
|
41
|
+
require 'chainer/functions/connection/convolution_2d'
|
33
42
|
require 'chainer/functions/connection/linear'
|
43
|
+
require 'chainer/functions/noise/dropout'
|
44
|
+
require 'chainer/functions/normalization/batch_normalization'
|
45
|
+
require 'chainer/functions/pooling/pooling_2d'
|
46
|
+
require 'chainer/functions/pooling/max_pooling_2d'
|
47
|
+
require 'chainer/testing/array'
|
34
48
|
require 'chainer/training/extension'
|
35
49
|
require 'chainer/training/extensions/evaluator'
|
50
|
+
require 'chainer/training/extensions/exponential_shift'
|
36
51
|
require 'chainer/training/extensions/log_report'
|
37
52
|
require 'chainer/training/extensions/print_report'
|
38
53
|
require 'chainer/training/extensions/progress_bar'
|
@@ -44,8 +59,10 @@ require 'chainer/training/standard_updater'
|
|
44
59
|
require 'chainer/training/triggers/interval'
|
45
60
|
require 'chainer/parameter'
|
46
61
|
require 'chainer/optimizers/adam'
|
62
|
+
require 'chainer/optimizers/momentum_sgd'
|
47
63
|
require 'chainer/dataset/download'
|
48
64
|
require 'chainer/datasets/mnist'
|
65
|
+
require 'chainer/datasets/cifar'
|
49
66
|
require 'chainer/datasets/tuple_dataset'
|
50
67
|
require 'chainer/reporter'
|
51
68
|
require 'chainer/serializer'
|
data/lib/chainer/cuda.rb
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
module Chainer
  module_function

  # Returns the array module to use for the given values: +Numo::NArray+ now,
  # and eventually +Cumo::NArray+ for GPU arrays.
  #
  # Presumably ported from Python Chainer's +cuda.get_array_module+. Unlike a
  # CUDA-only helper, it is usable without CUDA. Once Cumo is supported,
  # +Chainer::Variable+ arguments are meant to resolve through their data arrays.
  # CUDA is not supported yet, so the arguments are currently ignored and
  # +Numo::NArray+ is always returned.
  #
  # @param [Array<Chainer::Variable> or Array<Numo::NArray> or Array<Cumo::NArray>] args Values to determine whether Numo or Cumo should be used.
  # @return [Class] the array class itself (currently always +Numo::NArray+), not an instance
  # @todo CUDA is not supported, yet.
  def get_array_module(*_args)
    Numo::NArray
  end
end
|
@@ -29,30 +29,51 @@ module Chainer
|
|
29
29
|
|
30
30
|
# Combines a mini-batch of examples into a single array.
#
# Plain Ruby values (e.g. Integer labels) are cast into one Numo array:
#   [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
# Numo arrays are stacked along a new leading axis:
#   [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
#   => Numo::SFloat#shape=[2,2]
#      [[1, 2], [3, 4]]
# When +padding+ is truthy, the (possibly cast) batch is handed to
# +concat_arrays_with_padding+ instead.
#
# @param arrays [Array] the examples of one mini-batch
# @param padding [Object, nil] fill value for padding, or nil/false for none
# @return [Numo::NArray] the combined batch
def self.concat_arrays(arrays, padding)
  narray_input = arrays[0].kind_of?(Numo::NArray)
  arrays = Numo::NArray.cast(arrays) unless narray_input

  return concat_arrays_with_padding(arrays, padding) if padding
  # Casting already produced the batch for non-Numo input.
  return arrays unless narray_input

  # arr[:-, false] adds a length-1 leading axis to each example so that
  # concatenating along axis 0 stacks them (see the shape example above).
  head, *tail = arrays.map { |arr| arr[:-, false] }
  head.concatenate(*tail)
end
|
41
50
|
|
42
51
|
def self.concat_arrays_with_padding(arrays, padding)
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
52
|
+
if arrays[0].is_a? Numo::NArray
|
53
|
+
shape = Numo::Int32.cast(arrays[0].shape)
|
54
|
+
arrays[1..-1].each do |array|
|
55
|
+
if Numo::Bit.[](shape != array.shape).any?
|
56
|
+
shape = Numo::Int32.maximum(shape, array.shape)
|
57
|
+
end
|
47
58
|
end
|
59
|
+
else # Integer
|
60
|
+
shape = []
|
61
|
+
end
|
62
|
+
|
63
|
+
shape = shape.insert(0, arrays.size).to_a
|
64
|
+
if arrays[0].is_a? Numo::NArray
|
65
|
+
result = arrays[0].class.new(shape).fill(padding)
|
66
|
+
else # Integer
|
67
|
+
result = Numo::Int32.new(shape).fill(padding)
|
48
68
|
end
|
49
69
|
|
50
|
-
shape = [shape.insert(0, arrays.size)]
|
51
|
-
result = arrays[0].dtype.[](*shape).full(padding)
|
52
70
|
arrays.size.times do |i|
|
53
71
|
src = arrays[i]
|
54
|
-
|
55
|
-
|
72
|
+
if src.is_a? Numo::NArray
|
73
|
+
result[i, 0...src.shape[0], 0...src.shape[1]] = src
|
74
|
+
else # Integer
|
75
|
+
result[i] = src
|
76
|
+
end
|
56
77
|
end
|
57
78
|
|
58
79
|
result
|
@@ -0,0 +1,56 @@
|
|
1
|
+
require 'datasets'
|
2
|
+
|
3
|
+
module Chainer
  module Datasets
    # Loaders for the CIFAR-10 and CIFAR-100 datasets, built on red-datasets'
    # +::Datasets::CIFAR+.
    module CIFAR
      # Loads CIFAR-10.
      #
      # @param with_label [Boolean] return TupleDatasets of (images, labels) when true, bare image arrays otherwise
      # @param ndim [Integer] 1 for flat 3072-element images, 3 for (3, 32, 32) images
      # @param scale [Float] pixel values are scaled into [0, scale]
      # @return [Array] two elements: [train, test]
      def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
        get_cifar(10, with_label, ndim, scale)
      end

      # Loads CIFAR-100 using its fine-grained (100-class) labels.
      # Parameters and return value are the same as for +get_cifar10+.
      def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
        get_cifar(100, with_label, ndim, scale)
      end

      # Loads and preprocesses both splits for +n_classes+ (10 or 100).
      #
      # @return [Array] two elements: [train, test]
      def self.get_cifar(n_classes, with_label, ndim, scale)
        %i[train test].map do |type|
          images, labels = load_split(n_classes, type)
          preprocess_cifar(images, labels, with_label, ndim, scale)
        end
      end

      # Reads one split (+:train+ or +:test+) into raw UInt8 arrays.
      #
      # @return [Array(Numo::UInt8, Numo::UInt8)] pixel rows and labels
      def self.load_split(n_classes, type)
        pixels = []
        labels = []
        ::Datasets::CIFAR.new(n_classes: n_classes, type: type).each do |record|
          pixels << record.pixels
          # CIFAR-100 uses the fine-grained label.
          labels << (n_classes == 10 ? record.label : record.fine_label)
        end
        # cast avoids splatting tens of thousands of rows as method arguments.
        [Numo::UInt8.cast(pixels), Numo::UInt8.cast(labels)]
      end
      private_class_method :load_split

      # Reshapes and rescales raw CIFAR images and optionally pairs them with labels.
      #
      # @param images [Numo::UInt8] one row of 3072 pixel values per image
      # @param labels [Numo::UInt8] class index per image
      # @param with_label [Boolean] wrap images and Int32 labels in a TupleDataset
      # @param ndim [Integer] 1 => (N, 3072), 3 => (N, 3, 32, 32)
      # @param scale [Float] pixel values are mapped from [0, 255] to [0, scale]
      # @return [TupleDataset, Numo::DFloat]
      # @raise [RuntimeError] when +ndim+ is neither 1 nor 3
      def self.preprocess_cifar(images, labels, with_label, ndim, scale)
        images =
          case ndim
          when 1 then images.reshape(images.shape[0], 3072)
          when 3 then images.reshape(images.shape[0], 3, 32, 32)
          else raise 'invalid ndim for CIFAR dataset'
          end
        images = images.cast_to(Numo::DFloat)
        images *= scale / 255.0

        return images unless with_label

        TupleDataset.new(images, labels.cast_to(Numo::Int32))
      end
    end
  end
end
|
56
|
+
|