ruby-dnn 0.15.2 → 0.15.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ebd331d481ec073102bb4b8c79a6ef280741deaf147a7908ef83f35603f784b9
4
- data.tar.gz: c65399f9f29b65be2ffdc179047f58ae17d6b598328163fc45bb2da5f5f64751
3
+ metadata.gz: b213164ca3e4a7d781a673c4c49ec8dfbd6713468687a98c3739c0cf5629de73
4
+ data.tar.gz: 92a8133dc0ab085d387199f92c2a602bb3782b810c20a931f7ee8de6e247f2a4
5
5
  SHA512:
6
- metadata.gz: fb50c04b888f520d441de0643a6c9870be6df0c7e6d5c54f7c2b8e441e339723945f427cef67c51dc4d250bf0f02c20150e0ca47e1c13974fbe19d272f2301fb
7
- data.tar.gz: 696f4fcbe01aa29f31ea08d32c7e399789ef841e22bfd77029701f132c10687ea81f25fe56e3044de95e4e8cd4c86b947fc32b257896f4dbcc3c125710d52ab8
6
+ metadata.gz: ea6652994f71d97142fe7357947701eff5f0c26708698a670a148e995fb10816b75b62b0bd60004d9ea99b6c837e5b599e790982ad12d010394e4ec7d6962135
7
+ data.tar.gz: 9d474a7bdd120f0e39fcb05e11f0f8370100dd21e79440fe0408bbccf676890726623f516ada78ae6b8829bcf0799ceeea78218f64d74a70ca7f5dd9286bdae3
@@ -27,7 +27,8 @@ module DNN
27
27
 
28
28
  # This callback wrap the lambda function.
29
29
  class LambdaCallback < Callback
30
- def initialize(event, lambda)
30
+ def initialize(event, lambda = nil, &block)
31
+ lambda = block unless lambda
31
32
  instance_eval do
32
33
  define_singleton_method(event) { lambda.call }
33
34
  end
@@ -37,15 +38,19 @@ module DNN
37
38
  # A callback that save the model at the after of the epoch.
38
39
  # @param [String] base_file_name Base file name for saving.
39
40
  # @param [Boolean] include_model When set a true, save data included model structure.
41
+ # @param [Integer] interval Save interval.
40
42
  class CheckPoint < Callback
41
- def initialize(base_file_name, include_model: true)
43
+ def initialize(base_file_name, include_model: true, interval: 1)
42
44
  @base_file_name = base_file_name
43
45
  @include_model = include_model
46
+ @interval = interval
44
47
  end
45
48
 
46
49
  def after_epoch
47
50
  saver = Savers::MarshalSaver.new(@model, include_model: @include_model)
48
- saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
51
+ if @model.last_log[:epoch] % @interval == 0
52
+ saver.save(@base_file_name + "_epoch#{model.last_log[:epoch]}.marshal")
53
+ end
49
54
  end
50
55
  end
51
56
 
@@ -17,7 +17,7 @@ module DNN
17
17
  end
18
18
  loss_value = forward(y, t)
19
19
  loss_value += regularizers_forward(layers) if layers
20
- loss_value.is_a?(Float) ? loss_value : loss_value.sum
20
+ loss_value
21
21
  end
22
22
 
23
23
  def forward(y, t)
@@ -65,8 +65,7 @@ module DNN
65
65
 
66
66
  class MeanSquaredError < Loss
67
67
  def forward(y, t)
68
- batch_size = t.shape[0]
69
- 0.5 * ((y - t)**2).sum / batch_size
68
+ 0.5 * ((y - t)**2).mean(0).sum
70
69
  end
71
70
 
72
71
  def backward(y, t)
@@ -76,8 +75,7 @@ module DNN
76
75
 
77
76
  class MeanAbsoluteError < Loss
78
77
  def forward(y, t)
79
- batch_size = t.shape[0]
80
- (y - t).abs.sum / batch_size
78
+ (y - t).abs.mean(0).sum
81
79
  end
82
80
 
83
81
  def backward(y, t)
@@ -91,7 +89,7 @@ module DNN
91
89
  class Hinge < Loss
92
90
  def forward(y, t)
93
91
  @a = 1 - y * t
94
- Xumo::SFloat.maximum(0, @a)
92
+ Xumo::SFloat.maximum(0, @a).mean(0).sum
95
93
  end
96
94
 
97
95
  def backward(y, t)
@@ -119,13 +117,11 @@ module DNN
119
117
  private
120
118
 
121
119
  def loss_l1(y, t)
122
- batch_size = t.shape[0]
123
- (y - t).abs.sum / batch_size
120
+ (y - t).abs.mean(0).sum
124
121
  end
125
122
 
126
123
  def loss_l2(y, t)
127
- batch_size = t.shape[0]
128
- 0.5 * ((y - t)**2).sum / batch_size
124
+ 0.5 * ((y - t)**2).mean(0).sum
129
125
  end
130
126
  end
131
127
 
@@ -147,8 +143,7 @@ module DNN
147
143
 
148
144
  def forward(y, t)
149
145
  @x = SoftmaxCrossEntropy.softmax(y)
150
- batch_size = t.shape[0]
151
- -(t * Xumo::NMath.log(@x + @eps)).sum / batch_size
146
+ -(t * Xumo::NMath.log(@x + @eps)).mean(0).sum
152
147
  end
153
148
 
154
149
  def backward(y, t)
@@ -182,7 +177,7 @@ module DNN
182
177
 
183
178
  def forward(y, t)
184
179
  @x = SigmoidCrossEntropy.sigmoid(y)
185
- -(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps))
180
+ -(t * Xumo::NMath.log(@x + @eps) + (1 - t) * Xumo::NMath.log(1 - @x + @eps)).mean(0).sum
186
181
  end
187
182
 
188
183
  def backward(y, t)
@@ -1,3 +1,3 @@
1
1
  module DNN
2
- VERSION = "0.15.2"
2
+ VERSION = "0.15.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: ruby-dnn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.15.2
4
+ version: 0.15.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - unagiootoro
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-12-01 00:00:00.000000000 Z
11
+ date: 2019-12-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray