ruby-dnn 0.15.2 → 0.15.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ebd331d481ec073102bb4b8c79a6ef280741deaf147a7908ef83f35603f784b9
4
- data.tar.gz: c65399f9f29b65be2ffdc179047f58ae17d6b598328163fc45bb2da5f5f64751
3
+ metadata.gz: b213164ca3e4a7d781a673c4c49ec8dfbd6713468687a98c3739c0cf5629de73
4
+ data.tar.gz: 92a8133dc0ab085d387199f92c2a602bb3782b810c20a931f7ee8de6e247f2a4
5
5
  SHA512:
6
- metadata.gz: fb50c04b888f520d441de0643a6c9870be6df0c7e6d5c54f7c2b8e441e339723945f427cef67c51dc4d250bf0f02c20150e0ca47e1c13974fbe19d272f2301fb
7
- data.tar.gz: 696f4fcbe01aa29f31ea08d32c7e399789ef841e22bfd77029701f132c10687ea81f25fe56e3044de95e4e8cd4c86b947fc32b257896f4dbcc3c125710d52ab8
6
+ metadata.gz: ea6652994f71d97142fe7357947701eff5f0c26708698a670a148e995fb10816b75b62b0bd60004d9ea99b6c837e5b599e790982ad12d010394e4ec7d6962135
7
+ data.tar.gz: 9d474a7bdd120f0e39fcb05e11f0f8370100dd21e79440fe0408bbccf676890726623f516ada78ae6b8829bcf0799ceeea78218f64d74a70ca7f5dd9286bdae3
@@ -27,7 +27,8 @@ module DNN
27
27
 
28
28
  # This callback wrap the lambda function.
29
29
  class LambdaCallback < Callback
30
- def initialize(event, lambda)
30
+ def initialize(event, lambda = nil, &block)
31
+ lambda = block unless lambda
31
32
  instance_eval do
32
33
  define_singleton_method(event) { lambda.call }
33
34
  end
@@ -37,15 +38,19 @@ module DNN
37
38
  # A callback that save the model at the after of the epoch.
38
39
  # @param [String] base_file_name Base file name for saving.
39
40
  # @param [Boolean] include_model When set a true, save data included model structure.
41
+ # @param [Integer] interval Save interval.
40
42
  class CheckPoint < Callback
41
- def initialize(base_file_name, include_model: true)
43
+ def initialize(base_file_name, include_model: true, interval: 1)
42
44
  @base_file_name = base_file_name
43
45
  @include_model = include_model
46
+ @interval = interval
44
47
  end
45
48
 
46
49
  def after_epoch
47
50
  saver = Savers::MarshalSaver.new(@model, include_model: @include_model)
48
- saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
51
+ if @model.last_log[:epoch] % @interval == 0
52
+ saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
53
+ end
49
54
  end
50
55
  end
51
56
 
@@ -17,7 +17,7 @@ module DNN
17
17
  end
18
18
  loss_value = forward(y, t)
19
19
  loss_value += regularizers_forward(layers) if layers
20
- loss_value.is_a?(Float) ? loss_value : loss_value.sum
20
+ loss_value
21
21
  end
22
22
 
23
23
  def forward(y, t)
@@ -65,8 +65,7 @@ module DNN
65
65
 
66
66
  class MeanSquaredError < Loss
67
67
  def forward(y, t)
68
- batch_size = t.shape[0]
69
- 0.5 * ((y - t)**2).sum / batch_size
68
+ 0.5 * ((y - t)**2).mean(0).sum
70
69
  end
71
70
 
72
71
  def backward(y, t)
@@ -76,8 +75,7 @@ module DNN
76
75
 
77
76
  class MeanAbsoluteError < Loss
78
77
  def forward(y, t)
79
- batch_size = t.shape[0]
80
- (y - t).abs.sum / batch_size
78
+ (y - t).abs.mean(0).sum
81
79
  end
82
80
 
83
81
  def backward(y, t)
@@ -91,7 +89,7 @@ module DNN
91
89
  class Hinge < Loss
92
90
  def forward(y, t)
93
91
  @a = 1 - y * t
94
- Xumo::SFloat.maximum(0, @a)
92
+ Xumo::SFloat.maximum(0, @a).mean(0).sum
95
93
  end
96
94
 
97
95
  def backward(y, t)
@@ -119,13 +117,11 @@ module DNN
119
117
  private
120
118
 
121
119
  def loss_l1(y, t)
122
- batch_size = t.shape[0]
123
- (y - t).abs.sum / batch_size
120
+ (y - t).abs.mean(0).sum
124
121
  end
125
122
 
126
123
  def loss_l2(y, t)
127
- batch_size = t.shape[0]
128
- 0.5 * ((y - t)**2).sum / batch_size
124
+ 0.5 * ((y - t)**2).mean(0).sum
129
125
  end
130
126
  end
131
127
 
@@ -147,8 +143,7 @@ module DNN
147
143
 
148
144
  def forward(y, t)
149
145
  @x = SoftmaxCrossEntropy.softmax(y)
150
- batch_size = t.shape[0]
151
- -(t * Xumo::NMath.log(@x + @eps)).sum / batch_size
146
+ -(t * Xumo::NMath.log(@x + @eps)).mean(0).sum
152
147
  end
153
148
 
154
149
  def backward(y, t)
@@ -182,7 +177,7 @@ module DNN
182
177
 
183
178
  def forward(y, t)
184
179
  @x = SigmoidCrossEntropy.sigmoid(y)
185
- -(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps))
180
+ -(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps)).mean(0).sum
186
181
  end
187
182
 
188
183
  def backward(y, t)
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.15.2"
2
+ VERSION = "0.15.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.15.2
4
+ version: 0.15.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-12-01 00:00:00.000000000 Z
11
+ date: 2019-12-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray