red-chainer 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
- SHA1:
3
- metadata.gz: f236d4d7bf68a410d804ef19c6dc10173eb167d4
4
- data.tar.gz: 71f41bc080aa3a1f11018a0c15c6ee3c27e02f05
2
+ SHA256:
3
+ metadata.gz: 01f7ae8c937f4fb6384d9d6a0bf975d236da004d9165a734f9f5ab78c48d1197
4
+ data.tar.gz: 218ae1d0ebe3f90db0861d98d88a0be2a52e160f8001952d223df64ee8d22838
5
5
  SHA512:
6
- metadata.gz: ed00166f602c037f22215a88e58ac4eaf40b155f0463b5c72737a343468b5246530e962f73d34ec85563196ed275077fffd3acc37da3e6afbd25e56007331199
7
- data.tar.gz: 1567a64df5951c04e0140082afa3530417424f468fa983ae07bf83c7beac9fdc13a6332b23fb6633c37d2d487d7aa6fe15a3a5647ed652b79cedea51fc89a782
6
+ metadata.gz: aba81905dd93397b368b64fdd517b3b99487cc47dfb2ca7477cadf5e8357038be3f920f950a4968394040bd34e04e6537f24965a52e93f7d4855f174ff305d60
7
+ data.tar.gz: 96752ee9d7b151eb63e36ddda56d82910a6eee07109376de47cd977a77d1bb2bf0ecf78353835a049b795b35441197ede4cfb1dbff78ad4bd4a220da1c6f5009
data/.gitignore CHANGED
@@ -7,6 +7,7 @@
7
7
  /pkg/
8
8
  /spec/reports/
9
9
  /tmp/
10
+ result
10
11
 
11
12
  # rspec failure tracking
12
13
  .rspec_status
data/.travis.yml CHANGED
@@ -1,5 +1,7 @@
1
1
  sudo: false
2
2
  language: ruby
3
3
  rvm:
4
- - 2.4.1
4
+ - 2.3.5
5
+ - 2.4.2
5
6
  before_install: gem install bundler -v 1.15.1
7
+ script: ruby test/run_test.rb
data/README.md CHANGED
@@ -5,39 +5,35 @@
5
5
  Red Chainer
6
6
 
7
7
  ## Description
8
-
9
8
  Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/chainer`. To experiment with that code, run `bin/console` for an interactive prompt.
10
9
 
11
- TODO: Delete this and the text above, and describe your gem
12
-
13
10
  ## Installation
14
11
 
15
12
  Add this line to your application's Gemfile:
16
13
 
17
- ```ruby
18
- gem 'red-chainer', github: 'red-data-tools/red-chainer'
14
+ ```ruby
15
+ gem 'red-chainer'
19
16
  ```
20
17
 
21
18
  And then execute:
22
19
 
23
- ```ruby
20
+ ```bash
24
21
  $ bundle
25
22
  ```
26
23
 
27
24
  Or install it yourself as:
28
25
 
29
- ```ruby
30
- gem install specific_install
31
- gem specific_install -l 'https://github.com/red-data-tools/red-chainer'
26
+ ```bash
27
+ $ gem install red-chainer
32
28
  ```
33
29
 
34
30
  ## Usage
35
31
  The MNIST sample program is [here](./examples/mnist.rb).
36
32
 
37
- ```ruby
38
- # install Gemfile
33
+ ```bash
34
+ # when installed via Gemfile
39
35
  $ bundle exec ruby examples/mnist.rb
40
- # install yourself
36
+ # when installed manually
41
37
  $ ruby examples/mnist.rb
42
38
  ```
43
39
 
data/examples/mnist.rb CHANGED
@@ -1,5 +1,6 @@
1
1
  require 'chainer'
2
2
  require 'fileutils'
3
+ require 'optparse'
3
4
  require 'tmpdir'
4
5
 
5
6
  class MLP < Chainer::Chain
@@ -22,21 +23,48 @@ class MLP < Chainer::Chain
22
23
  end
23
24
  end
24
25
 
25
# Default hyper-parameters; each can be overridden from the command line.
args = {
  batchsize: 100,
  frequency: -1,
  epoch: 20,
  resume: nil,
  unit: 1000,
  out: 'result'
}

# Numeric options use OptionParser's built-in Integer conversion so that
# malformed input (e.g. "-b ten") is rejected with an error instead of
# String#to_i silently turning it into 0.
opt = OptionParser.new
opt.on('-b', '--batchsize VALUE', Integer, "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v }
opt.on('-e', '--epoch VALUE', Integer, "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v }
opt.on('-f', '--frequency VALUE', Integer, "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v }
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
opt.on('-u', '--unit VALUE', Integer, "Number of units (default: #{args[:unit]})") { |v| args[:unit] = v }
opt.parse!(ARGV)

# A zero or negative batch size, epoch count or unit count cannot train.
%i[batchsize epoch unit].each do |key|
  abort("#{key} must be a positive integer (got #{args[key]})") unless args[key] > 0
end

model = Chainer::Links::Model::Classifier.new(MLP.new(args[:unit], 10))

optimizer = Chainer::Optimizers::Adam.new
optimizer.setup(model)
train, test = Chainer::Datasets::Mnist.get_mnist

train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])

trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))

# Take a snapshot every `frequency` epochs. The default (-1) uses the total
# epoch count, i.e. a single snapshot at the end of training. A low priority
# makes the snapshot run after the other extensions of the same iteration.
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'], priority: -100)

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

if args[:resume]
  # Restores the whole trainer state. The snapshot is read with Marshal.load,
  # which can instantiate arbitrary objects: only resume from trusted files.
  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
end

trainer.run
@@ -0,0 +1,64 @@
1
module Chainer
  module Functions
    module Activation
      # Leaky rectifier unit.
      class LeakyReLU < Function
        # Leaky Rectified Linear Unit function.
        #
        # This function is expressed as
        #
        # $$
        # f(x)=\\max(x, ax),
        # $$
        #
        # where $a$ is a configurable slope value.
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @param [Float] slope Slope value $a$ applied to negative inputs.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        #   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
        #   > x
        #   => Numo::DFloat#shape=[3,2]
        #   [[-1, 0],
        #    [2, -3],
        #    [-2, 1]]
        #   > F = Chainer::Functions::Activation::LeakyReLU
        #   > F.leaky_relu(x, slope:0.2).data
        #   => Numo::DFloat#shape=[3,2]
        #   [[-0.2, 0],
        #    [2, -0.6],
        #    [-0.4, 1]]
        #
        def self.leaky_relu(x, slope: 0.2)
          new(slope: slope).call(x)
        end

        # @param [Float] slope Multiplier for negative inputs.
        def initialize(slope: 0.2)
          @slope = slope
        end

        # Copies the input and scales its negative entries by the slope.
        def forward_cpu(inputs)
          source = inputs[0]
          y = source.dup
          negative = source < 0
          y[negative] = y[negative] * @slope
          # With a non-negative slope, backward finds the scaled entries from
          # the output (y < 0), so only the output has to be kept.
          if @slope >= 0
            retain_inputs([])
            retain_outputs([0])
          end
          [y]
        end

        # Scales the incoming gradient by the slope wherever forward scaled.
        def backward_cpu(inputs, grad_outputs)
          gx = grad_outputs[0].dup
          scaled = @slope >= 0 ? @output_data[0] < 0 : inputs[0] < 0
          gx[scaled] = gx[scaled] * @slope
          [gx]
        end
      end
    end
  end
end
@@ -10,26 +10,68 @@ module Chainer
10
10
  m + s
11
11
  end
12
12
 
13
# Log-softmax on a raw array: subtracts logsumexp(x) (a sibling helper of this
# module) from every element of +x+.
def self._log_softmax(x)
  x - logsumexp(x)
end
17
17
 
18
# Log-softmax activation function.
class LogSoftmax < Function
  # Channel-wise log-softmax function.
  #
  # This function computes the logarithm of the softmax along the second axis.
  # Let $\\mathbf{c} = (c_1, c_2, \\dots, c_D)$ be the slice of +x+ along
  # the second axis. For each slice $\\mathbf{c}$, it computes the logarithm of
  # the function $f(\\mathbf{c})$ defined as
  #
  # $$
  # f(\\mathbf{c}) = { \\exp(\\mathbf{c}) \\over \\sum_{ d } \\exp(c_d) }.
  # $$
  #
  # This method is theoretically equivalent to +log(softmax(x))+ but is
  # numerically more stable.
  #
  # @note
  #   +log(softmax(x))+ may underflow when +x+ is very small,
  #   because +softmax(x)+ may return +0+.
  #   +log_softmax+ avoids this problem.
  #
  # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $n$-dimensional ($n \\geq 2$) float array.
  # @return [Chainer::Variable] Output variable. A $n$-dimensional ($n \\geq 2$) float array with the same shape as +x+.
  #
  # @see Chainer::Functions::Softmax
  #
  # @example
  #   > x = Numo::DFloat[[0, 1, 2], [0, 2, 4]]
  #   => Numo::DFloat#shape=[2,3]
  #   [[0, 1, 2],
  #    [0, 2, 4]]
  #   > F = Chainer::Functions::Activation::LogSoftmax
  #   > F.log_softmax(x).data
  #   => Numo::DFloat#shape=[2,3]
  #   [[-2.40761, -1.40761, -0.407606],
  #    [-4.14293, -2.14293, -0.142932]]
  # @example Equivalence with log(softmax(x)) (F.log and F.softmax are not implemented yet)
  #   > F.log_softmax(x).data.nearly_eq(F.log(F.softmax(x)).data).all?
  #   => true
  #
  def self.log_softmax(x)
    self.new.(x)
  end

  def forward(xs)
    y = Chainer::Functions::Activation._log_softmax(xs[0])
    # Input shape and array class are recorded on the instance; they are not
    # read anywhere else in this class.
    @x_shape = xs[0].shape
    @x_dtype = xs[0].class
    # Backward needs only the output (softmax(x) == exp(y)), not the input.
    retain_inputs([])
    retain_outputs([0])
    [y]
  end

  # Gradient of log-softmax: gx = gy - softmax(x) * sum(gy) along axis 1,
  # where softmax(x) is recovered as exp(y) from the retained output.
  def backward(x, gy)
    y = @output_data[0]
    gx = gy[0] - Numo::NMath.exp(y) * gy[0].sum(axis: 1, keepdims: true)
    [gx]
  end
end
35
77
  end
@@ -1,7 +1,27 @@
1
1
  module Chainer
2
2
  module Functions
3
3
  module Activation
4
+ # Rectified Linear Unit.
4
5
  class Relu < Function
6
# Rectified Linear Unit function.
#
# $$
# f(x)=\\max(0, x).
# $$
#
# @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
# @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
# @example
#   > x = Numo::DFloat[[-1, 0], [2, -3], [-2, 1]]
#   > (x < 0).any?
#   => true
#   > F = Chainer::Functions::Activation::Relu
#   > y = F.relu(x)
#   > (y.data < 0).any?
#   => false
#   > y.shape
#   => [3, 2]
#
def self.relu(x)
  new.call(x)
end
@@ -14,7 +34,7 @@ module Chainer
14
34
  end
15
35
 
16
36
  def backward_cpu(x, gy)
17
- y = output_data[0]
37
+ y = @output_data[0]
18
38
  [Utils::Array.force_array(gy[0] * (y > 0))]
19
39
  end
20
40
  end
@@ -0,0 +1,43 @@
1
module Chainer
  module Functions
    module Activation
      # Logistic sigmoid function.
      class Sigmoid < Function
        # Element-wise sigmoid logistic function.
        #
        # $$
        # f(x)=(1 + \\exp(-x))^ { -1 }.
        # $$
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example It maps the input values into the range of $`[0, 1]`$.
        #   > x = Numo::DFloat.new(3).seq(-2, 2)
        #   => Numo::DFloat#shape=[3]
        #   [-2, 0, 2]
        #   > F = Chainer::Functions::Activation::Sigmoid
        #   > F.sigmoid(x).data
        #   => Numo::DFloat#shape=[3]
        #   [0.119203, 0.5, 0.880797]
        #
        def self.sigmoid(x)
          new.call(x)
        end

        # Computes sigmoid through the identity
        # sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5 and keeps only the output.
        def forward_cpu(x)
          half = 0.5
          y = Utils::Array.force_array(Numo::NMath.tanh(x[0] * half) * half + half)
          retain_inputs([])
          retain_outputs([0])
          [y]
        end

        # d sigmoid(x) / dx = y * (1 - y), computed from the retained output y.
        def backward_cpu(x, gy)
          y = @output_data[0]
          [Utils::Array.force_array(gy[0] * y * (1 - y))]
        end
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
module Chainer
  module Functions
    module Activation
      # Hyperbolic tangent function.
      class Tanh < Function
        # Elementwise hyperbolic tangent function.
        #
        # $$
        # f(x)=\\tanh(x).
        # $$
        #
        # @param [Chainer::Variable or Numo::DFloat] x Input variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @return [Chainer::Variable] Output variable. A $(s_1, s_2, ..., s_N)$-shaped float array.
        # @example
        #   > x = Numo::DFloat.new(3).seq(-1, 2)
        #   => Numo::DFloat#shape=[3]
        #   [-1, 1, 3]
        #   > F = Chainer::Functions::Activation::Tanh
        #   > F.tanh(x).data
        #   => Numo::DFloat#shape=[3]
        #   [-0.761594, 0.761594, 0.995055]
        #
        def self.tanh(x)
          self.new.(x)
        end

        # Computes tanh element-wise and keeps only the output for backward.
        def forward_cpu(x)
          y = Utils::Array.force_array(Numo::NMath.tanh(x[0]))
          retain_inputs([])
          retain_outputs([0])
          return [y]
        end

        # d tanh(x) / dx = 1 - tanh(x)^2, computed from the retained output y.
        # A plain Ruby 1 broadcasts against the Numo array. Numo arrays have no
        # NumPy-style `dtype`, so the previous `y.dtype.type(1)` raised
        # NoMethodError.
        def backward_cpu(x, gy)
          y = @output_data[0]
          [Utils::Array.force_array(gy[0] * (1 - y * y))]
        end
      end
    end
  end
end
@@ -31,7 +31,7 @@ module Chainer
31
31
 
32
32
  def forward_cpu(inputs)
33
33
  x, t = inputs
34
- log_y = Activation.log_softmax(x)
34
+ log_y = Activation._log_softmax(x)
35
35
 
36
36
  if @cache_score
37
37
  @y = Numo::NMath.exp(log_y)
@@ -81,7 +81,7 @@ module Chainer
81
81
  if self.instance_variable_defined?(:'@y')
82
82
  y = @y.dup
83
83
  else
84
- y = Activation.log_softmax(x)
84
+ y = Activation._log_softmax(x)
85
85
  y = Numo::NMath.exp(y)
86
86
  end
87
87
 
@@ -15,7 +15,7 @@ module Chainer
15
15
  def next
16
16
  raise StopIteration if !@repeat && @epoch > 0
17
17
 
18
- @previous_epoch_detail = @epoch_detail
18
+ @previous_epoch_detail = epoch_detail
19
19
 
20
20
  i = @current_position
21
21
  i_end = i + @batch_size
@@ -55,6 +55,31 @@ module Chainer
55
55
  @epoch + @current_position.to_f / @dataset.size
56
56
  end
57
57
 
58
# Saves or restores the iteration state through +serializer+.
#
# The shuffle order is stored under "order". When that entry is missing
# (KeyError), the legacy "_order" entry is used instead. When
# "previous_epoch_detail" is missing, it is estimated from the position one
# batch back.
#
# @param [Chainer::AbstractSerializer] serializer Serializer or deserializer.
def serialize(serializer)
  @current_position = serializer.('current_position', @current_position)
  @epoch = serializer.('epoch', @epoch)
  @is_new_epoch = serializer.('is_new_epoch', @is_new_epoch)

  unless @order.nil?
    # The order is handed over as a Numo array (which a deserializer fills in
    # place) and the result is written back, restoring the original type.
    was_array = @order.is_a?(Array)
    order = Numo::NArray.cast(@order)
    begin
      order = serializer.('order', order)
    rescue KeyError
      order = serializer.('_order', order)
    end
    @order = was_array ? order.to_a : order
  end

  begin
    @previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
  rescue KeyError
    # Guess previous_epoch_detail for snapshots from older versions.
    # to_f avoids Integer (floor) division of the position by the dataset size.
    @previous_epoch_detail = @epoch + (@current_position - @batch_size).to_f / @dataset.size
    if epoch_detail > 0
      @previous_epoch_detail = [@previous_epoch_detail, 0.0].max
    else
      @previous_epoch_detail = -1.0
    end
  end
end
82
+
58
83
  def reset
59
84
  if @shuffle
60
85
  order = @dataset.size.times.map(&:to_i).shuffle
data/lib/chainer/link.rb CHANGED
@@ -59,6 +59,27 @@ module Chainer
59
59
  def namedlinks(skipself: false)
60
60
  yield('/', self) unless skipself
61
61
  end
62
+
63
# Saves or restores the parameters and persistent values of this link.
#
# Names in @params and @persistent are used as instance-variable names
# (presumably symbols such as :@W; confirm against where they are
# registered). Parameter arrays go through +param.data+. A parameter whose
# data is still nil is first initialized from the loaded array's shape.
# Persistent values are written back to their instance variables, so
# deserialized scalars (which cannot be updated in place) are restored too.
#
# @param [Chainer::AbstractSerializer] serializer Serializer or deserializer.
def serialize(serializer)
  @params.each do |name|
    param = instance_variable_get(name)
    data = serializer.(name.to_s, param.data)
    next unless param.data.nil? && !data.nil?

    # Initialize the parameter here, then copy the loaded values in.
    param.init(data.shape)
    if param.data.is_a?(Numo::NArray)
      param.data.store(data)
    else
      param.data.set(Numo::NArray.cast(data))
    end
  end

  @persistent.each do |name|
    instance_variable_set(name, serializer.(name.to_s, instance_variable_get(name)))
  end
end
62
83
  end
63
84
 
64
85
  class Chain < Link
@@ -114,5 +135,13 @@ module Chainer
114
135
  end
115
136
  end
116
137
  end
138
+
139
# Serializes this chain's own parameters (via super), then every child link
# listed in @children under a child serializer named after that child.
def serialize(serializer)
  super
  ivars = Hash[instance_variables.map { |ivar| [ivar, instance_variable_get(ivar)] }]
  @children.each do |child_name|
    ivars[child_name].serialize(serializer[child_name.to_s])
  end
end
117
146
  end
118
147
  end
@@ -19,6 +19,17 @@ module Chainer
19
19
  hook(self)
20
20
  end
21
21
  end
22
+
23
# Saves or restores the optimizer counters and the state of every parameter
# update rule of the target link (each under a child serializer named after
# the parameter).
def serialize(serializer)
  @t = serializer.('t', @t)
  @epoch = serializer.('epoch', @epoch)

  @target.namedparams do |(name, param)|
    next unless param.respond_to?(:update_rule)

    param.update_rule.serialize(serializer[name.to_s])
  end
end
22
33
  end
23
34
 
24
35
  class UpdateRule
@@ -56,6 +67,33 @@ module Chainer
56
67
  raise NotImplementedError
57
68
  end
58
69
 
70
+
71
# Saves or restores the state of this update rule.
#
# Only the rule's own state (the @state Hash) is handled here. The
# parameters of the target link are not, so serialize the target link
# separately to fully recover the training state.
#
# When @state is nil and +serializer+ is a Chainer::Deserializer, an empty
# Hash is installed and #init_state runs on a shallow copy of this rule (the
# copy shares that Hash) with a dummy one-element array, to discover the state
# keys. Each key is then loaded from the deserializer. NOTE(review): this
# relies on init_state writing into @state; confirm for each subclass.
# When @state is nil and +serializer+ is not a deserializer, nothing happens.
#
# @param [Chainer::AbstractSerializer] serializer Serializer object.
def serialize(serializer)
  if @state.nil?
    return unless serializer.is_a?(Chainer::Deserializer)

    @state = {}
    shallow_copy = dup
    dummy = Numo::DFloat.new(1)
    shallow_copy.init_state(Chainer::Variable.new(dummy, grad: dummy))
    @state.keys.each { |key| @state[key] = serializer.(key.to_s, nil) }
  else
    @state.each { |key, val| @state[key] = serializer.(key.to_s, val) }
  end
end
96
+
59
97
  private
60
98
 
61
99
  def prepare(param)
@@ -0,0 +1,50 @@
1
module Chainer
  # Abstract base class of all serializers and deserializers.
  class AbstractSerializer
    # Gets a child serializer.
    # This operator creates a child serializer represented by the given key.
    #
    # @param [String] key Name of the child serializer.
    # @return [Chainer::AbstractSerializer] The child serializer.
    # @raise [NotImplementedError] Unless overridden by a subclass.
    def [](key)
      raise NotImplementedError, "#{self.class}#[] is not implemented"
    end

    # Serializes or deserializes a value by given name.
    #
    # If this is a serializer, the value is simply saved at the key.
    # Some type information might be lost depending on the implementation
    # (and the target file format).
    # If this is a deserializer, the value is loaded by the key.
    # Deserialization works differently on scalars and arrays:
    # - For scalars, the +value+ argument is only used to determine the type
    #   of the restored value, and the converted value is returned.
    # - For arrays, the restored elements are copied directly into the
    #   +value+ argument.
    # String values are treated like scalars.
    #
    # @param [String] key Name of the serialization entry.
    # @param [Object] value Object to be (de)serialized. +nil+ is only
    #   supported by deserializers.
    # @return [Object] Serialized or deserialized value.
    # @raise [NotImplementedError] Unless overridden by a subclass.
    def call(key, value)
      raise NotImplementedError, "#{self.class}#call is not implemented"
    end
  end

  # Base class of all serializers.
  class Serializer < AbstractSerializer
    # Saves an object by this serializer.
    # This is equivalent to +obj.serialize(self)+.
    #
    # @param [Object] obj Target object to be serialized.
    def save(obj)
      obj.serialize(self)
    end
  end

  # Base class of all deserializers.
  class Deserializer < AbstractSerializer
    # Loads an object by this deserializer.
    # This is equivalent to +obj.serialize(self)+.
    #
    # @param [Object] obj Target object to be deserialized.
    def load(obj)
      obj.serialize(self)
    end
  end
end
@@ -0,0 +1,83 @@
1
module Chainer
  module Serializers
    # Serializer that flattens values into a Hash keyed by slash-separated
    # paths and writes that Hash with Marshal.
    class MarshalSerializer < Chainer::Serializer
      attr_accessor :target, :path

      # Serializes +obj+ and dumps the collected Hash with Marshal.
      #
      # @param [String, IO] filename Path of the file to write, or an IO-like
      #   object responding to +write+ (such as a Tempfile).
      # @param [Object] obj Object supporting the serialization protocol.
      def self.save_file(filename, obj)
        s = self.new
        s.save(obj)
        if filename.respond_to?(:write)
          Marshal.dump(s.target, filename)
        else
          # Marshal.dump only accepts an IO as its port (a path String raises
          # TypeError), so open the file in binary mode here.
          File.open(filename, 'wb') { |f| Marshal.dump(s.target, f) }
        end
      end

      # @param [Hash, nil] target Hash receiving the entries; a new Hash when nil.
      # @param [String] path Key prefix of this (child) serializer.
      def initialize(target: nil, path: "")
        @target = target.nil? ? {} : target
        @path = path
      end

      # Returns a child serializer writing into the same Hash under +key+.
      def [](key)
        self.class.new(target: @target, path: File.join(@path, key, '/'))
      end

      # Stores +value+ under the current path. Booleans become one-element
      # Numo::Bit arrays, exact String instances are stored as-is, and any
      # other value is converted with Numo::NArray.cast.
      #
      # @return [Object] +value+, unchanged.
      def call(key, value)
        ret = value
        if value.is_a?(TrueClass)
          arr = Numo::Bit[1]
        elsif value.is_a?(FalseClass)
          arr = Numo::Bit[0]
        elsif value.instance_of?(String)
          arr = value
        else
          arr = Numo::NArray.cast(value)
        end
        @target[File.join(@path, key)] = arr
        ret
      end
    end

    # Deserializer reading the Hash produced by MarshalSerializer.
    class MarshalDeserializer < Chainer::Deserializer
      # Loads an object from a file in Marshal format.
      # This is a short-cut to load from a Marshal file that contains only one object.
      #
      # @note Marshal.load can instantiate arbitrary objects; only load files
      #   from trusted sources.
      #
      # @param [String] filename Name of the file to be loaded.
      # @param [Object] obj Object to be deserialized. It must support the serialization protocol.
      def self.load_file(filename, obj)
        File.open(filename, 'rb') do |f|
          d = self.new(Marshal.load(f))
          d.load(obj)
        end
      end

      # @param [Hash] marshal_data Hash produced by MarshalSerializer.
      # @param [String] path Key prefix of this (child) deserializer.
      # @param [Boolean] strict When true, a missing entry raises KeyError;
      #   when false, the given value is returned unchanged.
      def initialize(marshal_data, path: '', strict: true)
        @marshal_data = marshal_data
        @path = path
        @strict = strict
      end

      # Returns a child deserializer reading the same Hash under +key+.
      def [](key)
        self.class.new(@marshal_data, path: File.join(@path, key, '/'), strict: @strict)
      end

      # Loads the entry stored under +key+. A nil or exact String +value+ gets
      # the raw stored object. A Numo array +value+ is filled in place and
      # returned. A boolean +value+ gets the first element compared with 1.
      # Any other value gets the first element of the stored array.
      #
      # @raise [KeyError] In strict mode, when no entry exists for +key+.
      def call(key, value)
        key = File.join(@path, key)
        unless @marshal_data.key?(key)
          return value unless @strict

          # Raising KeyError (instead of handing nil to the branches below)
          # lets callers rescue it and fall back to an alternative key.
          raise KeyError, "#{key} is not a key of the marshal data"
        end

        dataset = @marshal_data[key]
        if value.nil?
          dataset
        elsif value.instance_of?(String)
          dataset
        elsif value.is_a?(Numo::NArray)
          value.store(dataset)
          value
        elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)
          dataset[0] == 1
        else
          dataset[0]
        end
      end
    end
  end
end
83
+
@@ -61,6 +61,21 @@ module Chainer
61
61
  init_summary
62
62
  end
63
63
 
64
# Saves or restores the trigger state (when the trigger supports it) and the
# accumulated log. The log goes through JSON, which may lose some information
# about small numerical differences.
def serialize(serializer)
  @trigger.serialize(serializer['_trigger']) if @trigger.respond_to?(:serialize)

  unless serializer.is_a?(Chainer::Serializer)
    @log = JSON.parse(serializer.('_log', ''))
    return @log
  end

  serializer.('_log', JSON.generate(@log))
end
78
+
64
79
  private
65
80
 
66
81
  def init_summary
@@ -44,6 +44,12 @@ module Chainer
44
44
  end
45
45
  end
46
46
 
47
# Delegates to the wrapped LogReport (if any) under the '_log_report' child.
def serialize(serializer)
  return unless @log_report.is_a?(Chainer::Training::Extensions::LogReport)

  @log_report.serialize(serializer['_log_report'])
end
52
+
47
53
  private
48
54
 
49
55
  def print(observation)
@@ -0,0 +1,33 @@
1
require 'fileutils'
require 'tempfile'

module Chainer
  module Training
    module Extensions
      # Trainer extension writing a snapshot of the trainer (or of a given
      # target object) into the trainer's output directory.
      class Snapshot < Extension
        attr_accessor :save_class, :filename_proc, :target

        # Builds a snapshot extension saving +target+ instead of the trainer.
        # The optional block receives the trainer and returns the file name.
        def self.snapshot_object(target:, save_class:, &block)
          self.new(save_class: save_class, filename_proc: block, target: target)
        end

        # Builds a snapshot extension saving the whole trainer.
        # The optional block receives the trainer and returns the file name.
        def self.snapshot(save_class: nil, &block)
          self.new(save_class: save_class, filename_proc: block)
        end

        # @param [Class] save_class Class providing +save_file(io, obj)+;
        #   defaults to Chainer::Serializers::MarshalSerializer.
        # @param [Proc] filename_proc Maps the trainer to a file name; defaults
        #   to "snapshot_iter_<iteration>".
        # @param [Object] target Object to save; the trainer itself when nil.
        def initialize(save_class: nil, filename_proc: nil, target: nil)
          @save_class = save_class || Chainer::Serializers::MarshalSerializer
          @filename_proc = filename_proc || Proc.new { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }
          @target = target
        end

        # Writes the snapshot to a temporary file in trainer.out and then moves
        # it into place, so an interrupted save never leaves a partial file
        # under the final name.
        def call(trainer)
          target = @target || trainer
          filename = filename_proc.call(trainer)
          temp_file = Tempfile.create("tmp#{filename}", trainer.out)
          begin
            save_class.save_file(temp_file, target)
            # Close (and flush) before moving so the snapshot is complete on disk.
            temp_file.close
            FileUtils.move(temp_file.path, File.join(trainer.out, filename))
          rescue StandardError
            temp_file.close unless temp_file.closed?
            FileUtils.rm_f(temp_file.path)
            raise
          end
        end
      end
    end
  end
end
33
+
@@ -58,6 +58,18 @@ module Chainer
58
58
  iterator.finalize
59
59
  end
60
60
  end
61
+
62
# Saves or restores every iterator, every optimizer together with its target
# model, and finally the iteration counter.
def serialize(serializer)
  @iterators.each do |key, iter|
    iter.serialize(serializer["iterator:#{key}"])
  end

  @optimizers.each do |key, opt|
    opt.serialize(serializer["optimizer:#{key}"])
    opt.target.serialize(serializer["model:#{key}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end
61
73
  end
62
74
  end
63
75
  end
@@ -44,7 +44,7 @@ module Chainer
44
44
  return @final_elapsed_time if @done
45
45
  raise "training has not been started yet" if @start_at.nil?
46
46
 
47
- Time.now.to_f - @start_at - @snapshot_elapsed_time.to_f
47
+ Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f
48
48
  end
49
49
 
50
50
  def extend(extension, name: nil, trigger: nil, priority: nil, invoke_before_training: nil)
@@ -97,7 +97,7 @@ module Chainer
97
97
  raise 'cannot run training loop multiple times' if @done
98
98
  FileUtils.mkdir_p(@out)
99
99
 
100
- extensions = @extensions.sort_by { |(_, e)| e.priority }.map { |(name, extension)| [name, extension] }
100
+ extensions = @extensions.sort_by { |(_, e)| -e.priority }.map { |(name, extension)| [name, extension] }
101
101
 
102
102
  @start_at = Time.now.to_f
103
103
 
@@ -115,7 +115,7 @@ module Chainer
115
115
  @observation = {}
116
116
  reporter.scope(@observation) do
117
117
  update.call
118
- extensions.each do |(_, entry)|
118
+ extensions.each do |(name, entry)|
119
119
  entry.extension.(self) if entry.trigger.(self)
120
120
  end
121
121
  end
@@ -131,6 +131,29 @@ module Chainer
131
131
  @final_elapsed_time = @elapsed_time
132
132
  @done = true
133
133
  end
134
+
135
# Saves or restores the updater, the stop trigger, each extension and its
# trigger (when they support serialization), and the elapsed training time.
def serialize(serializer)
  updater.serialize(serializer['updater'])
  @stop_trigger.serialize(serializer['stop_trigger']) if @stop_trigger.respond_to?(:serialize)

  extensions_serializer = serializer['extensions']
  triggers_serializer = serializer['extension_triggers']
  @extensions.each do |name, entry|
    extension = entry.extension
    extension.serialize(extensions_serializer[name]) if extension.respond_to?(:serialize)
    trigger = entry.trigger
    trigger.serialize(triggers_serializer[name]) if trigger.respond_to?(:serialize)
  end

  unless serializer.is_a?(Chainer::Serializer)
    @snapshot_elapsed_time = serializer.('_snapshot_elapsed_time', 0.0)
    return @snapshot_elapsed_time
  end

  serializer.('_snapshot_elapsed_time', elapsed_time)
end
134
157
  end
135
158
  end
136
159
  end
@@ -8,18 +8,43 @@ module Chainer
8
8
  @period = period
9
9
  @unit = unit
10
10
  @count = 0
11
+
12
+ @previous_iteration = 0
13
+ @previous_epoch_detail = 0.0
11
14
  end
12
15
 
13
16
  def call(trainer)
14
17
  updater = trainer.updater
15
18
  if @unit == 'epoch'
16
- prev = @count
17
- @count = updater.epoch_detail.div(@period)
18
- prev != @count
19
+ epoch_detail = updater.epoch_detail
20
+ previous_epoch_detail = @previous_epoch_detail
21
+
22
+ if previous_epoch_detail < 0
23
+ previous_epoch_detail = updater.previous_epoch_detail
24
+ end
25
+
26
+ @count = epoch_detail.div(@period).floor
27
+
28
+ fire = previous_epoch_detail.div(@period).floor != epoch_detail.div(@period).floor
19
29
  else
20
30
  iteration = updater.iteration
21
- iteration > 0 && iteration % @period == 0
31
+ previous_iteration = @previous_iteration
32
+ if previous_iteration < 0
33
+ previous_iteration = iteration - 1
34
+ end
35
+ fire = previous_iteration.div(@period).floor != iteration.div(@period).floor
22
36
  end
37
+
38
+ # save current values
39
+ @previous_iteration = updater.iteration
40
+ @previous_epoch_detail = updater.epoch_detail
41
+
42
+ fire
43
+ end
44
+
45
# Saves or restores the progress recorded by the last #call, so a resumed
# trigger does not fire spuriously.
def serialize(serializer)
  @previous_iteration = serializer.call('previous_iteration', @previous_iteration)
  @previous_epoch_detail = serializer.call('previous_epoch_detail', @previous_epoch_detail)
end
24
49
  end
25
50
  end
@@ -1,4 +1,4 @@
1
1
  module Chainer
2
- VERSION = "0.1.1"
2
+ VERSION = "0.2.0"
3
3
  end
4
4
 
data/lib/chainer.rb CHANGED
@@ -22,7 +22,10 @@ require 'chainer/variable_node'
22
22
  require 'chainer/utils/initializer'
23
23
  require 'chainer/utils/variable'
24
24
  require 'chainer/utils/array'
25
+ require 'chainer/functions/activation/leaky_relu'
25
26
  require 'chainer/functions/activation/relu'
27
+ require 'chainer/functions/activation/sigmoid'
28
+ require 'chainer/functions/activation/tanh'
26
29
  require 'chainer/functions/activation/log_softmax'
27
30
  require 'chainer/functions/evaluation/accuracy'
28
31
  require 'chainer/functions/math/basic_math'
@@ -33,6 +36,7 @@ require 'chainer/training/extensions/evaluator'
33
36
  require 'chainer/training/extensions/log_report'
34
37
  require 'chainer/training/extensions/print_report'
35
38
  require 'chainer/training/extensions/progress_bar'
39
+ require 'chainer/training/extensions/snapshot'
36
40
  require 'chainer/training/trainer'
37
41
  require 'chainer/training/updater'
38
42
  require 'chainer/training/util'
@@ -44,6 +48,8 @@ require 'chainer/dataset/download'
44
48
  require 'chainer/datasets/mnist'
45
49
  require 'chainer/datasets/tuple_dataset'
46
50
  require 'chainer/reporter'
51
+ require 'chainer/serializer'
52
+ require 'chainer/serializers/marshal'
47
53
 
48
54
  require 'numo/narray'
49
55
 
data/red-chainer.gemspec CHANGED
@@ -19,7 +19,7 @@ Gem::Specification.new do |spec|
19
19
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
20
20
  spec.require_paths = ["lib"]
21
21
 
22
- spec.add_runtime_dependency "numo-narray", ">= 0.9.0.8"
22
+ spec.add_runtime_dependency "numo-narray", ">= 0.9.1.1"
23
23
 
24
24
  spec.add_development_dependency "bundler", "~> 1.15"
25
25
  spec.add_development_dependency "rake", "~> 10.0"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2017-11-18 00:00:00.000000000 Z
11
+ date: 2018-02-01 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.9.0.8
19
+ version: 0.9.1.1
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.9.0.8
26
+ version: 0.9.1.1
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: bundler
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -92,8 +92,11 @@ files:
92
92
  - lib/chainer/datasets/mnist.rb
93
93
  - lib/chainer/datasets/tuple_dataset.rb
94
94
  - lib/chainer/function.rb
95
+ - lib/chainer/functions/activation/leaky_relu.rb
95
96
  - lib/chainer/functions/activation/log_softmax.rb
96
97
  - lib/chainer/functions/activation/relu.rb
98
+ - lib/chainer/functions/activation/sigmoid.rb
99
+ - lib/chainer/functions/activation/tanh.rb
97
100
  - lib/chainer/functions/connection/linear.rb
98
101
  - lib/chainer/functions/evaluation/accuracy.rb
99
102
  - lib/chainer/functions/loss/softmax_cross_entropy.rb
@@ -112,11 +115,14 @@ files:
112
115
  - lib/chainer/optimizers/adam.rb
113
116
  - lib/chainer/parameter.rb
114
117
  - lib/chainer/reporter.rb
118
+ - lib/chainer/serializer.rb
119
+ - lib/chainer/serializers/marshal.rb
115
120
  - lib/chainer/training/extension.rb
116
121
  - lib/chainer/training/extensions/evaluator.rb
117
122
  - lib/chainer/training/extensions/log_report.rb
118
123
  - lib/chainer/training/extensions/print_report.rb
119
124
  - lib/chainer/training/extensions/progress_bar.rb
125
+ - lib/chainer/training/extensions/snapshot.rb
120
126
  - lib/chainer/training/standard_updater.rb
121
127
  - lib/chainer/training/trainer.rb
122
128
  - lib/chainer/training/triggers/interval.rb
@@ -149,7 +155,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
149
155
  version: '0'
150
156
  requirements: []
151
157
  rubyforge_project:
152
- rubygems_version: 2.6.13
158
+ rubygems_version: 2.7.3
153
159
  signing_key:
154
160
  specification_version: 4
155
161
  summary: A flexible framework for neural network for Ruby