red-chainer 0.3.2 → 0.4.0

Files changed (81)
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
data/lib/chainer/functions/connection/deconvolution_2d.rb (new file)
@@ -0,0 +1,159 @@
+ module Chainer
+   module Functions
+     module Connection
+       class Deconvolution2DFunction < Chainer::FunctionNode
+         attr_reader :sy, :sx, :ph, :pw, :cover_all
+
+         # Two-dimensional deconvolution function.
+         #
+         # This is an implementation of two-dimensional deconvolution.
+         # In most deep learning frameworks and papers,
+         # this function is called <b>transposed convolution</b>.
+         # But for historical reasons (e.g. the paper "Deconvolutional Networks" by Zeiler et al.) and backward compatibility,
+         # this function is called +deconvolution+ in Chainer.
+         #
+         # It takes three variables: the input image +x+,
+         # the filter weight +W+, and the bias vector +b+.
+         #
+         # - $n$ is the batch size.
+         # - $c_I$ and $c_O$ are the number of input and output channels, respectively.
+         # - $h_I$ and $w_I$ are the height and width of the input image, respectively.
+         # - $h_K$ and $w_K$ are the height and width of the filters, respectively.
+         # - $h_P$ and $w_P$ are the height and width of the spatial padding, respectively.
+         #
+         # Let $(s_Y, s_X)$ be the stride of filter application.
+         # Then, the output size $(h_O, w_O)$ is estimated by the following equations:
+         #
+         # $
+         # h_O = s_Y (h_I - 1) + h_K - 2 h_P,
+         # w_O = s_X (w_I - 1) + w_K - 2 w_P.
+         # $
+         #
+         # @param [Chainer::Variable or Numo::NArray] x Input variable of shape $(n, c_I, h_I, w_I)$.
+         # @param [Chainer::Variable or Numo::NArray] w Weight variable of shape $(c_I, c_O, h_K, w_K)$.
+         # @param [Chainer::Variable or Numo::NArray] b Bias variable of length $c_O$ (optional).
+         # @param [integer or Array<integer>] stride Stride of filter applications. +stride=s+ and +stride=[s, s]+ are equivalent.
+         # @param [integer or Array<integer>] pad Spatial padding width for input arrays. +pad=p+ and +pad=[p, p]+ are equivalent.
+         # @param [integer or Array<integer>] outsize Expected output size of the deconvolution operation.
+         #   It should be a pair of height and width $(h_O, w_O)$.
+         #   Default value is +nil+, in which case the outsize is estimated from the input size, stride and pad.
+         # @return [Chainer::Variable] Output variable of shape $(n, c_O, h_O, w_O)$.
+         #
+         # Example
+         #   > n = 10
+         #   > c_i, c_o = 1, 3
+         #   > h_i, w_i = 5, 10
+         #   > h_k, w_k = 10, 10
+         #   > h_p, w_p = 5, 5
+         #   > x = Numo::DFloat.new(n, c_i, h_i, w_i).rand
+         #   > x.shape
+         #   => [10, 1, 5, 10]
+         #   > w = Numo::DFloat.new(c_i, c_o, h_k, w_k).rand
+         #   > w.shape
+         #   => [1, 3, 10, 10]
+         #   > b = Numo::DFloat.new(c_o).rand
+         #   > b.shape
+         #   => [3]
+         #   > s_y, s_x = 5, 5
+         #   > y = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(x, w, b: b, stride: [s_y, s_x], pad: [h_p, w_p])
+         #   > y.shape
+         #   => [10, 3, 20, 45]
+         #   > h_o = s_y * (h_i - 1) + h_k - 2 * h_p
+         #   > w_o = s_x * (w_i - 1) + w_k - 2 * w_p
+         #   > y.shape == [n, c_o, h_o, w_o]
+         #   => true
+         def self.deconvolution_2d(x, w, b: nil, stride: 1, pad: 0, outsize: nil)
+           func = Deconvolution2DFunction.new(stride: stride, pad: pad, outsize: outsize)
+           if b.nil?
+             args = x, w
+           else
+             args = x, w, b
+           end
+           func.apply(args).first
+         end
+
+         def initialize(stride: 1, pad: 0, outsize: nil)
+           @cover_all = nil
+
+           @sy, @sx = stride.is_a?(::Array) ? stride : [stride, stride]
+           @ph, @pw = pad.is_a?(::Array) ? pad : [pad, pad]
+           @outh, @outw = outsize.nil? ? [nil, nil] : outsize
+         end
+
+         def forward(inputs)
+           retain_inputs([0, 1])
+           x, w = inputs[0...2]
+           b = inputs.size == 3 ? inputs[2] : nil
+
+           unless inputs.all? { |i| i.is_a?(Numo::NArray) }
+             if b.nil?
+               raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}"
+             else
+               raise TypeError, "Numo::NArray must not be used together w: #{w.class}, x: #{x.class}, b: #{b.class}"
+             end
+           end
+
+           kh, kw = w.shape[2..-1]
+           _, _, x_h, x_w = x.shape
+
+           gcol = Chainer::Utils::Math.tensordot(w, x, [0, 1]).cast_to(x.class)
+           # - k, m, n: shape of out_channel
+           # - b: number of inputs
+           # - h, w: height and width of kernels
+           # k, m, n, b, h, w -> b, k, m, n, h, w
+           gcol = gcol.transpose(3, 0, 1, 2, 4, 5)
+
+           if @outh.nil?
+             @outh = Chainer::Utils::Conv.get_deconv_outsize(x_h, kh, @sy, @ph)
+             raise TypeError, 'Height in the output should be positive.' if @outh <= 0
+           end
+           if @outw.nil?
+             @outw = Chainer::Utils::Conv.get_deconv_outsize(x_w, kw, @sx, @pw)
+             raise TypeError, 'Width in the output should be positive.' if @outw <= 0
+           end
+
+           y = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, @outh, @outw)
+           if !b.nil?
+             y += b.reshape(1, b.size, 1, 1)
+           end
+           [y]
+         end
+
+         def backward(indexes, grad_outputs)
+           x, w = get_retained_inputs
+           gy = grad_outputs.first
+
+           ret = []
+
+           if indexes.include?(0)
+             set_cover_all(x, w) if @cover_all.nil?
+             gw = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(gy, w, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all)
+             ret << gw
+           end
+
+           if indexes.include?(1)
+             set_cover_all(x, w) if @cover_all.nil?
+             gw = Chainer::Functions::Connection::Convolution2DGradW.new(self).apply([gy, x]).first
+             ret << gw
+           end
+
+           if indexes.include?(2)
+             gb = Chainer::Functions::Math::Sum.sum(gy, axis: [0, 2, 3])
+             ret << gb
+           end
+
+           ret
+         end
+
+         private
+
+         def set_cover_all(x, w)
+           in_h, in_w = x.shape[2..-1]
+           kh, kw = w.shape[2..-1]
+
+           @cover_all = in_h != Chainer::Utils::Conv.get_conv_outsize(@outh, kh, @sy, @ph) || in_w != Chainer::Utils::Conv.get_conv_outsize(@outw, kw, @sx, @pw)
+         end
+       end
+     end
+   end
+ end
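The output-size estimate in forward above comes from Chainer::Utils::Conv.get_deconv_outsize. A minimal sketch of the same arithmetic, assuming it follows the docstring formula h_O = s_Y (h_I - 1) + h_K - 2 h_P (plain Ruby; deconv_outsize is a hypothetical stand-in name):

    # Hypothetical stand-in for Chainer::Utils::Conv.get_deconv_outsize in the
    # default (no cover_all) case, per the docstring formula above.
    def deconv_outsize(size, k, s, p)
      s * (size - 1) + k - 2 * p
    end

    deconv_outsize(5, 10, 5, 5)   # => 20, the h_o from the docstring example
    deconv_outsize(10, 10, 5, 5)  # => 45, the w_o from the docstring example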
data/lib/chainer/functions/connection/linear.rb
@@ -1,17 +1,23 @@
  module Chainer
    module Functions
      module Connection
-       class LinearFunction < Chainer::Function
+       class LinearFunction < Chainer::FunctionNode
          def self.linear(x, w, b=nil)
+           if x.ndim > 2
+             x = x.reshape(x.shape.first, -1)
+           end
+
            if b.nil?
-             self.new.(x, w)
+             args = x, w
            else
-             self.new.(x, w, b)
+             args = x, w, b
            end
+
+           self.new.apply(args).first
          end

          def forward(inputs)
-           x = as_mat(inputs[0])
+           x = inputs[0]
            w = inputs[1]

            y = x.dot(w.transpose).cast_to(x.class)
@@ -19,28 +25,29 @@ module Chainer
              b = inputs[2]
              y += b
            end
-           return [y]
-         end

-         def backward(inputs, grad_outputs)
-           x = as_mat(inputs[0])
-           w = inputs[1]
-           gy = grad_outputs[0]
-           gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
-           gw = gy.transpose.dot(x).cast_to(w.class)
-           if inputs.size == 3
-             gb = gy.sum(0)
-             [gx, gw, gb]
-           else
-             [gx, gw]
-           end
+           retain_inputs([0, 1])
+           return [y]
          end

-         private
+         def backward(indexes, grad_outputs)
+           x, w = get_retained_inputs
+           gy = grad_outputs.first

-         def as_mat(x)
-           return x if x.ndim == 2
-           x.reshape(x.shape.first, true)
+           ret = []
+           if indexes.include?(0)
+             gx = LinearFunction.linear(gy, w.transpose)
+             ret << Chainer::Functions::Array::Cast.cast(gx, x.dtype)
+           end
+           if indexes.include?(1)
+             gw = LinearFunction.linear(gy.transpose, x.transpose)
+             ret << Chainer::Functions::Array::Cast.cast(gw, w.dtype)
+           end
+           if indexes.include?(2)
+             gb = Chainer::Functions::Math::Sum.sum(gy, axis: 0)
+             ret << gb
+           end
+           ret
          end
        end
      end
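The rewritten backward builds each gradient out of FunctionNode ops so the gradients themselves stay differentiable. A minimal sketch of the underlying formulas with plain Numo arrays (for y = x.dot(w.transpose) + b, the gradients are gx = gy.dot(w), gw = gy.transpose.dot(x), gb = gy.sum(0), exactly as the code above computes through linear/Cast/Sum):

    require 'numo/narray'

    x  = Numo::DFloat.new(4, 3).rand   # batch of 4, in_size 3
    w  = Numo::DFloat.new(2, 3).rand   # out_size 2, in_size 3
    gy = Numo::DFloat.new(4, 2).rand   # upstream gradient, same shape as y

    gx = gy.dot(w)             # [4, 3], same shape as x
    gw = gy.transpose.dot(x)   # [2, 3], same shape as w
    gb = gy.sum(0)             # [2],    same shape as the bias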
data/lib/chainer/functions/evaluation/accuracy.rb
@@ -12,13 +12,13 @@ module Chainer

        def forward(inputs)
          y, t = inputs
+         xm = Chainer.get_array_module(*inputs)
          if @ignore_label
            mask = t.eq(@ignore_label)
            ignore_cnt = mask.count

-           # this work
-           pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx}
-           pred = y.class[*pred].reshape(*t.shape)
+           pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+           pred = pred.reshape(*t.shape)
            pred[mask] = @ignore_label
            count = pred.eq(t).count - ignore_cnt

@@ -30,8 +30,8 @@ module Chainer
              [y.class.cast(count.to_f / total)]
            end
          else
-           pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx}
-           pred = y.class[*pred].reshape(*t.shape)
+           pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
+           pred = pred.reshape(*t.shape)

            [y.class.cast(y.class[pred.eq(t)].mean)]
          end
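The new prediction code leans on the fact that max_index with an axis returns indices into the flattened array; subtracting the row offsets 0, n_cols, 2*n_cols, ... recovers the per-row argmax without the old to_a round trip. A minimal sketch, assuming Numo only:

    require 'numo/narray'

    y = Numo::DFloat[[0.1, 0.8, 0.1],
                     [0.7, 0.2, 0.1]]
    flat = y.max_index(axis: 1)                 # flat indices: [1, 3]
    pred = flat - Numo::Int32.new(y.shape[0]).seq(0, y.shape[1])
    pred.to_a                                   # => [1, 0], the argmax column per row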
data/lib/chainer/functions/loss/mean_squared_error.rb
@@ -2,31 +2,40 @@ module Chainer
    module Functions
      module Loss
        # Mean squared error (a.k.a. Euclidean loss) function.
-       class MeanSquaredError < Function
+       class MeanSquaredError < FunctionNode
          # Mean squared error function.
          #
          # This function computes mean squared error between two variables. The mean
          # is taken over the minibatch. Note that the error is not scaled by 1/2.
          #
-         # @param [Chainer::Variable or Numo::NArray] x0 Input variable.
-         # @param [Chainer::Variable or Numo::NArray] x1 Input variable.
+         # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
+         # @param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
          # @return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.
          #
          def self.mean_squared_error(x0, x1)
-           self.new.(x0, x1)
+           self.new.apply([x0, x1]).first
          end

-         def forward_cpu(inputs)
-           x0, x1 = inputs
-           @diff = x0 - x1
-           diff = @diff.flatten.dup()
+         def forward(inputs)
+           retain_inputs([0, 1])
+           diff = (inputs[0] - inputs[1]).flatten.dup
            [diff.class.cast(diff.dot(diff) / diff.size)]
          end

-         def backward(inputs, gy)
-           coeff = gy[0] * gy[0].class.cast(2.0 / @diff.size)
-           gx0 = coeff * @diff
-           [gx0, -(gx0)]
+         def backward(indexes, gy)
+           x0, x1 = get_retained_inputs
+           diff = x0 - x1
+           gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
+           gx0 = gy0 * diff * (2.0 / diff.size)
+
+           ret = []
+           if indexes.include?(0)
+             ret << gx0
+           end
+           if indexes.include?(1)
+             ret << -gx0
+           end
+           ret
          end
        end
      end
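A minimal usage sketch of the updated function, assuming the usual Chainer::Variable data/backward/grad API used elsewhere in this gem; the backward pass above yields 2 * (x0 - x1) / N for x0 and its negation for x1:

    x0 = Chainer::Variable.new(Numo::DFloat[1.0, 2.0, 3.0])
    x1 = Chainer::Variable.new(Numo::DFloat[1.5, 2.0, 2.0])

    loss = Chainer::Functions::Loss::MeanSquaredError.mean_squared_error(x0, x1)
    loss.data      # => (0.25 + 0.0 + 1.0) / 3 ≈ 0.4167
    loss.backward
    x0.grad        # => (2.0 / 3) * [-0.5, 0.0, 1.0]
    x1.grad        # => the negation of x0.grad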
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -2,68 +2,101 @@ module Chainer
    module Functions
      module Loss
        class SoftmaxCrossEntropy < Function
-         def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
-           self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+         def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean', enable_double_backprop: false)
+           if enable_double_backprop
+             self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+           else
+             self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+           end
+         end
+
+         def self.double_backward_softmax_cross_entropy(x, t, normalize, class_weight, ignore_label, reduce)
+           if t.is_a?(Chainer::Variable)
+             t = t.data
+           end
+
+           self.check_class_weight_option(class_weight)
+           self.check_reduce_option(reduce)
+
+           loss = -Activation::LogSoftmax.log_softmax(x)
+
+           if class_weight
+             shape = x.ndim.times.map { |d| d != 1 ? 1 : class_weight.shape[-1] }
+             class_weight = Chainer::Functions::Array::BroadcastTo.broadcast_to(class_weight.reshape(*shape), x.shape)
+             loss = loss * class_weight
+           end
+
+           dtype = x.is_a?(Chainer::Variable) ? x.dtype : x.class
+           in_use = t.ne(ignore_label).cast_to(dtype)
+
+           loss = Chainer::Functions::Array::Rollaxis.rollaxis(loss, 1, start: loss.ndim)
+
+           # TODO: loss = chainer.functions.reshape(loss, (-1, loss.shape[-1]))
+           shape = loss.shape
+           last_shape = shape.pop
+           loss = Chainer::Functions::Array::Reshape.reshape(loss, [shape.inject(:*), last_shape])
+
+           # Replace ignore_label value with one valid for F.select_item below.
+           t = t.clip(0, loss.shape[1] - 1)
+
+           loss = Chainer::Functions::Array::SelectItem.select_item(loss, t.flatten.dup)
+           loss = Chainer::Functions::Array::Reshape.reshape(loss, t.shape)
+
+           loss = loss * in_use
+
+           if reduce == "mean"
+             count = normalize ? in_use.sum : x.shape.first
+             count = [count, 1.0].max
+             loss = loss * (1.0 / count)
+             return Chainer::Functions::Math::Sum.sum(loss)
+           else
+             return loss
+           end
          end

          def initialize(normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
            @normalize = normalize
            @cache_score = cache_score
+           self.class.check_class_weight_option(class_weight)
            @class_weight = class_weight

-           unless class_weight.nil?
-             if @class_weight.ndim != 1
-               raise ArgumentError, 'class_weight.ndim should be 1'
-             elsif (@class_weight.class != Numo::DFloat) and (@class_weight.class != Numo::SFloat)
-               raise ArgumentError, "The dtype of class_weight should be 'Numo::DFloat' or 'Numo::SFloat'"
-             elsif @class_weight.kind_of?(Chainer::Variable)
-               raise ArgumentError, 'class_weight should be a Numo::NArray, not a chainer.Variable'
-             end
-           end
-
            @ignore_label = ignore_label
-           unless ['mean', 'no'].include?(reduce)
-             raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
-           end

+           self.class.check_reduce_option(reduce)
            @reduce = reduce
          end

-         def forward_cpu(inputs)
+         def forward(inputs)
+           xm = Chainer.get_array_module(*inputs)
            x, t = inputs
            log_y = Activation._log_softmax(x)

            if @cache_score
-             @y = Numo::NMath.exp(log_y)
+             @y = xm::NMath.exp(log_y)
            end
            if @class_weight
              shape = x.ndim.times.map { |e| e == 1 ? true : 1 }
-             log_y *= Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+             log_y *= Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
            end
-           log_yd = Chainer::Functions::Loss.rollaxis(log_y, 1)
+           log_yd = Chainer::Utils::Array.rollaxis(log_y, 1)
            begin
              log_yd = log_yd.reshape(log_yd.shape[0], true)
            rescue ArgumentError
            end
-           ravel_arr = t.dup.flatten.dup
-           ravel_arr[ravel_arr<0] = 0
-           arange_arr = t.class.new(t.size).seq
-
-           # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
-           log_p = []
-           ravel_arr.each_with_index do |r, i|
-             log_p << log_yd[r, i]
+
+           log_p = log_yd[t.class.maximum(t.flatten, 0), t.class.new(t.size).seq].diagonal
+           if @ignore_label
+             t_valid = t.ne(@ignore_label)
+             log_p *= t_valid.flatten
            end
-           log_p = log_yd.class.[](*log_p)
-           log_p[t.flatten.dup.eq(@ignore_label)] = 0

            if @reduce == 'mean'
-             if @normalize
-               count = t.ne(@ignore_label).count
+             if @normalize and t_valid
+               @coeff = 1.0 / log_p.class.maximum(Chainer::Utils::Array.force_array(t_valid.count), 1)
              else
                count = x.shape[0]
+               @coeff = 1.0 / [count, 1].max
              end
-             @coeff = 1.0 / [count, 1].max
              y = log_p.sum(keepdims: true) * (-@coeff)
              [y.class.cast(y[0])]
            else
@@ -71,7 +104,8 @@ module Chainer
            end
          end

-         def backward_cpu(inputs, grad_outputs)
+         def backward(inputs, grad_outputs)
+           xm = Chainer.get_array_module(*(inputs + grad_outputs))
            x, t = inputs
            gloss = grad_outputs[0]

@@ -79,24 +113,24 @@ module Chainer
              y = @y.dup
            else
              y = Activation._log_softmax(x)
-             y = Numo::NMath.exp(y)
+             y = xm::NMath.exp(y)
            end

            if y.ndim == 2
              gx = y
+             # TODO(sonots): Avoid to_a especially in Cumo to improve performance
              t.class.new(t.shape[0]).seq(0).to_a.zip(t.class.maximum(t, 0).to_a).each{|v| gx[*v] -= 1}

              if @class_weight
                shape = x.ndim.times.map { |d| d == 1 ? true : 1 }
-               c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
-               c = c.class.cast(t.class.new(t.shape[0]).seq.to_a.zip(t.class.maximum(t, 0).to_a).map{|v| c[*v]})
-               gx *= Chainer::Functions::Loss.broadcast_to(c.expand_dims(1), gx.shape)
+               c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
+               c = c[t.class.new(t.shape[0]).seq, t.class.maximum(t, 0)].diagonal.dup
+               gx *= Chainer::Utils::Array.broadcast_to(c.expand_dims(1), gx.shape)
              end

-             bit = t.flatten.dup
-             bit[t.ne(@ignore_label)] = 1
-             bit[bit.ne(1)] = 0
-             gx *= bit.reshape(t.shape[0], 1)
+             if @ignore_label
+               gx *= (t.ne @ignore_label).reshape(t.shape[0], 1)
+             end
            else
              # in the case where y.ndim is higher than 2,
              # we think that a current implementation is inefficient
@@ -104,18 +138,21 @@ module Chainer

              n_unit = t.size / t.shape[0]
              gx = y.reshape(y.shape[0], y.shape[1], true)
-             fst_index = Numo::Int32.new(t.size).seq(0) / n_unit
-             trd_index = Numo::Int32.new(t.size).seq(0) % n_unit
+             fst_index = xm::Int32.new(t.size).seq(0) / n_unit
+             trd_index = xm::Int32.new(t.size).seq(0) % n_unit
+             # TODO(sonots): Avoid to_a especially in Cumo to improve performance
              fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).each{|v| gx[*v] -= 1}
              if @class_weight
                shape = x.ndim.times.map{|d| d == 1 ? true : 1}
-               c = Chainer::Functions::Loss.broadcast_to(@class_weight.reshape(*shape), x.shape)
+               c = Chainer::Utils::Array.broadcast_to(@class_weight.reshape(*shape), x.shape)
                c = c.reshape(*gx.shape)
-               c = c.class.cast(fst_index.to_a.zip(t.class.maximum(t.flatten.dup, 0).to_a, trd_index.to_a).map{|v| c[*v]})
+               c = c[fst_index, t.class.maximum(t.flatten.dup, 0), trd_index].diagonal.diagonal.dup
                c = c.reshape(y.shape[0], 1, true)
-               gx *= Chainer::Functions::Loss.broadcast_to(c, gx.shape)
+               gx *= Chainer::Utils::Array.broadcast_to(c, gx.shape)
+             end
+             if @ignore_label
+               gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
              end
-             gx *= (t.ne @ignore_label).reshape(t.shape[0], 1, true)
              gx = gx.reshape(*y.shape)
            end

@@ -126,36 +163,26 @@ module Chainer
              end
            end
            return [gx, nil]
-         end

-         def rollaxis(y, axis, start: 0)
-           axes = (0...y.ndim).to_a
-           axes.delete_at(axis)
-           axes.insert(start <= axes.size ? start : -1, axis)
-           y.transpose(*axes)
-         end
+         def self.check_class_weight_option(class_weight)
+           return if class_weight.nil?

-         def broadcast_to(array, shape)
-           if array.shape.size > shape.size
-             raise TypeError, "Shape of data mismatch\n array.shape.size(#{array.shape.size}) > shape.size(#{shape.size})"
+           xm = Chainer.get_array_module(@class_weight)
+           if class_weight.ndim != 1
+             raise ArgumentError, 'class_weight.ndim should be 1'
+           elsif (class_weight.class != xm::DFloat) and (class_weight.class != xm::SFloat)
+             raise ArgumentError, "The dtype of class_weight should be 'DFloat' or 'SFloat'"
+           elsif class_weight.kind_of?(Chainer::Variable)
+             raise ArgumentError, 'class_weight should be a NArray, not a chainer.Variable'
+           end
          end

-           tile_shape = []
-           shape_check = shape[-array.shape.size..-1]
-           shape_check.each_with_index{|s, i|
-             if array.shape[i] == 1
-               tile_shape << s
-             elsif array.shape[i] == s
-               tile_shape << 1
-             else
-               raise TypeError, "Shape of data mismatch\n#{array.shape} != #{shape}"
+         def self.check_reduce_option(reduce)
+           unless ['mean', 'no'].include?(reduce)
+             raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but #{reduce} is given"
             end
-           }
-
-           array.tile(*shape[0...-array.shape.size], *tile_shape)
+         end
        end
-
-       module_function :rollaxis, :broadcast_to
      end
    end
  end
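With the new enable_double_backprop option the loss is assembled from FunctionNode ops (log_softmax, select_item, reshape, sum), so the gradient computation is itself part of a differentiable graph; the default keeps the original fused Function path. A minimal usage sketch, assuming Numo inputs:

    x = Numo::SFloat.new(5, 10).rand   # minibatch of 5 samples, 10 classes
    t = Numo::Int32[0, 3, 9, 1, 2]     # ground-truth labels

    # Default path: the fused Function implementation.
    loss = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(x, t)

    # FunctionNode-based path whose backward graph can be differentiated again.
    loss2 = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(
      x, t, enable_double_backprop: true)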