red-chainer 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
@@ -15,13 +15,14 @@ module Chainer
15
15
  def reallocate_cleared_grads
16
16
  @target.namedparams(include_uninit: false) do |(name, param)|
17
17
  if param.grad.nil?
18
- param.grad = Numo::NArray.[](*param.data).new_zeros
18
+ xm = Chainer.get_array_module(param.data)
19
+ param.grad = xm::NArray.[](*param.data).new_zeros
19
20
  end
20
21
  end
21
22
  end
22
23
 
23
24
  def call_hooks
24
- @hooks.each do |hook|
25
+ @hooks.values.each do |hook|
25
26
  _call_hook(hook)
26
27
  reallocate_cleared_grads
27
28
  end
@@ -1,7 +1,7 @@
1
1
  module Chainer
2
2
  module Initializers
3
- def self.generate_array(initializer, shape)
4
- klass = Numo::SFloat
3
+ def self.generate_array(initializer, shape, device: Chainer::Device.default)
4
+ klass = device.xm::SFloat
5
5
  if initializer.respond_to?(:dtype) && initializer.dtype
6
6
  klass = initializer.dtype
7
7
  end
@@ -9,10 +9,10 @@ module Chainer
9
9
  initializer.(array)
10
10
  end
11
11
 
12
- def self.get_initializer(initializer)
13
- return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
12
+ def self.get_initializer(initializer, device: Chainer::Device.default)
13
+ return HeNormal.new(scale: 1 / device.xm::NMath.sqrt(2)) if initializer.nil?
14
14
  return Constant.new(initializer) if initializer.kind_of?(Numeric)
15
- return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)
15
+ return Constant.new(initializer) if Chainer.array?(initializer)
16
16
 
17
17
  unless initializer.respond_to?(:call)
18
18
  raise TypeError, "invalid type of initializer: #{initializer.class}"
@@ -19,8 +19,10 @@ module Chainer
19
19
  end
20
20
 
21
21
  def call(array)
22
- fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
23
- s = @scale * Numo::NMath.sqrt(2.0 / fan_in)
22
+ # TODO(sonots): pass device from outside
23
+ device = Chainer::Device.default
24
+ fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape, device: device)
25
+ s = @scale * device.xm::NMath.sqrt(2.0 / fan_in)
24
26
  Normal.new(scale: s).(array)
25
27
  end
26
28
  end
@@ -0,0 +1,15 @@
1
+ module Chainer
2
+ module Initializers
3
+ class Uniform < ::Chainer::Initializer
4
+ def initialize(scale: 0.05, dtype: nil)
5
+ @scale = scale
6
+ super(dtype: dtype)
7
+ end
8
+
9
+ def call(array)
10
+ raise ArgumentError.new("dtypes are missmatched. #{dtype} != #{array.class}") if dtype && dtype != array.class
11
+ array.class.new(array.shape).rand(-@scale, @scale)
12
+ end
13
+ end
14
+ end
15
+ end
@@ -3,11 +3,13 @@ module Chainer
3
3
  class SerialIterator < Chainer::Dataset::Iterator
4
4
  attr_reader :epoch, :is_new_epoch
5
5
 
6
- def initialize(dataset, batch_size, repeat: true, shuffle: true)
6
+ def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
7
7
  @dataset = dataset
8
8
  @batch_size = batch_size
9
9
  @repeat = repeat
10
10
  @shuffle = shuffle
11
+ @device = device
12
+ @xm = device.xm
11
13
 
12
14
  reset
13
15
  end
@@ -83,10 +85,10 @@ module Chainer
83
85
  def reset
84
86
  if @shuffle
85
87
  order = @dataset.size.times.map(&:to_i).shuffle
86
- @order = Numo::Int64[*order]
88
+ @order = @xm::Int64[*order]
87
89
  else
88
90
  order = @dataset.size.times.map(&:to_i)
89
- @order = Numo::Int64[*order]
91
+ @order = @xm::Int64[*order]
90
92
  end
91
93
 
92
94
  @current_position = 0
data/lib/chainer/link.rb CHANGED
@@ -80,6 +80,8 @@ module Chainer
80
80
  end
81
81
 
82
82
  def serialize(serializer)
83
+ # TODO(sonots): pass device from outside
84
+ xm = Chainer::Device.default.xm
83
85
  d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
84
86
  @params.each do |name|
85
87
  param = d[name]
@@ -87,10 +89,10 @@ module Chainer
87
89
  if param.data.nil? && !data.nil?
88
90
  # Initialize the parameter here
89
91
  param.init(data.shape)
90
- if param.data.is_a?(Numo::NArray)
92
+ if Chainer.array?(param.data)
91
93
  param.data.store(data)
92
94
  else
93
- param.data.set(Numo::NArray.cast(data))
95
+ param.data.set(xm::NArray.cast(data))
94
96
  end
95
97
  end
96
98
  end
@@ -14,8 +14,8 @@ module Chainer
14
14
  # @param [integer or 2-d int array] stride Stride of filter applications.
15
15
  # @param [integer or 2-d int array] pad Spatial padding width for input arrays.
16
16
  # @param [boolean] nobias If `true`, then this link does not use the bias term.
17
- # @param [Numo::NArray] initialW Initial weight value. If `nil`, the default initializer is used.
18
- # @param [Numo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
17
+ # @param [Numo::NArray or Cumo::NArray] initial_w Initial weight value. If `nil`, the default initializer is used.
18
+ # @param [Numo::NArray or Cumo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
19
19
  #
20
20
  # Example
21
21
  # There are several ways to make a Convolution2D link.
@@ -4,26 +4,45 @@ module Chainer
4
4
  class Classifier < Chain
5
5
  attr_accessor :compute_accuracy
6
6
 
7
- def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
7
+ # @param [Chainer::Link] predictor Predictor network.
8
+ # @param [Function] lossfun Loss function.
9
+ # @param [Function] accfun Function that computes accuracy.
10
+ # @param [Integer, String] label_key Key to specify label variable from arguments.
11
+ # When it is Integer, a variable in positional arguments is used.
12
+ # And when it is String, a variable in keyword arguments is used.
13
+ def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
8
14
  super()
15
+
16
+ unless label_key.is_a?(Integer) || label_key.is_a?(String)
17
+ raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
18
+ end
19
+
9
20
  @lossfun = lossfun
10
21
  @accfun = accfun
11
22
  @y = nil
12
23
  @loss = nil
13
24
  @accuracy = nil
14
25
  @compute_accuracy = true
26
+ @label_key = label_key
15
27
 
16
28
  init_scope do
17
29
  @predictor = predictor
18
30
  end
19
31
  end
20
32
 
21
- def call(*args)
22
- t = args.pop
23
- x = args
33
+ def call(*args, **kwargs)
34
+ if @label_key.is_a?(Integer)
35
+ raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
36
+ t = args.slice!(@label_key)
37
+ elsif @label_key.is_a?(String)
38
+ raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
39
+ t = kwargs[@label_key]
40
+ kwargs.delete(@label_key)
41
+ end
42
+
24
43
  @y = nil
25
44
  @accuracy = nil
26
- @y = @predictor.(*x)
45
+ @y = @predictor.(*args, **kwargs)
27
46
 
28
47
  @loss = @lossfun.call(@y, t)
29
48
  Chainer::Reporter.save_report({loss: @loss}, self)
@@ -20,11 +20,12 @@ module Chainer
20
20
  # @param [integer or int array] size Size (or shape) of channel dimensions.
21
21
  # @param [float] decay Decay rate of moving average. It is used on training.
22
22
  # @param [float] eps Epsilon value for numerical stability.
23
- # @param [Numo::NArray.dtype] dtype Type to use in computing.
23
+ # @param [Numo::NArray.dtype or Cumo::NArray.dtype] dtype Type to use in computing.
24
24
  # @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
25
25
  # @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
26
- def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::SFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
26
+ def initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
27
27
  super()
28
+ dtype ||= Chainer::Device.default.xm::SFloat
28
29
  @avg_mean = dtype.zeros(size)
29
30
  register_persistent('avg_mean')
30
31
  @avg_var = dtype.zeros(size)
@@ -79,15 +80,11 @@ module Chainer
79
80
  decay = @decay
80
81
  end
81
82
 
82
- func = Chainer::Functions::Normalization::BatchNormalizationFunction.new(eps: @eps, mean: @avg_mean, var: @avg_var, decay: decay)
83
- ret = func.(x, gamma, beta)
84
-
85
- @avg_mean[false] = func.running_mean
86
- @avg_var[false] = func.running_var
83
+ ret = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta, eps: @eps, running_mean: @avg_mean, running_var: @avg_var, decay: decay)
87
84
  else
88
- mean = Chainer::Variable(@avg_mean)
89
- var = Chainer::Variable(@avg_var)
90
- ret = Chainer::Functions::Normalization::BatchNormalizationFunction.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
85
+ mean = Chainer::Variable.new(@avg_mean)
86
+ var = Chainer::Variable.new(@avg_var)
87
+ ret = Chainer::Functions::Normalization::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
91
88
  end
92
89
 
93
90
  ret
@@ -10,13 +10,32 @@ module Chainer
10
10
  @hooks = {}
11
11
  end
12
12
 
13
+ def add_hook(hook, name: nil)
14
+ if !hook.class.method_defined?(:call)
15
+ raise TypeError, 'hook function is not callable'
16
+ end
17
+
18
+ name = hook.name if name.nil?
19
+ if @hooks[name]
20
+ raise TypeError, "hook #{name} already exists"
21
+ end
22
+ @hooks[name] = hook
23
+ end
24
+
25
+ def call_hooks
26
+ @hooks.values.each do |hook|
27
+ _call_hook(hook)
28
+ reallocate_cleared_grads
29
+ end
30
+ end
31
+
13
32
  def _call_hook(hook)
14
33
  if hook.methods.include?(:call_for_each_param)
15
- @target.params.each do |param|
34
+ @target.params do |param|
16
35
  hook.(param.update_rule, param)
17
36
  end
18
37
  else
19
- hook(self)
38
+ hook.(self)
20
39
  end
21
40
  end
22
41
 
@@ -47,7 +66,9 @@ module Chainer
47
66
  return unless @enabled
48
67
 
49
68
  @t += 1
50
- prepare(param)
69
+ unless param.data.nil?
70
+ prepare(param)
71
+ end
51
72
  @hooks.values.each do |hook|
52
73
  hook.call(param)
53
74
  end
@@ -55,14 +76,22 @@ module Chainer
55
76
  end
56
77
 
57
78
  def update_core(param)
58
- # TODO: support GPU
59
- update_core_cpu(param)
79
+ xm = Chainer.get_array_module(param)
80
+ if xm == Cumo
81
+ update_core_gpu(param)
82
+ else
83
+ update_core_cpu(param)
84
+ end
60
85
  end
61
86
 
62
87
  def update_core_cpu
63
88
  raise NotImplementedError
64
89
  end
65
90
 
91
+ def update_core_gpu
92
+ raise NotImplementedError
93
+ end
94
+
66
95
  def init_state(param)
67
96
  raise NotImplementedError
68
97
  end
@@ -74,14 +103,16 @@ module Chainer
74
103
  # method, and so you need to serialize the target link separately if you
75
104
  # want to fully recover the training state including parameters.
76
105
  #
77
- # @param [Chainer::AbstractSerializer] serializer: Serializer object.
106
+ # @param [Chainer::AbstractSerializer] serializer Serializer object.
78
107
  def serialize(serializer)
79
108
  if @state.nil?
80
109
  if serializer.is_a?(Chainer::Deserializer)
81
110
  # try to initialize the state to retrieve state entries
82
111
  @state = {}
83
112
  self_copy = self.dup
84
- arr = Numo::SFloat.new(1)
113
+ # TODO(sonots): pass device from outside
114
+ xm = Chainer::Device.default.xm
115
+ arr = xm::SFloat.new(1)
85
116
  self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
86
117
  @state.keys.each do |key|
87
118
  @state[key] = serializer.(key.to_s, nil)
@@ -101,7 +132,7 @@ module Chainer
101
132
  @state = {}
102
133
  init_state(param)
103
134
  end
104
- @state.select! { |_, v| v.kind_of?(Numo::NArray) }
135
+ @state.select! { |_, v| Chainer.array?(v) }
105
136
  end
106
137
  end
107
138
 
@@ -126,11 +157,11 @@ module Chainer
126
157
  #
127
158
  # @param [Float] rate Coefficient for the weight decay
128
159
  class WeightDecay
129
- def self.name
160
+ def name
130
161
  "WeightDecay"
131
162
  end
132
163
 
133
- def self.call_for_each_param
164
+ def call_for_each_param
134
165
  true
135
166
  end
136
167
 
@@ -140,7 +171,7 @@ module Chainer
140
171
 
141
172
  def call(rule, param)
142
173
  return if param.data.nil? || param.grad.nil?
143
- param.grad += @rate * param.data
174
+ param.grad += (@rate * param.data)
144
175
  end
145
176
  end
146
177
  end
@@ -21,7 +21,7 @@ module Chainer
21
21
  @state[:v] = param.data.new_zeros
22
22
  end
23
23
 
24
- def update_core_cpu(param)
24
+ def update_core(param)
25
25
  grad = param.grad
26
26
  return if grad.nil?
27
27
 
@@ -29,7 +29,8 @@ module Chainer
29
29
 
30
30
  @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
31
31
  @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
32
- param.data -= lr * @state[:m] / (Numo::NMath.sqrt(@state[:v]) + hp.eps)
32
+ xm = Chainer.get_array_module(grad)
33
+ param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
33
34
  end
34
35
 
35
36
  def lr
@@ -17,7 +17,7 @@ module Chainer
17
17
  @state[:v] = param.data.new_zeros
18
18
  end
19
19
 
20
- def update_core_cpu(param)
20
+ def update_core(param)
21
21
  grad = param.grad
22
22
  return if grad.nil?
23
23
 
@@ -10,20 +10,21 @@ module Chainer
10
10
  end
11
11
 
12
12
  if shape.nil?
13
- if @initializer.kind_of?(Numo::NArray)
13
+ if Chainer.array?(initializer)
14
14
  super(initializer, name: name)
15
15
  else
16
16
  super(name: name)
17
17
  @initializer = initializer
18
18
  dtype = initializer.respond_to?(:dtype) ? initializer.dtype : 'SFloat'
19
- @grad_initializer = Chainer::Initializers.nan()
19
+ @grad_initializer = Chainer::Initializers.nan(dtype: dtype)
20
20
  end
21
21
  else
22
- if initializer.kind_of?(Numo::NArray)
22
+ if Chainer.array?(initializer)
23
23
  initializer = Initializers::Constant.new(initializer)
24
24
  end
25
25
  data = Chainer::Initializers.generate_array(initializer, shape)
26
- grad = Numo::NArray[*[1, 2]].new_fill(-922337203)
26
+ xm = Chainer.get_array_module(data)
27
+ grad = xm::NArray[*[1, 2]].new_fill(-922337203)
27
28
  super(data, name: name, grad: grad)
28
29
  end
29
30
 
@@ -40,8 +41,8 @@ module Chainer
40
41
  ginit = @grad_initializer
41
42
  grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)
42
43
 
43
- @data[0] = data
44
- @node.grad = grad
44
+ self.data = data
45
+ self.grad = grad
45
46
  end
46
47
 
47
48
  def update
@@ -4,7 +4,7 @@ module Chainer
4
4
  # Gets a child serializer.
5
5
  # This operator creates a child serializer represented by the given key.
6
6
  #
7
- # @param [string] key: Name of the child serializer.
7
+ # @param [string] key Name of the child serializer.
8
8
  def [](key)
9
9
  raise NotImplementedError
10
10
  end
@@ -21,8 +21,8 @@ module Chainer
21
21
  # For arrays, the restored elements are directly copied into the
22
22
  # ``value`` argument. String values are treated like scalars.
23
23
  #
24
- # @param [string] key: Name of the serialization entry.
25
- # @param [any] value: Object to be (de)serialized.
24
+ # @param [string] key Name of the serialization entry.
25
+ # @param [any] value Object to be (de)serialized.
26
26
  # ``None`` is only supported by deserializers.
27
27
  # @return Serialized or deserialized value.
28
28
  def call(key, value)
@@ -35,7 +35,7 @@ module Chainer
35
35
  # Saves an object by this serializer.
36
36
  # This is equivalent to ``obj.serialize(self)``.
37
37
  #
38
- # @param [any] obj: Target object to be serialized.
38
+ # @param [any] obj Target object to be serialized.
39
39
  def save(obj)
40
40
  obj.serialize(self)
41
41
  end
@@ -1,6 +1,6 @@
1
1
  module Chainer
2
2
  module Serializers
3
- class MarshalSerializer < Chainer::Serializer
3
+ class MarshalSerializer < Chainer::Serializer
4
4
  attr_accessor :target, :path
5
5
 
6
6
  # @param [string] file_path Target file path
@@ -13,9 +13,11 @@ module Chainer
13
13
  end
14
14
  end
15
15
 
16
- def initialize(target: nil, path: "")
16
+ def initialize(target: nil, path: "", device: Chainer::Device.default)
17
17
  @target = target.nil? ? {} : target
18
18
  @path = path
19
+ @device = device
20
+ @xm = device.xm
19
21
  end
20
22
 
21
23
  def [](key)
@@ -25,13 +27,13 @@ module Chainer
25
27
  def call(key, value)
26
28
  ret = value
27
29
  if value.is_a?(TrueClass)
28
- arr = Numo::Bit[1]
30
+ arr = @xm::Bit[1]
29
31
  elsif value.is_a?(FalseClass)
30
- arr = Numo::Bit[0]
32
+ arr = @xm::Bit[0]
31
33
  elsif value.instance_of?(String) || value.nil?
32
34
  arr = value
33
35
  else
34
- arr = Numo::NArray.cast(value)
36
+ arr = @xm::NArray.cast(value)
35
37
  end
36
38
  @target[File.join(@path, key)] = arr
37
39
  ret
@@ -42,8 +44,8 @@ module Chainer
42
44
  # Loads an object from the file in Marshal format.
43
45
  # This is a short-cut function to load from an Marshal file that contains only one object.
44
46
  #
45
- # @param [string ]filename: Name of the file to be loaded.
46
- # @param [object] obj: Object to be deserialized. It must support serialization protocol.
47
+ # @param [string] filename Name of the file to be loaded.
48
+ # @param [object] obj Object to be deserialized. It must support serialization protocol.
47
49
  def self.load_file(filename, obj)
48
50
  File.open(filename) do |f|
49
51
  d = self.new(Marshal.load(f))
@@ -72,7 +74,7 @@ module Chainer
72
74
  return dataset
73
75
  elsif value.instance_of?(String)
74
76
  return dataset
75
- elsif value.is_a?(Numo::NArray)
77
+ elsif Chainer.array?(value)
76
78
  value.store(dataset)
77
79
  return value
78
80
  elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)