red-chainer 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
@@ -15,13 +15,14 @@ module Chainer
     def reallocate_cleared_grads
       @target.namedparams(include_uninit: false) do |(name, param)|
         if param.grad.nil?
-
+          xm = Chainer.get_array_module(param.data)
+          param.grad = xm::NArray.[](*param.data).new_zeros
         end
       end
     end
 
     def call_hooks
-      @hooks.each do |hook|
+      @hooks.values.each do |hook|
         _call_hook(hook)
         reallocate_cleared_grads
       end
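The replacement lines above lean on `Chainer.get_array_module` (presumably added with `data/lib/chainer/backend.rb` in the file list), which picks the array backend — Numo on CPU, Cumo on GPU — for a given array. A minimal sketch of backend-agnostic allocation written against that helper, assuming a CPU-only setup with Numo installed; the array values are only illustrative:

```ruby
require 'chainer'
require 'numo/narray'

# Ask red-chainer which backend module (Numo or Cumo) owns this array,
# then allocate a matching zero-filled gradient without hard-coding Numo.
x    = Numo::SFloat[[1, 2], [3, 4]]
xm   = Chainer.get_array_module(x)   # => Numo on a CPU-only setup
grad = xm::SFloat.zeros(x.shape)

puts xm.name             # "Numo"
puts grad.shape.inspect  # [2, 2]
```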
@@ -1,7 +1,7 @@
 module Chainer
   module Initializers
-    def self.generate_array(initializer, shape)
-      klass =
+    def self.generate_array(initializer, shape, device: Chainer::Device.default)
+      klass = device.xm::SFloat
       if initializer.respond_to?(:dtype) && initializer.dtype
         klass = initializer.dtype
       end
@@ -9,10 +9,10 @@ module Chainer
       initializer.(array)
     end
 
-    def self.get_initializer(initializer)
-      return HeNormal.new(scale: 1 /
+    def self.get_initializer(initializer, device: Chainer::Device.default)
+      return HeNormal.new(scale: 1 / device.xm::NMath.sqrt(2)) if initializer.nil?
       return Constant.new(initializer) if initializer.kind_of?(Numeric)
-      return Constant.new(initializer) if
+      return Constant.new(initializer) if Chainer.array?(initializer)
 
       unless initializer.respond_to?(:call)
         raise TypeError, "invalid type of initializer: #{initializer.class}"
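Per the hunk above, `Initializers.get_initializer` now normalizes whatever it is given: `nil` becomes a `HeNormal` with scale `1/sqrt(2)`, a `Numeric` or an NArray becomes a `Constant`, and anything else must respond to `call`. A hedged sketch of that dispatch from the caller's side (the values are arbitrary):

```ruby
require 'chainer'
require 'numo/narray'

# nil               -> HeNormal.new(scale: 1 / sqrt(2))
# Numeric           -> Constant filled with that number
# Numo/Cumo NArray  -> Constant filled with that array
a = Chainer::Initializers.get_initializer(nil)
b = Chainer::Initializers.get_initializer(0.5)
c = Chainer::Initializers.get_initializer(Numo::SFloat[1, 2, 3])

puts a.class  # Chainer::Initializers::HeNormal
puts b.class  # Chainer::Initializers::Constant
puts c.class  # Chainer::Initializers::Constant
```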
@@ -19,8 +19,10 @@ module Chainer
     end
 
     def call(array)
-
-
+      # TODO(sonots): pass device from outside
+      device = Chainer::Device.default
+      fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape, device: device)
+      s = @scale * device.xm::NMath.sqrt(2.0 / fan_in)
       Normal.new(scale: s).(array)
     end
   end
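`HeNormal#call` above draws from a normal distribution whose standard deviation is `scale * sqrt(2 / fan_in)`, with `fan_in` derived from the array shape on the default device's backend. A short sketch of producing a weight matrix through `Initializers.generate_array` on the default (CPU/Numo) device; the shape is arbitrary:

```ruby
require 'chainer'

# He-initialized [out, in] weight matrix; for shape [100, 50] the
# standard deviation works out to sqrt(2 / 50) ≈ 0.2.
w = Chainer::Initializers.generate_array(
  Chainer::Initializers::HeNormal.new(scale: 1.0),
  [100, 50]
)

puts w.class          # Numo::SFloat on the default device
puts w.shape.inspect  # [100, 50]
```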
@@ -0,0 +1,15 @@
+module Chainer
+  module Initializers
+    class Uniform < ::Chainer::Initializer
+      def initialize(scale: 0.05, dtype: nil)
+        @scale = scale
+        super(dtype: dtype)
+      end
+
+      def call(array)
+        raise ArgumentError.new("dtypes are missmatched. #{dtype} != #{array.class}") if dtype && dtype != array.class
+        array.class.new(array.shape).rand(-@scale, @scale)
+      end
+    end
+  end
+end
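The new `Uniform` initializer (evidently the 15-line `data/lib/chainer/initializers/uniform.rb` from the file list) returns an array of the same class and shape as its argument, filled with samples drawn uniformly from `[-scale, scale)`, so it works for both Numo and Cumo arrays. A usage sketch with an arbitrary 3×3 Numo array:

```ruby
require 'chainer'
require 'numo/narray'

init = Chainer::Initializers::Uniform.new(scale: 0.05)
w = init.(Numo::SFloat.new(3, 3))  # only the class and shape of the argument matter

puts w.class                        # Numo::SFloat
puts((w.abs <= 0.05).count_true)    # 9 -- every element lies within [-0.05, 0.05)
```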
@@ -3,11 +3,13 @@ module Chainer
     class SerialIterator < Chainer::Dataset::Iterator
       attr_reader :epoch, :is_new_epoch
 
-      def initialize(dataset, batch_size, repeat: true, shuffle: true)
+      def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
         @dataset = dataset
         @batch_size = batch_size
         @repeat = repeat
         @shuffle = shuffle
+        @device = device
+        @xm = device.xm
 
         reset
       end
@@ -83,10 +85,10 @@ module Chainer
       def reset
         if @shuffle
           order = @dataset.size.times.map(&:to_i).shuffle
-          @order =
+          @order = @xm::Int64[*order]
         else
           order = @dataset.size.times.map(&:to_i)
-          @order =
+          @order = @xm::Int64[*order]
         end
 
         @current_position = 0
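Both hunks above thread a `device` through the iterator so the shuffled index array lives on the same backend (`@xm::Int64`) as the data it selects. A hedged usage sketch on the default CPU device, assuming the iterator exposes `next` as in the gem's bundled examples:

```ruby
require 'chainer'

# Toy dataset: ten [features, label] pairs.
dataset = 10.times.map { |i| [[i.to_f], i % 2] }

it = Chainer::Iterators::SerialIterator.new(dataset, 4, repeat: true, shuffle: true)
batch = it.next   # first minibatch, 4 examples in shuffled order
puts batch.size   # 4
puts it.epoch     # 0 until the dataset has been passed through once
```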
data/lib/chainer/link.rb
CHANGED
@@ -80,6 +80,8 @@ module Chainer
     end
 
     def serialize(serializer)
+      # TODO(sonots): pass device from outside
+      xm = Chainer::Device.default.xm
       d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
       @params.each do |name|
         param = d[name]
@@ -87,10 +89,10 @@ module Chainer
         if param.data.nil? && !data.nil?
           # Initialize the parameter here
           param.init(data.shape)
-          if param.data
+          if Chainer.array?(param.data)
             param.data.store(data)
           else
-            param.data.set(
+            param.data.set(xm::NArray.cast(data))
           end
         end
       end
@@ -14,8 +14,8 @@ module Chainer
       # @param [integer or 2-d int array] stride Stride of filter applications.
       # @param [integer or 2-d int array] pad Spatial padding width for input arrays.
       # @param [boolean] nobias If `true`, then this link does not use the bias term.
-      # @param [Numo::NArray]
-      # @param [Numo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
+      # @param [Numo::NArray or Cumo::NArray] initial_w Initial weight value. If `nil`, the default initializer is used.
+      # @param [Numo::NArray or Cumo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
       #
       # Example
       #   There are several ways to make a Convolution2D link.
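The corrected doc comment says `initial_w`/`initial_bias` may be Numo or Cumo arrays. A hedged construction sketch, assuming the positional `(in_channels, out_channels, ksize)` signature that the link's own example section describes; the sizes and data are arbitrary:

```ruby
require 'chainer'
require 'numo/narray'

# 3 input channels, 8 output channels, 3x3 kernels, explicit initial weights.
initial_w = Numo::SFloat.new(8, 3, 3, 3).rand_norm(0, 0.01)

conv = Chainer::Links::Connection::Convolution2D.new(
  3, 8, 3,
  stride: 1, pad: 1,
  initial_w: initial_w
)

x = Chainer::Variable.new(Numo::SFloat.new(1, 3, 32, 32).rand)
y = conv.(x)
puts y.data.shape.inspect  # [1, 8, 32, 32] with stride 1 and pad 1
```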
@@ -4,26 +4,45 @@ module Chainer
     class Classifier < Chain
       attr_accessor :compute_accuracy
 
-
+      # @param [Chainer::Link] predictor Predictor network.
+      # @param [Function] lossfun Loss function.
+      # @param [Function] accfun Function that computes accuracy.
+      # @param [Integer, String] label_key Key to specify label variable from arguments.
+      #   When it is Integer, a variable in positional arguments is used.
+      #   And when it is String, a variable in keyword arguments is used.
+      def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
         super()
+
+        unless label_key.is_a?(Integer) || label_key.is_a?(String)
+          raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
+        end
+
         @lossfun = lossfun
         @accfun = accfun
         @y = nil
         @loss = nil
         @accuracy = nil
         @compute_accuracy = true
+        @label_key = label_key
 
         init_scope do
           @predictor = predictor
         end
       end
 
-      def call(*args)
-
-
+      def call(*args, **kwargs)
+        if @label_key.is_a?(Integer)
+          raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
+          t = args.slice!(@label_key)
+        elsif @label_key.is_a?(String)
+          raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
+          t = kwargs[@label_key]
+          kwargs.delete(@label_key)
+        end
+
         @y = nil
         @accuracy = nil
-        @y = @predictor.(*
+        @y = @predictor.(*args, **kwargs)
 
         @loss = @lossfun.call(@y, t)
         Chainer::Reporter.save_report({loss: @loss}, self)
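With the new `label_key` handling, `Classifier#call` extracts the label itself: the default `label_key = -1` takes the last positional argument, while a String key pulls it out of the keyword arguments. A hedged sketch with a linear predictor (layer sizes and data are arbitrary):

```ruby
require 'chainer'
require 'numo/narray'

predictor = Chainer::Links::Connection::Linear.new(4, out_size: 3)
model = Chainer::Links::Model::Classifier.new(predictor)

x = Numo::SFloat.new(5, 4).rand   # 5 samples, 4 features
t = Numo::Int32[0, 1, 2, 1, 0]    # integer class labels

loss = model.(x, t)               # label_key = -1: the last argument is the label
puts loss.data.inspect            # scalar softmax cross-entropy
```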
@@ -20,11 +20,12 @@ module Chainer
       # @param [integer or int array] size Size (or shape) of channel dimensions.
       # @param [float] decay Decay rate of moving average. It is used on training.
       # @param [float] eps Epsilon value for numerical stability.
-      # @param [Numo::NArray.dtype] dtype Type to use in computing.
+      # @param [Numo::NArray.dtype or Cumo::NArray.dtype] dtype Type to use in computing.
       # @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
       # @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
-      def initialize(size, decay: 0.9, eps: 2e-5, dtype:
+      def initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
         super()
+        dtype ||= Chainer::Device.default.xm::SFloat
         @avg_mean = dtype.zeros(size)
         register_persistent('avg_mean')
         @avg_var = dtype.zeros(size)
@@ -79,15 +80,11 @@ module Chainer
           decay = @decay
         end
 
-
-        ret = func.(x, gamma, beta)
-
-        @avg_mean[false] = func.running_mean
-        @avg_var[false] = func.running_var
+        ret = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta, eps: @eps, running_mean: @avg_mean, running_var: @avg_var, decay: decay)
       else
-        mean = Chainer::Variable(@avg_mean)
-        var = Chainer::Variable(@avg_var)
-        ret = Chainer::Functions::Normalization::
+        mean = Chainer::Variable.new(@avg_mean)
+        var = Chainer::Variable.new(@avg_var)
+        ret = Chainer::Functions::Normalization::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
       end
 
       ret
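After this change the link forwards to the rewritten functional API: `batch_normalization` during training (updating `@avg_mean`/`@avg_var` in place through the `running_mean`/`running_var` keywords) and `fixed_batch_normalization` with the stored statistics otherwise. A minimal link-level sketch on the default device; batch and channel sizes are arbitrary:

```ruby
require 'chainer'
require 'numo/narray'

bn = Chainer::Links::Normalization::BatchNormalization.new(16)  # 16 channels

x = Chainer::Variable.new(Numo::SFloat.new(8, 16, 5, 5).rand)   # NCHW minibatch
y = bn.(x)                  # training path: running mean/var are updated
puts y.data.shape.inspect   # [8, 16, 5, 5], same shape as the input
```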
data/lib/chainer/optimizer.rb
CHANGED
@@ -10,13 +10,32 @@ module Chainer
       @hooks = {}
     end
 
+    def add_hook(hook, name: nil)
+      if !hook.class.method_defined?(:call)
+        raise TypeError, 'hook function is not callable'
+      end
+
+      name = hook.name if name.nil?
+      if @hooks[name]
+        raise TypeError, "hook #{name} already exists"
+      end
+      @hooks[name] = hook
+    end
+
+    def call_hooks
+      @hooks.values.each do |hook|
+        _call_hook(hook)
+        reallocate_cleared_grads
+      end
+    end
+
     def _call_hook(hook)
       if hook.methods.include?(:call_for_each_param)
-        @target.params
+        @target.params do |param|
          hook.(param.update_rule, param)
        end
      else
-        hook(self)
+        hook.(self)
      end
    end
 
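`add_hook` keys hooks by their `name` in a Hash — hence the switch to `@hooks.values.each` — and rejects duplicates and non-callable objects. A sketch of registering the built-in weight-decay hook, assuming `WeightDecay.new` takes the decay rate described by its `@param` comment further down in this diff:

```ruby
require 'chainer'

optimizer = Chainer::Optimizers::Adam.new
# optimizer.setup(model) would normally come first; omitted for brevity.

# Registers L2 regularization under the key returned by hook.name ("WeightDecay").
optimizer.add_hook(Chainer::WeightDecay.new(0.0001))

begin
  optimizer.add_hook(Chainer::WeightDecay.new(0.0005))  # same name again
rescue TypeError => e
  puts e.message  # "hook WeightDecay already exists"
end
```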
@@ -47,7 +66,9 @@ module Chainer
       return unless @enabled
 
       @t += 1
-
+      unless param.data.nil?
+        prepare(param)
+      end
       @hooks.values.each do |hook|
         hook.call(param)
       end
@@ -55,14 +76,22 @@ module Chainer
     end
 
     def update_core(param)
-
-
+      xm = Chainer.get_array_module(param)
+      if xm == Cumo
+        update_core_gpu(param)
+      else
+        update_core_cpu(param)
+      end
     end
 
     def update_core_cpu
       raise NotImplementedError
     end
 
+    def update_core_gpu
+      raise NotImplementedError
+    end
+
     def init_state(param)
       raise NotImplementedError
     end
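`update_core` now routes to `update_core_cpu` or `update_core_gpu` depending on whether the parameter lives in Numo or Cumo. A hedged sketch of a custom rule written against that split — the class is hypothetical and assumes the base class is `Chainer::UpdateRule` as defined in `optimizer.rb`:

```ruby
require 'chainer'

# Hypothetical plain-SGD update rule (not part of red-chainer itself).
class PlainSGDRule < Chainer::UpdateRule
  LR = 0.01

  def update_core_cpu(param)
    return if param.grad.nil?
    param.data -= LR * param.grad
  end

  # Cumo arrays support the same arithmetic, so the GPU path can reuse it.
  def update_core_gpu(param)
    update_core_cpu(param)
  end
end
```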
@@ -74,14 +103,16 @@ module Chainer
     # method, and so you need to serialize the target link separately if you
     # want to fully recover the training state including parameters.
     #
-    # @param [Chainer::AbstractSerializer] serializer
+    # @param [Chainer::AbstractSerializer] serializer Serializer object.
     def serialize(serializer)
       if @state.nil?
         if serializer.is_a?(Chainer::Deserializer)
           # try to initialize the state to retrieve state entries
           @state = {}
           self_copy = self.dup
-
+          # TODO(sonots): pass device from outside
+          xm = Chainer::Device.default.xm
+          arr = xm::SFloat.new(1)
           self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
           @state.keys.each do |key|
             @state[key] = serializer.(key.to_s, nil)
@@ -101,7 +132,7 @@ module Chainer
         @state = {}
         init_state(param)
       end
-      @state.select! { |_, v|
+      @state.select! { |_, v| Chainer.array?(v) }
     end
   end
 
@@ -126,11 +157,11 @@ module Chainer
   #
   # @param [Float] rate Coefficient for the weight decay
   class WeightDecay
-    def
+    def name
       "WeightDecay"
     end
 
-    def
+    def call_for_each_param
       true
     end
 
@@ -140,7 +171,7 @@ module Chainer
 
     def call(rule, param)
       return if param.data.nil? || param.grad.nil?
-      param.grad += @rate * param.data
+      param.grad += (@rate * param.data)
     end
   end
 end
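Taken together, the last few hunks define the hook protocol: `name` supplies the registry key, `call_for_each_param` selects per-parameter dispatch, and `call(rule, param)` (or `call(optimizer)` for whole-optimizer hooks) does the work. A sketch of a hypothetical gradient-clipping hook written against that protocol; the class name and threshold are made up:

```ruby
require 'chainer'

# Hypothetical per-parameter hook: clamp every gradient element to [-1, 1].
class GradClipValue
  def name
    "GradClipValue"
  end

  def call_for_each_param
    true
  end

  def call(rule, param)
    return if param.grad.nil?
    param.grad = param.grad.clip(-1.0, 1.0)
  end
end

# optimizer.add_hook(GradClipValue.new)
```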
@@ -21,7 +21,7 @@ module Chainer
       @state[:v] = param.data.new_zeros
     end
 
-    def
+    def update_core(param)
       grad = param.grad
       return if grad.nil?
 
@@ -29,7 +29,8 @@ module Chainer
 
       @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
       @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
-
+      xm = Chainer.get_array_module(grad)
+      param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
     end
 
     def lr
data/lib/chainer/parameter.rb
CHANGED
@@ -10,20 +10,21 @@ module Chainer
     end
 
     if shape.nil?
-      if
+      if Chainer.array?(initializer)
         super(initializer, name: name)
       else
         super(name: name)
         @initializer = initializer
         dtype = initializer.respond_to?(:dtype) ? initializer.dtype : 'SFloat'
-        @grad_initializer = Chainer::Initializers.nan()
+        @grad_initializer = Chainer::Initializers.nan(dtype: dtype)
       end
     else
-      if
+      if Chainer.array?(initializer)
        initializer = Initializers::Constant.new(initializer)
      end
      data = Chainer::Initializers.generate_array(initializer, shape)
-
+      xm = Chainer.get_array_module(data)
+      grad = xm::NArray[*[1, 2]].new_fill(-922337203)
      super(data, name: name, grad: grad)
    end
 
@@ -40,8 +41,8 @@ module Chainer
     ginit = @grad_initializer
     grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)
 
-
-
+    self.data = data
+    self.grad = grad
   end
 
   def update
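Both parameter.rb hunks use `Chainer.array?` and `Chainer.get_array_module` so a parameter behaves the same whether its data is Numo or Cumo. A hedged construction sketch, assuming the keyword signature `Parameter.new(initializer:, shape:, name:)` that these branches imply; the shape and name are arbitrary:

```ruby
require 'chainer'

w = Chainer::Parameter.new(
  initializer: Chainer::Initializers::HeNormal.new(scale: 1.0),
  shape: [10, 5],
  name: 'W'
)

puts w.data.class          # Numo::SFloat on the default (CPU) device
puts w.data.shape.inspect  # [10, 5]
```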
data/lib/chainer/serializer.rb
CHANGED
@@ -4,7 +4,7 @@ module Chainer
     # Gets a child serializer.
     # This operator creates a child serializer represented by the given key.
     #
-    # @param [string] key
+    # @param [string] key Name of the child serializer.
     def [](key)
       raise NotImplementedError
     end
@@ -21,8 +21,8 @@ module Chainer
     # For arrays, the restored elements are directly copied into the
     # ``value`` argument. String values are treated like scalars.
     #
-    # @param [string] key
-    # @param [any] value
+    # @param [string] key Name of the serialization entry.
+    # @param [any] value Object to be (de)serialized.
     #   ``None`` is only supported by deserializers.
     # @return Serialized or deserialized value.
     def call(key, value)
@@ -35,7 +35,7 @@ module Chainer
     # Saves an object by this serializer.
     # This is equivalent to ``obj.serialize(self)``.
     #
-    # @param [any] obj
+    # @param [any] obj Target object to be serialized.
     def save(obj)
       obj.serialize(self)
     end
@@ -1,6 +1,6 @@
 module Chainer
   module Serializers
-    class MarshalSerializer < Chainer::Serializer
+    class MarshalSerializer < Chainer::Serializer
       attr_accessor :target, :path
 
       # @param [string] file_path Target file path
@@ -13,9 +13,11 @@ module Chainer
         end
       end
 
-      def initialize(target: nil, path: "")
+      def initialize(target: nil, path: "", device: Chainer::Device.default)
         @target = target.nil? ? {} : target
         @path = path
+        @device = device
+        @xm = device.xm
       end
 
       def [](key)
@@ -25,13 +27,13 @@ module Chainer
       def call(key, value)
         ret = value
         if value.is_a?(TrueClass)
-          arr =
+          arr = @xm::Bit[1]
         elsif value.is_a?(FalseClass)
-          arr =
+          arr = @xm::Bit[0]
         elsif value.instance_of?(String) || value.nil?
           arr = value
         else
-          arr =
+          arr = @xm::NArray.cast(value)
         end
         @target[File.join(@path, key)] = arr
         ret
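`MarshalSerializer#call` now normalizes every value through the serializer's backend: `true`/`false` become one-element `Bit` arrays, strings and `nil` pass through, and everything else goes through `xm::NArray.cast`. A sketch of filling a target hash by hand and marshalling it to disk; the file name is arbitrary:

```ruby
require 'chainer'

target = {}
s = Chainer::Serializers::MarshalSerializer.new(target: target)
s.('epoch', 3)         # cast to a Numo integer array on the CPU backend
s.('use_gpu', false)   # stored as Numo::Bit[0]

File.binwrite('snapshot.marshal', Marshal.dump(target))
puts target.keys.inspect                 # keys are File.join(path, key)
puts target.values.map(&:class).inspect  # backend array classes
```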
@@ -42,8 +44,8 @@ module Chainer
       # Loads an object from the file in Marshal format.
       # This is a short-cut function to load from an Marshal file that contains only one object.
       #
-      # @param [string
-      # @param [object] obj
+      # @param [string] filename Name of the file to be loaded.
+      # @param [object] obj Object to be deserialized. It must support serialization protocol.
       def self.load_file(filename, obj)
         File.open(filename) do |f|
           d = self.new(Marshal.load(f))
@@ -72,7 +74,7 @@ module Chainer
         return dataset
       elsif value.instance_of?(String)
         return dataset
-      elsif
+      elsif Chainer.array?(value)
         value.store(dataset)
         return value
       elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)