ruby-dnn 0.6.7 → 0.6.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/API-Reference.ja.md +24 -6
- data/examples/mnist_conv2d_example.rb +0 -1
- data/examples/mnist_lstm_example.rb +0 -1
- data/examples/xor_example.rb +0 -1
- data/lib/dnn/core/optimizers.rb +44 -11
- data/lib/dnn/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 95749384aa75b235acbae6d8be2ff595120eef4e4737470ade303c610ea448f4
|
4
|
+
data.tar.gz: fdaff7544dbfbb8ef7d8f7cdb6676d33fc99d88f4d7f79f8c6f4c7564ecb58c5
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e537e7a795c39f61933af29730256e9f523879620075fccf125260b61f8ccd623b3ddc6a0bca5e0d5d07ad1f5aeee1129ad60ece5359211d25e98cbc92b13ade
|
7
|
+
data.tar.gz: 89ffc6e14c02476f683eb4057c1192c9e80a1a8f79b08a577378cef8f7b58c62992351ab01680d64a6c0339e853a5f0131037621bba975d590493dce0e98507a
|
data/API-Reference.ja.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
ruby-dnnのAPIリファレンスです。このリファレンスでは、APIを利用するうえで必要となるクラスとメソッドしか記載していません。
|
3
3
|
そのため、プログラムの詳細が必要な場合は、ソースコードを参照してください。
|
4
4
|
|
5
|
-
最終更新バージョン:0.6.
|
5
|
+
最終更新バージョン:0.6.8
|
6
6
|
|
7
7
|
# module DNN
|
8
8
|
ruby-dnnの名前空間をなすモジュールです。
|
@@ -757,18 +757,36 @@ RMSPropによるオプティマイザです。
|
|
757
757
|
|
758
758
|
## 【Properties】
|
759
759
|
|
760
|
-
## attr_accessor :
|
761
|
-
Float
|
760
|
+
## attr_accessor :alpha
|
761
|
+
Float alpha
|
762
762
|
指数平均移動のための係数。
|
763
763
|
|
764
764
|
## 【Instance methods】
|
765
765
|
|
766
|
-
## def initialize(learning_rate = 0.001,
|
766
|
+
## def initialize(learning_rate = 0.001, alpha: 0.9)
|
767
767
|
コンストラクタ。
|
768
768
|
### arguments
|
769
769
|
* Float learning_rate
|
770
770
|
学習率。
|
771
|
-
* Float
|
771
|
+
* Float alpha
|
772
|
+
指数平均移動のための係数。
|
773
|
+
|
774
|
+
|
775
|
+
# class AdaDelta < Optimizer
|
776
|
+
AdaDeltaによるオプティマイザです。
|
777
|
+
|
778
|
+
## 【Properties】
|
779
|
+
|
780
|
+
## attr_accessor :rho
|
781
|
+
Float rho
|
782
|
+
指数平均移動のための係数。
|
783
|
+
|
784
|
+
## 【Instance methods】
|
785
|
+
|
786
|
+
## def initialize(rho: 0.95)
|
787
|
+
コンストラクタ。
|
788
|
+
### arguments
|
789
|
+
* Float rho
|
772
790
|
指数平均移動のための係数。
|
773
791
|
|
774
792
|
|
@@ -787,7 +805,7 @@ Float beta2
|
|
787
805
|
|
788
806
|
## 【Instance methods】
|
789
807
|
|
790
|
-
## def initialize(learning_rate = 0.001, beta1
|
808
|
+
## def initialize(learning_rate = 0.001, beta1: 0.9, beta2: 0.999)
|
791
809
|
コンストラクタ。
|
792
810
|
### arguments
|
793
811
|
* Float beta1
|
data/examples/xor_example.rb
CHANGED
data/lib/dnn/core/optimizers.rb
CHANGED
@@ -95,15 +95,15 @@ module DNN
|
|
95
95
|
|
96
96
|
|
97
97
|
class RMSProp < Optimizer
|
98
|
-
attr_accessor :
|
98
|
+
attr_accessor :alpha
|
99
99
|
|
100
100
|
def self.load_hash(hash)
|
101
|
-
self.new(hash[:learning_rate], hash[:
|
101
|
+
self.new(hash[:learning_rate], alpha: hash[:alpha])
|
102
102
|
end
|
103
103
|
|
104
|
-
def initialize(learning_rate = 0.001,
|
104
|
+
def initialize(learning_rate = 0.001, alpha: 0.9)
|
105
105
|
super(learning_rate)
|
106
|
-
@
|
106
|
+
@alpha = alpha
|
107
107
|
@g = {}
|
108
108
|
end
|
109
109
|
|
@@ -111,13 +111,46 @@ module DNN
|
|
111
111
|
@g[layer] ||= {}
|
112
112
|
layer.params.each_key do |key|
|
113
113
|
@g[layer][key] ||= 0
|
114
|
-
@g[layer][key] = @
|
114
|
+
@g[layer][key] = @alpha * @g[layer][key] + (1 - @alpha) * layer.grads[key]**2
|
115
115
|
layer.params[key] -= (@learning_rate / Xumo::NMath.sqrt(@g[layer][key] + 1e-7)) * layer.grads[key]
|
116
116
|
end
|
117
117
|
end
|
118
118
|
|
119
119
|
def to_hash
|
120
|
-
super({
|
120
|
+
super({alpha: @alpha})
|
121
|
+
end
|
122
|
+
end
|
123
|
+
|
124
|
+
|
125
|
+
class AdaDelta < Optimizer
|
126
|
+
attr_accessor :rho
|
127
|
+
|
128
|
+
def self.load_hash(hash)
|
129
|
+
self.new(rho: hash[:rho])
|
130
|
+
end
|
131
|
+
|
132
|
+
def initialize(rho: 0.95)
|
133
|
+
super(nil)
|
134
|
+
@rho = rho
|
135
|
+
@h = {}
|
136
|
+
@s = {}
|
137
|
+
end
|
138
|
+
|
139
|
+
def update(layer)
|
140
|
+
@h[layer] ||= {}
|
141
|
+
@s[layer] ||= {}
|
142
|
+
layer.params.each_key do |key|
|
143
|
+
@h[layer][key] ||= Xumo::SFloat.zeros(*layer.params[key].shape)
|
144
|
+
@s[layer][key] ||= Xumo::SFloat.zeros(*layer.params[key].shape)
|
145
|
+
@h[layer][key] = @rho * @h[layer][key] + (1 - @rho) * layer.grads[key]**2
|
146
|
+
v = (Xumo::NMath.sqrt(@s[layer][key] + 1e-6) / Xumo::NMath.sqrt(@h[layer][key] + 1e-6)) * layer.grads[key]
|
147
|
+
@s[layer][key] = @rho * @s[layer][key] + (1 - @rho) * v**2
|
148
|
+
layer.params[key] -= v
|
149
|
+
end
|
150
|
+
end
|
151
|
+
|
152
|
+
def to_hash
|
153
|
+
super({rho: @rho})
|
121
154
|
end
|
122
155
|
end
|
123
156
|
|
@@ -125,8 +158,12 @@ module DNN
|
|
125
158
|
class Adam < Optimizer
|
126
159
|
attr_accessor :beta1
|
127
160
|
attr_accessor :beta2
|
161
|
+
|
162
|
+
def self.load_hash(hash)
|
163
|
+
self.new(hash[:learning_rate], beta1: hash[:beta1], beta2: hash[:beta2])
|
164
|
+
end
|
128
165
|
|
129
|
-
def initialize(learning_rate = 0.001, beta1
|
166
|
+
def initialize(learning_rate = 0.001, beta1: 0.9, beta2: 0.999)
|
130
167
|
super(learning_rate)
|
131
168
|
@beta1 = beta1
|
132
169
|
@beta2 = beta2
|
@@ -135,10 +172,6 @@ module DNN
|
|
135
172
|
@v = {}
|
136
173
|
end
|
137
174
|
|
138
|
-
def self.load_hash(hash)
|
139
|
-
self.new(hash[:learning_rate], hash[:beta1], hash[:beta2])
|
140
|
-
end
|
141
|
-
|
142
175
|
def update(layer)
|
143
176
|
@iter += 1
|
144
177
|
@m[layer] ||= {}
|
data/lib/dnn/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: ruby-dnn
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.6.
|
4
|
+
version: 0.6.8
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- unagiootoro
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-
|
11
|
+
date: 2018-09-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|