red-chainer 0.2.1 → 0.3.0

Files changed (52)
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/examples/cifar/models/vgg.rb +84 -0
  4. data/examples/cifar/train_cifar.rb +70 -0
  5. data/examples/iris.rb +103 -0
  6. data/lib/chainer.rb +17 -0
  7. data/lib/chainer/configuration.rb +2 -1
  8. data/lib/chainer/cuda.rb +18 -0
  9. data/lib/chainer/dataset/convert.rb +30 -9
  10. data/lib/chainer/datasets/cifar.rb +56 -0
  11. data/lib/chainer/datasets/mnist.rb +3 -3
  12. data/lib/chainer/datasets/tuple_dataset.rb +3 -1
  13. data/lib/chainer/function.rb +1 -0
  14. data/lib/chainer/functions/activation/leaky_relu.rb +4 -4
  15. data/lib/chainer/functions/activation/log_softmax.rb +4 -4
  16. data/lib/chainer/functions/activation/relu.rb +3 -4
  17. data/lib/chainer/functions/activation/sigmoid.rb +4 -4
  18. data/lib/chainer/functions/activation/tanh.rb +5 -5
  19. data/lib/chainer/functions/connection/convolution_2d.rb +92 -0
  20. data/lib/chainer/functions/connection/linear.rb +1 -1
  21. data/lib/chainer/functions/loss/mean_squared_error.rb +34 -0
  22. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +67 -40
  23. data/lib/chainer/functions/math/identity.rb +26 -0
  24. data/lib/chainer/functions/noise/dropout.rb +45 -0
  25. data/lib/chainer/functions/normalization/batch_normalization.rb +136 -0
  26. data/lib/chainer/functions/pooling/max_pooling_2d.rb +57 -0
  27. data/lib/chainer/functions/pooling/pooling_2d.rb +20 -0
  28. data/lib/chainer/gradient_check.rb +240 -0
  29. data/lib/chainer/initializer.rb +2 -0
  30. data/lib/chainer/initializers/constant.rb +1 -1
  31. data/lib/chainer/initializers/init.rb +5 -1
  32. data/lib/chainer/initializers/normal.rb +1 -1
  33. data/lib/chainer/iterators/serial_iterator.rb +1 -1
  34. data/lib/chainer/link.rb +11 -0
  35. data/lib/chainer/links/connection/convolution_2d.rb +98 -0
  36. data/lib/chainer/links/normalization/batch_normalization.rb +106 -0
  37. data/lib/chainer/optimizer.rb +40 -1
  38. data/lib/chainer/optimizers/momentum_sgd.rb +49 -0
  39. data/lib/chainer/parameter.rb +1 -1
  40. data/lib/chainer/serializers/marshal.rb +7 -3
  41. data/lib/chainer/testing/array.rb +32 -0
  42. data/lib/chainer/training/extensions/exponential_shift.rb +78 -0
  43. data/lib/chainer/training/extensions/snapshot.rb +1 -1
  44. data/lib/chainer/training/standard_updater.rb +4 -0
  45. data/lib/chainer/training/trainer.rb +1 -1
  46. data/lib/chainer/utils/array.rb +13 -2
  47. data/lib/chainer/utils/conv.rb +59 -0
  48. data/lib/chainer/utils/math.rb +72 -0
  49. data/lib/chainer/utils/variable.rb +7 -3
  50. data/lib/chainer/version.rb +1 -1
  51. data/red-chainer.gemspec +1 -0
  52. metadata +37 -3
data/lib/chainer/datasets/mnist.rb
@@ -3,7 +3,7 @@ require 'zlib'
  module Chainer
    module Datasets
      module Mnist
-       def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::DFloat, label_dtype: Numo::Int32)
+       def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::SFloat, label_dtype: Numo::Int32)
          train_raw = retrieve_mnist_training
          train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)

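Note: the default image dtype of get_mnist moves from Numo::DFloat (double precision) to Numo::SFloat (single precision). A minimal usage sketch, assuming get_mnist still returns the train and test splits as a pair, as in upstream Chainer; pass dtype: Numo::DFloat to keep the old behaviour:

  require 'chainer'

  # images come back as Numo::SFloat by default in 0.3.0
  train, test = Chainer::Datasets::Mnist.get_mnist(ndim: 3)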
@@ -15,9 +15,9 @@ module Chainer
        def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
          images = raw[:x]
          if ndim == 2
-           images = images.reshape(-1, 28, 28)
+           images = images.reshape(true, 28, 28)
          elsif ndim == 3
-           images = images.reshape(-1, 1, 28, 28)
+           images = images.reshape(true, 1, 28, 28)
          elsif ndim != 1
            raise "invalid ndim for MNIST dataset"
          end
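Note: Numo uses true, not NumPy's -1, as the placeholder for a dimension that should be inferred in reshape. A quick plain-Numo illustration (the sample array is ours, not part of the gem):

  require 'numo/narray'

  flat = Numo::SFloat.new(2, 784).seq     # two flattened 28x28 images
  imgs = flat.reshape(true, 28, 28)       # first dimension is inferred
  p imgs.shape                            # => [2, 28, 28]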
data/lib/chainer/datasets/tuple_dataset.rb
@@ -16,7 +16,9 @@ module Chainer
      end

      def [](index)
-       batches = @datasets.map { |dataset| dataset.ndim > 1 ? dataset[index, 0...dataset.shape[1]] : dataset[index] }
+       batches = @datasets.map do |dataset|
+         dataset.ndim > 1 ? dataset[index, false] : dataset[index]
+       end
        if index.kind_of?(Enumerable)
          length = batches[0].shape[0]
          length.times.map {|i| batches.map { |m| m[i] } }
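Note: in Numo indexing, false stands for "all remaining dimensions", so dataset[index, false] picks whole examples whatever the trailing shape is, where the old 0...dataset.shape[1] range only worked for 2-dimensional data. A plain-Numo illustration (the sample data is ours):

  require 'numo/narray'

  data = Numo::SFloat.new(5, 3, 4).seq
  p data[0, false].shape      # => [3, 4]   (one example, every remaining axis)
  p data[0..1, false].shape   # => [2, 3, 4]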
data/lib/chainer/function.rb
@@ -23,6 +23,7 @@ module Chainer
      @input_indexes_to_retain = nil
      @output_indexes_to_retain = nil
      outputs = forward(in_data)
+     raise if !outputs.is_a? Array

      ret = outputs.map do |y|
        Variable.new(y, requires_grad: requires_grad)
data/lib/chainer/functions/activation/leaky_relu.rb
@@ -13,19 +13,19 @@ module Chainer
      #
      # where $a$ is a configurable slope value.
      #
-     # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+     # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @param [float] slope Slope value $a$.
      # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @example
-     # > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
+     # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
      # > x
-     # => Numo::DFloat#shape=[3,2]
+     # => Numo::SFloat#shape=[3,2]
      # [[-1, 0],
      # [2, -3],
      # [-2, 1]]
      # > F = Chainer::Functions::Activation::LeakyReLU
      # > F.leaky_relu(x, slope:0.2).data
-     # => Numo::DFloat#shape=[3,2]
+     # => Numo::SFloat#shape=[3,2]
      # [[-0.2, 0],
      # [2, -0.6],
      # [-0.4, 1]]
data/lib/chainer/functions/activation/log_softmax.rb
@@ -36,19 +36,19 @@ module Chainer
      # because +softmax(x)+ may returns +0+.
      # +log_softmax+ method is more stable.
      #
-     # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
+     # @param [Chainer::Variable or Numo::NArray] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
      # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which is the same shape with x.
      #
      # @see Chainer::Functions::Softmax
      #
      # @example
-     # > x = Numo::DFloat[[0, 1, 2], [0, 2, 4]]
-     # => Numo::DFloat#shape=[2,3]
+     # > x = Numo::SFloat[[0, 1, 2], [0, 2, 4]]
+     # => Numo::SFloat#shape=[2,3]
      # [[0, 1, 2],
      # [0, 2, 4]]
      # > F = Chainer::Functions::Activation::LogSoftmax
      # > F.log_softmax(x).data
-     # => Numo::DFloat#shape=[2,3]
+     # => Numo::SFloat#shape=[2,3]
      # [[-2.40761, -1.40761, -0.407606],
      # [-4.14293, -2.14293, -0.142932]]
      # @example (T.B.I : F.log, F.softmax)
data/lib/chainer/functions/activation/relu.rb
@@ -9,10 +9,10 @@ module Chainer
      # f(x)=\\max(0, x).
      # $$
      #
-     # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+     # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @example
-     # > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
+     # > x = Numo::SFloat[[-1, 0], [2, -3], [-2, 1]]
      # > (x < 0).any?
      # => true
      # > F = Chainer::Functions::Activation::Relu
@@ -29,8 +29,7 @@ module Chainer
      def forward_cpu(x)
        retain_inputs([])
        retain_outputs([0])
-       x[0][x[0]<=0] = 0
-       [Utils::Array.force_array(x[0])]
+       [Utils::Array.force_array(x[0].class.maximum(x[0], 0))]
      end

      def backward_cpu(x, gy)
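Note: the CPU forward pass no longer zeroes the input in place; it builds a fresh array with the class-level maximum. A standalone Numo sketch of the same idea (the sample array is ours):

  require 'numo/narray'

  x = Numo::SFloat[[-1, 0], [2, -3]]
  y = x.class.maximum(x, 0)   # element-wise max against 0; x itself is left untouched
  # y is [[0, 0], [2, 0]]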
data/lib/chainer/functions/activation/sigmoid.rb
@@ -9,15 +9,15 @@ module Chainer
      # f(x)=(1 + \\exp(-x))^ { -1 }.
      # $$
      #
-     # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+     # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @example It maps the input values into the range of $`[0, 1]`$.
-     # > x = Numo::DFloat.new(3).seq(-2, 2)
-     # => Numo::DFloat#shape=[3]
+     # > x = Numo::SFloat.new(3).seq(-2, 2)
+     # => Numo::SFloat#shape=[3]
      # [-2, 0, 2]
      # > F = Chainer::Functions::Activation::Sigmoid
      # > F.sigmoid(x).data
-     # => Numo::DFloat#shape=[3]
+     # => Numo::SFloat#shape=[3]
      # [0.119203, 0.5, 0.880797]
      #
      def self.sigmoid(x)
data/lib/chainer/functions/activation/tanh.rb
@@ -9,15 +9,15 @@ module Chainer
      # f(x)=\\tanh(x).
      # $$
      #
-     # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
+     # @param [Chainer::Variable or Numo::NArray] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
      # @example
-     # > x = Numo::DFloat.new(3).seq(-1, 2)
-     # => Numo::DFloat#shape=[3]
+     # > x = Numo::SFloat.new(3).seq(-1, 2)
+     # => Numo::SFloat#shape=[3]
      # [-1, 1, 3]
      # > F = Chainer::Functions::Activation::Tanh
      # > F.tanh(x).data
-     # => Numo::DFloat#shape=[3]
+     # => Numo::SFloat#shape=[3]
      # [-0.761594, 0.761594, 0.995055]
      #
      def self.tanh(x)
@@ -33,7 +33,7 @@ module Chainer

      def backward_cpu(x, gy)
        y = @output_data[0]
-       one = y.dtype.type(1)
+       one = y.class.cast(1)
        [Utils::Array.force_array(gy[0] * (one - y * y))]
      end
    end
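Note: y.class.cast(1) builds a scalar NArray in the output's own dtype, replacing the NumPy-style y.dtype.type(1). A tiny illustration (the sample values are ours):

  require 'numo/narray'

  y = Numo::SFloat[0.2, 0.8]
  one = y.class.cast(1)       # scalar Numo::SFloat holding 1.0
  p one - y * y               # => Numo::SFloat[0.96, 0.36], still single precision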
data/lib/chainer/functions/connection/convolution_2d.rb
@@ -0,0 +1,92 @@
+ module Chainer
+   module Functions
+     module Connection
+       class Convolution2DFunction < Chainer::Function
+         # Two-dimensional convolution function.
+         # This is an implementation of two-dimensional convolution in ConvNets.
+         # It takes three variables: the input image `x`, the filter weight `w`, and the bias vector `b`.
+         #
+         # a notation for dimensionalities.
+         #
+         # - :math:`n` is the batch size.
+         # - :math:`c_I` and :math:`c_O` are the number of the input and output channels, respectively.
+         # - :math:`h_I` and :math:`w_I` are the height and width of the input image, respectively.
+         # - :math:`h_K` and :math:`w_K` are the height and width of the filters, respectively.
+         # - :math:`h_P` and :math:`w_P` are the height and width of the spatial padding size, respectively.
+         #
+         # Then the `Convolution2D` function computes correlations between filters and patches of size :math:`(h_K, w_K)` in `x`.
+         # Patches are extracted at positions shifted by multiples of `stride` from the first position `(-h_P, -w_P)` for each spatial axis.
+         # The right-most (or bottom-most) patches do not run over the padded spatial size.
+         # Let :math:`(s_Y, s_X)` be the stride of filter application.
+         # Then, the output size :math:`(h_O, w_O)` is determined by the following equations:
+         #
+         # math:
+         #   h_O &= (h_I + 2h_P - h_K) / s_Y + 1,\\\\
+         #   w_O &= (w_I + 2w_P - w_K) / s_X + 1.
+         # If `cover_all` option is `true`, the filter will cover the all spatial locations.
+         # So, if the last stride of filter does not cover the end of spatial locations,
+         # an addtional stride will be applied to the end part of spatial locations.
+         # In this case, the output size :math:`(h_O, w_O)` is determined by the following equations:
+         #
+         # math:
+         #   h_O &= (h_I + 2h_P - h_K + s_Y - 1) / s_Y + 1,\\\\
+         #   w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.
+         # If the bias vector is given, then it is added to all spatial locations of the output of convolution.
+         #
+         # @param [Chainer::Variable or Numo::NArray] x Input variable of shape :math:`(n, c_I, h_I, w_I)`.
+         # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
+         # @param [Chainer::Variable or Numo::NArray] b Bias variable of length :math:`c_O`
+         # @param [Int or 2-D Array] stride Stride of filter applications. `stride=s` and `stride=(s, s)` are equivalent.
+         # @param [Int or 2-D Array] pad Spatial padding width for input arrays.
+         # @param [Boolean] cover_all If `true`, all spatial locations are convoluted into some output pixels.
+         # @return [Chainer::Variable] Output variable of shape :math:`(n, c_O, h_O, w_O)`.
+         def self.convolution_2d(x, w, b: nil, stride: 1, pad: 0, cover_all: false)
+           func = self.new(stride: stride, pad: pad, cover_all: cover_all)
+           if b.nil?
+             func.(x, w)
+           else
+             func.(x, w, b)
+           end
+         end
+
+         def initialize(stride: 1, pad: 0, cover_all: false)
+           @sy, @sx = stride.is_a?(Array) ? stride : [stride, stride]
+           @ph, @pw = pad.is_a?(Array) ? pad : [pad, pad]
+           @cover_all = cover_all
+         end
+
+         def forward_cpu(inputs)
+           x = inputs[0]
+           w = inputs[1]
+           b = inputs.size == 3 ? inputs[2] : nil
+
+           kh, kw = w.shape[2], w.shape[3]
+
+           @col = Chainer::Utils::Conv.im2col_cpu(x, kh, kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
+           y = Chainer::Utils::Math.tensordot(@col, w, [[1, 2, 3], [1, 2, 3]])
+           y += b if b
+
+           [y.transpose(0, 3, 1, 2)]
+         end
+
+         def backward_cpu(inputs, grad_outputs)
+           x, w, b = inputs[0], inputs[1], inputs[2]
+           gy = grad_outputs[0]
+           height, width = x.shape[2..-1]
+
+           gw = Chainer::Utils::Math.tensordot(gy, @col, [[0, 2, 3], [0, 4, 5]])
+           gcol = Chainer::Utils::Math.tensordot(w, gy, [0, 1])
+           gcol = gcol.transpose(3, 0, 1, 2)
+           gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, height, width)
+
+           if b.nil?
+             [gx, gw]
+           else
+             gb = gy.sum(axis: [0, 2, 3])
+             [gx, gw, gb]
+           end
+         end
+       end
+     end
+   end
+ end
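Note: the output-size formulas quoted in the comment above are easy to sanity-check in plain Ruby. The helper name below is ours, not part of the gem; integer division mirrors how the conv utilities truncate:

  def conv_out_size(size, k, s, p, cover_all: false)
    if cover_all
      (size + 2 * p - k + s - 1) / s + 1
    else
      (size + 2 * p - k) / s + 1
    end
  end

  conv_out_size(32, 3, 1, 1)   # => 32  (3x3 kernel, stride 1, pad 1 keeps the size)
  conv_out_size(28, 5, 2, 0)   # => 12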
data/lib/chainer/functions/connection/linear.rb
@@ -40,7 +40,7 @@ module Chainer

      def as_mat(x)
        return x if x.ndim == 2
-       x.reshape(x.size, -1)
+       x.reshape(x.shape.first, true)
      end
    end
  end
data/lib/chainer/functions/loss/mean_squared_error.rb
@@ -0,0 +1,34 @@
+ module Chainer
+   module Functions
+     module Loss
+       # Mean squared error (a.k.a. Euclidean loss) function.
+       class MeanSquaredError < Function
+         # Mean squared error function.
+         #
+         # This function computes mean squared error between two variables. The mean
+         # is taken over the minibatch. Note that the error is not scaled by 1/2.
+         #
+         # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
+         # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+         # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
+         #
+         def self.mean_squared_error(x0, x1)
+           self.new.(x0, x1)
+         end
+
+         def forward_cpu(inputs)
+           x0, x1 = inputs
+           @diff = x0 - x1
+           diff = @diff.flatten.dup()
+           [diff.class.cast(diff.dot(diff) / diff.size)]
+         end
+
+         def backward(inputs, gy)
+           coeff = gy[0] * gy[0].class.cast(2.0 / @diff.size)
+           gx0 = coeff * @diff
+           [gx0, -(gx0)]
+         end
+       end
+     end
+   end
+ end
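Note: a minimal usage sketch of the new loss, assuming the module path shown above; the sample arrays are ours. The forward pass shown averages the squared differences over all elements:

  require 'chainer'
  require 'numo/narray'

  x0 = Numo::SFloat[[1.0, 2.0], [3.0, 4.0]]
  x1 = Numo::SFloat[[1.5, 2.0], [2.0, 4.0]]

  loss = Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(x0, x1)
  # loss.data == (0.25 + 0 + 1.0 + 0) / 4 == 0.3125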
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -13,17 +13,17 @@ module Chainer

      unless class_weight.nil?
        if @class_weight.ndim != 1
-         raise ArgumentError 'class_weight.ndim should be 1'
-       elsif @class_weight.dtype != Numo::DFloat
-         raise ArgumentError 'The dtype of class_weight should be \'Numo::DFloat\''
+         raise ArgumentError, 'class_weight.ndim should be 1'
+       elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
+         raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
        elsif @class_weight.kind_of?(Chainer::Variable)
-         raise ArgumentError 'class_weight should be a Numo::NArray, not a chainer.Variable'
+         raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
        end
      end

      @ignore_label = ignore_label
      unless ['mean', 'no'].include?(reduce)
-       raise ArgumentError "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
+       raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
      end

      @reduce = reduce
@@ -37,40 +37,37 @@ module Chainer
          @y = Numo::NMath.exp(log_y)
        end
        if @class_weight
-         shape = x.ndim.times.map { |e| e == 1 ? -1 : 1 }
-         log_y += broadcast_to(@class_weight.reshape(*shape), x.shape)
+         shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
+         log_y *= Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
        end
-       log_yd = rollaxis(log_y, 1)
+       log_yd = Chainer::Functions::Loss.rollaxis(log_y, 1)
        begin
-         log_yd = log_yd.reshape(log_yd.size, -1)
+         log_yd = log_yd.reshape(log_yd.shape[0], true)
        rescue ArgumentError
        end
-
        ravel_arr = t.dup.flatten.dup
        ravel_arr[ravel_arr<0] = 0
        arange_arr = t.class.new(t.size).seq

        # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
        log_p = []
-       arange_arr.each do |col_idx|
-         log_p << log_yd[ravel_arr, col_idx][col_idx]
+       ravel_arr.each_with_index do |r, i|
+         log_p << log_yd[r, i]
        end
-       log_p = Numo::NArray.[](*log_p)
-
-       log_p[log_p.eq(@ignore_label)] = 0
+       log_p = log_yd.class.[](*log_p)
+       log_p[t.flatten.dup.eq(@ignore_label)] = 0

        if @reduce == 'mean'
          if @normalize
            count = t.ne(@ignore_label).count
          else
-           count = x.size
+           count = x.shape[0]
          end
          @coeff = 1.0 / [count, 1].max
-
          y = log_p.sum(keepdims: true) * (-@coeff)
-         [y.reshape(())]
+         [y.class.cast(y[0])]
        else
-         [-log_p.reshape(t.shape)]
+         [-log_p.reshape(*t.shape)]
        end
      end

@@ -87,48 +84,78 @@ module Chainer

        if y.ndim == 2
          gx = y
-         t[t<0] = 0
-         t.each_with_index do |v, idx|
-           gx[(idx * 10)...(idx * 10 + 10)][v] -= 1
-         end
+         t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}

          if @class_weight
-           shape = x.ndim.times.map { |d| d == 1 ? -1 : 1 }
-           c = broadcast_to(@class_weight.reshape(shape), x.shape)
-           c = c[Numo::DFloat.new(t.size).seq, t]
-           gx *= broadcast_to(t.expand_dims(1), gx.shape)
+           shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
+           c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+           c = c.class.cast(t.class.new(t.shape[0]).seq.to_a.zip(t.class.maximum(t, 0).to_a).map{|v| c[*v]})
+           gx *= Chainer::Functions::Loss.broadcast_to(c.expand_dims(1), gx.shape)
          end

          bit = t.flatten.dup
          bit[t.ne(@ignore_label)] = 1
          bit[bit.ne(1)] = 0
-         gx *= bit.reshape(t.size, 1)
+         gx *= bit.reshape(t.shape[0], 1)
        else
-         raise 'TODO: ndim > 2 backward'
+         # in the case where y.ndim is higher than 2,
+         # we think that a current implementation is inefficient
+         # because it yields two provisional arrays for indexing.
+
+         n_unit = t.size / t.shape[0]
+         gx = y.reshape(y.shape[0], y.shape[1], true)
+         fst_index = Numo::Int32.new(t.size).seq(0) / n_unit
+         trd_index = Numo::Int32.new(t.size).seq(0) % n_unit
+         fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
+         if @class_weight
+           shape = x.ndim.times.map{|d| d == 1 ? true : 1}
+           c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+           c = c.reshape(*gx.shape)
+           c = c.class.cast(fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).map{|v| c[*v]})
+           c = c.reshape(y.shape[0], 1, true)
+           gx *= Chainer::Functions::Loss.broadcast_to(c, gx.shape)
+         end
+         gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
+         gx = gx.reshape(*y.shape)
        end

        if @reduce == 'mean'
          gx *= gloss * @coeff
        else
-         raise 'TODO: reduce'
+         gx *= gloss[true,:- , false]
        end
        return [gx, nil]
      end
+   end

+   def rollaxis(y, axis, start: 0)
+     axes = (0...y.ndim).to_a
+     axes.delete_at(axis)
+     axes.insert(start <= axes.size ? start : -1, axis)
+     y.transpose(*axes)
+   end

-   private
-
-   def broadcast_to(array, shape)
-     array.class.tile(array, shape[0]).reshape(*shape)
+   def broadcast_to(array, shape)
+     if array.shape.size > shape.size
+       raise TypeError, "Shape of data mismatch\n array.shape.size(#{array.shape.size}) > shape.size(#{shape.size})"
      end

-   def rollaxis(y, axis, start: 0)
-     axes = (0...y.ndim).to_a
-     axes.delete_at(axis)
-     axes.insert(start, axis)
-     y.transpose(*axes)
-   end
+     tile_shape = []
+     shape_check = shape[-array.shape.size..-1]
+     shape_check.each_with_index{|s, i|
+       if array.shape[i] == 1
+         tile_shape << s
+       elsif array.shape[i] == s
+         tile_shape << 1
+       else
+         raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+       end
+     }
+
+     array.tile(*shape[0...-array.shape.size], *tile_shape)
    end
+
+   module_function :rollaxis, :broadcast_to
  end
end
end
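Note: rollaxis above mirrors NumPy's rollaxis on top of transpose, and broadcast_to now tiles each size-1 axis up to the target shape instead of the old tile-and-reshape shortcut. A standalone illustration of what rollaxis computes (the sample array is ours):

  require 'numo/narray'

  x = Numo::SFloat.new(2, 3, 4).seq
  axes = (0...x.ndim).to_a      # [0, 1, 2]
  axes.delete_at(1)             # take axis 1 out ...
  axes.insert(0, 1)             # ... and move it to the front => [1, 0, 2]
  p x.transpose(*axes).shape    # => [3, 2, 4]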