red-chainer 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/examples/cifar/models/vgg.rb +84 -0
- data/examples/cifar/train_cifar.rb +70 -0
- data/examples/iris.rb +103 -0
- data/lib/chainer.rb +17 -0
- data/lib/chainer/configuration.rb +2 -1
- data/lib/chainer/cuda.rb +18 -0
- data/lib/chainer/dataset/convert.rb +30 -9
- data/lib/chainer/datasets/cifar.rb +56 -0
- data/lib/chainer/datasets/mnist.rb +3 -3
- data/lib/chainer/datasets/tuple_dataset.rb +3 -1
- data/lib/chainer/function.rb +1 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
- data/lib/chainer/functions/activation/log_softmax.rb +4 -4
- data/lib/chainer/functions/activation/relu.rb +3 -4
- data/lib/chainer/functions/activation/sigmoid.rb +4 -4
- data/lib/chainer/functions/activation/tanh.rb +5 -5
- data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
- data/lib/chainer/functions/connection/linear.rb +1 -1
- data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
- data/lib/chainer/functions/math/identity.rb +26 -0
- data/lib/chainer/functions/noise/dropout.rb +45 -0
- data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
- data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
- data/lib/chainer/gradient_check.rb +240 -0
- data/lib/chainer/initializer.rb +2 -0
- data/lib/chainer/initializers/constant.rb +1 -1
- data/lib/chainer/initializers/init.rb +5 -1
- data/lib/chainer/initializers/normal.rb +1 -1
- data/lib/chainer/iterators/serial_iterator.rb +1 -1
- data/lib/chainer/link.rb +11 -0
- data/lib/chainer/links/connection/convolution_2d.rb +98 -0
- data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
- data/lib/chainer/optimizer.rb +40 -1
- data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
- data/lib/chainer/parameter.rb +1 -1
- data/lib/chainer/serializers/marshal.rb +7 -3
- data/lib/chainer/testing/array.rb +32 -0
- data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
- data/lib/chainer/training/extensions/snapshot.rb +1 -1
- data/lib/chainer/training/standard_updater.rb +4 -0
- data/lib/chainer/training/trainer.rb +1 -1
- data/lib/chainer/utils/array.rb +13 -2
- data/lib/chainer/utils/conv.rb +59 -0
- data/lib/chainer/utils/math.rb +72 -0
- data/lib/chainer/utils/variable.rb +7 -3
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +1 -0
- metadata +37 -3
data/lib/chainer/datasets/mnist.rb CHANGED
@@ -3,7 +3,7 @@ require 'zlib'
 module Chainer
   module Datasets
     module Mnist
-      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::
+      def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::SFloat, label_dtype: Numo::Int32)
        train_raw = retrieve_mnist_training
        train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)
 
@@ -15,9 +15,9 @@ module Chainer
       def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
         images = raw[:x]
         if ndim == 2
-          images = images.reshape(
+          images = images.reshape(true, 28, 28)
         elsif ndim == 3
-          images = images.reshape(
+          images = images.reshape(true, 1, 28, 28)
         elsif ndim != 1
           raise "invalid ndim for MNIST dataset"
         end
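Taken together, the two mnist.rb hunks change the default dtype keywords and the reshape targets. A minimal usage sketch, assuming the method mirrors Chainer's API and returns a train/test pair (the keyword values below are the 0.3.0 defaults, the rest is illustrative):

    require 'chainer'

    train, test = Chainer::Datasets::Mnist.get_mnist(
      withlabel: true, ndim: 3, dtype: Numo::SFloat, label_dtype: Numo::Int32
    )
    # With ndim: 3 each image is reshaped to [1, 28, 28]; ndim: 2 gives [28, 28];
    # ndim: 1 keeps the flat 784-element vector.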
data/lib/chainer/datasets/tuple_dataset.rb CHANGED
@@ -16,7 +16,9 @@ module Chainer
       end
 
       def [](index)
-        batches = @datasets.map
+        batches = @datasets.map do |dataset|
+          dataset.ndim > 1 ? dataset[index, false] : dataset[index]
+        end
         if index.kind_of?(Enumerable)
           length = batches[0].shape[0]
           length.times.map {|i| batches.map { |m| m[i] } }
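The new block passes `false` as a trailing index for multi-dimensional datasets. A small Numo sketch (illustrative values, not from the gem) of why the two cases differ:

    require 'numo/narray'

    data   = Numo::SFloat.new(5, 3).seq   # 5 examples with 3 features each
    labels = Numo::Int32[0, 1, 0, 1, 1]   # 1-D labels

    data[2, false]   # row 2 with all remaining axes kept => shape [3]
    labels[2]        # plain scalar label for example 2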
data/lib/chainer/function.rb CHANGED
data/lib/chainer/functions/activation/leaky_relu.rb CHANGED
@@ -13,19 +13,19 @@ module Chainer
       #
       # where $a$ is a configurable slope value.
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @param [float] slope Slope value $a$.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
+      # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
       # > x
-      # => Numo::
+      # => Numo::SFloat#shape=[3,2]
       # [[-1, 0],
       # [2, -3],
       # [-2, 1]]
       # > F = Chainer::Functions::Activation::LeakyReLU
       # > F.leaky_relu(x, slope:0.2).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3,2]
       # [[-0.2, 0],
       # [2, -0.6],
       # [-0.4, 1]]
data/lib/chainer/functions/activation/log_softmax.rb CHANGED
@@ -36,19 +36,19 @@ module Chainer
       # because +softmax(x)+ may returns +0+.
       # +log_softmax+ method is more stable.
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
       # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
       #
       # @see Chainer::Functions::Softmax
       #
       # @example
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat[[0, 1, 2], [0, 2, 4]]
+      # => Numo::SFloat#shape=[2,3]
       # [[0, 1, 2],
       # [0, 2, 4]]
       # > F = Chainer::Functions::Activation::LogSoftmax
       # > F.log_softmax(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[2,3]
       # [[-2.40761, -1.40761, -0.407606],
       # [-4.14293, -2.14293, -0.142932]]
       # @example (T.B.I : F.log, F.softmax)
data/lib/chainer/functions/activation/relu.rb CHANGED
@@ -9,10 +9,10 @@ module Chainer
       # f(x)=\\max(0, x).
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
+      # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
       # > (x < 0).any?
       # => true
       # > F = Chainer::Functions::Activation::Relu
@@ -29,8 +29,7 @@ module Chainer
       def forward_cpu(x)
         retain_inputs([])
         retain_outputs([0])
-        x[0]
-        [Utils::Array.force_array(x[0])]
+        [Utils::Array.force_array(x[0].class.maximum(x[0], 0))]
       end
 
       def backward_cpu(x, gy)
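The rewritten `forward_cpu` takes the element-wise maximum against zero via the input array's own class. A quick Numo illustration with assumed values:

    require 'numo/narray'

    x = Numo::SFloat[[-1, 0], [2, -3]]
    x.class.maximum(x, 0)   # => [[0, 0], [2, 0]], i.e. ReLU applied element-wise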
data/lib/chainer/functions/activation/sigmoid.rb CHANGED
@@ -9,15 +9,15 @@ module Chainer
       # f(x)=(1 + \\exp(-x))^ { -1 }.
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example It maps the input values into the range of $`[0, 1]`$.
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat.new(3).seq(-2, 2)
+      # => Numo::SFloat#shape=[3]
       # [-2, 0, 2]
       # > F = Chainer::Functions::Activation::Sigmoid
       # > F.sigmoid(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3]
       # [0.119203, 0.5, 0.880797]
       #
       def self.sigmoid(x)
data/lib/chainer/functions/activation/tanh.rb CHANGED
@@ -9,15 +9,15 @@ module Chainer
       # f(x)=\\tanh(x).
       # $$
       #
-      # @param [Chainer::Variable or Numo::
+      # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
       # @example
-      # > x = Numo::
-      # => Numo::
+      # > x = Numo::SFloat.new(3).seq(-1, 2)
+      # => Numo::SFloat#shape=[3]
       # [-1, 1, 3]
       # > F = Chainer::Functions::Activation::Tanh
       # > F.tanh(x).data
-      # => Numo::
+      # => Numo::SFloat#shape=[3]
       # [-0.761594, 0.761594, 0.995055]
       #
       def self.tanh(x)
@@ -33,7 +33,7 @@ module Chainer
 
       def backward_cpu(x, gy)
         y = @output_data[0]
-        one = y.
+        one = y.class.cast(1)
         [Utils::Array.force_array(gy[0] * (one - y * y))]
       end
     end
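The `one = y.class.cast(1)` line keeps the constant in the same dtype as the retained output, and the backward pass is the identity d tanh(x)/dx = 1 - tanh(x)^2. A standalone sketch with assumed values:

    require 'numo/narray'

    y   = Numo::SFloat[-0.761594, 0.0, 0.761594]  # assumed forward outputs, tanh(x)
    gy  = Numo::SFloat[1, 1, 1]                   # upstream gradient
    one = y.class.cast(1)
    gy * (one - y * y)                            # gradient w.r.t. x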
data/lib/chainer/functions/connection/convolution_2d.rb ADDED
@@ -0,0 +1,92 @@
+module Chainer
+  module Functions
+    module Connection
+      class Convolution2DFunction < Chainer::Function
+        # Two-dimensional convolution function.
+        # This is an implementation of two-dimensional convolution in ConvNets.
+        # It takes three variables: the input image `x`, the filter weight `w`, and the bias vector `b`.
+        #
+        # a notation for dimensionalities.
+        #
+        # - :math:`n` is the batch size.
+        # - :math:`c_I` and :math:`c_O` are the number of the input and output channels, respectively.
+        # - :math:`h_I` and :math:`w_I` are the height and width of the input image, respectively.
+        # - :math:`h_K` and :math:`w_K` are the height and width of the filters, respectively.
+        # - :math:`h_P` and :math:`w_P` are the height and width of the spatial padding size, respectively.
+        #
+        # Then the `Convolution2D` function computes correlations between filters and patches of size :math:`(h_K, w_K)` in `x`.
+        # Patches are extracted at positions shifted by multiples of `stride` from the first position `(-h_P, -w_P)` for each spatial axis.
+        # The right-most (or bottom-most) patches do not run over the padded spatial size.
+        # Let :math:`(s_Y, s_X)` be the stride of filter application.
+        # Then, the output size :math:`(h_O, w_O)` is determined by the following equations:
+        #
+        # math:
+        #   h_O &= (h_I + 2h_P - h_K) / s_Y + 1,\\\\
+        #   w_O &= (w_I + 2w_P - w_K) / s_X + 1.
+        # If `cover_all` option is `true`, the filter will cover the all spatial locations.
+        # So, if the last stride of filter does not cover the end of spatial locations,
+        # an addtional stride will be applied to the end part of spatial locations.
+        # In this case, the output size :math:`(h_O, w_O)` is determined by the following equations:
+        #
+        # math:
+        #   h_O &= (h_I + 2h_P - h_K + s_Y - 1) / s_Y + 1,\\\\
+        #   w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.
+        # If the bias vector is given, then it is added to all spatial locations of the output of convolution.
+        #
+        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
+        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
+        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length :math:`c_O`
+        # @param [Int or 2-D Array] stride Stride of filter applications. `stride=s` and `stride=(s, s)` are equivalent.
+        # @param [Int or 2-D Array] pad Spatial padding width for input arrays.
+        # @param [Boolean] cover_all If `true`, all spatial locations are convoluted into some output pixels.
+        # @return [Chainer::Variable] Output variable of shape :math:`(n, c_O, h_O, w_O)`.
+        def self.convolution_2d(x, w, b: nil, stride: 1, pad: 0, cover_all: false)
+          func = self.new(stride: stride, pad: pad, cover_all: cover_all)
+          if b.nil?
+            func.(x, w)
+          else
+            func.(x, w, b)
+          end
+        end
+
+        def initialize(stride: 1, pad: 0, cover_all: false)
+          @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+          @cover_all = cover_all
+        end
+
+        def forward_cpu(inputs)
+          x = inputs[0]
+          w = inputs[1]
+          b = inputs.size == 3 ? inputs[2] : nil
+
+          kh, kw = w.shape[2], w.shape[3]
+
+          @col = Chainer::Utils::Conv.im2col_cpu(x, kh, kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
+          y = Chainer::Utils::Math.tensordot(@col, w, [[1, 2, 3], [1, 2, 3]])
+          y += b if b
+
+          [y.transpose(0, 3, 1, 2)]
+        end
+
+        def backward_cpu(inputs, grad_outputs)
+          x, w, b = inputs[0], inputs[1], inputs[2]
+          gy = grad_outputs[0]
+          height, width = x.shape[2..-1]
+
+          gw = Chainer::Utils::Math.tensordot(gy, @col, [[0, 2, 3], [0, 4, 5]])
+          gcol = Chainer::Utils::Math.tensordot(w, gy, [0, 1])
+          gcol = gcol.transpose(3, 0, 1, 2)
+          gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, height, width)
+
+          if b.nil?
+            [gx, gw]
+          else
+            gb = gy.sum(axis: [0, 2, 3])
+            [gx, gw, gb]
+          end
+        end
+      end
+    end
+  end
+end
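The two output-size formulas in the doc comment are plain integer arithmetic; a quick sketch with assumed sizes (28x28 input, 5x5 kernel, stride 2, pad 1):

    h_i, w_i = 28, 28   # input height/width
    h_k, w_k = 5, 5     # kernel height/width
    h_p, w_p = 1, 1     # padding
    s_y, s_x = 2, 2     # stride

    h_o = (h_i + 2 * h_p - h_k) / s_y + 1                # => 13  (cover_all: false)
    w_o = (w_i + 2 * w_p - w_k) / s_x + 1                # => 13
    h_o_all = (h_i + 2 * h_p - h_k + s_y - 1) / s_y + 1  # => 14  (cover_all: true)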
data/lib/chainer/functions/loss/mean_squared_error.rb ADDED
@@ -0,0 +1,34 @@
+module Chainer
+  module Functions
+    module Loss
+      # Mean squared error (a.k.a. Euclidean loss) function.
+      class MeanSquaredError < Function
+        # Mean squared error function.
+        #
+        # This function computes mean squared error between two variables. The mean
+        # is taken over the minibatch. Note that the error is not scaled by 1/2.
+        #
+        # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
+        # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+        # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
+        #
+        def self.mean_squared_error(x0, x1)
+          self.new.(x0, x1)
+        end
+
+        def forward_cpu(inputs)
+          x0, x1 = inputs
+          @diff = x0 - x1
+          diff = @diff.flatten.dup()
+          [diff.class.cast(diff.dot(diff) / diff.size)]
+        end
+
+        def backward(inputs, gy)
+          coeff = gy[0] * gy[0].class.cast(2.0 / @diff.size)
+          gx0 = coeff * @diff
+          [gx0, -(gx0)]
+        end
+      end
+    end
+  end
+end
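What `forward_cpu` and `backward` compute can be reproduced with bare Numo arrays; a minimal sketch with assumed inputs:

    require 'numo/narray'

    x0 = Numo::SFloat[[1, 2], [3, 4]]
    x1 = Numo::SFloat[[1, 1], [4, 2]]

    diff = (x0 - x1).flatten
    mse  = diff.dot(diff) / diff.size                    # => 1.5; mean over all elements, no 1/2 factor
    gx0  = diff.reshape(*x0.shape) * (2.0 / diff.size)   # gradient w.r.t. x0 for a unit upstream gradient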
data/lib/chainer/functions/loss/softmax_cross_entropy.rb CHANGED
@@ -13,17 +13,17 @@ module Chainer
 
       unless class_weight.nil?
         if @class_weight.ndim != 1
-          raise ArgumentError 'class_weight.ndim should be 1'
-        elsif @class_weight.
-          raise ArgumentError
+          raise ArgumentError, 'class_weight.ndim should be 1'
+        elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
+          raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
         elsif @class_weight.kind_of?(Chainer::Variable)
-          raise ArgumentError 'class_weight should be a Numo::NArray, not a chainer.Variable'
+          raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
         end
       end
 
       @ignore_label = ignore_label
       unless ['mean', 'no'].include?(reduce)
-        raise ArgumentError "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
+        raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
       end
 
       @reduce = reduce
@@ -37,40 +37,37 @@ module Chainer
         @y = Numo::NMath.exp(log_y)
       end
       if @class_weight
-        shape = x.ndim.times.map { |e| e == 1 ?
-        log_y
+        shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
+        log_y *= Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
       end
-      log_yd = rollaxis(log_y, 1)
+      log_yd = Chainer::Functions::Loss.rollaxis(log_y, 1)
       begin
-        log_yd = log_yd.reshape(log_yd.
+        log_yd = log_yd.reshape(log_yd.shape[0], true)
       rescue ArgumentError
       end
-
       ravel_arr = t.dup.flatten.dup
       ravel_arr[ravel_arr<0] = 0
       arange_arr = t.class.new(t.size).seq
 
       # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
       log_p = []
-
-      log_p << log_yd[
+      ravel_arr.each_with_index do |r, i|
+        log_p << log_yd[r, i]
       end
-      log_p =
-
-      log_p[log_p.eq(@ignore_label)] = 0
+      log_p = log_yd.class.[](*log_p)
+      log_p[t.flatten.dup.eq(@ignore_label)] = 0
 
       if @reduce == 'mean'
         if @normalize
           count = t.ne(@ignore_label).count
         else
-          count = x.
+          count = x.shape[0]
         end
         @coeff = 1.0 / [count, 1].max
-
        y = log_p.sum(keepdims: true) * (-@coeff)
-        [y.
+        [y.class.cast(y[0])]
       else
-        [-log_p.reshape(t.shape)]
+        [-log_p.reshape(*t.shape)]
       end
     end
 
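The `ravel_arr.each_with_index` loop gathers the log-probability of each sample's true class. A hedged sketch of that gather step with assumed shapes (ignoring the ignore_label clamping):

    require 'numo/narray'

    log_yd = Numo::SFloat.new(3, 4).seq * -0.1   # assumed [n_classes, n_samples] layout after rollaxis/reshape
    t      = Numo::Int32[2, 0, 1, 2]             # one class label per sample

    log_p = t.to_a.each_with_index.map { |r, i| log_yd[r, i] }
    log_p = Numo::SFloat[*log_p]                 # per-sample log p(true class); the loss is -mean of these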
@@ -87,48 +84,78 @@ module Chainer
 
       if y.ndim == 2
         gx = y
-        t[t
-        t.each_with_index do |v, idx|
-          gx[(idx * 10)...(idx * 10 + 10)][v] -= 1
-        end
+        t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}
 
         if @class_weight
-          shape = x.ndim.times.map { |d| d == 1 ?
-          c = broadcast_to(@class_weight.reshape(shape), x.shape)
-          c = c
-          gx *= broadcast_to(
+          shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
+          c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+          c = c.class.cast(t.class.new(t.shape[0]).seq.to_a.zip(t.class.maximum(t, 0).to_a).map{|v| c[*v]})
+          gx *= Chainer::Functions::Loss.broadcast_to(c.expand_dims(1), gx.shape)
         end
 
         bit = t.flatten.dup
         bit[t.ne(@ignore_label)] = 1
         bit[bit.ne(1)] = 0
-        gx *= bit.reshape(t.
+        gx *= bit.reshape(t.shape[0], 1)
       else
-
+        # in the case where y.ndim is higher than 2,
+        # we think that a current implementation is inefficient
+        # because it yields two provisional arrays for indexing.
+
+        n_unit = t.size / t.shape[0]
+        gx = y.reshape(y.shape[0], y.shape[1], true)
+        fst_index = Numo::Int32.new(t.size).seq(0) / n_unit
+        trd_index = Numo::Int32.new(t.size).seq(0) % n_unit
+        fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
+        if @class_weight
+          shape = x.ndim.times.map{|d| d == 1 ? true : 1}
+          c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+          c = c.reshape(*gx.shape)
+          c = c.class.cast(fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).map{|v| c[*v]})
+          c = c.reshape(y.shape[0], 1, true)
+          gx *= Chainer::Functions::Loss.broadcast_to(c, gx.shape)
+        end
+        gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
+        gx = gx.reshape(*y.shape)
       end
 
       if @reduce == 'mean'
        gx *= gloss * @coeff
       else
-
+        gx *= gloss[true,:- , false]
       end
       return [gx, nil]
     end
+    end
 
+    def rollaxis(y, axis, start: 0)
+      axes = (0...y.ndim).to_a
+      axes.delete_at(axis)
+      axes.insert(start <= axes.size ? start : -1, axis)
+      y.transpose(*axes)
+    end
 
-
-
-
-      array.class.tile(array, shape[0]).reshape(*shape)
+    def broadcast_to(array, shape)
+      if array.shape.size > shape.size
+        raise TypeError, "Shape of data mismatch\n array.shape.size(#{array.shape.size}) > shape.size(#{shape.size})"
       end
 
-
-
-
-
-
-
+      tile_shape = []
+      shape_check = shape[-array.shape.size..-1]
+      shape_check.each_with_index{|s, i|
+        if array.shape[i] == 1
+          tile_shape << s
+        elsif array.shape[i] == s
+          tile_shape << 1
+        else
+          raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+        end
+      }
+
+      array.tile(*shape[0...-array.shape.size], *tile_shape)
     end
+
+    module_function :rollaxis, :broadcast_to
   end
  end
 end
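The new `rollaxis` helper is a thin wrapper over `transpose`, and `broadcast_to` expands size-1 axes by tiling. An illustrative sketch with assumed shapes:

    require 'numo/narray'

    y = Numo::SFloat.new(2, 3, 4).seq
    # rollaxis(y, 1) moves axis 1 to the front, i.e. y.transpose(1, 0, 2)
    y.transpose(1, 0, 2).shape      # => [3, 2, 4]

    w = Numo::SFloat[[1], [2], [3]] # shape [3, 1]
    w.tile(1, 4).shape              # => [3, 4], what broadcast_to(w, [3, 4]) produces here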