red-chainer 0.1.1 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
- SHA1:
3
- metadata.gz: f236d4d7bf68a410d804ef19c6dc10173eb167d4
4
- data.tar.gz: 71f41bc080aa3a1f11018a0c15c6ee3c27e02f05
2
+ SHA256:
3
+ metadata.gz: 01f7ae8c937f4fb6384d9d6a0bf975d236da004d9165a734f9f5ab78c48d1197
4
+ data.tar.gz: 218ae1d0ebe3f90db0861d98d88a0be2a52e160f8001952d223df64ee8d22838
5
5
  SHA512:
6
- metadata.gz: ed00166f602c037f22215a88e58ac4eaf40b155f0463b5c72737a343468b5246530e962f73d34ec85563196ed275077fffd3acc37da3e6afbd25e56007331199
7
- data.tar.gz: 1567a64df5951c04e0140082afa3530417424f468fa983ae07bf83c7beac9fdc13a6332b23fb6633c37d2d487d7aa6fe15a3a5647ed652b79cedea51fc89a782
6
+ metadata.gz: aba81905dd93397b368b64fdd517b3b99487cc47dfb2ca7477cadf5e8357038be3f920f950a4968394040bd34e04e6537f24965a52e93f7d4855f174ff305d60
7
+ data.tar.gz: 96752ee9d7b151eb63e36ddda56d82910a6eee07109376de47cd977a77d1bb2bf0ecf78353835a049b795b35441197ede4cfb1dbff78ad4bd4a220da1c6f5009
data/.gitignore CHANGED
@@ -7,6 +7,7 @@
7
7
  /pkg/
8
8
  /spec/reports/
9
9
  /tmp/
10
+ result
10
11
 
11
12
  # rspec failure tracking
12
13
  .rspec_status
data/.travis.yml CHANGED
@@ -1,5 +1,7 @@
1
1
  sudo: false
2
2
  language: ruby
3
3
  rvm:
4
- - 2.4.1
4
+ - 2.3.5
5
+ - 2.4.2
5
6
  before_install: gem install bundler -v 1.15.1
7
+ script: ruby test/run_test.rb
data/README.md CHANGED
@@ -5,39 +5,35 @@
5
5
  Red Chainer
6
6
 
7
7
  ## Description
8
-
9
8
  Red Chainer is a flexible framework for neural networks in Ruby. To experiment with the library interactively, run `bin/console`.
10
9
 
11
- TODO: Delete this and the text above, and describe your gem
12
-
13
10
  ## Installation
14
11
 
15
12
  Add this line to your application's Gemfile:
16
13
 
17
- ```ruby
18
- gem 'red-chainer', github: 'red-data-tools/red-chainer'
14
+ ```ruby
15
+ gem 'red-chainer'
19
16
  ```
20
17
 
21
18
  And then execute:
22
19
 
23
- ```ruby
20
+ ```bash
24
21
  $ bundle
25
22
  ```
26
23
 
27
24
  Or install it yourself as:
28
25
 
29
- ```ruby
30
- gem install specific_install
31
- gem specific_install -l 'https://github.com/red-data-tools/red-chainer'
26
+ ```bash
27
+ $ gem install red-chainer
32
28
  ```
33
29
 
34
30
  ## Usage
35
31
  An MNIST sample program is available [here](./examples/mnist.rb).
36
32
 
37
- ```ruby
38
- # install Gemfile
33
+ ```bash
34
+ # when installed via the Gemfile
39
35
  $ bundle exec ruby examples/mnist.rb
40
- # install yourself
36
+ # when installed manually
41
37
  $ ruby examples/mnist.rb
42
38
  ```
43
39
 
data/examples/mnist.rb CHANGED
@@ -1,5 +1,6 @@
1
1
  require 'chainer'
2
2
  require 'fileutils'
3
+ require 'optparse'
3
4
  require 'tmpdir'
4
5
 
5
6
  class MLP < Chainer::Chain
@@ -22,21 +23,48 @@ class MLP < Chainer::Chain
22
23
  end
23
24
  end
24
25
 
25
# Command-line options with their default values.
options = {
  batchsize: 100,
  frequency: -1,
  epoch: 20,
  resume: nil,
  unit: 1000,
  out: 'result'
}

OptionParser.new do |parser|
  parser.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{options[:batchsize]})") { |v| options[:batchsize] = v.to_i }
  parser.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{options[:epoch]})") { |v| options[:epoch] = v.to_i }
  parser.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{options[:frequency]})") { |v| options[:frequency] = v.to_i }
  parser.on('-o', '--out VALUE', "Directory to output the result (default: #{options[:out]})") { |v| options[:out] = v }
  parser.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| options[:resume] = v }
  parser.on('-u', '--unit VALUE', "Number of units (default: #{options[:unit]})") { |v| options[:unit] = v.to_i }
end.parse!(ARGV)

model = Chainer::Links::Model::Classifier.new(MLP.new(options[:unit], 10))

optimizer = Chainer::Optimizers::Adam.new
optimizer.setup(model)
train, test = Chainer::Datasets::Mnist.get_mnist

batch_size = options[:batchsize]
train_iter = Chainer::Iterators::SerialIterator.new(train, batch_size)
test_iter = Chainer::Iterators::SerialIterator.new(test, batch_size, repeat: false, shuffle: false)

updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [options[:epoch], 'epoch'], out: options[:out])

trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))

# Snapshot every `frequency` epochs (at least 1). The default of -1 means
# the interval equals the total epoch count, i.e. a single final snapshot.
snapshot_interval =
  if options[:frequency] == -1
    options[:epoch]
  else
    [1, options[:frequency]].max
  end
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [snapshot_interval, 'epoch'], priority: -100)

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(%w[epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time]))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

# Restore the trainer state (updater, extensions, elapsed time) from a snapshot.
Chainer::Serializers::MarshalDeserializer.load_file(options[:resume], trainer) if options[:resume]

trainer.run
@@ -0,0 +1,64 @@
1
module Chainer
  module Functions
    module Activation
      # Leaky rectifier unit.
      class LeakyReLU < Function
        # Leaky Rectified Linear Unit function.
        #
        # This function is expressed as
        #
        # $$
        # f(x) = x \\ (x \\geq 0), \\quad f(x) = ax \\ (x < 0),
        # $$
        #
        # which equals $\\max(x, ax)$ when $0 \\leq a \\leq 1$.
        # Here $a$ is a configurable slope value.
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @param [Float] slope Slope value $a$.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        #   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
        #   > x
        #   => Numo::DFloat#shape=[3,2]
        #   [[-1, 0],
        #    [2, -3],
        #    [-2, 1]]
        #   > F = Chainer::Functions::Activation::LeakyReLU
        #   > F.leaky_relu(x, slope:0.2).data
        #   => Numo::DFloat#shape=[3,2]
        #   [[-0.2, 0],
        #    [2, -0.6],
        #    [-0.4, 1]]
        #
        def self.leaky_relu(x, slope: 0.2)
          self.new(slope: slope).(x)
        end

        # @param [Float] slope Slope applied to negative inputs.
        def initialize(slope: 0.2)
          @slope = slope
        end

        # Multiplies the negative entries of the input by the slope.
        def forward_cpu(x)
          y = x[0].dup
          y[x[0] < 0] *= @slope
          # Only for a strictly positive slope does sign(y) == sign(x) hold,
          # so only then can backward locate the negative region from the
          # output alone and the input be released. With slope == 0 every
          # negative input maps to (-)0.0, which is not < 0, so the input
          # must be kept (same path as a negative slope).
          if @slope > 0
            retain_inputs([])
            retain_outputs([0])
          end
          [y]
        end

        # Scales the upstream gradient by the slope where the input was negative.
        def backward_cpu(x, gy)
          gx = gy[0].dup
          if @slope > 0
            y = @output_data
            gx[y[0] < 0] *= @slope
          else
            gx[x[0] < 0] *= @slope
          end
          [gx]
        end
      end
    end
  end
end
@@ -10,26 +10,68 @@ module Chainer
10
10
  m + s
11
11
  end
12
12
 
13
# Numerically stable log(softmax(x)): subtracts logsumexp(x) (a helper of
# this module) from x.
def self._log_softmax(x)
  log_z = logsumexp(x)
  x - log_z
end

# Log-softmax activation function.
class LogSoftmax < Function
  # Channel-wise log-softmax function.
  #
  # This function computes the logarithm of softmax along the second axis.
  # Let $c = (c_1, c_2, \\dots, c_D)$ be the slice of +x+ along the
  # second axis. For each slice $c$, it computes the logarithm of
  # the function $f(c)$ defined as
  #
  # $$
  # f(c) = { \\exp(c) \\over \\sum_{ d } \\exp(c_d) }.
  # $$
  #
  # This method is theoretically equivalent to +log(softmax(x))+ but is
  # numerically more stable.
  #
  # @note
  #   +log(softmax(x))+ may cause underflow when +x+ is too small,
  #   because +softmax(x)+ may return +0+.
  #   The +log_softmax+ method is more stable.
  #
  # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
  # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array, which has the same shape as x.
  #
  # @see Chainer::Functions::Softmax
  #
  # @example
  #   > x = Numo::DFloat[[0, 1, 2], [0, 2, 4]]
  #   => Numo::DFloat#shape=[2,3]
  #   [[0, 1, 2],
  #    [0, 2, 4]]
  #   > F = Chainer::Functions::Activation::LogSoftmax
  #   > F.log_softmax(x).data
  #   => Numo::DFloat#shape=[2,3]
  #   [[-2.40761, -1.40761, -0.407606],
  #    [-4.14293, -2.14293, -0.142932]]
  # @example (T.B.I : F.log, F.softmax)
  #   > F.log_softmax(x).data.nearly_eq(F.log(F.softmax(x)).data).all?
  #   => true
  #
  def self.log_softmax(x)
    self.new.(x)
  end

  # Forward pass: y = x - logsumexp(x). Only the output is retained,
  # because backward needs nothing but y.
  def forward(xs)
    y = Chainer::Functions::Activation._log_softmax(xs[0])
    @x_shape = xs[0].shape
    @x_dtype = xs[0].class
    retain_inputs([])
    retain_outputs([0])
    [y]
  end

  # Backward pass: gx = gy - exp(y) * sum(gy, axis 1), where exp(y) is softmax(x).
  def backward(x, gy)
    y = @output_data[0]
    gx = gy[0] - Numo::NMath.exp(y) * gy[0].sum(axis: 1, keepdims: true)
    [gx]
  end
end
35
77
  end
@@ -1,7 +1,27 @@
1
1
  module Chainer
2
2
  module Functions
3
3
  module Activation
4
+ # Rectified Linear Unit.
4
5
  class Relu < Function
6
# Rectified Linear Unit function.
#
# $$
# f(x)=\\max(0, x).
# $$
#
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
# @example
#   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
#   > (x < 0).any?
#   => true
#   > F = Chainer::Functions::Activation::Relu
#   > y = F.relu(x)
#   > (y.data < 0).any?
#   => false
#   > y.shape
#   => [3, 2]
#
def self.relu(x)
  new.call(x)
end
@@ -14,7 +34,7 @@ module Chainer
14
34
  end
15
35
 
16
36
  def backward_cpu(x, gy)
17
- y = output_data[0]
37
+ y = @output_data[0]
18
38
  [Utils::Array.force_array(gy[0] * (y > 0))]
19
39
  end
20
40
  end
@@ -0,0 +1,43 @@
1
module Chainer
  module Functions
    module Activation
      # Logistic sigmoid function.
      class Sigmoid < Function
        # Element-wise logistic sigmoid function.
        #
        # $$
        # f(x)=(1 + \\exp(-x))^ { -1 }.
        # $$
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example It maps the input values into the range $[0, 1]$.
        #   > x = Numo::DFloat.new(3).seq(-2, 2)
        #   => Numo::DFloat#shape=[3]
        #   [-2, 0, 2]
        #   > F = Chainer::Functions::Activation::Sigmoid
        #   > F.sigmoid(x).data
        #   => Numo::DFloat#shape=[3]
        #   [0.119203, 0.5, 0.880797]
        #
        def self.sigmoid(x)
          new.call(x)
        end

        # Computes sigmoid(x) through the identity
        # sigmoid(x) = tanh(x / 2) / 2 + 1 / 2. Only the output is kept for backward.
        def forward_cpu(x)
          scaled = Numo::NMath.tanh(x[0] * 0.5)
          output = Utils::Array.force_array((scaled * 0.5) + 0.5)
          retain_inputs([])
          retain_outputs([0])
          [output]
        end

        # sigmoid'(x) = y * (1 - y), where y is the retained forward output.
        def backward_cpu(x, gy)
          y = @output_data[0]
          [Utils::Array.force_array((gy[0] * y) * (1 - y))]
        end
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
module Chainer
  module Functions
    module Activation
      # Hyperbolic tangent function.
      class Tanh < Function
        # Element-wise hyperbolic tangent function.
        #
        # $$
        # f(x)=\\tanh(x).
        # $$
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        #   > x = Numo::DFloat.new(3).seq(-1, 2)
        #   => Numo::DFloat#shape=[3]
        #   [-1, 1, 3]
        #   > F = Chainer::Functions::Activation::Tanh
        #   > F.tanh(x).data
        #   => Numo::DFloat#shape=[3]
        #   [-0.761594, 0.761594, 0.995055]
        #
        def self.tanh(x)
          self.new.(x)
        end

        # Computes y = tanh(x); only the output is retained for backward.
        def forward_cpu(x)
          y = Utils::Array.force_array(Numo::NMath.tanh(x[0]))
          retain_inputs([])
          retain_outputs([0])
          [y]
        end

        # d tanh(x)/dx = 1 - tanh(x)^2 = 1 - y^2.
        def backward_cpu(x, gy)
          y = @output_data[0]
          # Use a plain Integer scalar (as Sigmoid does): Numo arrays broadcast
          # Ruby scalars and have no NumPy-style `dtype.type` to build one from.
          [Utils::Array.force_array(gy[0] * (1 - y * y))]
        end
      end
    end
  end
end
@@ -31,7 +31,7 @@ module Chainer
31
31
 
32
32
  def forward_cpu(inputs)
33
33
  x, t = inputs
34
- log_y = Activation.log_softmax(x)
34
+ log_y = Activation._log_softmax(x)
35
35
 
36
36
  if @cache_score
37
37
  @y = Numo::NMath.exp(log_y)
@@ -81,7 +81,7 @@ module Chainer
81
81
  if self.instance_variable_defined?(:'@y')
82
82
  y = @y.dup
83
83
  else
84
- y = Activation.log_softmax(x)
84
+ y = Activation._log_softmax(x)
85
85
  y = Numo::NMath.exp(y)
86
86
  end
87
87
 
@@ -15,7 +15,7 @@ module Chainer
15
15
  def next
16
16
  raise StopIteration if !@repeat && @epoch > 0
17
17
 
18
- @previous_epoch_detail = @epoch_detail
18
+ @previous_epoch_detail = epoch_detail
19
19
 
20
20
  i = @current_position
21
21
  i_end = i + @batch_size
@@ -55,6 +55,31 @@ module Chainer
55
55
  @epoch + @current_position.to_f / @dataset.size
56
56
  end
57
57
 
58
# Serializes or deserializes the iterator position (current position, epoch,
# new-epoch flag, shuffle order and previous epoch detail).
#
# @param [Chainer::AbstractSerializer] serializer Serializer or deserializer.
def serialize(serializer)
  @current_position = serializer.('current_position', @current_position)
  @epoch = serializer.('epoch', @epoch)
  @is_new_epoch = serializer.('is_new_epoch', @is_new_epoch)
  unless @order.nil?
    begin
      serializer.('order', @order)
    rescue KeyError
      # Fall back to the alternative key (presumably written by older versions).
      serializer.('_order', @order)
    end
    # NOTE(review): the return value is discarded, so @order is only restored
    # when the deserializer fills it in place (e.g. a Numo::NArray) — confirm
    # for Array orders.
  end

  begin
    @previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
  rescue KeyError
    # Guess previous_epoch_detail for snapshots that lack it. fdiv keeps the
    # fraction; Integer#/ would floor it (e.g. -50 / 1000 == -1).
    @previous_epoch_detail = @epoch + (@current_position - @batch_size).fdiv(@dataset.size)
    if epoch_detail > 0
      @previous_epoch_detail = [@previous_epoch_detail, 0.0].max
    else
      @previous_epoch_detail = -1.0
    end
  end
end
82
+
58
83
  def reset
59
84
  if @shuffle
60
85
  order = @dataset.size.times.map(&:to_i).shuffle
data/lib/chainer/link.rb CHANGED
@@ -59,6 +59,27 @@ module Chainer
59
59
  def namedlinks(skipself: false)
60
60
  yield('/', self) unless skipself
61
61
  end
62
+
63
# Serializes or deserializes this link's parameters and persistent values.
#
# Each name in @params is (de)serialized through its parameter's data. An
# uninitialized parameter (nil data) is allocated with the loaded data's
# shape and the data is copied in. Each name in @persistent is
# (de)serialized and the result is written back onto the link.
#
# @param [Chainer::AbstractSerializer] serializer Serializer or deserializer.
def serialize(serializer)
  d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
  @params.each do |name|
    param = d[name]
    data = serializer.(name.to_s, param.data)
    if param.data.nil? && !data.nil?
      # Initialize the parameter here
      param.init(data.shape)
      if param.data.is_a?(Numo::NArray)
        param.data.store(data)
      else
        param.data.set(Numo::NArray.cast(data))
      end
    end
  end

  @persistent.each do |name|
    value = serializer.(name.to_s, d[name])
    # Write the result back onto the link. Assigning only into the local
    # snapshot hash `d` would silently discard restored (scalar) values.
    instance_variable_set(name, value) if d.key?(name)
  end
end
62
83
  end
63
84
 
64
85
  class Chain < Link
@@ -114,5 +135,13 @@ module Chainer
114
135
  end
115
136
  end
116
137
  end
138
+
139
# Serializes this chain: first its own parameters and persistents (via the
# superclass), then every child link under a child serializer named after it.
def serialize(serializer)
  super(serializer)
  values = Hash[instance_variables.map { |ivar| [ivar, instance_variable_get(ivar)] }]
  @children.each { |name| values[name].serialize(serializer[name.to_s]) }
end
117
146
  end
118
147
  end
@@ -19,6 +19,17 @@ module Chainer
19
19
  hook(self)
20
20
  end
21
21
  end
22
+
23
# Serializes the optimizer counters @t and @epoch, then the update-rule state
# of every target parameter that has one, each under a child serializer
# named after the parameter.
def serialize(serializer)
  @t = serializer.call('t', @t)
  @epoch = serializer.call('epoch', @epoch)

  @target.namedparams do |(name, param)|
    next unless param.respond_to?(:update_rule)
    param.update_rule.serialize(serializer[name.to_s])
  end
end
22
33
  end
23
34
 
24
35
  class UpdateRule
@@ -56,6 +67,33 @@ module Chainer
56
67
  raise NotImplementedError
57
68
  end
58
69
 
70
+
71
# Serializes the update rule state.
#
# Only the state of the update rule itself is saved/loaded; the parameters of
# the target link are not. Serialize the target link separately to fully
# recover the training state including parameters.
#
# @param [Chainer::AbstractSerializer] serializer Serializer object.
def serialize(serializer)
  unless @state.nil?
    return @state.each { |key, val| @state[key] = serializer.call(key.to_s, val) }
  end
  return unless serializer.is_a?(Chainer::Deserializer)

  # No state yet: let a shallow copy (which shares this very @state hash,
  # presumably unless init_state replaces it) run init_state on a dummy
  # 1-element parameter so the expected state keys get registered, then load
  # each of them from the deserializer.
  @state = {}
  probe = dup
  dummy = Numo::DFloat.new(1)
  probe.init_state(Chainer::Variable.new(dummy, grad: dummy))
  @state.keys.each { |key| @state[key] = serializer.call(key.to_s, nil) }
end
96
+
59
97
  private
60
98
 
61
99
  def prepare(param)
@@ -0,0 +1,50 @@
1
module Chainer
  # Abstract base class of all serializers and deserializers.
  #
  # Subclasses implement +[]+ (child lookup) and +call+ (save or load a single
  # value). Objects take part in the protocol by defining
  # +serialize(serializer)+.
  class AbstractSerializer
    # Returns a child serializer represented by the given key.
    #
    # @param [String] key Name of the child serializer.
    # @raise [NotImplementedError] always, in this abstract class.
    def [](key)
      raise NotImplementedError
    end

    # Serializes or deserializes a value under the given name.
    #
    # A serializer simply saves +value+ at +key+; some type information may be
    # lost depending on the implementation (and the target file format).
    # A deserializer loads the value stored at +key+, and scalars and arrays
    # are handled differently. For scalars, +value+ only determines the type
    # the restored value is converted to, and the converted value is returned.
    # For arrays, the restored elements are copied directly into +value+.
    # Strings are treated like scalars.
    #
    # @param [String] key Name of the serialization entry.
    # @param [Object, nil] value Object to be (de)serialized. +nil+ is only
    #   supported by deserializers.
    # @return [Object] The serialized or deserialized value.
    # @raise [NotImplementedError] always, in this abstract class.
    def call(key, value)
      raise NotImplementedError
    end
  end

  # Base class of all serializers (the saving side).
  class Serializer < AbstractSerializer
    # Saves +obj+ with this serializer; equivalent to +obj.serialize(self)+.
    #
    # @param [Object] obj Target object to be serialized.
    def save(obj)
      obj.public_send(:serialize, self)
    end
  end

  # Base class of all deserializers (the loading side).
  class Deserializer < AbstractSerializer
    # Loads state into +obj+; equivalent to +obj.serialize(self)+.
    #
    # @param [Object] obj Target object to be deserialized.
    def load(obj)
      obj.public_send(:serialize, self)
    end
  end
end
@@ -0,0 +1,83 @@
1
module Chainer
  module Serializers
    # Serializer that collects values into a flat Hash keyed by
    # slash-separated paths and writes it with Ruby's Marshal.
    class MarshalSerializer < Chainer::Serializer
      attr_accessor :target, :path

      # Serializes +obj+ and writes the collected Hash in Marshal format.
      #
      # @param [String, IO] filename Path of the file to write, or an IO that
      #   is already open for writing.
      # @param [Object] obj Object to be serialized. It must support the
      #   serialization protocol (+serialize(serializer)+).
      def self.save_file(filename, obj)
        s = self.new
        s.save(obj)
        if filename.respond_to?(:write)
          Marshal.dump(s.target, filename)
        else
          # Marshal.dump only writes to an IO, so open the path here in binary
          # mode (Marshal data is not text).
          File.open(filename, 'wb') { |f| Marshal.dump(s.target, f) }
        end
      end

      # @param [Hash, nil] target Hash receiving the entries; a new Hash is
      #   created when nil. Child serializers share the same Hash.
      # @param [String] path Key prefix for entries written by this serializer.
      def initialize(target: nil, path: "")
        @target = target.nil? ? {} : target
        @path = path
      end

      # Returns a child serializer that writes into the same Hash under +key+.
      def [](key)
        self.class.new(target: @target, path: File.join(@path, key, '/'))
      end

      # Stores +value+ under File.join(path, key) and returns +value+ unchanged.
      # Booleans are stored as a 1-element Numo::Bit and Strings as-is; anything
      # else goes through Numo::NArray.cast.
      def call(key, value)
        ret = value
        if value.is_a?(TrueClass)
          arr = Numo::Bit[1]
        elsif value.is_a?(FalseClass)
          arr = Numo::Bit[0]
        elsif value.instance_of?(String)
          arr = value
        else
          arr = Numo::NArray.cast(value)
        end
        @target[File.join(@path, key)] = arr
        ret
      end
    end

    # Deserializer that reads values from a Hash written by MarshalSerializer.
    class MarshalDeserializer < Chainer::Deserializer
      # Loads an object from a file in Marshal format.
      # This is a shortcut for loading from a Marshal file that contains only one object.
      #
      # @note Marshal.load can instantiate arbitrary Ruby objects; only load
      #   snapshot files from trusted sources.
      #
      # @param [String] filename Name of the file to be loaded.
      # @param [Object] obj Object to be deserialized. It must support the serialization protocol.
      def self.load_file(filename, obj)
        File.open(filename, 'rb') do |f|
          d = self.new(Marshal.load(f))
          d.load(obj)
        end
      end

      # @param [Hash] marshal_data Hash of path => value, as written by MarshalSerializer.
      # @param [String] path Key prefix for entries read by this deserializer.
      # @param [Boolean] strict When true a missing key raises KeyError; when
      #   false the given +value+ is returned unchanged.
      def initialize(marshal_data, path: '', strict: true)
        @marshal_data = marshal_data
        @path = path
        @strict = strict
      end

      # Returns a child deserializer reading the same data under +key+.
      def [](key)
        self.class.new(@marshal_data, path: File.join(@path, key, '/'), strict: @strict)
      end

      # Loads the value stored under File.join(path, key).
      #
      # @raise [KeyError] in strict mode when the key is absent. Callers such as
      #   SerialIterator#serialize rescue KeyError to fall back to defaults.
      def call(key, value)
        key = File.join(@path, key)
        unless @marshal_data.key?(key)
          return value unless @strict
          raise KeyError, "key not found: #{key}"
        end

        dataset = @marshal_data[key]
        if value.nil?
          dataset
        elsif value.instance_of?(String)
          dataset
        elsif value.is_a?(Numo::NArray)
          # Arrays are restored in place.
          value.store(dataset)
          value
        elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)
          dataset[0] == 1
        else
          dataset[0]
        end
      end
    end
  end
end
83
+
@@ -61,6 +61,21 @@ module Chainer
61
61
  init_summary
62
62
  end
63
63
 
64
# Serializes the trigger (when it supports it) and the accumulated log.
# The log is stored as a JSON string under '_log'; this round trip may lose
# small numerical differences.
def serialize(serializer)
  @trigger.serialize(serializer['_trigger']) if @trigger.respond_to?(:serialize)

  unless serializer.is_a?(Chainer::Serializer)
    @log = JSON.parse(serializer.call('_log', ''))
    return @log
  end
  serializer.call('_log', JSON.generate(@log))
end
78
+
64
79
  private
65
80
 
66
81
  def init_summary
@@ -44,6 +44,12 @@ module Chainer
44
44
  end
45
45
  end
46
46
 
47
# Delegates to the wrapped LogReport (under the '_log_report' child key) when
# there is one; otherwise nothing is serialized.
def serialize(serializer)
  return unless @log_report.is_a?(Chainer::Training::Extensions::LogReport)

  @log_report.serialize(serializer['_log_report'])
end
52
+
47
53
  private
48
54
 
49
55
  def print(observation)
@@ -0,0 +1,33 @@
1
require 'fileutils'
require 'tempfile'

module Chainer
  module Training
    module Extensions
      # Trainer extension that writes a snapshot of an object (by default the
      # trainer itself) into the trainer's output directory.
      class Snapshot < Extension
        attr_accessor :save_class, :filename_proc, :target

        # Builds a snapshot extension for an arbitrary +target+ object. The
        # optional block receives the trainer and returns the file name.
        def self.snapshot_object(target:, save_class:, &block)
          self.new(save_class: save_class, filename_proc: block, target: target)
        end

        # Builds a snapshot extension for the trainer itself. The optional
        # block receives the trainer and returns the file name.
        def self.snapshot(save_class: nil, &block)
          self.new(save_class: save_class, filename_proc: block)
        end

        # @param [Class] save_class Class responding to +save_file(file, obj)+;
        #   defaults to Chainer::Serializers::MarshalSerializer.
        # @param [Proc] filename_proc Receives the trainer and returns the file
        #   name; defaults to "snapshot_iter_<iteration>".
        # @param [Object] target Object to snapshot; the trainer when nil.
        def initialize(save_class: nil, filename_proc: nil, target: nil)
          @save_class = save_class || Chainer::Serializers::MarshalSerializer
          @filename_proc = filename_proc || Proc.new { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }
          @target = target
        end

        # Writes the snapshot to a temporary file inside trainer.out and then
        # moves it into place, so an interrupted write never leaves a
        # truncated file under the final name.
        def call(trainer)
          target = @target || trainer
          filename = filename_proc.call(trainer)
          prefix = "tmp#{filename}"
          # Tempfile.create takes basename and tmpdir positionally.
          temp_file = Tempfile.create(prefix, trainer.out, mode: File::BINARY)
          begin
            save_class.save_file(temp_file, target)
            # Close (flushing Ruby's buffer) before moving the file.
            temp_file.close
            FileUtils.move(temp_file.path, File.join(trainer.out, filename))
          rescue StandardError
            temp_file.close unless temp_file.closed?
            File.unlink(temp_file.path) if File.exist?(temp_file.path)
            raise
          end
        end
      end
    end
  end
end
33
+
@@ -58,6 +58,18 @@ module Chainer
58
58
  iterator.finalize
59
59
  end
60
60
  end
61
+
62
# Serializes every iterator, every optimizer followed by its target model,
# and finally the iteration counter.
def serialize(serializer)
  @iterators.each do |key, iter|
    iter.serialize(serializer["iterator:#{key}"])
  end

  @optimizers.each do |key, opt|
    opt.serialize(serializer["optimizer:#{key}"])
    model = opt.target
    model.serialize(serializer["model:#{key}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end
61
73
  end
62
74
  end
63
75
  end
@@ -44,7 +44,7 @@ module Chainer
44
44
  return @final_elapsed_time if @done
45
45
  raise "training has not been started yet" if @start_at.nil?
46
46
 
47
- Time.now.to_f - @start_at - @snapshot_elapsed_time.to_f
47
+ Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
48
48
  end
49
49
 
50
50
  def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
@@ -97,7 +97,7 @@ module Chainer
97
97
  raise 'cannot run training loop multiple times' if @done
98
98
  FileUtils.mkdir_p(@out)
99
99
 
100
- extensions = @extensions.sort_by { |(_, e)| e.priority }.map { |(name, extension)| [name, extension] }
100
+ extensions = @extensions.sort_by { |(_, e)| -e.priority }.map { |(name, extension)| [name, extension] }
101
101
 
102
102
  @start_at = Time.now.to_f
103
103
 
@@ -115,7 +115,7 @@ module Chainer
115
115
  @observation = {}
116
116
  reporter.scope(@observation) do
117
117
  update.call
118
- extensions.each do |(_, entry)|
118
+ extensions.each do |(name, entry)|
119
119
  entry.extension.(self) if entry.trigger.(self)
120
120
  end
121
121
  end
@@ -131,6 +131,29 @@ module Chainer
131
131
  @final_elapsed_time = @elapsed_time
132
132
  @done = true
133
133
  end
134
+
135
# Serializes the training state: the updater, the stop trigger, each
# extension and extension trigger that supports serialization, and the
# elapsed time.
def serialize(serializer)
  updater.serialize(serializer['updater'])
  @stop_trigger.serialize(serializer['stop_trigger']) if @stop_trigger.respond_to?(:serialize)

  extension_serializer = serializer['extensions']
  trigger_serializer = serializer['extension_triggers']
  @extensions.each do |name, entry|
    ext = entry.extension
    ext.serialize(extension_serializer[name]) if ext.respond_to?(:serialize)
    trig = entry.trigger
    trig.serialize(trigger_serializer[name]) if trig.respond_to?(:serialize)
  end

  # When saving, record the elapsed time so far. When loading, keep it as the
  # offset that elapsed_time adds after resuming.
  if serializer.is_a?(Chainer::Serializer)
    serializer.call('_snapshot_elapsed_time', elapsed_time)
  else
    @snapshot_elapsed_time = serializer.call('_snapshot_elapsed_time', 0.0)
  end
end
134
157
  end
135
158
  end
136
159
  end
@@ -8,18 +8,43 @@ module Chainer
8
8
  @period = period
9
9
  @unit = unit
10
10
  @count = 0
11
+
12
+ @previous_iteration = 0
13
+ @previous_epoch_detail = 0.0
11
14
  end
12
15
 
13
16
  def call(trainer)
14
17
  updater = trainer.updater
15
18
  if @unit == 'epoch'
16
- prev = @count
17
- @count = updater.epoch_detail.div(@period)
18
- prev != @count
19
+ epoch_detail = updater.epoch_detail
20
+ previous_epoch_detail = @previous_epoch_detail
21
+
22
+ if previous_epoch_detail < 0
23
+ previous_epoch_detail = updater.previous_epoch_detail
24
+ end
25
+
26
+ @count = epoch_detail.div(@period).floor
27
+
28
+ fire = previous_epoch_detail.div(@period).floor != epoch_detail.div(@period).floor
19
29
  else
20
30
  iteration = updater.iteration
21
- iteration > 0 && iteration % @period == 0
31
+ previous_iteration = @previous_iteration
32
+ if previous_iteration < 0
33
+ previous_iteration = iteration - 1
34
+ end
35
+ fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
22
36
  end
37
+
38
+ # save current values
39
+ @previous_iteration = updater.iteration
40
+ @previous_epoch_detail = updater.epoch_detail
41
+
42
+ fire
43
+ end
44
+
45
+ def serialize(serializer)
46
+ @previous_iteration = serializer.('previous_iteration', @previous_iteration)
47
+ @previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
23
48
  end
24
49
  end
25
50
  end
@@ -1,4 +1,4 @@
1
1
  module Chainer
2
- VERSION = "0.1.1"
2
+ VERSION = "0.2.0"
3
3
  end
4
4
 
data/lib/chainer.rb CHANGED
@@ -22,7 +22,10 @@ require 'chainer/variable_node'
22
22
  require 'chainer/utils/initializer'
23
23
  require 'chainer/utils/variable'
24
24
  require 'chainer/utils/array'
25
+ require 'chainer/functions/activation/leaky_relu'
25
26
  require 'chainer/functions/activation/relu'
27
+ require 'chainer/functions/activation/sigmoid'
28
+ require 'chainer/functions/activation/tanh'
26
29
  require 'chainer/functions/activation/log_softmax'
27
30
  require 'chainer/functions/evaluation/accuracy'
28
31
  require 'chainer/functions/math/basic_math'
@@ -33,6 +36,7 @@ require 'chainer/training/extensions/evaluator'
33
36
  require 'chainer/training/extensions/log_report'
34
37
  require 'chainer/training/extensions/print_report'
35
38
  require 'chainer/training/extensions/progress_bar'
39
+ require 'chainer/training/extensions/snapshot'
36
40
  require 'chainer/training/trainer'
37
41
  require 'chainer/training/updater'
38
42
  require 'chainer/training/util'
@@ -44,6 +48,8 @@ require 'chainer/dataset/download'
44
48
  require 'chainer/datasets/mnist'
45
49
  require 'chainer/datasets/tuple_dataset'
46
50
  require 'chainer/reporter'
51
+ require 'chainer/serializer'
52
+ require 'chainer/serializers/marshal'
47
53
 
48
54
  require 'numo/narray'
49
55
 
data/red-chainer.gemspec CHANGED
@@ -19,7 +19,7 @@ Gem::Specification.new do |spec|
19
19
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
20
20
  spec.require_paths = ["lib"]
21
21
 
22
- spec.add_runtime_dependency "numo-narray", ">= 0.9.0.8"
22
+ spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
23
23
 
24
24
  spec.add_development_dependency "bundler", "~> 1.15"
25
25
  spec.add_development_dependency "rake", "~> 10.0"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2017-11-18 00:00:00.000000000 Z
11
+ date: 2018-02-01 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.9.0.8
19
+ version: 0.9.1.1
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.9.0.8
26
+ version: 0.9.1.1
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: bundler
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -92,8 +92,11 @@ files:
92
92
  - lib/chainer/datasets/mnist.rb
93
93
  - lib/chainer/datasets/tuple_dataset.rb
94
94
  - lib/chainer/function.rb
95
+ - lib/chainer/functions/activation/leaky_relu.rb
95
96
  - lib/chainer/functions/activation/log_softmax.rb
96
97
  - lib/chainer/functions/activation/relu.rb
98
+ - lib/chainer/functions/activation/sigmoid.rb
99
+ - lib/chainer/functions/activation/tanh.rb
97
100
  - lib/chainer/functions/connection/linear.rb
98
101
  - lib/chainer/functions/evaluation/accuracy.rb
99
102
  - lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -112,11 +115,14 @@ files:
112
115
  - lib/chainer/optimizers/adam.rb
113
116
  - lib/chainer/parameter.rb
114
117
  - lib/chainer/reporter.rb
118
+ - lib/chainer/serializer.rb
119
+ - lib/chainer/serializers/marshal.rb
115
120
  - lib/chainer/training/extension.rb
116
121
  - lib/chainer/training/extensions/evaluator.rb
117
122
  - lib/chainer/training/extensions/log_report.rb
118
123
  - lib/chainer/training/extensions/print_report.rb
119
124
  - lib/chainer/training/extensions/progress_bar.rb
125
+ - lib/chainer/training/extensions/snapshot.rb
120
126
  - lib/chainer/training/standard_updater.rb
121
127
  - lib/chainer/training/trainer.rb
122
128
  - lib/chainer/training/triggers/interval.rb
@@ -149,7 +155,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
149
155
  version: '0'
150
156
  requirements: []
151
157
  rubyforge_project:
152
- rubygems_version: 2.6.13
158
+ rubygems_version: 2.7.3
153
159
  signing_key:
154
160
  specification_version: 4
155
161
  summary: A flexible framework for neural network for Ruby