red-chainer 0.3.2 → 0.4.0
This diff shows the published contents of the two package versions as they appear in their public registry; it is provided for informational purposes only.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
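The headline change in 0.4.0 is the new `Chainer::FunctionNode` base class (`data/lib/chainer/function_node.rb`, +454 lines). Functions converted to it are invoked through `apply`, explicitly retain the inputs they need, and implement `backward(indexes, grad_outputs)` in terms of other differentiable functions, which is the pattern running through the hunks below. A minimal sketch of that calling convention, assuming only what those hunks show (the `SquareSketch` class is illustrative and not part of the gem):

```ruby
require 'chainer'

# Hypothetical FunctionNode following the 0.4.0 pattern: forward works on
# NArrays, backward is written in terms of Variables so it stays differentiable.
class SquareSketch < Chainer::FunctionNode
  def self.square(x)
    self.new.apply([x]).first          # new-style entry point: apply([...]) instead of self.new.(...)
  end

  def forward(inputs)
    retain_inputs([0])                 # keep x around for backward
    x, = inputs
    [x * x]
  end

  def backward(indexes, grad_outputs)
    x, = get_retained_inputs           # retained inputs come back as Chainer::Variable
    gy = grad_outputs.first
    [gy * x * 2.0]                     # gradient built from Variable ops, not raw NArrays
  end
end
```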
data/lib/chainer/functions/connection/deconvolution_2d.rb (new file)
@@ -0,0 +1,159 @@
+module Chainer
+  module Functions
+    module Connection
+      class Deconvolution2DFunction < Chainer::FunctionNode
+        attr_reader :sy, :sx, :ph, :pw, :cover_all
+
+        # Two dimensional deconvolution function.
+        #
+        # This is an implementation of two-dimensional deconvolution.
+        # In most of deep learning frameworks and papers,
+        # this function is called <b>transposed convolution</b>.
+        # But because of historical reasons (e.g. paper by Ziller Deconvolutional Networks) and backward compatibility,
+        # this function is called +deconvolution+ in Chainer.
+        #
+        # It takes three variables: input image +x+,
+        # the filter weight +W+, and the bias vector +b+.
+        #
+        # - $n$ is the batch size.
+        # - $c_I$ and $c_O$ are the number of the input and output channels, respectively.
+        # - $h_I$ and $w_I$ are the height and width of the input image, respectively.
+        # - $h_K$ and $w_K$ are the height and width of the filters, respectively.
+        # - $h_P$ and $w_P$ are the height and width of the spatial padding size, respectively.
+        #
+        # Let $(s_Y, s_X)$ be the stride of filter application.
+        # Then, the output size $(h_O, w_O)$ is estimated by the following equations:
+        #
+        # $
+        # h_O &= s_Y (h_I - 1) + h_K - 2h_P,
+        # w_O &= s_X (w_I - 1) + w_K - 2w_P.
+        # $
+        #
+        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape $(n, c_I, h_I, w_I)$.
+        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape $(c_I, c_O, h_K, w_K)$.
+        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length $c_O$ (optional).
+        # @param [integer or Array<integer>] stride Stride of filter applications. +stride=s+ and +stride=[s, s]+ are equivalent.
+        # @param [integer or Array<integer>] pad Spatial padding width for input arrays. +pad=p+ and +pad=[p, p]+ are equivalent.
+        # @param [integer or Arrat<integer>] outsize Expected output size of deconvolutional operation.
+        # It should be pair of height and width $(h_O, w_O)$.
+        # Default value is +nil+ and the outsize is estimated by input size, stride and pad.
+        # @return [Chainer::Variable] Output variable of shape $(n, c_O, h_O, w_O)$.
+        #
+        # Example
+        # > n = 10
+        # > c_i, c_o = 1, 3
+        # > h_i, w_i = 5, 10
+        # > h_k, w_k = 10, 10
+        # > h_p, w_p = 5, 5
+        # > x = Numo::DFloat.new(n, c_i, h_i, w_i).rand
+        # > x.shape
+        # => [10, 1, 5, 10]
+        # > w = Numo::DFloat.new(c_i, c_o, h_k, w_k).rand
+        # > w.shape
+        # => [1, 3, 10, 10]
+        # > b = Numo::DFloat.new(c_o).rand
+        # > b.shape
+        # => [3]
+        # > s_y, s_x = 5, 5
+        # > y = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(x, w, b: b, stride: [s_y, s_x], pad: [h_p, w_p])
+        # > y.shape
+        # => [10, 3, 20, 45]
+        # > h_o = s_y * (h_i - 1) + h_k - 2 * h_p
+        # > w_o = s_x * (w_i - 1) + w_k - 2 * w_p
+        # > y.shape == [n, c_o, h_o, w_o]
+        # => true
+        def self.deconvolution_2d(x, w, b: nil, stride: 1, pad: 0, outsize: nil)
+          func = Deconvolution2DFunction.new(stride: stride, pad: pad, outsize: outsize)
+          if b.nil?
+            args = x, w
+          else
+            args = x, w, b
+          end
+          func.apply(args).first
+        end
+
+        def initialize(stride: 1, pad: 0, outsize: nil)
+          @cover_all = nil
+
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad : [pad, pad]
+          @outh, @outw = outsize.nil? ? [nil, nil] : outsize
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          x, w = inputs[0...2]
+          b = inputs.size == 3 ? inputs[2] : nil
+
+          unless inputs.all? { |i| i.is_a?(Numo::NArray) }
+            if b.nil?
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}"
+            else
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}, b: #{b.class}"
+            end
+          end
+
+          kh, kw = w.shape[2..-1]
+          _, _, x_h, x_w = x.shape
+
+          gcol = Chainer::Utils::Math.tensordot(w, x, [0, 1]).cast_to(x.class)
+          # - k, m, n: shape of out_channel
+          # - b: number of inputs
+          # - h, w: height and width of kernels
+          # k, m, n, b, h, w -> b, k, m, n, h, w
+          gcol = gcol.transpose(3, 0, 1, 2, 4, 5)
+
+          if @outh.nil?
+            @outh = Chainer::Utils::Conv.get_deconv_outsize(x_h, kh, @sy, @ph)
+            raise TypeError, 'Height in the output should be positive.' if @outh <= 0
+          end
+          if @outw.nil?
+            @outw = Chainer::Utils::Conv.get_deconv_outsize(x_w, kw, @sx, @pw)
+            raise TypeError, 'Width in the output should be positive.' if @outw <= 0
+          end
+
+          y = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, @outh, @outw)
+          if !b.nil?
+            y += b.reshape(1, b.size, 1, 1)
+          end
+          [y]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first
+
+          ret = []
+
+          if indexes.include?(0)
+            set_cover_all(x, w) if @cover_all.nil?
+            gw = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all)
+            ret << gw
+          end
+
+          if indexes.include?(1)
+            set_cover_all(x, w) if @cover_all.nil?
+            gw = Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([gy, x]).first
+            ret << gw
+          end
+
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3])
+            ret << gb
+          end
+
+          ret
+        end
+
+        private
+
+        def set_cover_all(x, w)
+          in_h, in_w = x.shape[2..-1]
+          kh, kw = w.shape[2..-1]
+
+          @cover_all = in_h != Chainer::Utils::Conv.get_conv_outsize(@outh, kh, @sy, @ph) || in_w != Chainer::Utils::Conv.get_conv_outsize(@outw, kw, @sx, @pw)
+        end
+      end
+    end
+  end
+end
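The output-size formula in the docstring above can be checked against the example's own numbers with plain Ruby (nothing here comes from the gem):

```ruby
# h_O = s_Y * (h_I - 1) + h_K - 2 * h_P, and likewise for the width.
s_y, s_x = 5, 5
h_i, w_i = 5, 10
h_k, w_k = 10, 10
h_p, w_p = 5, 5

h_o = s_y * (h_i - 1) + h_k - 2 * h_p   # => 20
w_o = s_x * (w_i - 1) + w_k - 2 * w_p   # => 45
p [h_o, w_o]                            # matches the [10, 3, 20, 45] output shape in the example
```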
data/lib/chainer/functions/connection/linear.rb
@@ -1,17 +1,23 @@
 module Chainer
   module Functions
     module Connection
-      class LinearFunction < Chainer::
+      class LinearFunction < Chainer::FunctionNode
         def self.linear(x, w, b=nil)
+          if x.ndim > 2
+            x = x.reshape(x.shape.first, -1)
+          end
+
           if b.nil?
-
+            args = x, w
           else
-
+            args = x, w, b
           end
+
+          self.new.apply(args).first
         end

         def forward(inputs)
-          x =
+          x = inputs[0]
           w = inputs[1]

           y = x.dot(w.transpose).cast_to(x.class)
@@ -19,28 +25,29 @@ module Chainer
             b = inputs[2]
             y += b
           end
-          return [y]
-        end

-
-
-          w = inputs[1]
-          gy = grad_outputs[0]
-          gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
-          gw = gy.transpose.dot(x).cast_to(w.class)
-          if inputs.size == 3
-            gb = gy.sum(0)
-            [gx, gw, gb]
-          else
-            [gx, gw]
-          end
+          retain_inputs([0, 1])
+          return [y]
         end

-
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first

-
-
-
+          ret = []
+          if indexes.include?(0)
+            gx = LinearFunction.linear(gy, w.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gx, x.dtype)
+          end
+          if indexes.include?(1)
+            gw = LinearFunction.linear(gy.transpose, x.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gw, w.dtype)
+          end
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: 0)
+            ret << gb
+          end
+          ret
         end
       end
     end
data/lib/chainer/functions/evaluation/accuracy.rb
@@ -12,13 +12,13 @@ module Chainer

         def forward(inputs)
           y, t = inputs
+          xm = Chainer.get_array_module(*inputs)
           if @ignore_label
             mask = t.eq(@ignore_label)
             ignore_cnt = mask.count

-
-            pred =
-            pred = y.class[*pred].reshape(*t.shape)
+            pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+            pred = pred.reshape(*t.shape)
             pred[mask] = @ignore_label
             count = pred.eq(t).count - ignore_cnt

@@ -30,8 +30,8 @@ module Chainer
               [y.class.cast(count.to_f / total)]
             end
           else
-            pred = y.max_index(axis: 1).
-            pred =
+            pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+            pred = pred.reshape(*t.shape)

             [y.class.cast(y.class[pred.eq(t)].mean)]
           end
data/lib/chainer/functions/loss/mean_squared_error.rb
@@ -2,31 +2,40 @@ module Chainer
   module Functions
     module Loss
       # Mean squared error (a.k.a. Euclidean loss) function.
-      class MeanSquaredError <
+      class MeanSquaredError < FunctionNode
        # Mean squared error function.
        #
        # This function computes mean squared error between two variables. The mean
        # is taken over the minibatch. Note that the error is not scaled by 1/2.
        #
-        # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
-        # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
        # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
        #
        def self.mean_squared_error(x0, x1)
-          self.new.(x0, x1)
+          self.new.apply([x0, x1]).first
        end

-        def
-
-
-          diff = @diff.flatten.dup()
+        def forward(inputs)
+          retain_inputs([0, 1])
+          diff = (inputs[0] - inputs[1]).flatten.dup
          [diff.class.cast(diff.dot(diff) / diff.size)]
        end

-        def backward(
-
-
-          [
+        def backward(indexes, gy)
+          x0, x1 = get_retained_inputs
+          diff = x0 - x1
+          gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
+          gx0 = gy0 * diff * (2.0 / diff.size)
+
+          ret = []
+          if indexes.include?(0)
+            ret << gx0
+          end
+          if indexes.include?(1)
+            ret << -gx0
+          end
+          ret
        end
      end
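A quick worked check of the rewritten `MeanSquaredError` (values chosen here; it assumes `Variable#backward` seeds a size-1 loss with a gradient of 1):

```ruby
require 'chainer'

x0 = Chainer::Variable.new(Numo::DFloat[[1.0, 2.0], [3.0, 4.0]])
x1 = Chainer::Variable.new(Numo::DFloat[[1.0, 0.0], [0.0, 0.0]])

loss = Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(x0, x1)
# forward: mean of (x0 - x1)^2 over all 4 elements = (0 + 4 + 9 + 16) / 4 = 7.25

loss.backward
# backward: gx0 = 2 * (x0 - x1) / 4, so x0.grad == [[0.0, 1.0], [1.5, 2.0]] and x1.grad == -x0.grad
```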
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -2,68 +2,101 @@ module Chainer
   module Functions
     module Loss
       class SoftmaxCrossEntropy < Function
-        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
-
+        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean', enable_double_backprop: false)
+          if enable_double_backprop
+            self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          else
+            self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+          end
+        end
+
+        def self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          if t.is_a?(Chainer::Variable)
+            t = t.data
+          end
+
+          self.check_class_weight_option(class_weight)
+          self.check_reduce_option(reduce)
+
+          loss = -Activation::LogSoftmax.log_softmax(x)
+
+          if class_weight
+            shape = x.ndim.times.map { |d| d != 1 ? 1 : class_weight.shape[-1] }
+            class_weight = Chainer::Functions::Array::BroadcastTo.broadcast_to(class_weight.reshape(*shape), x.shape)
+            loss = loss * class_weight
+          end
+
+          dtype = x.is_a?(Chainer::Variable) ? x.dtype : x.class
+          in_use = t.ne(ignore_label).cast_to(dtype)
+
+          loss = Chainer::Functions::Array::Rollaxis.rollaxis(loss, 1, start: loss.ndim)
+
+          # TODO: loss = chainer.functions.reshape(loss, (-1, loss.shape[-1]))
+          shape = loss.shape
+          last_shape = shape.pop
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, [shape.inject(:*), last_shape])
+
+          # Replace ignore_label value with one valid for F.select_item below.
+          t = t.clip(0, loss.shape[1] - 1)
+
+          loss = Chainer::Functions::Array::SelectItem.select_item(loss, t.flatten.dup)
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, t.shape)
+
+          loss = loss * in_use
+
+          if reduce == "mean"
+            count = normalize ? in_use.sum : x.shape.first
+            count = [count, 1.0].max
+            loss = loss * (1.0 / count)
+            return Chainer::Functions::Math::Sum.sum(loss)
+          else
+            return loss
+          end
         end

         def initialize(normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
           @normalize = normalize
           @cache_score = cache_score
+          self.class.check_class_weight_option(class_weight)
           @class_weight = class_weight

-          unless class_weight.nil?
-            if @class_weight.ndim != 1
-              raise ArgumentError, 'class_weight.ndim should be 1'
-            elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
-              raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
-            elsif @class_weight.kind_of?(Chainer::Variable)
-              raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
-            end
-          end
-
           @ignore_label = ignore_label
-          unless ['mean', 'no'].include?(reduce)
-            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
-          end

+          self.class.check_reduce_option(reduce)
           @reduce = reduce
         end

-        def
+        def forward(inputs)
+          xm = Chainer.get_array_module(*inputs)
           x, t = inputs
           log_y = Activation._log_softmax(x)

           if @cache_score
-            @y =
+            @y = xm::NMath.exp(log_y)
           end
           if @class_weight
             shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
-            log_y *= Chainer::
+            log_y *= Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
           end
-          log_yd = Chainer::
+          log_yd = Chainer::Utils::Array.rollaxis(log_y, 1)
           begin
             log_yd = log_yd.reshape(log_yd.shape[0], true)
           rescue ArgumentError
           end
-
-
-
-
-
-          log_p = []
-          ravel_arr.each_with_index do |r, i|
-            log_p << log_yd[r, i]
+
+          log_p = log_yd[t.class.maximum(t.flatten, 0), t.class.new(t.size).seq].diagonal
+          if @ignore_label
+            t_valid= t.ne(@ignore_label)
+            log_p *= t_valid.flatten
           end
-          log_p = log_yd.class.[](*log_p)
-          log_p[t.flatten.dup.eq(@ignore_label)] = 0

           if @reduce == 'mean'
-            if @normalize
-
+            if @normalize and t_valid
+              @coeff = 1.0 / log_p.class.maximum(Chainer::Utils::Array.force_array(t_valid.count), 1)
             else
               count = x.shape[0]
+              @coeff = 1.0 / [count, 1].max
             end
-            @coeff = 1.0 / [count, 1].max
             y = log_p.sum(keepdims: true) * (-@coeff)
             [y.class.cast(y[0])]
           else
@@ -71,7 +104,8 @@ module Chainer
           end
         end

-        def
+        def backward(inputs, grad_outputs)
+          xm = Chainer.get_array_module(*(inputs + grad_outputs))
           x, t = inputs
           gloss = grad_outputs[0]

@@ -79,24 +113,24 @@ module Chainer
             y = @y.dup
           else
             y = Activation._log_softmax(x)
-            y =
+            y = xm::NMath.exp(y)
           end

           if y.ndim == 2
             gx = y
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
             t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}

             if @class_weight
               shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
-              c = Chainer::
-              c = c
-              gx *= Chainer::
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
+              c = c[t.class.new(t.shape[0]).seq, t.class.maximum(t, 0)].diagonal.dup
+              gx *= Chainer::Utils::Array.broadcast_to(c.expand_dims(1), gx.shape)
             end

-
-
-
-            gx *= bit.reshape(t.shape[0], 1)
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1)
+            end
           else
             # in the case where y.ndim is higher than 2,
             # we think that a current implementation is inefficient
@@ -104,18 +138,21 @@ module Chainer

             n_unit = t.size / t.shape[0]
             gx = y.reshape(y.shape[0], y.shape[1], true)
-            fst_index =
-            trd_index =
+            fst_index = xm::Int32.new(t.size).seq(0) / n_unit
+            trd_index = xm::Int32.new(t.size).seq(0) % n_unit
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
             fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
             if @class_weight
               shape = x.ndim.times.map{|d| d == 1 ? true : 1}
-              c = Chainer::
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
               c = c.reshape(*gx.shape)
-              c = c
+              c = c[fst_index, t.class.maximum(t.flatten.dup, 0), trd_index].diagonal.diagonal.dup
               c = c.reshape(y.shape[0], 1, true)
-              gx *= Chainer::
+              gx *= Chainer::Utils::Array.broadcast_to(c, gx.shape)
+            end
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
             end
-            gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
             gx = gx.reshape(*y.shape)
           end

@@ -126,36 +163,26 @@ module Chainer
           end
           return [gx, nil]
         end
-      end

-
-
-          axes.delete_at(axis)
-          axes.insert(start <= axes.size ? start : -1, axis)
-          y.transpose(*axes)
-        end
+        def self.check_class_weight_option(class_weight)
+          return if class_weight.nil?

-
-
-
+          xm = Chainer.get_array_module(@class_weight)
+          if class_weight.ndim != 1
+            raise ArgumentError, 'class_weight.ndim should be 1'
+          elsif (class_weight.class != xm::DFloat) and (class_weight.class != xm::SFloat)
+            raise ArgumentError, "The dtype of class_weight should be 'DFloat' or 'SFloat'"
+          elsif class_weight.kind_of?(Chainer::Variable)
+            raise ArgumentError, 'class_weight should be a NArray, not a chainer.Variable'
+          end
         end

-
-
-
-          if array.shape[i] == 1
-            tile_shape << s
-          elsif array.shape[i] == s
-            tile_shape << 1
-          else
-            raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+        def self.check_reduce_option(reduce)
+          unless ['mean', 'no'].include?(reduce)
+            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
           end
-
-
-          array.tile(*shape[0...-array.shape.size], *tile_shape)
+        end
       end
-
-        module_function :rollaxis, :broadcast_to
     end
   end
 end
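The only user-visible change to `softmax_cross_entropy` is the new `enable_double_backprop` keyword, which routes the call through `double_backward_softmax_cross_entropy` so the loss is composed entirely of FunctionNode ops (`log_softmax`, `select_item`, `sum`) and its gradient graph can itself be differentiated. A hedged usage sketch (data and shapes are assumptions):

```ruby
require 'chainer'

x = Chainer::Variable.new(Numo::SFloat.new(8, 10).rand)   # batch of 8, 10 classes
t = Numo::Int32.new(8).rand(10)                           # integer labels in 0...10

loss = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(
  x, t, enable_double_backprop: true
)
loss.backward   # x.grad is built from differentiable ops rather than hand-written NArray math
```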