red-chainer 0.3.2 → 0.4.0
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/connection/deconvolution_2d.rb (new file; all 159 lines added):

module Chainer
  module Functions
    module Connection
      class Deconvolution2DFunction < Chainer::FunctionNode
        attr_reader :sy, :sx, :ph, :pw, :cover_all

        # Two dimensional deconvolution function.
        #
        # This is an implementation of two-dimensional deconvolution.
        # In most of deep learning frameworks and papers,
        # this function is called <b>transposed convolution</b>.
        # But because of historical reasons (e.g. the paper by Zeiler, Deconvolutional Networks) and backward compatibility,
        # this function is called +deconvolution+ in Chainer.
        #
        # It takes three variables: input image +x+,
        # the filter weight +W+, and the bias vector +b+.
        #
        # - $n$ is the batch size.
        # - $c_I$ and $c_O$ are the number of the input and output channels, respectively.
        # - $h_I$ and $w_I$ are the height and width of the input image, respectively.
        # - $h_K$ and $w_K$ are the height and width of the filters, respectively.
        # - $h_P$ and $w_P$ are the height and width of the spatial padding size, respectively.
        #
        # Let $(s_Y, s_X)$ be the stride of filter application.
        # Then, the output size $(h_O, w_O)$ is estimated by the following equations:
        #
        # $
        # h_O &= s_Y (h_I - 1) + h_K - 2h_P,
        # w_O &= s_X (w_I - 1) + w_K - 2w_P.
        # $
        #
        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape $(n, c_I, h_I, w_I)$.
        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape $(c_I, c_O, h_K, w_K)$.
        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length $c_O$ (optional).
        # @param [integer or Array<integer>] stride Stride of filter applications. +stride=s+ and +stride=[s, s]+ are equivalent.
        # @param [integer or Array<integer>] pad Spatial padding width for input arrays. +pad=p+ and +pad=[p, p]+ are equivalent.
        # @param [integer or Array<integer>] outsize Expected output size of the deconvolutional operation.
        #   It should be a pair of height and width $(h_O, w_O)$.
        #   Default value is +nil+ and the outsize is estimated from the input size, stride and pad.
        # @return [Chainer::Variable] Output variable of shape $(n, c_O, h_O, w_O)$.
        #
        # Example
        # > n = 10
        # > c_i, c_o = 1, 3
        # > h_i, w_i = 5, 10
        # > h_k, w_k = 10, 10
        # > h_p, w_p = 5, 5
        # > x = Numo::DFloat.new(n, c_i, h_i, w_i).rand
        # > x.shape
        # => [10, 1, 5, 10]
        # > w = Numo::DFloat.new(c_i, c_o, h_k, w_k).rand
        # > w.shape
        # => [1, 3, 10, 10]
        # > b = Numo::DFloat.new(c_o).rand
        # > b.shape
        # => [3]
        # > s_y, s_x = 5, 5
        # > y = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(x, w, b: b, stride: [s_y, s_x], pad: [h_p, w_p])
        # > y.shape
        # => [10, 3, 20, 45]
        # > h_o = s_y * (h_i - 1) + h_k - 2 * h_p
        # > w_o = s_x * (w_i - 1) + w_k - 2 * w_p
        # > y.shape == [n, c_o, h_o, w_o]
        # => true
        def self.deconvolution_2d(x, w, b: nil, stride: 1, pad: 0, outsize: nil)
          func = Deconvolution2DFunction.new(stride: stride, pad: pad, outsize: outsize)
          if b.nil?
            args = x, w
          else
            args = x, w, b
          end
          func.apply(args).first
        end

        def initialize(stride: 1, pad: 0, outsize: nil)
          @cover_all = nil

          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
          @ph, @pw = pad.is_a?(::Array) ? pad : [pad, pad]
          @outh, @outw = outsize.nil? ? [nil, nil] : outsize
        end

        def forward(inputs)
          retain_inputs([0, 1])
          x, w = inputs[0...2]
          b = inputs.size == 3 ? inputs[2] : nil

          unless inputs.all? { |i| i.is_a?(Numo::NArray) }
            if b.nil?
              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}"
            else
              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}, b: #{b.class}"
            end
          end

          kh, kw = w.shape[2..-1]
          _, _, x_h, x_w = x.shape

          gcol = Chainer::Utils::Math.tensordot(w, x, [0, 1]).cast_to(x.class)
          # - k, m, n: shape of out_channel
          # - b: number of inputs
          # - h, w: height and width of kernels
          # k, m, n, b, h, w -> b, k, m, n, h, w
          gcol = gcol.transpose(3, 0, 1, 2, 4, 5)

          if @outh.nil?
            @outh = Chainer::Utils::Conv.get_deconv_outsize(x_h, kh, @sy, @ph)
            raise TypeError, 'Height in the output should be positive.' if @outh <= 0
          end
          if @outw.nil?
            @outw = Chainer::Utils::Conv.get_deconv_outsize(x_w, kw, @sx, @pw)
            raise TypeError, 'Width in the output should be positive.' if @outw <= 0
          end

          y = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, @outh, @outw)
          if !b.nil?
            y += b.reshape(1, b.size, 1, 1)
          end
          [y]
        end

        def backward(indexes, grad_outputs)
          x, w = get_retained_inputs
          gy = grad_outputs.first

          ret = []

          if indexes.include?(0)
            set_cover_all(x, w) if @cover_all.nil?
            gw = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all)
            ret << gw
          end

          if indexes.include?(1)
            set_cover_all(x, w) if @cover_all.nil?
            gw = Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([gy, x]).first
            ret << gw
          end

          if indexes.include?(2)
            gb = Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3])
            ret << gb
          end

          ret
        end

        private

        def set_cover_all(x, w)
          in_h, in_w = x.shape[2..-1]
          kh, kw = w.shape[2..-1]

          @cover_all = in_h != Chainer::Utils::Conv.get_conv_outsize(@outh, kh, @sy, @ph) || in_w != Chainer::Utils::Conv.get_conv_outsize(@outw, kw, @sx, @pw)
        end
      end
    end
  end
end
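When outsize: is omitted, the forward pass estimates the spatial output size through Chainer::Utils::Conv.get_deconv_outsize, following the docstring formula $h_O = s_Y (h_I - 1) + h_K - 2h_P$. The following is a minimal standalone sketch of that calculation only (plain Ruby, not the library helper itself, and ignoring the cover_all case):

```ruby
# Deconvolution output size for one spatial dimension:
#   out = stride * (in - 1) + kernel - 2 * padding
def deconv_outsize(size, k, s, p)
  s * (size - 1) + k - 2 * p
end

deconv_outsize(5, 10, 5, 5)   # => 20, the h_o from the docstring example
deconv_outsize(10, 10, 5, 5)  # => 45, the w_o from the docstring example
```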
data/lib/chainer/functions/connection/linear.rb (removed lines that did not survive extraction are left as bare "-"):

@@ -1,17 +1,23 @@
 module Chainer
   module Functions
     module Connection
-      class LinearFunction < Chainer::Function
+      class LinearFunction < Chainer::FunctionNode
         def self.linear(x, w, b=nil)
+          if x.ndim > 2
+            x = x.reshape(x.shape.first, -1)
+          end
+
           if b.nil?
-
+            args = x, w
           else
-
+            args = x, w, b
           end
+
+          self.new.apply(args).first
         end

         def forward(inputs)
-          x =
+          x = inputs[0]
           w = inputs[1]

           y = x.dot(w.transpose).cast_to(x.class)
@@ -19,28 +25,29 @@ module Chainer
             b = inputs[2]
             y += b
           end
-          return [y]
-        end

-
-
-          w = inputs[1]
-          gy = grad_outputs[0]
-          gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
-          gw = gy.transpose.dot(x).cast_to(w.class)
-          if inputs.size == 3
-            gb = gy.sum(0)
-            [gx, gw, gb]
-          else
-            [gx, gw]
-          end
+          retain_inputs([0, 1])
+          return [y]
         end

-
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first

-
-
-
+          ret = []
+          if indexes.include?(0)
+            gx = LinearFunction.linear(gy, w.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gx, x.dtype)
+          end
+          if indexes.include?(1)
+            gw = LinearFunction.linear(gy.transpose, x.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gw, w.dtype)
+          end
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: 0)
+            ret << gb
+          end
+          ret
         end
       end
     end
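For orientation, here is a small usage sketch of the rewritten LinearFunction.linear. The array sizes are illustrative only; as in the deconvolution example above, raw Numo arrays can be passed straight in and are wrapped by apply:

```ruby
require 'chainer'

batch, in_dim, out_dim = 4, 3, 2
x = Numo::SFloat.new(batch, in_dim).rand
w = Numo::SFloat.new(out_dim, in_dim).rand  # weight is (out, in); forward uses x.dot(w.transpose)
b = Numo::SFloat.new(out_dim).rand

y = Chainer::Functions::Connection::LinearFunction.linear(x, w, b)
y.shape  # => [4, 2]
```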
data/lib/chainer/functions/evaluation/accuracy.rb (removed lines that did not survive extraction are left truncated or bare):

@@ -12,13 +12,13 @@ module Chainer

       def forward(inputs)
         y, t = inputs
+        xm = Chainer.get_array_module(*inputs)
         if @ignore_label
           mask = t.eq(@ignore_label)
           ignore_cnt = mask.count

-
-          pred =
-          pred = y.class[*pred].reshape(*t.shape)
+          pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+          pred = pred.reshape(*t.shape)
           pred[mask] = @ignore_label
           count = pred.eq(t).count - ignore_cnt

@@ -30,8 +30,8 @@ module Chainer
             [y.class.cast(count.to_f / total)]
           end
         else
-          pred = y.max_index(axis: 1).
-          pred =
+          pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+          pred = pred.reshape(*t.shape)

           [y.class.cast(y.class[pred.eq(t)].mean)]
         end
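The rewritten lines rely on max_index(axis: 1) returning indices into the flattened array, which is why the per-row offsets xm::Int32.new(y.shape[0]).seq(0, y.shape[1]) are subtracted to recover the predicted class of each row. A small Numo illustration of that arithmetic (assuming Numo's flat-index behaviour for max_index with axis:):

```ruby
require 'numo/narray'

y = Numo::DFloat[[0.1, 0.7, 0.2],
                 [0.6, 0.3, 0.1]]

flat    = y.max_index(axis: 1)                            # => [1, 3]  (flat indices)
offsets = Numo::Int32.new(y.shape[0]).seq(0, y.shape[1])  # => [0, 3]  (row starts)
pred    = flat - offsets                                  # => [1, 0]  (argmax per row)
```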
data/lib/chainer/functions/loss/mean_squared_error.rb (removed lines that did not survive extraction are left truncated or bare):

@@ -2,31 +2,40 @@ module Chainer
   module Functions
     module Loss
      # Mean squared error (a.k.a. Euclidean loss) function.
-      class MeanSquaredError < Function
+      class MeanSquaredError < FunctionNode
        # Mean squared error function.
        #
        # This function computes mean squared error between two variables. The mean
        # is taken over the minibatch. Note that the error is not scaled by 1/2.
        #
-        # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
-        # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
        # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
        #
        def self.mean_squared_error(x0, x1)
-          self.new.(x0, x1)
+          self.new.apply([x0, x1]).first
        end

-        def
-
-
-          diff = @diff.flatten.dup()
+        def forward(inputs)
+          retain_inputs([0, 1])
+          diff = (inputs[0] - inputs[1]).flatten.dup
          [diff.class.cast(diff.dot(diff) / diff.size)]
        end

-        def backward(
-
-
-          [
+        def backward(indexes, gy)
+          x0, x1 = get_retained_inputs
+          diff = x0 - x1
+          gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
+          gx0 = gy0 * diff * (2.0 / diff.size)
+
+          ret = []
+          if indexes.include?(0)
+            ret << gx0
+          end
+          if indexes.include?(1)
+            ret << -gx0
+          end
+          ret
        end
      end
    end
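The new backward matches the analytic gradient: with $N$ elements, $L = \frac{1}{N} \sum_i (x_{0,i} - x_{1,i})^2$, so $\partial L/\partial x_0 = \frac{2}{N}(x_0 - x_1)$ and $\partial L/\partial x_1$ is its negation, which is exactly gx0 = gy0 * diff * (2.0 / diff.size) and -gx0 above. A minimal usage sketch, with illustrative shapes and assuming the usual Variable#backward / #grad API:

```ruby
require 'chainer'

x0 = Chainer::Variable.new(Numo::SFloat.new(2, 3).rand)
x1 = Chainer::Variable.new(Numo::SFloat.new(2, 3).rand)

loss = Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(x0, x1)
loss.backward

# Up to floating-point rounding, x0.grad equals 2 * (x0.data - x1.data) / 6
x0.grad
```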
data/lib/chainer/functions/loss/softmax_cross_entropy.rb (removed lines that did not survive extraction are left truncated or bare):

@@ -2,68 +2,101 @@ module Chainer
   module Functions
     module Loss
       class SoftmaxCrossEntropy < Function
-        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
-
+        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean', enable_double_backprop: false)
+          if enable_double_backprop
+            self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          else
+            self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+          end
+        end
+
+        def self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          if t.is_a?(Chainer::Variable)
+            t = t.data
+          end
+
+          self.check_class_weight_option(class_weight)
+          self.check_reduce_option(reduce)
+
+          loss = -Activation::LogSoftmax.log_softmax(x)
+
+          if class_weight
+            shape = x.ndim.times.map { |d| d != 1 ? 1 : class_weight.shape[-1] }
+            class_weight = Chainer::Functions::Array::BroadcastTo.broadcast_to(class_weight.reshape(*shape), x.shape)
+            loss = loss * class_weight
+          end
+
+          dtype = x.is_a?(Chainer::Variable) ? x.dtype : x.class
+          in_use = t.ne(ignore_label).cast_to(dtype)
+
+          loss = Chainer::Functions::Array::Rollaxis.rollaxis(loss, 1, start: loss.ndim)
+
+          # TODO: loss = chainer.functions.reshape(loss, (-1, loss.shape[-1]))
+          shape = loss.shape
+          last_shape = shape.pop
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, [shape.inject(:*), last_shape])
+
+          # Replace ignore_label value with one valid for F.select_item below.
+          t = t.clip(0, loss.shape[1] - 1)
+
+          loss = Chainer::Functions::Array::SelectItem.select_item(loss, t.flatten.dup)
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, t.shape)
+
+          loss = loss * in_use
+
+          if reduce == "mean"
+            count = normalize ? in_use.sum : x.shape.first
+            count = [count, 1.0].max
+            loss = loss * (1.0 / count)
+            return Chainer::Functions::Math::Sum.sum(loss)
+          else
+            return loss
+          end
         end

         def initialize(normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
           @normalize = normalize
           @cache_score = cache_score
+          self.class.check_class_weight_option(class_weight)
           @class_weight = class_weight

-          unless class_weight.nil?
-            if @class_weight.ndim != 1
-              raise ArgumentError, 'class_weight.ndim should be 1'
-            elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
-              raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
-            elsif @class_weight.kind_of?(Chainer::Variable)
-              raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
-            end
-          end
-
           @ignore_label = ignore_label
-          unless ['mean', 'no'].include?(reduce)
-            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
-          end

+          self.class.check_reduce_option(reduce)
           @reduce = reduce
         end

-        def
+        def forward(inputs)
+          xm = Chainer.get_array_module(*inputs)
           x, t = inputs
           log_y = Activation._log_softmax(x)

           if @cache_score
-            @y =
+            @y = xm::NMath.exp(log_y)
           end
           if @class_weight
             shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
-            log_y *= Chainer::
+            log_y *= Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
           end
-          log_yd = Chainer::
+          log_yd = Chainer::Utils::Array.rollaxis(log_y, 1)
           begin
             log_yd = log_yd.reshape(log_yd.shape[0], true)
           rescue ArgumentError
           end
-
-
-
-
-
-          log_p = []
-          ravel_arr.each_with_index do |r, i|
-            log_p << log_yd[r, i]
+
+          log_p = log_yd[t.class.maximum(t.flatten, 0), t.class.new(t.size).seq].diagonal
+          if @ignore_label
+            t_valid = t.ne(@ignore_label)
+            log_p *= t_valid.flatten
           end
-          log_p = log_yd.class.[](*log_p)
-          log_p[t.flatten.dup.eq(@ignore_label)] = 0

           if @reduce == 'mean'
-            if @normalize
-
+            if @normalize and t_valid
+              @coeff = 1.0 / log_p.class.maximum(Chainer::Utils::Array.force_array(t_valid.count), 1)
             else
               count = x.shape[0]
+              @coeff = 1.0 / [count, 1].max
             end
-            @coeff = 1.0 / [count, 1].max
             y = log_p.sum(keepdims: true) * (-@coeff)
             [y.class.cast(y[0])]
           else
@@ -71,7 +104,8 @@ module Chainer
           end
         end

-        def
+        def backward(inputs, grad_outputs)
+          xm = Chainer.get_array_module(*(inputs + grad_outputs))
           x, t = inputs
           gloss = grad_outputs[0]

@@ -79,24 +113,24 @@ module Chainer
             y = @y.dup
           else
             y = Activation._log_softmax(x)
-            y =
+            y = xm::NMath.exp(y)
           end

           if y.ndim == 2
             gx = y
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
             t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}

             if @class_weight
               shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
-              c = Chainer::
-              c = c
-              gx *= Chainer::
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
+              c = c[t.class.new(t.shape[0]).seq, t.class.maximum(t, 0)].diagonal.dup
+              gx *= Chainer::Utils::Array.broadcast_to(c.expand_dims(1), gx.shape)
             end

-
-
-
-            gx *= bit.reshape(t.shape[0], 1)
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1)
+            end
           else
             # in the case where y.ndim is higher than 2,
             # we think that a current implementation is inefficient
@@ -104,18 +138,21 @@ module Chainer

             n_unit = t.size / t.shape[0]
             gx = y.reshape(y.shape[0], y.shape[1], true)
-            fst_index =
-            trd_index =
+            fst_index = xm::Int32.new(t.size).seq(0) / n_unit
+            trd_index = xm::Int32.new(t.size).seq(0) % n_unit
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
             fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
             if @class_weight
               shape = x.ndim.times.map{|d| d == 1 ? true : 1}
-              c = Chainer::
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
               c = c.reshape(*gx.shape)
-              c = c
+              c = c[fst_index, t.class.maximum(t.flatten.dup, 0), trd_index].diagonal.diagonal.dup
               c = c.reshape(y.shape[0], 1, true)
-              gx *= Chainer::
+              gx *= Chainer::Utils::Array.broadcast_to(c, gx.shape)
+            end
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
             end
-            gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
             gx = gx.reshape(*y.shape)
           end

@@ -126,36 +163,26 @@ module Chainer
           end
           return [gx, nil]
         end
-      end

-
-
-        axes.delete_at(axis)
-        axes.insert(start <= axes.size ? start : -1, axis)
-        y.transpose(*axes)
-      end
+        def self.check_class_weight_option(class_weight)
+          return if class_weight.nil?

-
-
-
+          xm = Chainer.get_array_module(@class_weight)
+          if class_weight.ndim != 1
+            raise ArgumentError, 'class_weight.ndim should be 1'
+          elsif (class_weight.class != xm::DFloat) and (class_weight.class != xm::SFloat)
+            raise ArgumentError, "The dtype of class_weight should be 'DFloat' or 'SFloat'"
+          elsif class_weight.kind_of?(Chainer::Variable)
+            raise ArgumentError, 'class_weight should be a NArray, not a chainer.Variable'
+          end
         end

-
-
-
-          if array.shape[i] == 1
-            tile_shape << s
-          elsif array.shape[i] == s
-            tile_shape << 1
-          else
-            raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+        def self.check_reduce_option(reduce)
+          unless ['mean', 'no'].include?(reduce)
+            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
           end
-
-
-          array.tile(*shape[0...-array.shape.size], *tile_shape)
+        end
       end
-
-        module_function :rollaxis, :broadcast_to
     end
   end
 end
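The added enable_double_backprop: keyword selects between the two implementations shown above: the default Function-based path, and double_backward_softmax_cross_entropy, which composes the loss from FunctionNode operations (log_softmax, broadcast, select_item, sum) so the result itself can be differentiated again. A small sketch of the call, with illustrative data shapes:

```ruby
require 'chainer'

x = Chainer::Variable.new(Numo::SFloat.new(4, 3).rand)  # 4 samples, 3 classes
t = Numo::Int32[0, 2, 1, 1]                             # ground-truth labels

# Default path: the Function-based forward/backward above.
loss = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(x, t)

# FunctionNode-based path, usable when higher-order gradients are needed.
loss2 = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(
  x, t, enable_double_backprop: true)
```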