red-chainer 0.2.1 → 0.3.0
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
@@ -3,7 +3,7 @@ require 'zlib'
 module Chainer
   module Datasets
     module Mnist
-      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::
+      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::SFloat, label_dtype: Numo::Int32)
        train_raw = retrieve_mnist_training
        train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)

@@ -15,9 +15,9 @@ module Chainer
      def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
        images = raw[:x]
        if ndim == 2
-          images = images.reshape(
+          images = images.reshape(true, 28, 28)
        elsif ndim == 3
-          images = images.reshape(
+          images = images.reshape(true, 1, 28, 28)
        elsif ndim != 1
          raise "invalid ndim for MNIST dataset"
        end
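Taken together, the two MNIST hunks above switch the default pixel dtype to Numo::SFloat and make the reshape calls concrete, so ndim: 2 yields 28x28 images and ndim: 3 yields channel-first 1x28x28 images. A minimal usage sketch; the [train, test] return pair and the per-example [image, label] structure are assumed from how the gem's bundled examples use the dataset:

require 'chainer'

# Hedged sketch: fetch MNIST with 3-D (channel, height, width) images
# and the new Numo::SFloat default dtype.
train, test = Chainer::Datasets::Mnist.get_mnist(ndim: 3)

image, label = train[0]
p image.shape   # expected [1, 28, 28] with ndim: 3
p label         # an integer class id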
@@ -16,7 +16,9 @@ module Chainer
     end

     def [](index)
-      batches = @datasets.map
+      batches = @datasets.map do |dataset|
+        dataset.ndim > 1 ? dataset[index, false] : dataset[index]
+      end
       if index.kind_of?(Enumerable)
         length = batches[0].shape[0]
         length.times.map {|i| batches.map { |m| m[i] } }
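The new block walks each underlying dataset and, for multi-dimensional arrays, slices along the first axis with dataset[index, false], so a row (or batch of rows) comes back instead of a flat element. A small sketch of the intended indexing behaviour; the TupleDataset.new(x, y) constructor form and the sample data are assumptions for illustration:

require 'chainer'

# Illustrative data only: four 2-feature samples and four integer labels.
x = Numo::SFloat[[1, 2], [3, 4], [5, 6], [7, 8]]
y = Numo::Int32[0, 1, 0, 1]

data = Chainer::Datasets::TupleDataset.new(x, y)

p data[1]      # single example: x is sliced row-wise => [Numo::SFloat[3, 4], 1]
p data[0..1]   # enumerable index: one [x_row, label] pair per requested position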
data/lib/chainer/function.rb CHANGED
@@ -13,19 +13,19 @@ module Chainer
       #
       # where $a$ is a configurable slope value.
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @param [float] slope Slope value $a$.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
+      # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
       # > x
-      # => Numo::
+      # => Numo::SFloat#shape=[3,2]
       # [[-1, 0],
       #  [2, -3],
       #  [-2, 1]]
       # > F = Chainer::Functions::Activation::LeakyReLU
       # > F.leaky_relu(x, slope:0.2).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3,2]
       # [[-0.2, 0],
       #  [2, -0.6],
       #  [-0.4, 1]]
@@ -36,19 +36,19 @@ module Chainer
       # because +softmax(x)+ may returns +0+.
       # +log_softmax+ method is more stable.
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
       # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
       #
       # @see Chainer::Functions::Softmax
       #
       # @example
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat[[0, 1, 2], [0, 2, 4]]
+      # => Numo::SFloat#shape=[2,3]
       # [[0, 1, 2],
       #  [0, 2, 4]]
       # > F = Chainer::Functions::Activation::LogSoftmax
       # > F.log_softmax(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[2,3]
       # [[-2.40761, -1.40761, -0.407606],
       #  [-4.14293, -2.14293, -0.142932]]
       # @example (T.B.I : F.log, F.softmax)
@@ -9,10 +9,10 @@ module Chainer
       # f(x)=\\max(0, x).
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
+      # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
       # > (x < 0).any?
       # => true
       # > F = Chainer::Functions::Activation::Relu
@@ -29,8 +29,7 @@ module Chainer
       def forward_cpu(x)
         retain_inputs([])
         retain_outputs([0])
-        x[0]
-        [Utils::Array.force_array(x[0])]
+        [Utils::Array.force_array(x[0].class.maximum(x[0], 0))]
       end

       def backward_cpu(x, gy)
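The fix replaces a pass-through of x[0] with an element-wise maximum against zero, so the CPU forward pass now actually clamps negative values. A quick sketch of the expected behaviour; the class-method name relu is assumed from the doc example's F = Chainer::Functions::Activation::Relu and the pattern used by the other activations in this diff:

require 'chainer'

x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]

F = Chainer::Functions::Activation::Relu
p F.relu(x).data   # expected [[0, 0], [2, 0], [0, 1]] now that negatives are clamped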
@@ -9,15 +9,15 @@ module Chainer
       # f(x)=(1 + \\exp(-x))^ { -1 }.
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example It maps the input values into the range of $`[0, 1]`$.
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat.new(3).seq(-2, 2)
+      # => Numo::SFloat#shape=[3]
       # [-2, 0, 2]
       # > F = Chainer::Functions::Activation::Sigmoid
       # > F.sigmoid(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3]
       # [0.119203, 0.5, 0.880797]
       #
       def self.sigmoid(x)
@@ -9,15 +9,15 @@ module Chainer
       # f(x)=\\tanh(x).
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat.new(3).seq(-1, 2)
+      # => Numo::SFloat#shape=[3]
       # [-1, 1, 3]
       # > F = Chainer::Functions::Activation::Tanh
       # > F.tanh(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3]
       # [-0.761594, 0.761594, 0.995055]
       #
       def self.tanh(x)
@@ -33,7 +33,7 @@ module Chainer

       def backward_cpu(x, gy)
         y = @output_data[0]
-        one = y.
+        one = y.class.cast(1)
         [Utils::Array.force_array(gy[0] * (one - y * y))]
       end
     end
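The backward pass reuses the retained output y and the identity d/dx tanh(x) = 1 - tanh(x)^2; the fix casts the constant 1 with y.class.cast so the subtraction stays in the array's dtype. A small numeric check of that identity using plain Numo (values arbitrary):

require 'numo/narray'

y  = Numo::SFloat[-0.761594, 0.0, 0.761594]   # tanh outputs for x = [-1, 0, 1]
gy = Numo::SFloat[1.0, 1.0, 1.0]              # upstream gradient of ones

one = y.class.cast(1)                         # same cast as the fixed backward_cpu
p gy * (one - y * y)                          # ~[0.419974, 1.0, 0.419974] = 1 - tanh(x)**2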
@@ -0,0 +1,92 @@
+module Chainer
+  module Functions
+    module Connection
+      class Convolution2DFunction < Chainer::Function
+        # Two-dimensional convolution function.
+        # This is an implementation of two-dimensional convolution in ConvNets.
+        # It takes three variables: the input image `x`, the filter weight `w`, and the bias vector `b`.
+        #
+        # a notation for dimensionalities.
+        #
+        # - :math:`n` is the batch size.
+        # - :math:`c_I` and :math:`c_O` are the number of the input and output channels, respectively.
+        # - :math:`h_I` and :math:`w_I` are the height and width of the input image, respectively.
+        # - :math:`h_K` and :math:`w_K` are the height and width of the filters, respectively.
+        # - :math:`h_P` and :math:`w_P` are the height and width of the spatial padding size, respectively.
+        #
+        # Then the `Convolution2D` function computes correlations between filters and patches of size :math:`(h_K, w_K)` in `x`.
+        # Patches are extracted at positions shifted by multiples of `stride` from the first position `(-h_P, -w_P)` for each spatial axis.
+        # The right-most (or bottom-most) patches do not run over the padded spatial size.
+        # Let :math:`(s_Y, s_X)` be the stride of filter application.
+        # Then, the output size :math:`(h_O, w_O)` is determined by the following equations:
+        #
+        # math:
+        #   h_O &= (h_I + 2h_P - h_K) / s_Y + 1,\\\\
+        #   w_O &= (w_I + 2w_P - w_K) / s_X + 1.
+        # If `cover_all` option is `true`, the filter will cover the all spatial locations.
+        # So, if the last stride of filter does not cover the end of spatial locations,
+        # an addtional stride will be applied to the end part of spatial locations.
+        # In this case, the output size :math:`(h_O, w_O)` is determined by the following equations:
+        #
+        # math:
+        #   h_O &= (h_I + 2h_P - h_K + s_Y - 1) / s_Y + 1,\\\\
+        #   w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.
+        # If the bias vector is given, then it is added to all spatial locations of the output of convolution.
+        #
+        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
+        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
+        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length :math:`c_O`
+        # @param [Int or 2-D Array] stride Stride of filter applications. `stride=s` and `stride=(s, s)` are equivalent.
+        # @param [Int or 2-D Array] pad Spatial padding width for input arrays.
+        # @param [Boolean] cover_all If `true`, all spatial locations are convoluted into some output pixels.
+        # @return [Chainer::Variable] Output variable of shape :math:`(n, c_O, h_O, w_O)`.
+        def self.convolution_2d(x, w, b: nil, stride: 1, pad: 0, cover_all: false)
+          func = self.new(stride: stride, pad: pad, cover_all: cover_all)
+          if b.nil?
+            func.(x, w)
+          else
+            func.(x, w, b)
+          end
+        end
+
+        def initialize(stride: 1, pad: 0, cover_all: false)
+          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+          @cover_all = cover_all
+        end
+
+        def forward_cpu(inputs)
+          x = inputs[0]
+          w = inputs[1]
+          b = inputs.size == 3 ? inputs[2] : nil
+
+          kh, kw = w.shape[2], w.shape[3]
+
+          @col = Chainer::Utils::Conv.im2col_cpu(x, kh, kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
+          y = Chainer::Utils::Math.tensordot(@col, w, [[1, 2, 3], [1, 2, 3]])
+          y += b if b
+
+          [y.transpose(0, 3, 1, 2)]
+        end
+
+        def backward_cpu(inputs, grad_outputs)
+          x, w, b = inputs[0], inputs[1], inputs[2]
+          gy = grad_outputs[0]
+          height, width = x.shape[2..-1]
+
+          gw = Chainer::Utils::Math.tensordot(gy, @col, [[0, 2, 3], [0, 4, 5]])
+          gcol = Chainer::Utils::Math.tensordot(w, gy, [0, 1])
+          gcol = gcol.transpose(3, 0, 1, 2)
+          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, height, width)
+
+          if b.nil?
+            [gx, gw]
+          else
+            gb = gy.sum(axis: [0, 2, 3])
+            [gx, gw, gb]
+          end
+        end
+      end
+    end
+  end
+end
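Under the default cover_all: false, the docstring's formula h_O = (h_I + 2*h_P - h_K) / s_Y + 1 fixes the output size. A hedged usage sketch with made-up shapes (a batch of 2 RGB 32x32 images, 16 filters of size 3x3, stride 1 and pad 1, so (32 + 2*1 - 3) / 1 + 1 = 32 and the spatial size is preserved); the returned object is assumed to be a Chainer::Variable with a .data array:

require 'chainer'

x = Numo::SFloat.new(2, 3, 32, 32).rand   # input images, shape (n, c_I, h_I, w_I)
w = Numo::SFloat.new(16, 3, 3, 3).rand    # filters, shape (c_O, c_I, h_K, w_K)
b = Numo::SFloat.zeros(16)                # one bias per output channel

conv = Chainer::Functions::Connection::Convolution2DFunction
y = conv.convolution_2d(x, w, b: b, stride: 1, pad: 1)
p y.data.shape   # expected [2, 16, 32, 32]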
@@ -0,0 +1,34 @@
+module Chainer
+  module Functions
+    module Loss
+      # Mean squared error (a.k.a. Euclidean loss) function.
+      class MeanSquaredError < Function
+        # Mean squared error function.
+        #
+        # This function computes mean squared error between two variables. The mean
+        # is taken over the minibatch. Note that the error is not scaled by 1/2.
+        #
+        # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
+        # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+        # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
+        #
+        def self.mean_squared_error(x0, x1)
+          self.new.(x0, x1)
+        end
+
+        def forward_cpu(inputs)
+          x0, x1 = inputs
+          @diff = x0 - x1
+          diff = @diff.flatten.dup()
+          [diff.class.cast(diff.dot(diff) / diff.size)]
+        end
+
+        def backward(inputs, gy)
+          coeff = gy[0] * gy[0].class.cast(2.0 / @diff.size)
+          gx0 = coeff * @diff
+          [gx0, -(gx0)]
+        end
+      end
+    end
+  end
+end
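forward_cpu flattens the difference and divides the squared sum by the element count, i.e. the unscaled mean of squared differences. A minimal call sketch (numbers arbitrary; the .data accessor on the returned Variable is assumed):

require 'chainer'

x0 = Numo::SFloat[[1, 2], [3, 4]]
x1 = Numo::SFloat[[1, 0], [3, 2]]

loss = Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(x0, x1)
p loss.data   # (0 + 4 + 0 + 4) / 4 elements = 2.0, with no extra 1/2 factor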
@@ -13,17 +13,17 @@ module Chainer

        unless class_weight.nil?
          if @class_weight.ndim != 1
-            raise ArgumentError 'class_weight.ndim should be 1'
-          elsif @class_weight.
-            raise ArgumentError
+            raise ArgumentError, 'class_weight.ndim should be 1'
+          elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
+            raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
          elsif @class_weight.kind_of?(Chainer::Variable)
-            raise ArgumentError 'class_weight should be a Numo::NArray, not a chainer.Variable'
+            raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
          end
        end

        @ignore_label = ignore_label
        unless ['mean', 'no'].include?(reduce)
-          raise ArgumentError "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
+          raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
        end

        @reduce = reduce
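The corrected raise calls pass the message as the second argument to ArgumentError (the old raise ArgumentError '...' form would itself fail with NoMethodError), and the dtype check now accepts either Numo::DFloat or Numo::SFloat for class_weight. A sketch of arguments that would pass or trip this validation; the keyword names and their defaults are assumptions modelled on the checks shown here:

require 'chainer'

sce = Chainer::Functions::Loss::SoftmaxCrossEntropy

# A 1-D Numo::SFloat (or Numo::DFloat) weight per class passes the checks:
sce.new(class_weight: Numo::SFloat[1.0, 2.0, 0.5])

# These would now raise ArgumentError with a readable message:
# sce.new(class_weight: Numo::Int32[1, 2, 3])            # wrong dtype
# sce.new(class_weight: Numo::SFloat[[1.0, 2.0, 0.5]])   # ndim != 1
# sce.new(reduce: 'sum')                                 # reduce must be 'mean' or 'no'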
@@ -37,40 +37,37 @@ module Chainer
          @y = Numo::NMath.exp(log_y)
        end
        if @class_weight
-          shape = x.ndim.times.map { |e| e == 1 ?
-          log_y
+          shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
+          log_y *= Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
        end
-        log_yd = rollaxis(log_y, 1)
+        log_yd = Chainer::Functions::Loss.rollaxis(log_y, 1)
        begin
-          log_yd = log_yd.reshape(log_yd.
+          log_yd = log_yd.reshape(log_yd.shape[0], true)
        rescue ArgumentError
        end
-
        ravel_arr = t.dup.flatten.dup
        ravel_arr[ravel_arr<0] = 0
        arange_arr = t.class.new(t.size).seq

        # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
        log_p = []
-
-        log_p << log_yd[
+        ravel_arr.each_with_index do |r, i|
+          log_p << log_yd[r, i]
        end
-        log_p =
-
-        log_p[log_p.eq(@ignore_label)] = 0
+        log_p = log_yd.class.[](*log_p)
+        log_p[t.flatten.dup.eq(@ignore_label)] = 0

        if @reduce == 'mean'
          if @normalize
            count = t.ne(@ignore_label).count
          else
-            count = x.
+            count = x.shape[0]
          end
          @coeff = 1.0 / [count, 1].max
-
          y = log_p.sum(keepdims: true) * (-@coeff)
-          [y.
+          [y.class.cast(y[0])]
        else
-          [-log_p.reshape(t.shape)]
+          [-log_p.reshape(*t.shape)]
        end
      end

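The forward pass gathers the log-probability of each sample's target class and, with reduce: 'mean', averages the negated values over the minibatch. A call sketch; the softmax_cross_entropy class-method name is assumed by analogy with the other helper methods introduced in this release (relu, sigmoid, mean_squared_error, and so on):

require 'chainer'

x = Numo::SFloat[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]   # logits: 2 samples, 3 classes
t = Numo::Int32[2, 0]                                # target class index per sample

loss = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(x, t)
p loss.data   # mean of -log softmax(x)[i, t[i]], roughly 0.4076 for these logits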
@@ -87,48 +84,78 @@ module Chainer

        if y.ndim == 2
          gx = y
-          t[t
-          t.each_with_index do |v, idx|
-            gx[(idx * 10)...(idx * 10 + 10)][v] -= 1
-          end
+          t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}

          if @class_weight
-            shape = x.ndim.times.map { |d| d == 1 ?
-            c = broadcast_to(@class_weight.reshape(shape), x.shape)
-            c = c
-            gx *= broadcast_to(
+            shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
+            c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+            c = c.class.cast(t.class.new(t.shape[0]).seq.to_a.zip(t.class.maximum(t, 0).to_a).map{|v| c[*v]})
+            gx *= Chainer::Functions::Loss.broadcast_to(c.expand_dims(1), gx.shape)
          end

          bit = t.flatten.dup
          bit[t.ne(@ignore_label)] = 1
          bit[bit.ne(1)] = 0
-          gx *= bit.reshape(t.
+          gx *= bit.reshape(t.shape[0], 1)
        else
-
+          # in the case where y.ndim is higher than 2,
+          # we think that a current implementation is inefficient
+          # because it yields two provisional arrays for indexing.
+
+          n_unit = t.size / t.shape[0]
+          gx = y.reshape(y.shape[0], y.shape[1], true)
+          fst_index = Numo::Int32.new(t.size).seq(0) / n_unit
+          trd_index = Numo::Int32.new(t.size).seq(0) % n_unit
+          fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
+          if @class_weight
+            shape = x.ndim.times.map{|d| d == 1 ? true : 1}
+            c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+            c = c.reshape(*gx.shape)
+            c = c.class.cast(fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).map{|v| c[*v]})
+            c = c.reshape(y.shape[0], 1, true)
+            gx *= Chainer::Functions::Loss.broadcast_to(c, gx.shape)
+          end
+          gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
+          gx = gx.reshape(*y.shape)
        end

        if @reduce == 'mean'
          gx *= gloss * @coeff
        else
-
+          gx *= gloss[true,:- , false]
        end
        return [gx, nil]
      end
+    end

+    def rollaxis(y, axis, start: 0)
+      axes = (0...y.ndim).to_a
+      axes.delete_at(axis)
+      axes.insert(start <= axes.size ? start : -1, axis)
+      y.transpose(*axes)
+    end

-
-
-
-      array.class.tile(array, shape[0]).reshape(*shape)
+    def broadcast_to(array, shape)
+      if array.shape.size > shape.size
+        raise TypeError, "Shape of data mismatch\n array.shape.size(#{array.shape.size}) > shape.size(#{shape.size})"
      end

-
-
-
-
-
-
+      tile_shape = []
+      shape_check = shape[-array.shape.size..-1]
+      shape_check.each_with_index{|s, i|
+        if array.shape[i] == 1
+          tile_shape << s
+        elsif array.shape[i] == s
+          tile_shape << 1
+        else
+          raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+        end
+      }
+
+      array.tile(*shape[0...-array.shape.size], *tile_shape)
    end
+
+    module_function :rollaxis, :broadcast_to
  end
 end
 end
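The new module functions rollaxis and broadcast_to back the class-weight handling above: rollaxis moves the class axis to the front, and broadcast_to tiles an array out to a target shape, raising TypeError on incompatible shapes. A small sketch of what they do, using shapes chosen purely for illustration:

require 'chainer'

loss_mod = Chainer::Functions::Loss

y = Numo::SFloat.new(2, 3, 4).seq        # (batch, class, unit) layout
p loss_mod.rollaxis(y, 1).shape          # class axis moved to the front: [3, 2, 4]

w = Numo::SFloat[[1, 2, 3]]              # shape [1, 3]
p loss_mod.broadcast_to(w, [2, 3])       # the single row tiled out to shape [2, 3]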