red-chainer 0.3.2 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (81) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +2 -2
  3. data/.travis.yml +8 -3
  4. data/.yardopts +1 -0
  5. data/Gemfile +6 -1
  6. data/README.md +34 -3
  7. data/examples/cifar/train_cifar.rb +13 -2
  8. data/examples/iris/iris.rb +9 -5
  9. data/examples/mnist/mnist.rb +16 -4
  10. data/lib/chainer.rb +17 -1
  11. data/lib/chainer/backend.rb +27 -0
  12. data/lib/chainer/cuda.rb +37 -15
  13. data/lib/chainer/dataset/convert.rb +20 -16
  14. data/lib/chainer/datasets/cifar.rb +8 -6
  15. data/lib/chainer/datasets/mnist.rb +14 -55
  16. data/lib/chainer/device.rb +88 -0
  17. data/lib/chainer/function.rb +103 -41
  18. data/lib/chainer/function_node.rb +454 -0
  19. data/lib/chainer/functions/activation/leaky_relu.rb +38 -13
  20. data/lib/chainer/functions/activation/log_softmax.rb +46 -9
  21. data/lib/chainer/functions/activation/relu.rb +8 -8
  22. data/lib/chainer/functions/activation/relu_grad2.rb +34 -0
  23. data/lib/chainer/functions/activation/sigmoid.rb +13 -11
  24. data/lib/chainer/functions/activation/sigmoid_grad.rb +25 -0
  25. data/lib/chainer/functions/activation/tanh.rb +48 -11
  26. data/lib/chainer/functions/array/broadcast_to.rb +56 -0
  27. data/lib/chainer/functions/array/cast.rb +41 -0
  28. data/lib/chainer/functions/array/reshape.rb +28 -0
  29. data/lib/chainer/functions/array/rollaxis.rb +57 -0
  30. data/lib/chainer/functions/array/select_item.rb +72 -0
  31. data/lib/chainer/functions/array/squeeze.rb +78 -0
  32. data/lib/chainer/functions/array/transpose.rb +44 -0
  33. data/lib/chainer/functions/connection/convolution_2d.rb +43 -26
  34. data/lib/chainer/functions/connection/convolution_2d_grad_w.rb +48 -0
  35. data/lib/chainer/functions/connection/deconvolution_2d.rb +159 -0
  36. data/lib/chainer/functions/connection/linear.rb +29 -22
  37. data/lib/chainer/functions/evaluation/accuracy.rb +5 -5
  38. data/lib/chainer/functions/loss/mean_squared_error.rb +21 -12
  39. data/lib/chainer/functions/loss/softmax_cross_entropy.rb +98 -71
  40. data/lib/chainer/functions/math/basic_math.rb +36 -30
  41. data/lib/chainer/functions/math/exp.rb +28 -0
  42. data/lib/chainer/functions/math/identity.rb +4 -3
  43. data/lib/chainer/functions/math/sum.rb +52 -0
  44. data/lib/chainer/functions/noise/dropout.rb +20 -4
  45. data/lib/chainer/functions/normalization/batch_normalization.rb +257 -104
  46. data/lib/chainer/functions/pooling/average_pooling_2d.rb +29 -6
  47. data/lib/chainer/functions/pooling/max_pooling_2d.rb +67 -12
  48. data/lib/chainer/functions/pooling/pooling_2d.rb +6 -4
  49. data/lib/chainer/gradient_check.rb +157 -73
  50. data/lib/chainer/gradient_method.rb +3 -2
  51. data/lib/chainer/initializers/init.rb +5 -5
  52. data/lib/chainer/initializers/normal.rb +4 -2
  53. data/lib/chainer/initializers/uniform.rb +15 -0
  54. data/lib/chainer/iterators/serial_iterator.rb +5 -3
  55. data/lib/chainer/link.rb +4 -2
  56. data/lib/chainer/links/connection/convolution_2d.rb +2 -2
  57. data/lib/chainer/links/model/classifier.rb +24 -5
  58. data/lib/chainer/links/normalization/batch_normalization.rb +7 -10
  59. data/lib/chainer/optimizer.rb +42 -11
  60. data/lib/chainer/optimizers/adam.rb +3 -2
  61. data/lib/chainer/optimizers/momentum_sgd.rb +1 -1
  62. data/lib/chainer/parameter.rb +7 -6
  63. data/lib/chainer/serializer.rb +4 -4
  64. data/lib/chainer/serializers/marshal.rb +10 -8
  65. data/lib/chainer/testing/array.rb +1 -1
  66. data/lib/chainer/training/extensions/evaluator.rb +2 -3
  67. data/lib/chainer/training/extensions/exponential_shift.rb +1 -1
  68. data/lib/chainer/training/extensions/progress_bar.rb +1 -0
  69. data/lib/chainer/training/trainer.rb +4 -9
  70. data/lib/chainer/training/triggers/interval.rb +7 -2
  71. data/lib/chainer/utils/array.rb +80 -1
  72. data/lib/chainer/utils/conv.rb +10 -2
  73. data/lib/chainer/utils/initializer.rb +2 -2
  74. data/lib/chainer/variable.rb +159 -69
  75. data/lib/chainer/variable_node.rb +64 -10
  76. data/lib/chainer/version.rb +1 -1
  77. data/red-chainer.gemspec +4 -3
  78. data/templates/default/layout/html/layout.erb +40 -0
  79. data/templates/default/onefile/html/layout.erb +33 -0
  80. metadata +44 -11
  81. data/lib/chainer/dataset/download.rb +0 -56
@@ -15,13 +15,14 @@ module Chainer
15
15
  def reallocate_cleared_grads
16
16
  @target.namedparams(include_uninit: false) do |(name, param)|
17
17
  if param.grad.nil?
18
- param.grad = Numo::NArray.[](*param.data).new_zeros
18
+ xm = Chainer.get_array_module(param.data)
19
+ param.grad = xm::NArray.[](*param.data).new_zeros
19
20
  end
20
21
  end
21
22
  end
22
23
 
23
24
  def call_hooks
24
- @hooks.each do |hook|
25
+ @hooks.values.each do |hook|
25
26
  _call_hook(hook)
26
27
  reallocate_cleared_grads
27
28
  end
@@ -1,7 +1,7 @@
1
1
  module Chainer
2
2
  module Initializers
3
- def self.generate_array(initializer, shape)
4
- klass = Numo::SFloat
3
+ def self.generate_array(initializer, shape, device: Chainer::Device.default)
4
+ klass = device.xm::SFloat
5
5
  if initializer.respond_to?(:dtype) && initializer.dtype
6
6
  klass = initializer.dtype
7
7
  end
@@ -9,10 +9,10 @@ module Chainer
9
9
  initializer.(array)
10
10
  end
11
11
 
12
- def self.get_initializer(initializer)
13
- return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
12
+ def self.get_initializer(initializer, device: Chainer::Device.default)
13
+ return HeNormal.new(scale: 1 / device.xm::NMath.sqrt(2)) if initializer.nil?
14
14
  return Constant.new(initializer) if initializer.kind_of?(Numeric)
15
- return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)
15
+ return Constant.new(initializer) if Chainer.array?(initializer)
16
16
 
17
17
  unless initializer.respond_to?(:call)
18
18
  raise TypeError, "invalid type of initializer: #{initializer.class}"
@@ -19,8 +19,10 @@ module Chainer
19
19
  end
20
20
 
21
21
  def call(array)
22
- fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape)
23
- s = @scale * Numo::NMath.sqrt(2.0 / fan_in)
22
+ # TODO(sonots): pass device from outside
23
+ device = Chainer::Device.default
24
+ fan_in, fan_out = Chainer::Utils::Initializer.get_fans(array.shape, device: device)
25
+ s = @scale * device.xm::NMath.sqrt(2.0 / fan_in)
24
26
  Normal.new(scale: s).(array)
25
27
  end
26
28
  end
@@ -0,0 +1,15 @@
1
+ module Chainer
2
+ module Initializers
3
+ class Uniform < ::Chainer::Initializer
4
+ def initialize(scale: 0.05, dtype: nil)
5
+ @scale = scale
6
+ super(dtype: dtype)
7
+ end
8
+
9
+ def call(array)
10
+ raise ArgumentError.new("dtypes are missmatched. #{dtype} != #{array.class}") if dtype && dtype != array.class
11
+ array.class.new(array.shape).rand(-@scale, @scale)
12
+ end
13
+ end
14
+ end
15
+ end
@@ -3,11 +3,13 @@ module Chainer
3
3
  class SerialIterator < Chainer::Dataset::Iterator
4
4
  attr_reader :epoch, :is_new_epoch
5
5
 
6
- def initialize(dataset, batch_size, repeat: true, shuffle: true)
6
+ def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
7
7
  @dataset = dataset
8
8
  @batch_size = batch_size
9
9
  @repeat = repeat
10
10
  @shuffle = shuffle
11
+ @device = device
12
+ @xm = device.xm
11
13
 
12
14
  reset
13
15
  end
@@ -83,10 +85,10 @@ module Chainer
83
85
  def reset
84
86
  if @shuffle
85
87
  order = @dataset.size.times.map(&:to_i).shuffle
86
- @order = Numo::Int64[*order]
88
+ @order = @xm::Int64[*order]
87
89
  else
88
90
  order = @dataset.size.times.map(&:to_i)
89
- @order = Numo::Int64[*order]
91
+ @order = @xm::Int64[*order]
90
92
  end
91
93
 
92
94
  @current_position = 0
data/lib/chainer/link.rb CHANGED
@@ -80,6 +80,8 @@ module Chainer
80
80
  end
81
81
 
82
82
  def serialize(serializer)
83
+ # TODO(sonots): pass device from outside
84
+ xm = Chainer::Device.default.xm
83
85
  d = self.instance_variables.each_with_object({}) { |sym, h| h[sym] = self.instance_variable_get(sym) }
84
86
  @params.each do |name|
85
87
  param = d[name]
@@ -87,10 +89,10 @@ module Chainer
87
89
  if param.data.nil? && !data.nil?
88
90
  # Initialize the parameter here
89
91
  param.init(data.shape)
90
- if param.data.is_a?(Numo::NArray)
92
+ if Chainer.array?(param.data)
91
93
  param.data.store(data)
92
94
  else
93
- param.data.set(Numo::NArray.cast(data))
95
+ param.data.set(xm::NArray.cast(data))
94
96
  end
95
97
  end
96
98
  end
@@ -14,8 +14,8 @@ module Chainer
14
14
  # @param [integer or 2-d int array] stride Stride of filter applications.
15
15
  # @param [integer or 2-d int array] pad Spatial padding width for input arrays.
16
16
  # @param [boolean] nobias If `true`, then this link does not use the bias term.
17
- # @param [Numo::NArray] initialW Initial weight value. If `nil`, the default initializer is used.
18
- # @param [Numo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
17
+ # @param [Numo::NArray or Cumo::NArray] initial_w Initial weight value. If `nil`, the default initializer is used.
18
+ # @param [Numo::NArray or Cumo::NArray] initial_bias Initial bias value. If `nil`, the bias is set to 0.
19
19
  #
20
20
  # Example
21
21
  # There are several ways to make a Convolution2D link.
@@ -4,26 +4,45 @@ module Chainer
4
4
  class Classifier < Chain
5
5
  attr_accessor :compute_accuracy
6
6
 
7
- def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy))
7
+ # @param [Chainer::Link] predictor Predictor network.
8
+ # @param [Function] lossfun Loss function.
9
+ # @param [Function] accfun Function that computes accuracy.
10
+ # @param [Integer, String] label_key Key to specify label variable from arguments.
11
+ # When it is Integer, a variable in positional arguments is used.
12
+ # And when it is String, a variable in keyword arguments is used.
13
+ def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
8
14
  super()
15
+
16
+ unless label_key.is_a?(Integer) || label_key.is_a?(String)
17
+ raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
18
+ end
19
+
9
20
  @lossfun = lossfun
10
21
  @accfun = accfun
11
22
  @y = nil
12
23
  @loss = nil
13
24
  @accuracy = nil
14
25
  @compute_accuracy = true
26
+ @label_key = label_key
15
27
 
16
28
  init_scope do
17
29
  @predictor = predictor
18
30
  end
19
31
  end
20
32
 
21
- def call(*args)
22
- t = args.pop
23
- x = args
33
+ def call(*args, **kwargs)
34
+ if @label_key.is_a?(Integer)
35
+ raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
36
+ t = args.slice!(@label_key)
37
+ elsif @label_key.is_a?(String)
38
+ raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
39
+ t = kwargs[@label_key]
40
+ kwargs.delete(@label_key)
41
+ end
42
+
24
43
  @y = nil
25
44
  @accuracy = nil
26
- @y = @predictor.(*x)
45
+ @y = @predictor.(*args, **kwargs)
27
46
 
28
47
  @loss = @lossfun.call(@y, t)
29
48
  Chainer::Reporter.save_report({loss: @loss}, self)
@@ -20,11 +20,12 @@ module Chainer
20
20
  # @param [integer or int array] size Size (or shape) of channel dimensions.
21
21
  # @param [float] decay Decay rate of moving average. It is used on training.
22
22
  # @param [float] eps Epsilon value for numerical stability.
23
- # @param [Numo::NArray.dtype] dtype Type to use in computing.
23
+ # @param [Numo::NArray.dtype or Cumo::NArray.dtype] dtype Type to use in computing.
24
24
  # @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
25
25
  # @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
26
- def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::SFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
26
+ def initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
27
27
  super()
28
+ dtype ||= Chainer::Device.default.xm::SFloat
28
29
  @avg_mean = dtype.zeros(size)
29
30
  register_persistent('avg_mean')
30
31
  @avg_var = dtype.zeros(size)
@@ -79,15 +80,11 @@ module Chainer
79
80
  decay = @decay
80
81
  end
81
82
 
82
- func = Chainer::Functions::Normalization::BatchNormalizationFunction.new(eps: @eps, mean: @avg_mean, var: @avg_var, decay: decay)
83
- ret = func.(x, gamma, beta)
84
-
85
- @avg_mean[false] = func.running_mean
86
- @avg_var[false] = func.running_var
83
+ ret = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta, eps: @eps, running_mean: @avg_mean, running_var: @avg_var, decay: decay)
87
84
  else
88
- mean = Chainer::Variable(@avg_mean)
89
- var = Chainer::Variable(@avg_var)
90
- ret = Chainer::Functions::Normalization::BatchNormalizationFunction.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
85
+ mean = Chainer::Variable.new(@avg_mean)
86
+ var = Chainer::Variable.new(@avg_var)
87
+ ret = Chainer::Functions::Normalization::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps)
91
88
  end
92
89
 
93
90
  ret
@@ -10,13 +10,32 @@ module Chainer
10
10
  @hooks = {}
11
11
  end
12
12
 
13
+ def add_hook(hook, name: nil)
14
+ if !hook.class.method_defined?(:call)
15
+ raise TypeError, 'hook function is not callable'
16
+ end
17
+
18
+ name = hook.name if name.nil?
19
+ if @hooks[name]
20
+ raise TypeError, "hook #{name} already exists"
21
+ end
22
+ @hooks[name] = hook
23
+ end
24
+
25
+ def call_hooks
26
+ @hooks.values.each do |hook|
27
+ _call_hook(hook)
28
+ reallocate_cleared_grads
29
+ end
30
+ end
31
+
13
32
  def _call_hook(hook)
14
33
  if hook.methods.include?(:call_for_each_param)
15
- @target.params.each do |param|
34
+ @target.params do |param|
16
35
  hook.(param.update_rule, param)
17
36
  end
18
37
  else
19
- hook(self)
38
+ hook.(self)
20
39
  end
21
40
  end
22
41
 
@@ -47,7 +66,9 @@ module Chainer
47
66
  return unless @enabled
48
67
 
49
68
  @t += 1
50
- prepare(param)
69
+ unless param.data.nil?
70
+ prepare(param)
71
+ end
51
72
  @hooks.values.each do |hook|
52
73
  hook.call(param)
53
74
  end
@@ -55,14 +76,22 @@ module Chainer
55
76
  end
56
77
 
57
78
  def update_core(param)
58
- # TODO: support GPU
59
- update_core_cpu(param)
79
+ xm = Chainer.get_array_module(param)
80
+ if xm == Cumo
81
+ update_core_gpu(param)
82
+ else
83
+ update_core_cpu(param)
84
+ end
60
85
  end
61
86
 
62
87
  def update_core_cpu
63
88
  raise NotImplementedError
64
89
  end
65
90
 
91
+ def update_core_gpu
92
+ raise NotImplementedError
93
+ end
94
+
66
95
  def init_state(param)
67
96
  raise NotImplementedError
68
97
  end
@@ -74,14 +103,16 @@ module Chainer
74
103
  # method, and so you need to serialize the target link separately if you
75
104
  # want to fully recover the training state including parameters.
76
105
  #
77
- # @param [Chainer::AbstractSerializer] serializer: Serializer object.
106
+ # @param [Chainer::AbstractSerializer] serializer Serializer object.
78
107
  def serialize(serializer)
79
108
  if @state.nil?
80
109
  if serializer.is_a?(Chainer::Deserializer)
81
110
  # try to initialize the state to retrieve state entries
82
111
  @state = {}
83
112
  self_copy = self.dup
84
- arr = Numo::SFloat.new(1)
113
+ # TODO(sonots): pass device from outside
114
+ xm = Chainer::Device.default.xm
115
+ arr = xm::SFloat.new(1)
85
116
  self_copy.init_state(Chainer::Variable.new(arr, grad: arr))
86
117
  @state.keys.each do |key|
87
118
  @state[key] = serializer.(key.to_s, nil)
@@ -101,7 +132,7 @@ module Chainer
101
132
  @state = {}
102
133
  init_state(param)
103
134
  end
104
- @state.select! { |_, v| v.kind_of?(Numo::NArray) }
135
+ @state.select! { |_, v| Chainer.array?(v) }
105
136
  end
106
137
  end
107
138
 
@@ -126,11 +157,11 @@ module Chainer
126
157
  #
127
158
  # @param [Float] rate Coefficient for the weight decay
128
159
  class WeightDecay
129
- def self.name
160
+ def name
130
161
  "WeightDecay"
131
162
  end
132
163
 
133
- def self.call_for_each_param
164
+ def call_for_each_param
134
165
  true
135
166
  end
136
167
 
@@ -140,7 +171,7 @@ module Chainer
140
171
 
141
172
  def call(rule, param)
142
173
  return if param.data.nil? || param.grad.nil?
143
- param.grad += @rate * param.data
174
+ param.grad += (@rate * param.data)
144
175
  end
145
176
  end
146
177
  end
@@ -21,7 +21,7 @@ module Chainer
21
21
  @state[:v] = param.data.new_zeros
22
22
  end
23
23
 
24
- def update_core_cpu(param)
24
+ def update_core(param)
25
25
  grad = param.grad
26
26
  return if grad.nil?
27
27
 
@@ -29,7 +29,8 @@ module Chainer
29
29
 
30
30
  @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
31
31
  @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
32
- param.data -= lr * @state[:m] / (Numo::NMath.sqrt(@state[:v]) + hp.eps)
32
+ xm = Chainer.get_array_module(grad)
33
+ param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
33
34
  end
34
35
 
35
36
  def lr
@@ -17,7 +17,7 @@ module Chainer
17
17
  @state[:v] = param.data.new_zeros
18
18
  end
19
19
 
20
- def update_core_cpu(param)
20
+ def update_core(param)
21
21
  grad = param.grad
22
22
  return if grad.nil?
23
23
 
@@ -10,20 +10,21 @@ module Chainer
10
10
  end
11
11
 
12
12
  if shape.nil?
13
- if @initializer.kind_of?(Numo::NArray)
13
+ if Chainer.array?(initializer)
14
14
  super(initializer, name: name)
15
15
  else
16
16
  super(name: name)
17
17
  @initializer = initializer
18
18
  dtype = initializer.respond_to?(:dtype) ? initializer.dtype : 'SFloat'
19
- @grad_initializer = Chainer::Initializers.nan()
19
+ @grad_initializer = Chainer::Initializers.nan(dtype: dtype)
20
20
  end
21
21
  else
22
- if initializer.kind_of?(Numo::NArray)
22
+ if Chainer.array?(initializer)
23
23
  initializer = Initializers::Constant.new(initializer)
24
24
  end
25
25
  data = Chainer::Initializers.generate_array(initializer, shape)
26
- grad = Numo::NArray[*[1, 2]].new_fill(-922337203)
26
+ xm = Chainer.get_array_module(data)
27
+ grad = xm::NArray[*[1, 2]].new_fill(-922337203)
27
28
  super(data, name: name, grad: grad)
28
29
  end
29
30
 
@@ -40,8 +41,8 @@ module Chainer
40
41
  ginit = @grad_initializer
41
42
  grad = ginit.nil? ? nil : Chainer::Initializers.generate_array(ginit, shape)
42
43
 
43
- @data[0] = data
44
- @node.grad = grad
44
+ self.data = data
45
+ self.grad = grad
45
46
  end
46
47
 
47
48
  def update
@@ -4,7 +4,7 @@ module Chainer
4
4
  # Gets a child serializer.
5
5
  # This operator creates a child serializer represented by the given key.
6
6
  #
7
- # @param [string] key: Name of the child serializer.
7
+ # @param [string] key Name of the child serializer.
8
8
  def [](key)
9
9
  raise NotImplementedError
10
10
  end
@@ -21,8 +21,8 @@ module Chainer
21
21
  # For arrays, the restored elements are directly copied into the
22
22
  # ``value`` argument. String values are treated like scalars.
23
23
  #
24
- # @param [string] key: Name of the serialization entry.
25
- # @param [any] value: Object to be (de)serialized.
24
+ # @param [string] key Name of the serialization entry.
25
+ # @param [any] value Object to be (de)serialized.
26
26
  # ``None`` is only supported by deserializers.
27
27
  # @return Serialized or deserialized value.
28
28
  def call(key, value)
@@ -35,7 +35,7 @@ module Chainer
35
35
  # Saves an object by this serializer.
36
36
  # This is equivalent to ``obj.serialize(self)``.
37
37
  #
38
- # @param [any] obj: Target object to be serialized.
38
+ # @param [any] obj Target object to be serialized.
39
39
  def save(obj)
40
40
  obj.serialize(self)
41
41
  end
@@ -1,6 +1,6 @@
1
1
  module Chainer
2
2
  module Serializers
3
- class MarshalSerializer < Chainer::Serializer
3
+ class MarshalSerializer < Chainer::Serializer
4
4
  attr_accessor :target, :path
5
5
 
6
6
  # @param [string] file_path Target file path
@@ -13,9 +13,11 @@ module Chainer
13
13
  end
14
14
  end
15
15
 
16
- def initialize(target: nil, path: "")
16
+ def initialize(target: nil, path: "", device: Chainer::Device.default)
17
17
  @target = target.nil? ? {} : target
18
18
  @path = path
19
+ @device = device
20
+ @xm = device.xm
19
21
  end
20
22
 
21
23
  def [](key)
@@ -25,13 +27,13 @@ module Chainer
25
27
  def call(key, value)
26
28
  ret = value
27
29
  if value.is_a?(TrueClass)
28
- arr = Numo::Bit[1]
30
+ arr = @xm::Bit[1]
29
31
  elsif value.is_a?(FalseClass)
30
- arr = Numo::Bit[0]
32
+ arr = @xm::Bit[0]
31
33
  elsif value.instance_of?(String) || value.nil?
32
34
  arr = value
33
35
  else
34
- arr = Numo::NArray.cast(value)
36
+ arr = @xm::NArray.cast(value)
35
37
  end
36
38
  @target[File.join(@path, key)] = arr
37
39
  ret
@@ -42,8 +44,8 @@ module Chainer
42
44
  # Loads an object from the file in Marshal format.
43
45
  # This is a short-cut function to load from an Marshal file that contains only one object.
44
46
  #
45
- # @param [string ]filename: Name of the file to be loaded.
46
- # @param [object] obj: Object to be deserialized. It must support serialization protocol.
47
+ # @param [string] filename Name of the file to be loaded.
48
+ # @param [object] obj Object to be deserialized. It must support serialization protocol.
47
49
  def self.load_file(filename, obj)
48
50
  File.open(filename) do |f|
49
51
  d = self.new(Marshal.load(f))
@@ -72,7 +74,7 @@ module Chainer
72
74
  return dataset
73
75
  elsif value.instance_of?(String)
74
76
  return dataset
75
- elsif value.is_a?(Numo::NArray)
77
+ elsif Chainer.array?(value)
76
78
  value.store(dataset)
77
79
  return value
78
80
  elsif value.is_a?(TrueClass) || value.is_a?(FalseClass)