red-chainer 0.1.0

Files changed (58)
  1. checksums.yaml +7 -0
  2. data/.gitignore +12 -0
  3. data/.rspec +2 -0
  4. data/.travis.yml +5 -0
  5. data/CODE_OF_CONDUCT.md +74 -0
  6. data/Gemfile +4 -0
  7. data/LICENSE.txt +23 -0
  8. data/README.md +60 -0
  9. data/Rakefile +8 -0
  10. data/bin/console +14 -0
  11. data/bin/setup +8 -0
  12. data/examples/mnist.rb +42 -0
  13. data/lib/chainer.rb +59 -0
  14. data/lib/chainer/configuration.rb +10 -0
  15. data/lib/chainer/dataset/convert.rb +62 -0
  16. data/lib/chainer/dataset/download.rb +56 -0
  17. data/lib/chainer/dataset/iterator.rb +15 -0
  18. data/lib/chainer/datasets/mnist.rb +89 -0
  19. data/lib/chainer/datasets/tuple_dataset.rb +33 -0
  20. data/lib/chainer/function.rb +80 -0
  21. data/lib/chainer/functions/activation/log_softmax.rb +37 -0
  22. data/lib/chainer/functions/activation/relu.rb +23 -0
  23. data/lib/chainer/functions/connection/linear.rb +48 -0
  24. data/lib/chainer/functions/evaluation/accuracy.rb +42 -0
  25. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +134 -0
  26. data/lib/chainer/functions/math/basic_math.rb +119 -0
  27. data/lib/chainer/gradient_method.rb +63 -0
  28. data/lib/chainer/hyperparameter.rb +23 -0
  29. data/lib/chainer/initializer.rb +12 -0
  30. data/lib/chainer/initializers/constant.rb +18 -0
  31. data/lib/chainer/initializers/init.rb +24 -0
  32. data/lib/chainer/initializers/normal.rb +28 -0
  33. data/lib/chainer/iterators/serial_iterator.rb +74 -0
  34. data/lib/chainer/link.rb +118 -0
  35. data/lib/chainer/links/connection/linear.rb +43 -0
  36. data/lib/chainer/links/model/classifier.rb +39 -0
  37. data/lib/chainer/optimizer.rb +69 -0
  38. data/lib/chainer/optimizers/adam.rb +62 -0
  39. data/lib/chainer/parameter.rb +53 -0
  40. data/lib/chainer/reporter.rb +130 -0
  41. data/lib/chainer/training/extension.rb +25 -0
  42. data/lib/chainer/training/extensions/evaluator.rb +26 -0
  43. data/lib/chainer/training/extensions/log_report.rb +72 -0
  44. data/lib/chainer/training/extensions/print_report.rb +62 -0
  45. data/lib/chainer/training/extensions/progress_bar.rb +89 -0
  46. data/lib/chainer/training/standard_updater.rb +63 -0
  47. data/lib/chainer/training/trainer.rb +136 -0
  48. data/lib/chainer/training/triggers/interval.rb +27 -0
  49. data/lib/chainer/training/updater.rb +33 -0
  50. data/lib/chainer/training/util.rb +13 -0
  51. data/lib/chainer/utils/array.rb +10 -0
  52. data/lib/chainer/utils/initializer.rb +14 -0
  53. data/lib/chainer/utils/variable.rb +20 -0
  54. data/lib/chainer/variable.rb +204 -0
  55. data/lib/chainer/variable_node.rb +71 -0
  56. data/lib/chainer/version.rb +4 -0
  57. data/red-chainer.gemspec +27 -0
  58. metadata +156 -0
data/lib/chainer/dataset/iterator.rb
@@ -0,0 +1,15 @@
+ module Chainer
+   module Dataset
+     class Iterator
+       def next
+         raise NotImplementedError
+       end
+
+       def finalize
+       end
+
+       def serialize(serializer)
+       end
+     end
+   end
+ end
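Iterator is an abstract base: next must be overridden, while finalize and serialize are no-op hooks. A minimal hypothetical subclass (not part of the gem, just a sketch of the contract):

class ArrayIterator < Chainer::Dataset::Iterator
  def initialize(array)
    @array = array
    @pos = 0
  end

  # Return the next element, raising StopIteration when exhausted.
  def next
    raise StopIteration if @pos >= @array.size
    value = @array[@pos]
    @pos += 1
    value
  end
end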
data/lib/chainer/datasets/mnist.rb
@@ -0,0 +1,89 @@
+ require 'zlib'
+
+ module Chainer
+   module Datasets
+     module Mnist
+       def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: Numo::DFloat, label_dtype: Numo::Int32)
+         train_raw = retrieve_mnist_training
+         train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype)
+
+         test_raw = retrieve_mnist_test
+         test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, label_dtype)
+         [train, test]
+       end
+
+       def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
+         images = raw[:x]
+         # Numo infers a nil dimension (numpy-style -1 is not supported).
+         if ndim == 2
+           images = images.reshape(nil, 28, 28)
+         elsif ndim == 3
+           images = images.reshape(nil, 1, 28, 28)
+         elsif ndim != 1
+           raise "invalid ndim for MNIST dataset"
+         end
+
+         images = images.cast_to(image_dtype)
+         images *= scale / 255.0
+
+         if withlabel
+           labels = raw[:y].cast_to(label_dtype)
+           TupleDataset.new(images, labels)
+         else
+           images
+         end
+       end
+
+       def self.retrieve_mnist_training
+         urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
+                 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz']
+         retrieve_mnist('train.npz', urls)
+       end
+
+       def self.retrieve_mnist_test
+         urls = ['http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
+                 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz']
+         retrieve_mnist('test.npz', urls)
+       end
+
+       def self.retrieve_mnist(name, urls)
+         root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist')
+         path = File.expand_path(name, root)
+         Chainer::Dataset::Download.cache_or_load_file(path) do
+           make_npz(path, urls)
+         end
+       end
+
+       def self.make_npz(path, urls)
+         x_url, y_url = urls
+         x_path = Chainer::Dataset::Download.cached_download(x_url)
+         y_path = Chainer::Dataset::Download.cached_download(y_url)
+
+         x = nil
+         y = nil
+
+         Zlib::GzipReader.open(x_path) do |fx|
+           Zlib::GzipReader.open(y_path) do |fy|
+             fx.read(4)                     # skip the IDX magic number
+             fy.read(4)
+
+             n = fx.read(4).unpack('i>')[0] # number of examples (big-endian int32)
+             fy.read(4)                     # skip the label count (also n)
+             fx.read(8)                     # skip the row/column sizes (28 x 28)
+
+             x = Numo::UInt8.zeros(n, 784)  # filled from the payload below
+             y = Numo::UInt8.zeros(n)
+
+             n.times do |i|
+               y[i] = fy.read(1).ord
+               784.times do |j|
+                 x[i, j] = fx.read(1).ord
+               end
+             end
+           end
+         end
+
+         { x: x, y: y }
+       end
+     end
+   end
+ end
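A usage sketch (not part of the diff): on first use get_mnist downloads the four IDX files from yann.lecun.com, decodes them into Numo arrays, and caches the result; later calls load the cache.

require 'chainer'

train, test = Chainer::Datasets::Mnist.get_mnist(ndim: 1)
image, label = train[0]   # TupleDataset indexing yields [image, label]
image.shape               # => [784], values scaled into 0.0..1.0
label                     # => an Integer class id in 0..9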
data/lib/chainer/datasets/tuple_dataset.rb
@@ -0,0 +1,33 @@
+ module Chainer
+   module Datasets
+     class TupleDataset
+       def initialize(*datasets)
+         if datasets.empty?
+           raise "no datasets are given"
+         end
+         length = datasets[0].shape[0]
+
+         datasets.each_with_index do |dataset, idx|
+           raise "dataset of the index #{idx} has a wrong length" unless dataset.shape[0] == length
+         end
+
+         @datasets = datasets
+         @length = length
+       end
+
+       def [](index)
+         batches = @datasets.map { |dataset| dataset.ndim > 1 ? dataset[index, 0...dataset.shape[1]] : dataset[index] }
+         if index.kind_of?(Enumerable)
+           length = batches[0].shape[0]
+           length.times.map { |i| batches.map { |m| m[i] } }
+         else
+           batches
+         end
+       end
+
+       def size
+         @length
+       end
+     end
+   end
+ end
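TupleDataset zips equal-length arrays: a scalar index returns one example, an Enumerable index returns an array of examples. A quick sketch:

x = Numo::DFloat.new(5, 3).seq   # 5 examples with 3 features each
y = Numo::Int32.new(5).seq       # 5 labels
data = Chainer::Datasets::TupleDataset.new(x, y)

data.size      # => 5
data[2]        # => [x row 2, label 2]
data[[0, 3]]   # => [[x row 0, label 0], [x row 3, label 3]]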
data/lib/chainer/function.rb
@@ -0,0 +1,80 @@
+ module Chainer
+   class Function
+     attr_reader :rank, :inputs, :outputs, :retain_after_backward
+     attr_accessor :output_data
+
+     def initialize
+       @rank = 0
+     end
+
+     def call(*inputs)
+       # Wrap raw arrays in Variables so the graph can track them.
+       inputs = inputs.map do |x|
+         if x.kind_of?(Chainer::Variable)
+           x
+         else
+           Variable.new(x, requires_grad: false)
+         end
+       end
+
+       in_data = inputs.map(&:data)
+       requires_grad = inputs.any?(&:requires_grad)
+
+       @input_indexes_to_retain = nil
+       @output_indexes_to_retain = nil
+       outputs = forward(in_data)
+
+       ret = outputs.map do |y|
+         Variable.new(y, requires_grad: requires_grad)
+       end
+
+       if Chainer.configuration.enable_backprop
+         # The rank orders functions topologically for the backward pass.
+         @rank = inputs.map(&:rank).max || 0
+
+         ret.each { |y| y.creator = self }
+
+         @inputs = inputs.map(&:node)
+         @outputs = ret.map { |y| WeakRef.new(y.node) }
+
+         # Every input is retained by default; forward may narrow this
+         # via retain_inputs / retain_outputs.
+         @input_indexes_to_retain = 0...inputs.size if @input_indexes_to_retain.nil?
+         @input_indexes_to_retain.each do |index|
+           inputs[index].retain_data
+         end
+         remove_instance_variable(:@input_indexes_to_retain)
+
+         unless @output_indexes_to_retain.nil?
+           @output_indexes_to_retain.each do |index|
+             ret[index].retain_data
+           end
+           remove_instance_variable(:@output_indexes_to_retain)
+         end
+       end
+
+       ret.size == 1 ? ret[0] : ret
+     end
+
+     def forward(inputs)
+       # TODO: GPU branch processing
+       forward_cpu(inputs)
+     end
+
+     def forward_cpu(inputs)
+       raise NotImplementedError
+     end
+
+     def backward(inputs, grad_outputs)
+       backward_cpu(inputs, grad_outputs)
+     end
+
+     def retain_inputs(indexes)
+       @input_indexes_to_retain = indexes
+     end
+
+     def retain_outputs(indexes, retain_after_backward: false)
+       @output_indexes_to_retain = indexes
+       if retain_after_backward
+         @retain_after_backward = retain_after_backward
+       end
+     end
+   end
+ end
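Function#call is the define-by-run protocol: wrap inputs, run forward on raw arrays, wrap outputs, and record graph edges when backprop is enabled. A hypothetical element-wise square function (an illustration, not part of the gem) subclasses it like so:

class Square < Chainer::Function
  def forward_cpu(inputs)
    retain_inputs([0])   # keep x for the backward pass
    x = inputs[0]
    [Chainer::Utils::Array.force_array(x ** 2)]
  end

  def backward(inputs, grad_outputs)
    x = inputs[0]
    [Chainer::Utils::Array.force_array(2 * x * grad_outputs[0])]
  end
end

y = Square.new.(Chainer::Variable.new(Numo::DFloat[1, 2, 3]))
y.data  # => Numo::DFloat[1, 4, 9]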
data/lib/chainer/functions/activation/log_softmax.rb
@@ -0,0 +1,37 @@
+ module Chainer
+   module Functions
+     module Activation
+       # Numerically stable log(sum(exp(x))) along axis 1.
+       def self.logsumexp(x)
+         m = x.max(axis: 1, keepdims: true)
+         y = x - m
+         y = Numo::NMath.exp(y)
+         s = y.sum(axis: 1, keepdims: true)
+         s = Numo::NMath.log(s)
+         m + s
+       end
+
+       def self.log_softmax(x)
+         log_z = logsumexp(x)
+         x - log_z
+       end
+
+       class LogSoftmax < Function
+         def self.log_softmax(x)
+           self.new.(x)
+         end
+
+         def forward_cpu(x)
+           retain_inputs([])
+           retain_outputs([0])
+           [Utils::Array.force_array(Activation.log_softmax(x[0]))]
+         end
+
+         def backward_cpu(x, gy)
+           y = output_data[0]
+           # Derivative of log_softmax: gx = gy - exp(y) * sum(gy, axis: 1)
+           gx = gy[0] - Numo::NMath.exp(y) * gy[0].sum(axis: 1, keepdims: true)
+           [Utils::Array.force_array(gx)]
+         end
+       end
+     end
+   end
+ end
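The subtract-the-max trick in logsumexp keeps exp from overflowing even for large scores. A quick sanity check (assumes numo-narray is loaded):

x = Numo::DFloat[[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]]
log_p = Chainer::Functions::Activation.log_softmax(x)
Numo::NMath.exp(log_p).sum(axis: 1)  # => approx [1.0, 1.0], no overflow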
data/lib/chainer/functions/activation/relu.rb
@@ -0,0 +1,23 @@
+ module Chainer
+   module Functions
+     module Activation
+       class Relu < Function
+         def self.relu(x)
+           self.new.(x)
+         end
+
+         def forward_cpu(x)
+           retain_inputs([])
+           retain_outputs([0])
+           x[0][x[0] <= 0] = 0
+           [Utils::Array.force_array(x[0])]
+         end
+
+         def backward_cpu(x, gy)
+           y = output_data[0]
+           [Utils::Array.force_array(gy[0] * (y > 0))]
+         end
+       end
+     end
+   end
+ end
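Note that Relu#forward_cpu clamps the input array in place, so the data of the input Variable is modified as well. A usage sketch:

x = Chainer::Variable.new(Numo::DFloat[[-1.0, 2.0], [3.0, -4.0]])
y = Chainer::Functions::Activation::Relu.relu(x)
y.data  # => [[0.0, 2.0], [3.0, 0.0]] (x.data is clamped too)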
data/lib/chainer/functions/connection/linear.rb
@@ -0,0 +1,48 @@
+ module Chainer
+   module Functions
+     module Connection
+       class LinearFunction < Chainer::Function
+         def self.linear(x, w, b=nil)
+           if b.nil?
+             self.new.(x, w)
+           else
+             self.new.(x, w, b)
+           end
+         end
+
+         def forward(inputs)
+           x = as_mat(inputs[0])
+           w = inputs[1]
+
+           # x: (batch, in_features), w: (out_features, in_features)
+           y = x.dot(w.transpose).cast_to(x.class)
+           if inputs.size == 3
+             b = inputs[2]
+             y += b
+           end
+           [y]
+         end
+
+         def backward(inputs, grad_outputs)
+           x = as_mat(inputs[0])
+           w = inputs[1]
+           gy = grad_outputs[0]
+           gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
+           gw = gy.transpose.dot(x).cast_to(w.class)
+           if inputs.size == 3
+             gb = gy.sum(0)
+             [gx, gw, gb]
+           else
+             [gx, gw]
+           end
+         end
+
+         private
+
+         # Flatten all trailing dimensions into one: (batch, features).
+         # Numo infers the nil dimension.
+         def as_mat(x)
+           return x if x.ndim == 2
+           x.reshape(x.shape[0], nil)
+         end
+       end
+     end
+   end
+ end
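LinearFunction computes y = x . W^T + b, with W stored as (out_features, in_features) in Chainer's convention. A sketch with raw Numo arrays:

x = Numo::DFloat.new(2, 4).seq    # batch of 2, 4 features each
w = Numo::DFloat.new(3, 4).rand   # (out_features, in_features)
b = Numo::DFloat.zeros(3)

y = Chainer::Functions::Connection::LinearFunction.linear(x, w, b)
y.data.shape  # => [2, 3]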
data/lib/chainer/functions/evaluation/accuracy.rb
@@ -0,0 +1,42 @@
+ module Chainer
+   module Functions
+     module Evaluation
+       class Accuracy < Function
+         def self.accuracy(y, t, ignore_label: nil)
+           self.new(ignore_label: ignore_label).(y, t)
+         end
+
+         def initialize(ignore_label: nil)
+           @ignore_label = ignore_label
+         end
+
+         def forward(inputs)
+           y, t = inputs
+           if @ignore_label
+             mask = t.eq(@ignore_label)
+             ignore_cnt = mask.count
+
+             # max_index returns indices into the flattened array;
+             # subtract the row offset to recover per-row class indices.
+             pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx }
+             pred = y.class[*pred].reshape(*t.shape)
+             pred[mask] = @ignore_label
+             count = pred.eq(t).count - ignore_cnt
+
+             total = t.size - ignore_cnt
+
+             if total == 0
+               [y.class.cast(0.0)]
+             else
+               [y.class.cast(count.to_f / total)]
+             end
+           else
+             pred = y.max_index(axis: 1).to_a.map.with_index { |val, idx| val - y.shape[1] * idx }
+             pred = y.class[*pred].reshape(*t.shape)
+
+             [y.class.cast(y.class[pred.eq(t)].mean)]
+           end
+         end
+       end
+     end
+   end
+ end
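Accuracy takes raw scores and integer labels, predicts the argmax per row, and returns the matched fraction (optionally skipping ignore_label). For example:

y = Numo::DFloat[[0.1, 0.8, 0.1],
                 [0.9, 0.05, 0.05]]   # scores for 3 classes
t = Numo::Int32[1, 0]                 # ground-truth labels

acc = Chainer::Functions::Evaluation::Accuracy.accuracy(y, t)
acc.data  # => 1.0 (both rows predicted correctly)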
data/lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -0,0 +1,134 @@
+ module Chainer
+   module Functions
+     module Loss
+       class SoftmaxCrossEntropy < Function
+         def self.softmax_cross_entropy(x, t, normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
+           self.new(normalize: normalize, cache_score: cache_score, class_weight: class_weight, ignore_label: ignore_label, reduce: reduce).(x, t)
+         end
+
+         def initialize(normalize: true, cache_score: true, class_weight: nil, ignore_label: -1, reduce: 'mean')
+           @normalize = normalize
+           @cache_score = cache_score
+           @class_weight = class_weight
+
+           unless class_weight.nil?
+             if @class_weight.kind_of?(Chainer::Variable)
+               raise ArgumentError, 'class_weight should be a Numo::NArray, not a Chainer::Variable'
+             elsif @class_weight.ndim != 1
+               raise ArgumentError, 'class_weight.ndim should be 1'
+             elsif @class_weight.class != Numo::DFloat
+               raise ArgumentError, "the dtype of class_weight should be 'Numo::DFloat'"
+             end
+           end
+
+           @ignore_label = ignore_label
+           unless ['mean', 'no'].include?(reduce)
+             raise ArgumentError, "only 'mean' and 'no' are valid for 'reduce', but '#{reduce}' is given"
+           end
+
+           @reduce = reduce
+         end
+
+         def forward_cpu(inputs)
+           x, t = inputs
+           log_y = Activation.log_softmax(x)
+
+           if @cache_score
+             @y = Numo::NMath.exp(log_y)
+           end
+           if @class_weight
+             shape = x.ndim.times.map { |d| d == 1 ? nil : 1 }
+             log_y += broadcast_to(@class_weight.reshape(*shape), x.shape)
+           end
+           log_yd = rollaxis(log_y, 1)
+           begin
+             log_yd = log_yd.reshape(log_yd.shape[0], nil)
+           rescue ArgumentError
+           end
+
+           ravel_arr = t.flatten.dup
+           ravel_arr[ravel_arr < 0] = 0
+           arange_arr = t.class.new(t.size).seq
+
+           # https://github.com/chainer/chainer/blob/v2.0.2/chainer/functions/loss/softmax_cross_entropy.py#L79
+           log_p = []
+           arange_arr.each do |col_idx|
+             log_p << log_yd[ravel_arr, col_idx][col_idx]
+           end
+           log_p = Numo::NArray.[](*log_p)
+
+           # drop the contribution of examples whose label is ignore_label
+           log_p[t.flatten.eq(@ignore_label)] = 0
+
+           if @reduce == 'mean'
+             if @normalize
+               count = t.ne(@ignore_label).count
+             else
+               count = x.shape[0]
+             end
+             @coeff = 1.0 / [count, 1].max
+
+             y = log_p.sum(keepdims: true) * (-@coeff)
+             [y]
+           else
+             [-log_p.reshape(*t.shape)]
+           end
+         end
+
+         def backward_cpu(inputs, grad_outputs)
+           x, t = inputs
+           gloss = grad_outputs[0]
+
+           if instance_variable_defined?(:@y)
+             y = @y.dup
+           else
+             y = Activation.log_softmax(x)
+             y = Numo::NMath.exp(y)
+           end
+
+           if y.ndim == 2
+             gx = y
+             t[t < 0] = 0
+             # subtract 1 at each example's target class (softmax gradient)
+             t.each_with_index do |v, idx|
+               gx[idx, v] -= 1
+             end
+
+             if @class_weight
+               shape = x.ndim.times.map { |d| d == 1 ? nil : 1 }
+               c = broadcast_to(@class_weight.reshape(*shape), x.shape)
+               c = c[t.class.new(t.size).seq, t]
+               gx *= broadcast_to(c.expand_dims(1), gx.shape)
+             end
+
+             bit = t.flatten.dup
+             bit[t.ne(@ignore_label)] = 1
+             bit[bit.ne(1)] = 0
+             gx *= bit.reshape(t.size, 1)
+           else
+             raise 'TODO: ndim > 2 backward'
+           end
+
+           if @reduce == 'mean'
+             gx *= gloss * @coeff
+           else
+             raise 'TODO: reduce'
+           end
+           [gx, nil]
+         end
+
+         private
+
+         def broadcast_to(array, shape)
+           array.tile(shape[0]).reshape(*shape)
+         end
+
+         def rollaxis(y, axis, start: 0)
+           axes = (0...y.ndim).to_a
+           axes.delete_at(axis)
+           axes.insert(start, axis)
+           y.transpose(*axes)
+         end
+       end
+     end
+   end
+ end
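SoftmaxCrossEntropy fuses log_softmax with the negative log-likelihood: with reduce: 'mean' it averages over non-ignored examples, with reduce: 'no' it returns a per-example loss vector. A usage sketch:

x = Numo::DFloat[[2.0, 0.5, -1.0],
                 [0.0, 3.0, 0.2]]   # unnormalized scores
t = Numo::Int32[0, 1]               # target classes

loss = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(x, t)
loss.data  # mean negative log-likelihood over the batch

per_example = Chainer::Functions::Loss::SoftmaxCrossEntropy.softmax_cross_entropy(x, t, reduce: 'no')
per_example.data.shape  # => [2]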