red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/connection/deconvolution_2d.rb (new file)
@@ -0,0 +1,159 @@
+module Chainer
+  module Functions
+    module Connection
+      class Deconvolution2DFunction < Chainer::FunctionNode
+        attr_reader :sy, :sx, :ph, :pw, :cover_all
+
+        # Two dimensional deconvolution function.
+        #
+        # This is an implementation of two-dimensional deconvolution.
+        # In most of deep learning frameworks and papers,
+        # this function is called <b>transposed convolution</b>.
+        # But because of historical reasons (e.g. paper by Zeiler, Deconvolutional Networks) and backward compatibility,
+        # this function is called +deconvolution+ in Chainer.
+        #
+        # It takes three variables: input image +x+,
+        # the filter weight +W+, and the bias vector +b+.
+        #
+        # - $n$ is the batch size.
+        # - $c_I$ and $c_O$ are the number of the input and output channels, respectively.
+        # - $h_I$ and $w_I$ are the height and width of the input image, respectively.
+        # - $h_K$ and $w_K$ are the height and width of the filters, respectively.
+        # - $h_P$ and $w_P$ are the height and width of the spatial padding size, respectively.
+        #
+        # Let $(s_Y, s_X)$ be the stride of filter application.
+        # Then, the output size $(h_O, w_O)$ is estimated by the following equations:
+        #
+        # $
+        # h_O &= s_Y (h_I - 1) + h_K - 2h_P,
+        # w_O &= s_X (w_I - 1) + w_K - 2w_P.
+        # $
+        #
+        # @param [Chainer::Variable or Numo::NArray] x Input variable of shape $(n, c_I, h_I, w_I)$.
+        # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape $(c_I, c_O, h_K, w_K)$.
+        # @param [Chainer::Variable or Numo::NArray] b Bias variable of length $c_O$ (optional).
+        # @param [integer or Array<integer>] stride Stride of filter applications. +stride=s+ and +stride=[s, s]+ are equivalent.
+        # @param [integer or Array<integer>] pad Spatial padding width for input arrays. +pad=p+ and +pad=[p, p]+ are equivalent.
+        # @param [integer or Array<integer>] outsize Expected output size of deconvolutional operation.
+        #   It should be a pair of height and width $(h_O, w_O)$.
+        #   Default value is +nil+ and the outsize is estimated by input size, stride and pad.
+        # @return [Chainer::Variable] Output variable of shape $(n, c_O, h_O, w_O)$.
+        #
+        # Example
+        #   > n = 10
+        #   > c_i, c_o = 1, 3
+        #   > h_i, w_i = 5, 10
+        #   > h_k, w_k = 10, 10
+        #   > h_p, w_p = 5, 5
+        #   > x = Numo::DFloat.new(n, c_i, h_i, w_i).rand
+        #   > x.shape
+        #   => [10, 1, 5, 10]
+        #   > w = Numo::DFloat.new(c_i, c_o, h_k, w_k).rand
+        #   > w.shape
+        #   => [1, 3, 10, 10]
+        #   > b = Numo::DFloat.new(c_o).rand
+        #   > b.shape
+        #   => [3]
+        #   > s_y, s_x = 5, 5
+        #   > y = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(x, w, b: b, stride: [s_y, s_x], pad: [h_p, w_p])
+        #   > y.shape
+        #   => [10, 3, 20, 45]
+        #   > h_o = s_y * (h_i - 1) + h_k - 2 * h_p
+        #   > w_o = s_x * (w_i - 1) + w_k - 2 * w_p
+        #   > y.shape == [n, c_o, h_o, w_o]
+        #   => true
+        def self.deconvolution_2d(x, w, b: nil, stride: 1, pad: 0, outsize: nil)
+          func = Deconvolution2DFunction.new(stride: stride, pad: pad, outsize: outsize)
+          if b.nil?
+            args = x, w
+          else
+            args = x, w, b
+          end
+          func.apply(args).first
+        end
+
+        def initialize(stride: 1, pad: 0, outsize: nil)
+          @cover_all = nil
+
+          @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+          @ph, @pw = pad.is_a?(::Array) ? pad : [pad, pad]
+          @outh, @outw = outsize.nil? ? [nil, nil] : outsize
+        end
+
+        def forward(inputs)
+          retain_inputs([0, 1])
+          x, w = inputs[0...2]
+          b = inputs.size == 3 ? inputs[2] : nil
+
+          unless inputs.all? { |i| i.is_a?(Numo::NArray) }
+            if b.nil?
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}"
+            else
+              raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}, b: #{b.class}"
+            end
+          end
+
+          kh, kw = w.shape[2..-1]
+          _, _, x_h, x_w = x.shape
+
+          gcol = Chainer::Utils::Math.tensordot(w, x, [0, 1]).cast_to(x.class)
+          # - k, m, n: shape of out_channel
+          # - b: number of inputs
+          # - h, w: height and width of kernels
+          # k, m, n, b, h, w -> b, k, m, n, h, w
+          gcol = gcol.transpose(3, 0, 1, 2, 4, 5)
+
+          if @outh.nil?
+            @outh = Chainer::Utils::Conv.get_deconv_outsize(x_h, kh, @sy, @ph)
+            raise TypeError, 'Height in the output should be positive.' if @outh <= 0
+          end
+          if @outw.nil?
+            @outw = Chainer::Utils::Conv.get_deconv_outsize(x_w, kw, @sx, @pw)
+            raise TypeError, 'Width in the output should be positive.' if @outw <= 0
+          end
+
+          y = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, @outh, @outw)
+          if !b.nil?
+            y += b.reshape(1, b.size, 1, 1)
+          end
+          [y]
+        end
+
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first
+
+          ret = []
+
+          if indexes.include?(0)
+            set_cover_all(x, w) if @cover_all.nil?
+            gw = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all)
+            ret << gw
+          end
+
+          if indexes.include?(1)
+            set_cover_all(x, w) if @cover_all.nil?
+            gw = Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([gy, x]).first
+            ret << gw
+          end
+
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3])
+            ret << gb
+          end
+
+          ret
+        end
+
+        private
+
+        def set_cover_all(x, w)
+          in_h, in_w = x.shape[2..-1]
+          kh, kw = w.shape[2..-1]
+
+          @cover_all = in_h != Chainer::Utils::Conv.get_conv_outsize(@outh, kh, @sy, @ph) || in_w != Chainer::Utils::Conv.get_conv_outsize(@outw, kw, @sx, @pw)
+        end
+      end
+    end
+  end
+end
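
As a quick sanity check of the output-size formula quoted in the docstring above, the plain-Ruby arithmetic below reproduces the shape from its example (the variable names simply mirror the example; no red-chainer code is required):

    # Worked check of h_O = s_Y * (h_I - 1) + h_K - 2 * h_P (and likewise for the width),
    # using the same numbers as the docstring example above.
    s_y, s_x = 5, 5
    h_i, w_i = 5, 10
    h_k, w_k = 10, 10
    h_p, w_p = 5, 5

    h_o = s_y * (h_i - 1) + h_k - 2 * h_p   # => 20
    w_o = s_x * (w_i - 1) + w_k - 2 * w_p   # => 45
    p [h_o, w_o]                            # => [20, 45], matching y.shape[2..3] in the example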
data/lib/chainer/functions/connection/linear.rb
@@ -1,17 +1,23 @@
 module Chainer
   module Functions
     module Connection
-      class LinearFunction < Chainer::Function
+      class LinearFunction < Chainer::FunctionNode
         def self.linear(x, w, b=nil)
+          if x.ndim > 2
+            x = x.reshape(x.shape.first, -1)
+          end
+
           if b.nil?
-            self.new.(x, w)
+            args = x, w
           else
-            self.new.(x, w, b)
+            args = x, w, b
           end
+
+          self.new.apply(args).first
         end
 
         def forward(inputs)
-          x = as_mat(inputs[0])
+          x = inputs[0]
           w = inputs[1]
 
           y = x.dot(w.transpose).cast_to(x.class)
@@ -19,28 +25,29 @@ module Chainer
             b = inputs[2]
             y += b
           end
-          return [y]
-        end
 
-        def backward(inputs, grad_outputs)
-          x = as_mat(inputs[0])
-          w = inputs[1]
-          gy = grad_outputs[0]
-          gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
-          gw = gy.transpose.dot(x).cast_to(w.class)
-          if inputs.size == 3
-            gb = gy.sum(0)
-            [gx, gw, gb]
-          else
-            [gx, gw]
-          end
+          retain_inputs([0, 1])
+          return [y]
         end
 
-        private
+        def backward(indexes, grad_outputs)
+          x, w = get_retained_inputs
+          gy = grad_outputs.first
 
-        def as_mat(x)
-          return x if x.ndim == 2
-          x.reshape(x.shape.first, true)
+          ret = []
+          if indexes.include?(0)
+            gx = LinearFunction.linear(gy, w.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gx, x.dtype)
+          end
+          if indexes.include?(1)
+            gw = LinearFunction.linear(gy.transpose, x.transpose)
+            ret << Chainer::Functions::Array::Cast.cast(gw, w.dtype)
+          end
+          if indexes.include?(2)
+            gb = Chainer::Functions::Math::Sum.sum(gy, axis: 0)
+            ret << gb
+          end
+          ret
         end
       end
     end
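
The rewritten forward/backward pair above is plain matrix algebra: y = x.dot(w.transpose) + b, with gradients gy.dot(w), gy.transpose.dot(x), and a sum over the batch axis. A Numo-only sketch of that math (shapes are made up for illustration; this is not a call into red-chainer's API):

    require 'numo/narray'

    batch, in_dim, out_dim = 4, 3, 2
    x  = Numo::DFloat.new(batch, in_dim).rand
    w  = Numo::DFloat.new(out_dim, in_dim).rand
    b  = Numo::DFloat.new(out_dim).rand
    gy = Numo::DFloat.new(batch, out_dim).rand   # stands in for grad_outputs[0]

    y  = x.dot(w.transpose) + b                  # forward: (batch, out_dim)
    gx = gy.dot(w)                               # grad w.r.t. x, i.e. linear(gy, w.transpose)
    gw = gy.transpose.dot(x)                     # grad w.r.t. w, i.e. linear(gy.transpose, x.transpose)
    gb = gy.sum(axis: 0)                         # grad w.r.t. b: sum over the batch axis
    p [y.shape, gx.shape, gw.shape, gb.shape]    # => [[4, 2], [4, 3], [2, 3], [2]]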
data/lib/chainer/functions/evaluation/accuracy.rb
@@ -12,13 +12,13 @@ module Chainer
 
         def forward(inputs)
           y, t = inputs
+          xm = Chainer.get_array_module(*inputs)
           if @ignore_label
             mask = t.eq(@ignore_label)
             ignore_cnt = mask.count
 
-            # this work
-            pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx}
-            pred = y.class[*pred].reshape(*t.shape)
+            pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+            pred = pred.reshape(*t.shape)
             pred[mask] = @ignore_label
             count = pred.eq(t).count - ignore_cnt
 
@@ -30,8 +30,8 @@ module Chainer
               [y.class.cast(count.to_f / total)]
             end
           else
-            pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx}
-            pred = y.class[*pred].reshape(*t.shape)
+            pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+            pred = pred.reshape(*t.shape)
 
             [y.class.cast(y.class[pred.eq(t)].mean)]
           end
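
The new pred computation relies on max_index(axis: 1) returning indices into the flattened array, so subtracting each row's starting offset (row * n_class) recovers the per-row argmax. A small Numo-only illustration with made-up values:

    require 'numo/narray'

    # y: (batch, n_class) scores; values chosen only for illustration.
    y = Numo::DFloat[[0.1, 0.9, 0.0],
                     [0.7, 0.2, 0.1]]

    flat    = y.max_index(axis: 1)                            # flattened indices, here [1, 3]
    offsets = Numo::Int32.new(y.shape[0]).seq(0, y.shape[1])  # row offsets [0, 3]
    pred    = flat - offsets                                  # per-row argmax
    p pred.to_a                                               # => [1, 0]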
data/lib/chainer/functions/loss/mean_squared_error.rb
@@ -2,31 +2,40 @@ module Chainer
   module Functions
     module Loss
       # Mean squared error (a.k.a. Euclidean loss) function.
-      class MeanSquaredError < Function
+      class MeanSquaredError < FunctionNode
         # Mean squared error function.
         #
         # This function computes mean squared error between two variables. The mean
         # is taken over the minibatch. Note that the error is not scaled by 1/2.
         #
-        # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
-        # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
+        # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
        # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
        #
        def self.mean_squared_error(x0, x1)
-          self.new.(x0, x1)
+          self.new.apply([x0, x1]).first
        end
 
-        def forward_cpu(inputs)
-          x0, x1 = inputs
-          @diff = x0 - x1
-          diff = @diff.flatten.dup()
+        def forward(inputs)
+          retain_inputs([0, 1])
+          diff = (inputs[0] - inputs[1]).flatten.dup
           [diff.class.cast(diff.dot(diff) / diff.size)]
         end
 
-        def backward(inputs, gy)
-          coeff = gy[0] * gy[0].class.cast(2.0 / @diff.size)
-          gx0 = coeff * @diff
-          [gx0, -(gx0)]
+        def backward(indexes, gy)
+          x0, x1 = get_retained_inputs
+          diff = x0 - x1
+          gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
+          gx0 = gy0 * diff * (2.0 / diff.size)
+
+          ret = []
+          if indexes.include?(0)
+            ret << gx0
+          end
+          if indexes.include?(1)
+            ret << -gx0
+          end
+          ret
         end
       end
     end
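
In plain array terms, the forward pass above computes the mean squared element-wise difference, and the backward pass scales that difference by 2/N. A Numo-only sketch with made-up values (assuming an upstream gradient of 1):

    require 'numo/narray'

    x0 = Numo::DFloat[[1.0, 2.0], [3.0, 4.0]]
    x1 = Numo::DFloat[[1.5, 1.0], [3.0, 2.0]]

    diff = (x0 - x1).flatten
    loss = diff.dot(diff) / diff.size        # mean of squared differences => 1.3125
    gx0  = (x0 - x1) * (2.0 / diff.size)     # gradient w.r.t. x0
    gx1  = -gx0                              # gradient w.r.t. x1
    p loss                                   # => 1.3125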
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -2,68 +2,101 @@ module Chainer
   module Functions
     module Loss
       class SoftmaxCrossEntropy < Function
-        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
-          self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+        def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean', enable_double_backprop: false)
+          if enable_double_backprop
+            self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          else
+            self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+          end
+        end
+
+        def self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+          if t.is_a?(Chainer::Variable)
+            t = t.data
+          end
+
+          self.check_class_weight_option(class_weight)
+          self.check_reduce_option(reduce)
+
+          loss = -Activation::LogSoftmax.log_softmax(x)
+
+          if class_weight
+            shape = x.ndim.times.map { |d| d != 1 ? 1 : class_weight.shape[-1] }
+            class_weight = Chainer::Functions::Array::BroadcastTo.broadcast_to(class_weight.reshape(*shape), x.shape)
+            loss = loss * class_weight
+          end
+
+          dtype = x.is_a?(Chainer::Variable) ? x.dtype : x.class
+          in_use = t.ne(ignore_label).cast_to(dtype)
+
+          loss = Chainer::Functions::Array::Rollaxis.rollaxis(loss, 1, start: loss.ndim)
+
+          # TODO: loss = chainer.functions.reshape(loss, (-1, loss.shape[-1]))
+          shape = loss.shape
+          last_shape = shape.pop
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, [shape.inject(:*), last_shape])
+
+          # Replace ignore_label value with one valid for F.select_item below.
+          t = t.clip(0, loss.shape[1] - 1)
+
+          loss = Chainer::Functions::Array::SelectItem.select_item(loss, t.flatten.dup)
+          loss = Chainer::Functions::Array::Reshape.reshape(loss, t.shape)
+
+          loss = loss * in_use
+
+          if reduce == "mean"
+            count = normalize ? in_use.sum : x.shape.first
+            count = [count, 1.0].max
+            loss = loss * (1.0 / count)
+            return Chainer::Functions::Math::Sum.sum(loss)
+          else
+            return loss
+          end
         end
 
         def initialize(normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
           @normalize = normalize
           @cache_score = cache_score
+          self.class.check_class_weight_option(class_weight)
           @class_weight = class_weight
 
-          unless class_weight.nil?
-            if @class_weight.ndim != 1
-              raise ArgumentError, 'class_weight.ndim should be 1'
-            elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
-              raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
-            elsif @class_weight.kind_of?(Chainer::Variable)
-              raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
-            end
-          end
-
           @ignore_label = ignore_label
-          unless ['mean', 'no'].include?(reduce)
-            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
-          end
 
+          self.class.check_reduce_option(reduce)
           @reduce = reduce
         end
 
-        def forward_cpu(inputs)
+        def forward(inputs)
+          xm = Chainer.get_array_module(*inputs)
           x, t = inputs
           log_y = Activation._log_softmax(x)
 
           if @cache_score
-            @y = Numo::NMath.exp(log_y)
+            @y = xm::NMath.exp(log_y)
           end
           if @class_weight
             shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
-            log_y *= Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+            log_y *= Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
           end
-          log_yd = Chainer::Functions::Loss.rollaxis(log_y, 1)
+          log_yd = Chainer::Utils::Array.rollaxis(log_y, 1)
           begin
            log_yd = log_yd.reshape(log_yd.shape[0], true)
           rescue ArgumentError
           end
-          ravel_arr = t.dup.flatten.dup
-          ravel_arr[ravel_arr<0] = 0
-          arange_arr = t.class.new(t.size).seq
-
-          # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
-          log_p = []
-          ravel_arr.each_with_index do |r, i|
-            log_p << log_yd[r, i]
+
+          log_p = log_yd[t.class.maximum(t.flatten, 0), t.class.new(t.size).seq].diagonal
+          if @ignore_label
+            t_valid= t.ne(@ignore_label)
+            log_p *= t_valid.flatten
           end
-          log_p = log_yd.class.[](*log_p)
-          log_p[t.flatten.dup.eq(@ignore_label)] = 0
 
           if @reduce == 'mean'
-            if @normalize
-              count = t.ne(@ignore_label).count
+            if @normalize and t_valid
+              @coeff = 1.0 / log_p.class.maximum(Chainer::Utils::Array.force_array(t_valid.count), 1)
             else
              count = x.shape[0]
+              @coeff = 1.0 / [count, 1].max
            end
-            @coeff = 1.0 / [count, 1].max
            y = log_p.sum(keepdims: true) * (-@coeff)
            [y.class.cast(y[0])]
          else
@@ -71,7 +104,8 @@ module Chainer
           end
         end
 
-        def backward_cpu(inputs, grad_outputs)
+        def backward(inputs, grad_outputs)
+          xm = Chainer.get_array_module(*(inputs + grad_outputs))
           x, t = inputs
           gloss = grad_outputs[0]
 
@@ -79,24 +113,24 @@ module Chainer
            y = @y.dup
          else
            y = Activation._log_softmax(x)
-            y = Numo::NMath.exp(y)
+            y = xm::NMath.exp(y)
          end
 
          if y.ndim == 2
            gx = y
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
            t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}
 
            if @class_weight
              shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
-              c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
-              c = c.class.cast(t.class.new(t.shape[0]).seq.to_a.zip(t.class.maximum(t, 0).to_a).map{|v| c[*v]})
-              gx *= Chainer::Functions::Loss.broadcast_to(c.expand_dims(1), gx.shape)
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
+              c = c[t.class.new(t.shape[0]).seq, t.class.maximum(t, 0)].diagonal.dup
+              gx *= Chainer::Utils::Array.broadcast_to(c.expand_dims(1), gx.shape)
            end
 
-            bit = t.flatten.dup
-            bit[t.ne(@ignore_label)] = 1
-            bit[bit.ne(1)] = 0
-            gx *= bit.reshape(t.shape[0], 1)
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1)
+            end
          else
            # in the case where y.ndim is higher than 2,
            # we think that a current implementation is inefficient
@@ -104,18 +138,21 @@ module Chainer
 
            n_unit = t.size / t.shape[0]
            gx = y.reshape(y.shape[0], y.shape[1], true)
-            fst_index = Numo::Int32.new(t.size).seq(0) / n_unit
-            trd_index = Numo::Int32.new(t.size).seq(0) % n_unit
+            fst_index = xm::Int32.new(t.size).seq(0) / n_unit
+            trd_index = xm::Int32.new(t.size).seq(0) % n_unit
+            # TODO(sonots): Avoid to_a especially in Cumo to improve performance
            fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
            if @class_weight
              shape = x.ndim.times.map{|d| d == 1 ? true : 1}
-              c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+              c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
              c = c.reshape(*gx.shape)
-              c = c.class.cast(fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).map{|v| c[*v]})
+              c = c[fst_index, t.class.maximum(t.flatten.dup, 0), trd_index].diagonal.diagonal.dup
              c = c.reshape(y.shape[0], 1, true)
-              gx *= Chainer::Functions::Loss.broadcast_to(c, gx.shape)
+              gx *= Chainer::Utils::Array.broadcast_to(c, gx.shape)
+            end
+            if @ignore_label
+              gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
            end
-            gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
            gx = gx.reshape(*y.shape)
          end
 
@@ -126,36 +163,26 @@ module Chainer
            end
          end
          return [gx, nil]
-        end
 
-        def rollaxis(y, axis, start: 0)
-          axes = (0...y.ndim).to_a
-          axes.delete_at(axis)
-          axes.insert(start <= axes.size ? start : -1, axis)
-          y.transpose(*axes)
-        end
+        def self.check_class_weight_option(class_weight)
+          return if class_weight.nil?
 
-        def broadcast_to(array, shape)
-          if array.shape.size > shape.size
-            raise TypeError, "Shape of data mismatch\n array.shape.size(#{array.shape.size}) > shape.size(#{shape.size})"
+          xm = Chainer.get_array_module(@class_weight)
+          if class_weight.ndim != 1
+            raise ArgumentError, 'class_weight.ndim should be 1'
+          elsif (class_weight.class != xm::DFloat) and (class_weight.class != xm::SFloat)
+            raise ArgumentError, "The dtype of class_weight should be 'DFloat' or 'SFloat'"
+          elsif class_weight.kind_of?(Chainer::Variable)
+            raise ArgumentError, 'class_weight should be a NArray, not a chainer.Variable'
+          end
         end
 
-          tile_shape = []
-          shape_check = shape[-array.shape.size..-1]
-          shape_check.each_with_index{|s, i|
-            if array.shape[i] == 1
-              tile_shape << s
-            elsif array.shape[i] == s
-              tile_shape << 1
-            else
-              raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+        def self.check_reduce_option(reduce)
+          unless ['mean', 'no'].include?(reduce)
+            raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
          end
-          }
-
-          array.tile(*shape[0...-array.shape.size], *tile_shape)
+        end
       end
-
-      module_function :rollaxis, :broadcast_to
     end
   end
 end
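
For reference, with reduce: 'mean', no class_weight, and no ignored labels, the quantity computed above is the mean negative log-probability that the softmax assigns to the true class. A Numo-only sketch with made-up values (independent of red-chainer's internal helpers):

    require 'numo/narray'

    x = Numo::DFloat[[2.0, 1.0, 0.1],
                     [0.5, 2.5, 0.3]]   # (batch, n_class) scores
    t = Numo::Int32[0, 1]               # true class ids

    m     = x.max(axis: 1).expand_dims(1)                                   # per-row max, for numerical stability
    log_z = Numo::NMath.log(Numo::NMath.exp(x - m).sum(axis: 1)).expand_dims(1) + m
    log_y = x - log_z                                                       # log-softmax, shape (batch, n_class)

    nll  = t.to_a.each_with_index.map { |cls, i| -log_y[i, cls] }           # -log p(true class) per sample
    loss = nll.sum / x.shape[0]                                             # mean over the minibatch
    puts loss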