red-chainer 0.3.2 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +2 -2
- data/.travis.yml +8 -3
- data/.yardopts +1 -0
- data/Gemfile +6 -1
- data/README.md +34 -3
- data/examples/cifar/train_cifar.rb +13 -2
- data/examples/iris/iris.rb +9 -5
- data/examples/mnist/mnist.rb +16 -4
- data/lib/chainer.rb +17 -1
- data/lib/chainer/backend.rb +27 -0
- data/lib/chainer/cuda.rb +37 -15
- data/lib/chainer/dataset/convert.rb +20 -16
- data/lib/chainer/datasets/cifar.rb +8 -6
- data/lib/chainer/datasets/mnist.rb +14 -55
- data/lib/chainer/device.rb +88 -0
- data/lib/chainer/function.rb +103 -41
- data/lib/chainer/function_node.rb +454 -0
- data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
- data/lib/chainer/functions/activation/log_softmax.rb +46 -9
- data/lib/chainer/functions/activation/relu.rb +8 -8
- data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
- data/lib/chainer/functions/activation/sigmoid.rb +13 -11
- data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
- data/lib/chainer/functions/activation/tanh.rb +48 -11
- data/lib/chainer/functions/array/broadcast_to.rb +56 -0
- data/lib/chainer/functions/array/cast.rb +41 -0
- data/lib/chainer/functions/array/reshape.rb +28 -0
- data/lib/chainer/functions/array/rollaxis.rb +57 -0
- data/lib/chainer/functions/array/select_item.rb +72 -0
- data/lib/chainer/functions/array/squeeze.rb +78 -0
- data/lib/chainer/functions/array/transpose.rb +44 -0
- data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
- data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
- data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
- data/lib/chainer/functions/connection/linear.rb +29 -22
- data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
- data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
- data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
- data/lib/chainer/functions/math/basic_math.rb +36 -30
- data/lib/chainer/functions/math/exp.rb +28 -0
- data/lib/chainer/functions/math/identity.rb +4 -3
- data/lib/chainer/functions/math/sum.rb +52 -0
- data/lib/chainer/functions/noise/dropout.rb +20 -4
- data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
- data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
- data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
- data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
- data/lib/chainer/gradient_check.rb +157 -73
- data/lib/chainer/gradient_method.rb +3 -2
- data/lib/chainer/initializers/init.rb +5 -5
- data/lib/chainer/initializers/normal.rb +4 -2
- data/lib/chainer/initializers/uniform.rb +15 -0
- data/lib/chainer/iterators/serial_iterator.rb +5 -3
- data/lib/chainer/link.rb +4 -2
- data/lib/chainer/links/connection/convolution_2d.rb +2 -2
- data/lib/chainer/links/model/classifier.rb +24 -5
- data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
- data/lib/chainer/optimizer.rb +42 -11
- data/lib/chainer/optimizers/adam.rb +3 -2
- data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
- data/lib/chainer/parameter.rb +7 -6
- data/lib/chainer/serializer.rb +4 -4
- data/lib/chainer/serializers/marshal.rb +10 -8
- data/lib/chainer/testing/array.rb +1 -1
- data/lib/chainer/training/extensions/evaluator.rb +2 -3
- data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
- data/lib/chainer/training/extensions/progress_bar.rb +1 -0
- data/lib/chainer/training/trainer.rb +4 -9
- data/lib/chainer/training/triggers/interval.rb +7 -2
- data/lib/chainer/utils/array.rb +80 -1
- data/lib/chainer/utils/conv.rb +10 -2
- data/lib/chainer/utils/initializer.rb +2 -2
- data/lib/chainer/variable.rb +159 -69
- data/lib/chainer/variable_node.rb +64 -10
- data/lib/chainer/version.rb +1 -1
- data/red-chainer.gemspec +4 -3
- data/templates/default/layout/html/layout.erb +40 -0
- data/templates/default/onefile/html/layout.erb +33 -0
- metadata +44 -11
- data/lib/chainer/dataset/download.rb +0 -56
@@ -15,13 +15,14 @@ module Chainer
|
|
15
15
|
def reallocate_cleared_grads
|
16
16
|
@target.namedparams(include_uninit: false) do |(name, param)|
|
17
17
|
if param.grad.nil?
|
18
|
-
|
18
|
+
xm = Chainer.get_array_module(param.data)
|
19
|
+
param.grad = xm::NArray.[](*param.data).new_zeros
|
19
20
|
end
|
20
21
|
end
|
21
22
|
end
|
22
23
|
|
23
24
|
def call_hooks
|
24
|
-
@hooks.each do |hook|
|
25
|
+
@hooks.values.each do |hook|
|
25
26
|
_call_hook(hook)
|
26
27
|
reallocate_cleared_grads
|
27
28
|
end
|
@@ -1,7 +1,7 @@
|
|
1
1
|
module Chainer
|
2
2
|
module Initializers
|
3
|
-
def self.generate_array(initializer, shape)
|
4
|
-
klass =
|
3
|
+
def self.generate_array(initializer, shape, device: Chainer::Device.default)
|
4
|
+
klass = device.xm::SFloat
|
5
5
|
if initializer.respond_to?(:dtype) && initializer.dtype
|
6
6
|
klass = initializer.dtype
|
7
7
|
end
|
@@ -9,10 +9,10 @@ module Chainer
|
|
9
9
|
initializer.(array)
|
10
10
|
end
|
11
11
|
|
12
|
-
def self.get_initializer(initializer)
|
13
|
-
return HeNormal.new(scale: 1 /
|
12
|
+
def self.get_initializer(initializer, device: Chainer::Device.default)
|
13
|
+
return HeNormal.new(scale: 1 / device.xm::NMath.sqrt(2)) if initializer.nil?
|
14
14
|
return Constant.new(initializer) if initializer.kind_of?(Numeric)
|
15
|
-
return Constant.new(initializer) if
|
15
|
+
return Constant.new(initializer) if Chainer.array?(initializer)
|
16
16
|
|
17
17
|
unless initializer.respond_to?(:call)
|
18
18
|
raise TypeError, "invalid type of initializer: #{initializer.class}"
|
@@ -19,8 +19,10 @@ module Chainer
|
|
19
19
|
end
|
20
20
|
|
21
21
|
def call(array)
|
22
|
-
|
23
|
-
|
22
|
+
# TODO(sonots): pass device from outside
|
23
|
+
device = Chainer::Device.default
|
24
|
+
fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape, device: device)
|
25
|
+
s = @scale * device.xm::NMath.sqrt(2.0 / fan_in)
|
24
26
|
Normal.new(scale: s).(array)
|
25
27
|
end
|
26
28
|
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Chainer
|
2
|
+
module Initializers
|
3
|
+
class Uniform < ::Chainer::Initializer
|
4
|
+
def initialize(scale: 0.05, dtype: nil)
|
5
|
+
@scale = scale
|
6
|
+
super(dtype: dtype)
|
7
|
+
end
|
8
|
+
|
9
|
+
def call(array)
|
10
|
+
raise ArgumentError.new("dtypes are missmatched. #{dtype} != #{array.class}") if dtype && dtype != array.class
|
11
|
+
array.class.new(array.shape).rand(-@scale, @scale)
|
12
|
+
end
|
13
|
+
end
|
14
|
+
end
|
15
|
+
end
|
@@ -3,11 +3,13 @@ module Chainer
|
|
3
3
|
class SerialIterator < Chainer::Dataset::Iterator
|
4
4
|
attr_reader :epoch, :is_new_epoch
|
5
5
|
|
6
|
-
def initialize(dataset, batch_size, repeat: true, shuffle: true)
|
6
|
+
def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
|
7
7
|
@dataset = dataset
|
8
8
|
@batch_size = batch_size
|
9
9
|
@repeat = repeat
|
10
10
|
@shuffle = shuffle
|
11
|
+
@device = device
|
12
|
+
@xm = device.xm
|
11
13
|
|
12
14
|
reset
|
13
15
|
end
|
@@ -83,10 +85,10 @@ module Chainer
|
|
83
85
|
def reset
|
84
86
|
if @shuffle
|
85
87
|
order = @dataset.size.times.map(&:to_i).shuffle
|
86
|
-
@order =
|
88
|
+
@order = @xm::Int64[*order]
|
87
89
|
else
|
88
90
|
order = @dataset.size.times.map(&:to_i)
|
89
|
-
@order =
|
91
|
+
@order = @xm::Int64[*order]
|
90
92
|
end
|
91
93
|
|
92
94
|
@current_position = 0
|
data/lib/chainer/link.rb
CHANGED
@@ -80,6 +80,8 @@ module Chainer
|
|
80
80
|
end
|
81
81
|
|
82
82
|
def serialize(serializer)
|
83
|
+
# TODO(sonots): pass device from outside
|
84
|
+
xm = Chainer::Device.default.xm
|
83
85
|
d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
|
84
86
|
@params.each do |name|
|
85
87
|
param = d[name]
|
@@ -87,10 +89,10 @@ module Chainer
|
|
87
89
|
if param.data.nil? && !data.nil?
|
88
90
|
# Initialize the parameter here
|
89
91
|
param.init(data.shape)
|
90
|
-
if param.data
|
92
|
+
if Chainer.array?(param.data)
|
91
93
|
param.data.store(data)
|
92
94
|
else
|
93
|
-
param.data.set(
|
95
|
+
param.data.set(xm::NArray.cast(data))
|
94
96
|
end
|
95
97
|
end
|
96
98
|
end
|
@@ -14,8 +14,8 @@ module Chainer
|
|
14
14
|
# @param [integer or 2-d int array] stride Stride of filter applications.
|
15
15
|
# @param [integer or 2-d int array] pad Spatial padding width for input arrays.
|
16
16
|
# @param [boolean] nobias If `true`, then this link does not use the bias term.
|
17
|
-
# @param [Numo::NArray]
|
18
|
-
# @param [Numo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
|
17
|
+
# @param [Numo::NArray or Cumo::NArray] initial_w Initial weight value. If `nil`, the default initializer is used.
|
18
|
+
# @param [Numo::NArray or Cumo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
|
19
19
|
#
|
20
20
|
# Example
|
21
21
|
# There are several ways to make a Convolution2D link.
|
@@ -4,26 +4,45 @@ module Chainer
|
|
4
4
|
class Classifier < Chain
|
5
5
|
attr_accessor :compute_accuracy
|
6
6
|
|
7
|
-
|
7
|
+
# @param [Chainer::Link] predictor Predictor network.
|
8
|
+
# @param [Function] lossfun Loss function.
|
9
|
+
# @param [Function] accfun Function that computes accuracy.
|
10
|
+
# @param [Integer, String] label_key Key to specify label variable from arguments.
|
11
|
+
# When it is Integer, a variable in positional arguments is used.
|
12
|
+
# And when it is String, a variable in keyword arguments is used.
|
13
|
+
def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
|
8
14
|
super()
|
15
|
+
|
16
|
+
unless label_key.is_a?(Integer) || label_key.is_a?(String)
|
17
|
+
raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
|
18
|
+
end
|
19
|
+
|
9
20
|
@lossfun = lossfun
|
10
21
|
@accfun = accfun
|
11
22
|
@y = nil
|
12
23
|
@loss = nil
|
13
24
|
@accuracy = nil
|
14
25
|
@compute_accuracy = true
|
26
|
+
@label_key = label_key
|
15
27
|
|
16
28
|
init_scope do
|
17
29
|
@predictor = predictor
|
18
30
|
end
|
19
31
|
end
|
20
32
|
|
21
|
-
def call(*args)
|
22
|
-
|
23
|
-
|
33
|
+
def call(*args, **kwargs)
|
34
|
+
if @label_key.is_a?(Integer)
|
35
|
+
raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
|
36
|
+
t = args.slice!(@label_key)
|
37
|
+
elsif @label_key.is_a?(String)
|
38
|
+
raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
|
39
|
+
t = kwargs[@label_key]
|
40
|
+
kwargs.delete(@label_key)
|
41
|
+
end
|
42
|
+
|
24
43
|
@y = nil
|
25
44
|
@accuracy = nil
|
26
|
-
@y = @predictor.(*
|
45
|
+
@y = @predictor.(*args, **kwargs)
|
27
46
|
|
28
47
|
@loss = @lossfun.call(@y, t)
|
29
48
|
Chainer::Reporter.save_report({loss: @loss}, self)
|
@@ -20,11 +20,12 @@ module Chainer
|
|
20
20
|
# @param [integer or int array] size Size (or shape) of channel dimensions.
|
21
21
|
# @param [float] decay Decay rate of moving average. It is used on training.
|
22
22
|
# @param [float] eps Epsilon value for numerical stability.
|
23
|
-
# @param [Numo::NArray.dtype] dtype Type to use in computing.
|
23
|
+
# @param [Numo::NArray.dtype or Cumo::NArray.dtype] dtype Type to use in computing.
|
24
24
|
# @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
|
25
25
|
# @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
|
26
|
-
def initialize(size, decay: 0.9, eps: 2e-5, dtype:
|
26
|
+
def initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
|
27
27
|
super()
|
28
|
+
dtype ||= Chainer::Device.default.xm::SFloat
|
28
29
|
@avg_mean = dtype.zeros(size)
|
29
30
|
register_persistent('avg_mean')
|
30
31
|
@avg_var = dtype.zeros(size)
|
@@ -79,15 +80,11 @@ module Chainer
|
|
79
80
|
decay = @decay
|
80
81
|
end
|
81
82
|
|
82
|
-
|
83
|
-
ret = func.(x, gamma, beta)
|
84
|
-
|
85
|
-
@avg_mean[false] = func.running_mean
|
86
|
-
@avg_var[false] = func.running_var
|
83
|
+
ret = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta, eps: @eps, running_mean: @avg_mean, running_var: @avg_var, decay: decay)
|
87
84
|
else
|
88
|
-
mean = Chainer::Variable(@avg_mean)
|
89
|
-
var = Chainer::Variable(@avg_var)
|
90
|
-
ret = Chainer::Functions::Normalization::
|
85
|
+
mean = Chainer::Variable.new(@avg_mean)
|
86
|
+
var = Chainer::Variable.new(@avg_var)
|
87
|
+
ret = Chainer::Functions::Normalization::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
|
91
88
|
end
|
92
89
|
|
93
90
|
ret
|
data/lib/chainer/optimizer.rb
CHANGED
@@ -10,13 +10,32 @@ module Chainer
|
|
10
10
|
@hooks = {}
|
11
11
|
end
|
12
12
|
|
13
|
+
def add_hook(hook, name: nil)
|
14
|
+
if !hook.class.method_defined?(:call)
|
15
|
+
raise TypeError, 'hook function is not callable'
|
16
|
+
end
|
17
|
+
|
18
|
+
name = hook.name if name.nil?
|
19
|
+
if @hooks[name]
|
20
|
+
raise TypeError, "hook #{name} already exists"
|
21
|
+
end
|
22
|
+
@hooks[name] = hook
|
23
|
+
end
|
24
|
+
|
25
|
+
def call_hooks
|
26
|
+
@hooks.values.each do |hook|
|
27
|
+
_call_hook(hook)
|
28
|
+
reallocate_cleared_grads
|
29
|
+
end
|
30
|
+
end
|
31
|
+
|
13
32
|
def _call_hook(hook)
|
14
33
|
if hook.methods.include?(:call_for_each_param)
|
15
|
-
@target.params
|
34
|
+
@target.params do |param|
|
16
35
|
hook.(param.update_rule, param)
|
17
36
|
end
|
18
37
|
else
|
19
|
-
hook(self)
|
38
|
+
hook.(self)
|
20
39
|
end
|
21
40
|
end
|
22
41
|
|
@@ -47,7 +66,9 @@ module Chainer
|
|
47
66
|
return unless @enabled
|
48
67
|
|
49
68
|
@t += 1
|
50
|
-
|
69
|
+
unless param.data.nil?
|
70
|
+
prepare(param)
|
71
|
+
end
|
51
72
|
@hooks.values.each do |hook|
|
52
73
|
hook.call(param)
|
53
74
|
end
|
@@ -55,14 +76,22 @@ module Chainer
|
|
55
76
|
end
|
56
77
|
|
57
78
|
def update_core(param)
|
58
|
-
|
59
|
-
|
79
|
+
xm = Chainer.get_array_module(param)
|
80
|
+
if xm == Cumo
|
81
|
+
update_core_gpu(param)
|
82
|
+
else
|
83
|
+
update_core_cpu(param)
|
84
|
+
end
|
60
85
|
end
|
61
86
|
|
62
87
|
def update_core_cpu
|
63
88
|
raise NotImplementedError
|
64
89
|
end
|
65
90
|
|
91
|
+
def update_core_gpu
|
92
|
+
raise NotImplementedError
|
93
|
+
end
|
94
|
+
|
66
95
|
def init_state(param)
|
67
96
|
raise NotImplementedError
|
68
97
|
end
|
@@ -74,14 +103,16 @@ module Chainer
|
|
74
103
|
# method, and so you need to serialize the target link separately if you
|
75
104
|
# want to fully recover the training state including parameters.
|
76
105
|
#
|
77
|
-
# @param [Chainer::AbstractSerializer] serializer
|
106
|
+
# @param [Chainer::AbstractSerializer] serializer Serializer object.
|
78
107
|
def serialize(serializer)
|
79
108
|
if @state.nil?
|
80
109
|
if serializer.is_a?(Chainer::Deserializer)
|
81
110
|
# try to initialize the state to retrieve state entries
|
82
111
|
@state = {}
|
83
112
|
self_copy = self.dup
|
84
|
-
|
113
|
+
# TODO(sonots): pass device from outside
|
114
|
+
xm = Chainer::Device.default.xm
|
115
|
+
arr = xm::SFloat.new(1)
|
85
116
|
self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
|
86
117
|
@state.keys.each do |key|
|
87
118
|
@state[key] = serializer.(key.to_s, nil)
|
@@ -101,7 +132,7 @@ module Chainer
|
|
101
132
|
@state = {}
|
102
133
|
init_state(param)
|
103
134
|
end
|
104
|
-
@state.select! { |_, v|
|
135
|
+
@state.select! { |_, v| Chainer.array?(v) }
|
105
136
|
end
|
106
137
|
end
|
107
138
|
|
@@ -126,11 +157,11 @@ module Chainer
|
|
126
157
|
#
|
127
158
|
# @param [Float] rate Coefficient for the weight decay
|
128
159
|
class WeightDecay
|
129
|
-
def
|
160
|
+
def name
|
130
161
|
"WeightDecay"
|
131
162
|
end
|
132
163
|
|
133
|
-
def
|
164
|
+
def call_for_each_param
|
134
165
|
true
|
135
166
|
end
|
136
167
|
|
@@ -140,7 +171,7 @@ module Chainer
|
|
140
171
|
|
141
172
|
def call(rule, param)
|
142
173
|
return if param.data.nil? || param.grad.nil?
|
143
|
-
param.grad += @rate * param.data
|
174
|
+
param.grad += (@rate * param.data)
|
144
175
|
end
|
145
176
|
end
|
146
177
|
end
|
@@ -21,7 +21,7 @@ module Chainer
|
|
21
21
|
@state[:v] = param.data.new_zeros
|
22
22
|
end
|
23
23
|
|
24
|
-
def
|
24
|
+
def update_core(param)
|
25
25
|
grad = param.grad
|
26
26
|
return if grad.nil?
|
27
27
|
|
@@ -29,7 +29,8 @@ module Chainer
|
|
29
29
|
|
30
30
|
@state[:m] += (1 - hp.beta1) * (grad - @state[:m])
|
31
31
|
@state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
|
32
|
-
|
32
|
+
xm = Chainer.get_array_module(grad)
|
33
|
+
param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
|
33
34
|
end
|
34
35
|
|
35
36
|
def lr
|
data/lib/chainer/parameter.rb
CHANGED
@@ -10,20 +10,21 @@ module Chainer
|
|
10
10
|
end
|
11
11
|
|
12
12
|
if shape.nil?
|
13
|
-
if
|
13
|
+
if Chainer.array?(initializer)
|
14
14
|
super(initializer, name: name)
|
15
15
|
else
|
16
16
|
super(name: name)
|
17
17
|
@initializer = initializer
|
18
18
|
dtype = initializer.respond_to?(:dtype) ? initializer.dtype : 'SFloat'
|
19
|
-
@grad_initializer = Chainer::Initializers.nan()
|
19
|
+
@grad_initializer = Chainer::Initializers.nan(dtype: dtype)
|
20
20
|
end
|
21
21
|
else
|
22
|
-
if
|
22
|
+
if Chainer.array?(initializer)
|
23
23
|
initializer = Initializers::Constant.new(initializer)
|
24
24
|
end
|
25
25
|
data = Chainer::Initializers.generate_array(initializer, shape)
|
26
|
-
|
26
|
+
xm = Chainer.get_array_module(data)
|
27
|
+
grad = xm::NArray[*[1, 2]].new_fill(-922337203)
|
27
28
|
super(data, name: name, grad: grad)
|
28
29
|
end
|
29
30
|
|
@@ -40,8 +41,8 @@ module Chainer
|
|
40
41
|
ginit = @grad_initializer
|
41
42
|
grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)
|
42
43
|
|
43
|
-
|
44
|
-
|
44
|
+
self.data = data
|
45
|
+
self.grad = grad
|
45
46
|
end
|
46
47
|
|
47
48
|
def update
|
data/lib/chainer/serializer.rb
CHANGED
@@ -4,7 +4,7 @@ module Chainer
|
|
4
4
|
# Gets a child serializer.
|
5
5
|
# This operator creates a child serializer represented by the given key.
|
6
6
|
#
|
7
|
-
# @param [string] key
|
7
|
+
# @param [string] key Name of the child serializer.
|
8
8
|
def [](key)
|
9
9
|
raise NotImplementedError
|
10
10
|
end
|
@@ -21,8 +21,8 @@ module Chainer
|
|
21
21
|
# For arrays, the restored elements are directly copied into the
|
22
22
|
# ``value`` argument. String values are treated like scalars.
|
23
23
|
#
|
24
|
-
# @param [string] key
|
25
|
-
# @param [any] value
|
24
|
+
# @param [string] key Name of the serialization entry.
|
25
|
+
# @param [any] value Object to be (de)serialized.
|
26
26
|
# ``None`` is only supported by deserializers.
|
27
27
|
# @return Serialized or deserialized value.
|
28
28
|
def call(key, value)
|
@@ -35,7 +35,7 @@ module Chainer
|
|
35
35
|
# Saves an object by this serializer.
|
36
36
|
# This is equivalent to ``obj.serialize(self)``.
|
37
37
|
#
|
38
|
-
# @param [any] obj
|
38
|
+
# @param [any] obj Target object to be serialized.
|
39
39
|
def save(obj)
|
40
40
|
obj.serialize(self)
|
41
41
|
end
|
@@ -1,6 +1,6 @@
|
|
1
1
|
module Chainer
|
2
2
|
module Serializers
|
3
|
-
class MarshalSerializer < Chainer::Serializer
|
3
|
+
class MarshalSerializer < Chainer::Serializer
|
4
4
|
attr_accessor :target, :path
|
5
5
|
|
6
6
|
# @param [string] file_path Target file path
|
@@ -13,9 +13,11 @@ module Chainer
|
|
13
13
|
end
|
14
14
|
end
|
15
15
|
|
16
|
-
def initialize(target: nil, path: "")
|
16
|
+
def initialize(target: nil, path: "", device: Chainer::Device.default)
|
17
17
|
@target = target.nil? ? {} : target
|
18
18
|
@path = path
|
19
|
+
@device = device
|
20
|
+
@xm = device.xm
|
19
21
|
end
|
20
22
|
|
21
23
|
def [](key)
|
@@ -25,13 +27,13 @@ module Chainer
|
|
25
27
|
def call(key, value)
|
26
28
|
ret = value
|
27
29
|
if value.is_a?(TrueClass)
|
28
|
-
arr =
|
30
|
+
arr = @xm::Bit[1]
|
29
31
|
elsif value.is_a?(FalseClass)
|
30
|
-
arr =
|
32
|
+
arr = @xm::Bit[0]
|
31
33
|
elsif value.instance_of?(String) || value.nil?
|
32
34
|
arr = value
|
33
35
|
else
|
34
|
-
arr =
|
36
|
+
arr = @xm::NArray.cast(value)
|
35
37
|
end
|
36
38
|
@target[File.join(@path, key)] = arr
|
37
39
|
ret
|
@@ -42,8 +44,8 @@ module Chainer
|
|
42
44
|
# Loads an object from the file in Marshal format.
|
43
45
|
# This is a short-cut function to load from an Marshal file that contains only one object.
|
44
46
|
#
|
45
|
-
# @param [string
|
46
|
-
# @param [object] obj
|
47
|
+
# @param [string] filename Name of the file to be loaded.
|
48
|
+
# @param [object] obj Object to be deserialized. It must support serialization protocol.
|
47
49
|
def self.load_file(filename, obj)
|
48
50
|
File.open(filename) do |f|
|
49
51
|
d = self.new(Marshal.load(f))
|
@@ -72,7 +74,7 @@ module Chainer
|
|
72
74
|
return dataset
|
73
75
|
elsif value.instance_of?(String)
|
74
76
|
return dataset
|
75
|
-
elsif
|
77
|
+
elsif Chainer.array?(value)
|
76
78
|
value.store(dataset)
|
77
79
|
return value
|
78
80
|
elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)
|