nn 1.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: ae8f441634517a886dbdcee8b7186c5d02710c84dc52b381bb29d0d219f958f3
4
+ data.tar.gz: 3af1f9f95a8727aec20100b1d82ef598dda3857915641a3b73ba16cf372841de
5
+ SHA512:
6
+ metadata.gz: 2774e79cddbc52530d9f00cd34bb981f7beaa7f6d1c402b6205a41dbe7949fb40dedd2e6c97e9f1a38abc64b5e3bf1bffc40843c943bd0f88dcba2e9bd52f202
7
+ data.tar.gz: c45b37a70bfddb2a31f1682d4268e943098e228b16102d61b74d2d26dd526b2fd491c4ec4d7a4758cfe838a7511061d4e87b263652acb46a5e29e14d298676bf
@@ -0,0 +1,8 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /_yardoc/
4
+ /coverage/
5
+ /doc/
6
+ /pkg/
7
+ /spec/reports/
8
+ /tmp/
data/Gemfile ADDED
@@ -0,0 +1,6 @@
1
+ source "https://rubygems.org"
2
+
3
+ git_source(:github) {|repo_name| "https://github.com/#{repo_name}" }
4
+
5
+ # Specify your gem's dependencies in nn.gemspec
6
+ gemspec
@@ -0,0 +1,21 @@
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2018 unagiootoro
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ THE SOFTWARE.
@@ -0,0 +1,17 @@
1
+ # ruby-nn
2
+
3
+ ruby-nnは、rubyで書かれたニューラルネットワークライブラリです。
4
+ python向けの本格的なディープラーニングライブラリと比べると、性能や機能面で、大きく見劣りしますが、
5
+ MNISTで98%以上の精度を出せるぐらいの性能はあります。
6
+
7
+ ## インストール
8
+
9
+ $ gem install nn
10
+
11
+ ## 使用法
12
+
13
+ 付属のdocument.txtを参照してください。
14
+
15
+ ## ライセンス
16
+
17
+ この宝石は、[MITライセンス](https://opensource.org/licenses/MIT)の条件でオープンソースとして入手できます。
@@ -0,0 +1,2 @@
1
+ require "bundler/gem_tasks"
2
+ task :default => :spec
@@ -0,0 +1,14 @@
1
+ #!/usr/bin/env ruby
2
+
3
+ require "bundler/setup"
4
+ require "nn"
5
+
6
+ # You can add fixtures and/or initialization code here to make experimenting
7
+ # with your gem easier. You can also use a different console, if you like.
8
+
9
+ # (If you use this, don't forget to add pry to your Gemfile!)
10
+ # require "pry"
11
+ # Pry.start
12
+
13
+ require "irb"
14
+ IRB.start(__FILE__)
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ IFS=$'\n\t'
4
+ set -vx
5
+
6
+ bundle install
7
+
8
+ # Do any other automated setup that you need to do here
@@ -0,0 +1,212 @@
1
+ ruby-nnは、rubyで書かれたニューラルネットワークライブラリです。
2
+ python向けの本格的なディープラーニングライブラリと比べると、性能や機能面で、大きく見劣りしますが、
3
+ MNISTで98%以上の精度を出せるぐらいの性能はあります。
4
+
5
+ なお、ruby-nnはNumo/NArrayを使用しています。
6
+ そのため、ruby-nnの使用には、Numo/NArrayのインストールが必要です。
7
+
8
+
9
+ [リファレンス]
10
+
11
+ class NN
12
+ ニューラルネットワークを扱うクラスです。
13
+
14
+ <クラスメソッド>
15
+ load(file_name) : NN
16
+ JSON形式で保存された学習結果を読み込みます。
17
+ String file_name 読み込むJSONファイル名
18
+ 戻り値 NNのインスタンス
19
+
20
+ <プロパティ>
21
+ Array<SFloat> weights ネットワークの重みをSFloat形式で取得します。
22
+ Array<SFloat> biases ネットワークのバイアスをSFloat形式で取得します。
23
+ Array<Float> gammas バッチノーマライゼーションを使用している場合、gammaを取得します。
24
+ Array<Float> betas バッチノーマライゼーションを使用している場合、betaを取得します。
25
+ Float learning_rate 学習率
26
+ Integer batch_size ミニバッチの数
27
+ Array<Symbol> activation 活性化関数。配列の要素1が中間層の活性化関数で要素2が隠れ層の活性化関数です。
28
+ 中間層には、:sigmoidまたは:relu、出力層には、:identityまたは:softmaxが使用できます。
29
+ Float momentum モーメンタム係数
30
+ Float weight_decay L2正則化項の強さ
31
+ Float dropout_ratio ドロップアウトさせるノードの比率
32
+
33
+ <インスタンスメソッド>
34
+ initialize(num_nodes,
35
+ learning_rate: 0.01,
36
+ batch_size: 1,
37
+ activation: [:relu, :identity],
38
+ momentum: 0,
39
+ weight_decay: 0,
40
+ use_dropout: false,
41
+ dropout_ratio: 0.5,
42
+ use_batch_norm: false)
43
+ オブジェクトを初期化します。
44
+ Array<Integer> num_nodes 各層のノード数
45
+ Float learning_rate 学習率
46
+ Integer batch_size ミニバッチの数
47
+ Array<Symbol> activation 活性化関数。配列の要素1が中間層の活性化関数で要素2が隠れ層の活性化関数です。
48
+ 中間層には、:sigmoidまたは:relu、出力層には、:identityまたは:softmaxが使用できます。
49
+ Float momentum モーメンタム係数
50
+ Float weight_decay L2正則化項の強さ
51
+ bool use_dropout ドロップアウトを使用するか否か
52
+ Float dropout_ratio ドロップアウトさせるノードの比率
53
+ bool use_batch_norm バッチノーマライゼーションを使用するか否か
54
+
55
+ train(x_train, y_train, x_test, y_test, epoch,
56
+ save_dir: nil,
57
+ save_interval: 1,
58
+ test: nil,
59
+ border: nil,
60
+ tolerance: 0.5,
61
+ &block) : void
62
+ 学習を行います。
63
+ Array<Array<Numeric>> | SFloat x_train トレーニング用入力データ。
64
+ Array<Array<Numeric>> | SFloat y_train トレーニング用正解データ。
65
+ Integer epoch 学習回数。入力データすべてを見たタイミングを1エポックとします。
66
+ String save_dir 学習中にセーブを行う場合、セーブするディレクトリを指定します。nilの場合、セーブを行いません。
67
+ Integer save_interval 学習中にセーブするタイミングをエポック単位で指定します。
68
+ Array<Array<Array<Numeric>> | SFloat> test テストで使用するデータ。[x_test, y_test]の形式で指定してください。
69
+ nilを指定すると、エポックごとにテストを行いません。
70
+ Float border 学習の早期終了判定に使用するテストデータの正答率。
71
+ nilの場合、学習の早期終了を行いません。
72
+ Proc &block(SFloat x, SFloat y) : Array<SFloat> 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の
73
+ 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。
74
+
75
+ test(x_test, y_test, tolerance = 0.5, &block) : Float
76
+ テストデータを用いて、テストを行います。
77
+ Array<Array<Numeric>> | SFloat x_train テスト用入力データ。
78
+ Array<Array<Numeric>> | SFloat y_train テスト用正解データ。
79
+ Float tolerance 許容する誤差。出力層の活性化関数が:identityの場合に使用します。
80
+ 例えば出力が0.7で正解が1.0の場合、toleranceが0.4なら合格となり、0.2なら不合格となります。
81
+ Proc &block(SFloat x, SFloat y) : Array<SFloat> 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の
82
+ 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。
83
+ 戻り値 テストデータの正答率。
84
+
85
+ accurate(x_test, y_test, tolera)
86
+ テストデータを用いて、テストデータの正答率を取得します。
87
+ Array<Array<Numeric>> | SFloat x_train テスト用入力データ。
88
+ Array<Array<Numeric>> | SFloat y_train テスト用正解データ。
89
+ Float tolerance 許容する誤差。出力層の活性化関数が:identityの場合に使用します。
90
+ 例えば出力が0.7で正解が1.0の場合、toleranceが0.4なら合格となり、0.2なら不合格となります。
91
+ Proc &block(SFloat x, SFloat y) : Array<SFloat> 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の
92
+ 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。
93
+ 戻り値 テストデータの正答率。
94
+
95
+ learn(x_train, y_train, &block) : Float
96
+ 入力データを元に、1回だけ学習を行います。途中で学習を切り上げるなど、柔軟な学習を行いたい場合に使用します。
97
+ Array<Array<Numeric>> | SFloat x_train 入力データ
98
+ Array<Array<Numeric>> | SFloat y_train 正解データ
99
+ Proc &block(SFloat x, SFloat y) : Array<SFloat> 入力層のミニバッチを取得します。ブロックの戻り値は、ミニバッチを[x, y]の
100
+ 形で指定してください。入力層をミニバッチ単位で正規化したい場合に使用します。
101
+ 戻り値 誤差関数の値。誤差関数は、出力層の活性化関数が:identityの場合、二乗和誤差が、
102
+ :softmaxの場合、クロスエントロピー誤差が使用されます。なお、L2正則化を使用する場合、
103
+ 誤差関数の値には正則化項の値が含まれます。
104
+
105
+ run(x) : Array<Array<Numeric>>
106
+ 入力データから出力値を二次元配列で得ます。
107
+ Array<Array<Float>> | SFloat x 入力データ
108
+ 戻り値 出力ノードの値
109
+
110
+ save(file_name) : void
111
+ 学習結果をJSON形式で保存します。
112
+ String file_name 書き込むJSONファイル名
113
+
114
+
115
+ [サンプル1 XOR]
116
+
117
+ #ライブラリの読み込み
118
+ require "nn"
119
+
120
+ x = [
121
+ [0, 0],
122
+ [1, 0],
123
+ [0, 1],
124
+ [1, 1],
125
+ ]
126
+
127
+ y = [[0], [1], [1], [0]]
128
+
129
+ #ニューラルネットワークの初期化
130
+ nn = NN.new([2, 4, 1], #ノード数
131
+ learning_rate: 0.1, #学習率
132
+ batch_size: 4, #ミニバッチの数
133
+ activation: [:sigmoid, :identity] #活性化関数
134
+ )
135
+
136
+ #学習を行う
137
+ nn.train(x, y, 20000)
138
+
139
+ #学習結果の確認
140
+ p nn.run(x)
141
+
142
+
143
+ [MNISTデータを読み込む]
144
+ MNISTをRubyでも簡単に試せるよう、MNISTを扱うためのモジュールを用意しました。
145
+ 次のリンク(http://yann.lecun.com/exdb/mnist/)から、
146
+ train-images-idx3-ubyte.gz
147
+ train-labels-idx1-ubyte.gz
148
+ t10k-images-idx3-ubyte.gz
149
+ t10k-labels-idx1-ubyte.gz
150
+ の4つのファイルをダウンロードし、実行するRubyファイルと同じ階層のmnistディレクトリに格納したうえで、使用してください。
151
+
152
+ MNIST.load_trainで学習用データを読み込み、MNIST.load_testでテスト用データを読み込みます。
153
+ また、MNIST.categorycalを使用すると、正解データを10クラスにカテゴライズされた上で、配列形式で返します。
154
+ (RubyでのMNISTの読み込みは、以下のリンクを参考にさせていただきました。)
155
+ http://d.hatena.ne.jp/n_shuyo/20090913/mnist
156
+
157
+
158
+ [サンプル2 MNIST]
159
+
160
+ #ライブラリの読み込み
161
+ require "nn"
162
+ require "nn/mnist"
163
+
164
+ #MNISTのトレーニング用データを読み込む
165
+ x_train, y_train = MNIST.load_train
166
+
167
+ #y_trainを10クラスに配列でカテゴライズする
168
+ y_train = MNIST.categorical(y_train)
169
+
170
+ #MNISTのテスト用データを読み込む
171
+ x_test, y_test = MNIST.load_test
172
+
173
+ #y_testを10クラスにカテゴライズする
174
+ y_test = MNIST.categorical(y_test)
175
+
176
+ puts "load mnist"
177
+
178
+ #ニューラルネットワークの初期化
179
+ nn = NN.new([784, 100, 100, 10], #ノード数
180
+ learning_rate: 0.1, #学習率
181
+ batch_size: 100, #ミニバッチの数
182
+ activation: [:relu, :softmax], #活性化関数
183
+ momentum: 0.9, #モーメンタム係数
184
+ use_batch_norm: true, #バッチノーマライゼーションを使用する
185
+ )
186
+
187
+ #学習を行う
188
+ nn.train(x_train, y_train, 10, test: [x_test, y_test]) do |x_batch, y_batch|
189
+ x_batch /= 255 #ミニバッチを0~1の範囲で正規化
190
+ [x_batch, y_batch]
191
+ end
192
+
193
+ #学習結果のテストを行う
194
+ nn.test(x_test, y_test) do |x_batch, y_batch|
195
+ x_batch /= 255 #ミニバッチを0~1の範囲で正規化
196
+ [x_batch, y_batch]
197
+ end
198
+
199
+
200
+ [お断り]
201
+ 作者は、ニューラルネットワークを勉強し始めたばかりの初心者です。
202
+ そのため、バグや実装のミスもあるかと思いますが、温かい目で見守っていただけると、幸いでございます。
203
+
204
+
205
+ [更新履歴]
206
+ 2018/3/8 バージョン1.0公開
207
+ 2018/3/11 バージョン1.1公開
208
+ 2018/3/13 バージョン1.2公開
209
+ 2018/3/14 バージョン1.3公開
210
+ 2018/3/18 バージョン1.4公開
211
+ 2018/3/22 バージョン1.5公開
212
+ 2018/3/27 RubyGemに公開
@@ -0,0 +1,430 @@
1
+ require "numo/narray"
2
+ require "json"
3
+
4
+ class NN
5
+ VERSION = "1.5"
6
+
7
+ include Numo
8
+
9
+ attr_accessor :weights
10
+ attr_accessor :biases
11
+ attr_accessor :gammas
12
+ attr_accessor :betas
13
+ attr_accessor :learning_rate
14
+ attr_accessor :batch_size
15
+ attr_accessor :activation
16
+ attr_accessor :momentum
17
+ attr_accessor :weight_decay
18
+ attr_accessor :dropout_ratio
19
+ attr_reader :training
20
+
21
+ def initialize(num_nodes,
22
+ learning_rate: 0.01,
23
+ batch_size: 1,
24
+ activation: %i(relu identity),
25
+ momentum: 0,
26
+ weight_decay: 0,
27
+ use_dropout: false,
28
+ dropout_ratio: 0.5,
29
+ use_batch_norm: false)
30
+ SFloat.srand(rand(2 ** 64))
31
+ @num_nodes = num_nodes
32
+ @learning_rate = learning_rate
33
+ @batch_size = batch_size
34
+ @activation = activation
35
+ @momentum = momentum
36
+ @weight_decay = weight_decay
37
+ @use_dropout = use_dropout
38
+ @dropout_ratio = dropout_ratio
39
+ @use_batch_norm = use_batch_norm
40
+ init_weight_and_bias
41
+ init_gamma_and_beta if @use_batch_norm
42
+ @training = true
43
+ init_layers
44
+ end
45
+
46
+ def self.load(file_name)
47
+ json = JSON.parse(File.read(file_name))
48
+ nn = self.new(json["num_nodes"],
49
+ learning_rate: json["learning_rate"],
50
+ batch_size: json["batch_size"],
51
+ activation: json["activation"].map(&:to_sym),
52
+ momentum: json["momentum"],
53
+ weight_decay: json["weight_decay"],
54
+ use_dropout: json["use_dropout"],
55
+ dropout_ratio: json["dropout_ratio"],
56
+ use_batch_norm: json["use_batch_norm"],
57
+ )
58
+ nn.weights = json["weights"].map{|weight| SFloat.cast(weight)}
59
+ nn.biases = json["biases"].map{|bias| SFloat.cast(bias)}
60
+ if json["use_batch_norm"]
61
+ nn.gammas = json["gammas"].map{|gamma| SFloat.cast(gamma)}
62
+ nn.betas = json["betas"].map{|beta| SFloat.cast(beta)}
63
+ end
64
+ nn
65
+ end
66
+
67
+ def train(x_train, y_train, epoch,
68
+ save_dir: nil, save_interval: 1, test: nil, border: nil, tolerance: 0.5, &block)
69
+ num_train_data = x_train.is_a?(SFloat) ? x_train.shape[0] : x_train.length
70
+ (epoch * num_train_data / @batch_size).times do |count|
71
+ loss = learn(x_train, y_train, &block)
72
+ if loss.nan?
73
+ puts "loss is nan"
74
+ break
75
+ end
76
+ if (count + 1) % (num_train_data / @batch_size) == 0
77
+ now_epoch = (count + 1) / (num_train_data / @batch_size)
78
+ if save_dir && now_epoch % save_interval == 0
79
+ save("#{save_dir}/epoch#{now_epoch}.json")
80
+ end
81
+ msg = "epoch #{now_epoch}/#{epoch} loss: #{loss}"
82
+ if test
83
+ acc = accurate(*test, tolerance, &block)
84
+ puts "#{msg} accurate: #{acc}"
85
+ break if border && acc >= border
86
+ else
87
+ puts msg
88
+ end
89
+ end
90
+ end
91
+ end
92
+
93
+ def test(x_test, y_test, tolerance = 0.5, &block)
94
+ acc = accurate(x_test, y_test, tolerance, &block)
95
+ puts "accurate: #{acc}"
96
+ acc
97
+ end
98
+
99
+ def accurate(x_test, y_test, tolerance = 0.5, &block)
100
+ correct = 0
101
+ num_test_data = x_test.is_a?(SFloat) ? x_test.shape[0] : x_test.length
102
+ (num_test_data / @batch_size).times do |i|
103
+ x = SFloat.zeros(@batch_size, @num_nodes.first)
104
+ y = SFloat.zeros(@batch_size, @num_nodes.last)
105
+ @batch_size.times do |j|
106
+ k = i * @batch_size + j
107
+ if x_test.is_a?(SFloat)
108
+ x[j, true] = x_test[k, true]
109
+ y[j, true] = y_test[k, true]
110
+ else
111
+ x[j, true] = SFloat.cast(x_test[k])
112
+ y[j, true] = SFloat.cast(y_test[k])
113
+ end
114
+ end
115
+ x, y = block.call(x, y) if block
116
+ out = forward(x, false)
117
+ @batch_size.times do |j|
118
+ vout = out[j, true]
119
+ vy = y[j, true]
120
+ case @activation[1]
121
+ when :identity
122
+ correct += 1 unless (NMath.sqrt((vout - vy) ** 2) < tolerance).to_a.include?(0)
123
+ when :softmax
124
+ correct += 1 if vout.max_index == vy.max_index
125
+ end
126
+ end
127
+ end
128
+ correct.to_f / num_test_data
129
+ end
130
+
131
+ def learn(x_train, y_train, &block)
132
+ x = SFloat.zeros(@batch_size, @num_nodes.first)
133
+ y = SFloat.zeros(@batch_size, @num_nodes.last)
134
+ @batch_size.times do |i|
135
+ if x_train.is_a?(SFloat)
136
+ r = rand(x_train.shape[0])
137
+ x[i, true] = x_train[r, true]
138
+ y[i, true] = y_train[r, true]
139
+ else
140
+ r = rand(x_train.length)
141
+ x[i, true] = SFloat.cast(x_train[r])
142
+ y[i, true] = SFloat.cast(y_train[r])
143
+ end
144
+ end
145
+ x, y = block.call(x, y) if block
146
+ forward(x)
147
+ backward(y)
148
+ update_weight_and_bias
149
+ update_gamma_and_beta if @use_batch_norm
150
+ @layers[-1].loss(y)
151
+ end
152
+
153
+ def run(x)
154
+ x = SFloat.cast(x) if x.is_a?(Array)
155
+ out = forward(x, false)
156
+ out.to_a
157
+ end
158
+
159
+ def save(file_name)
160
+ json = {
161
+ "version" => VERSION,
162
+ "num_nodes" => @num_nodes,
163
+ "learning_rate" => @learning_rate,
164
+ "batch_size" => @batch_size,
165
+ "activation" => @activation,
166
+ "momentum" => @momentum,
167
+ "weight_decay" => @weight_decay,
168
+ "use_dropout" => @use_dropout,
169
+ "dropout_ratio" => @dropout_ratio,
170
+ "use_batch_norm" => @use_batch_norm,
171
+ "weights" => @weights.map(&:to_a),
172
+ "biases" => @biases.map(&:to_a),
173
+ }
174
+ if @use_batch_norm
175
+ json_batch_norm = {
176
+ "gammas" => @gammas,
177
+ "betas" => @betas
178
+ }
179
+ json.merge!(json_batch_norm)
180
+ end
181
+ File.write(file_name, JSON.dump(json))
182
+ end
183
+
184
+ private
185
+
186
+ def init_weight_and_bias
187
+ @weights = Array.new(@num_nodes.length - 1)
188
+ @biases = Array.new(@num_nodes.length - 1)
189
+ @weight_amounts = Array.new(@num_nodes.length - 1, 0)
190
+ @bias_amounts = Array.new(@num_nodes.length - 1, 0)
191
+ @num_nodes[0...-1].each_index do |i|
192
+ weight = SFloat.new(@num_nodes[i], @num_nodes[i + 1]).rand_norm
193
+ bias = SFloat.new(@num_nodes[i + 1]).rand_norm
194
+ if @activation[0] == :relu
195
+ @weights[i] = weight / Math.sqrt(@num_nodes[i]) * Math.sqrt(2)
196
+ @biases[i] = bias / Math.sqrt(@num_nodes[i]) * Math.sqrt(2)
197
+ else
198
+ @weights[i] = weight / Math.sqrt(@num_nodes[i])
199
+ @biases[i] = bias / Math.sqrt(@num_nodes[i])
200
+ end
201
+ end
202
+ end
203
+
204
+ def init_gamma_and_beta
205
+ @gammas = Array.new(@num_nodes.length - 2, 1)
206
+ @betas = Array.new(@num_nodes.length - 2, 0)
207
+ @gamma_amounts = Array.new(@num_nodes.length - 2, 0)
208
+ @beta_amounts = Array.new(@num_nodes.length - 2, 0)
209
+ end
210
+
211
+
212
+ def init_layers
213
+ @layers = []
214
+ @num_nodes[0...-2].each_index do |i|
215
+ @layers << Affine.new(self, i)
216
+ @layers << BatchNorm.new(self, i) if @use_batch_norm
217
+ @layers << case @activation[0]
218
+ when :sigmoid
219
+ Sigmoid.new
220
+ when :relu
221
+ ReLU.new
222
+ end
223
+ @layers << Dropout.new(self) if @use_dropout
224
+ end
225
+ @layers << Affine.new(self, -1)
226
+ @layers << case @activation[1]
227
+ when :identity
228
+ Identity.new(self)
229
+ when :softmax
230
+ Softmax.new(self)
231
+ end
232
+ end
233
+
234
+ def forward(x, training = true)
235
+ @training = training
236
+ @layers.each do |layer|
237
+ x = layer.forward(x)
238
+ end
239
+ x
240
+ end
241
+
242
+ def backward(y)
243
+ dout = @layers[-1].backward(y)
244
+ @layers[0...-1].reverse.each do |layer|
245
+ dout = layer.backward(dout)
246
+ end
247
+ end
248
+
249
+ def update_weight_and_bias
250
+ @layers.select{|layer| layer.is_a?(Affine)}.each.with_index do |layer, i|
251
+ weight_amount = layer.d_weight.mean(0) * @learning_rate
252
+ @weight_amounts[i] = weight_amount + @momentum * @weight_amounts[i]
253
+ @weights[i] -= @weight_amounts[i]
254
+ bias_amount = layer.d_bias.mean * @learning_rate
255
+ @bias_amounts[i] = bias_amount + @momentum * @bias_amounts[i]
256
+ @biases[i] -= @bias_amounts[i]
257
+ end
258
+ end
259
+
260
+ def update_gamma_and_beta
261
+ @layers.select{|layer| layer.is_a?(BatchNorm)}.each.with_index do |layer, i|
262
+ gamma_amount = layer.d_gamma.mean * @learning_rate
263
+ @gamma_amounts[i] = gamma_amount + @momentum * @gamma_amounts[i]
264
+ @gammas[i] -= @gamma_amounts[i]
265
+ beta_amount = layer.d_beta.mean * @learning_rate
266
+ @beta_amounts[i] = beta_amount + @momentum * @beta_amounts[i]
267
+ @betas[i] -= @beta_amounts[i]
268
+ end
269
+ end
270
+ end
271
+
272
+
273
+ class NN::Affine
274
+ include Numo
275
+
276
+ attr_reader :d_weight
277
+ attr_reader :d_bias
278
+
279
+ def initialize(nn, index)
280
+ @nn = nn
281
+ @index = index
282
+ @d_weight = nil
283
+ @d_bias = nil
284
+ end
285
+
286
+ def forward(x)
287
+ @x = x
288
+ @x.dot(@nn.weights[@index]) + @nn.biases[@index]
289
+ end
290
+
291
+ def backward(dout)
292
+ x = @x.reshape(*@x.shape, 1)
293
+ d_ridge = @nn.weight_decay * @nn.weights[@index]
294
+ @d_weight = x.dot(dout.reshape(dout.shape[0], 1, dout.shape[1])) + d_ridge
295
+ @d_bias = dout
296
+ dout.dot(@nn.weights[@index].transpose)
297
+ end
298
+ end
299
+
300
+
301
+ class NN::Sigmoid
302
+ def forward(x)
303
+ @out = 1.0 / (1 + Numo::NMath.exp(-x))
304
+ end
305
+
306
+ def backward(dout)
307
+ dout * (1.0 - @out) * @out
308
+ end
309
+ end
310
+
311
+
312
+ class NN::ReLU
313
+ def forward(x)
314
+ @x = x.clone
315
+ x[x < 0] = 0
316
+ x
317
+ end
318
+
319
+ def backward(dout)
320
+ @x[@x > 0] = 1.0
321
+ @x[@x <= 0] = 0.0
322
+ dout * @x
323
+ end
324
+ end
325
+
326
+
327
+ class NN::Identity
328
+ include Numo
329
+
330
+ def initialize(nn)
331
+ @nn = nn
332
+ end
333
+
334
+ def forward(x)
335
+ @out = x
336
+ end
337
+
338
+ def backward(y)
339
+ @out - y
340
+ end
341
+
342
+ def loss(y)
343
+ ridge = 0.5 * @nn.weight_decay * @nn.weights.reduce(0){|sum, weight| sum + (weight ** 2).sum}
344
+ 0.5 * ((@out - y) ** 2).sum / @nn.batch_size + ridge
345
+ end
346
+ end
347
+
348
+
349
+ class NN::Softmax
350
+ include Numo
351
+
352
+ def initialize(nn)
353
+ @nn = nn
354
+ end
355
+
356
+ def forward(x)
357
+ @out = NMath.exp(x) / NMath.exp(x).sum(1).reshape(x.shape[0], 1)
358
+ end
359
+
360
+ def backward(y)
361
+ @out - y
362
+ end
363
+
364
+ def loss(y)
365
+ ridge = 0.5 * @nn.weight_decay * @nn.weights.reduce(0){|sum, weight| sum + (weight ** 2).sum}
366
+ -(y * NMath.log(@out + 1e-7)).sum / @nn.batch_size + ridge
367
+ end
368
+ end
369
+
370
+
371
+ class NN::Dropout
372
+ include Numo
373
+
374
+ def initialize(nn)
375
+ @nn = nn
376
+ @mask = nil
377
+ end
378
+
379
+ def forward(x)
380
+ if @nn.training
381
+ @mask = SFloat.ones(*x.shape).rand < @nn.dropout_ratio
382
+ x[@mask] = 0
383
+ else
384
+ x *= (1 - @nn.dropout_ratio)
385
+ end
386
+ x
387
+ end
388
+
389
+ def backward(dout)
390
+ dout[@mask] = 0 if @nn.training
391
+ dout
392
+ end
393
+ end
394
+
395
+
396
+ class NN::BatchNorm
397
+ include Numo
398
+
399
+ attr_reader :d_gamma
400
+ attr_reader :d_beta
401
+
402
+ def initialize(nn, index)
403
+ @nn = nn
404
+ @index = index
405
+ end
406
+
407
+ def forward(x)
408
+ @x = x
409
+ @mean = x.mean(0)
410
+ @xc = x - @mean
411
+ @var = (@xc ** 2).mean(0)
412
+ @std = NMath.sqrt(@var + 1e-7)
413
+ @xn = @xc / @std
414
+ out = @nn.gammas[@index] * @xn + @nn.betas[@index]
415
+ out.reshape(*@x.shape)
416
+ end
417
+
418
+ def backward(dout)
419
+ @d_beta = dout.sum(0)
420
+ @d_gamma = (@xn * dout).sum(0)
421
+ dxn = @nn.gammas[@index] * dout
422
+ dxc = dxn / @std
423
+ dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
424
+ dvar = 0.5 * dstd / @std
425
+ dxc += (2.0 / @nn.batch_size) * @xc * dvar
426
+ dmean = dxc.sum(0)
427
+ dx = dxc - dmean / @nn.batch_size
428
+ dx.reshape(*@x.shape)
429
+ end
430
+ end
@@ -0,0 +1,54 @@
1
+ require "zlib"
2
+
3
+ module MNIST
4
+ def self.load_train
5
+ if File.exist?("mnist/train.marshal")
6
+ marshal = File.binread("mnist/train.marshal")
7
+ Marshal.load(marshal)
8
+ else
9
+ x_train, y_train = load("mnist/train-images-idx3-ubyte.gz", "mnist/train-labels-idx1-ubyte.gz")
10
+ marshal = Marshal.dump([x_train, y_train])
11
+ File.binwrite("mnist/train.marshal", marshal)
12
+ [x_train, y_train]
13
+ end
14
+ end
15
+
16
+ def self.load_test
17
+ if File.exist?("mnist/test.marshal")
18
+ marshal = File.binread("mnist/test.marshal")
19
+ Marshal.load(marshal)
20
+ else
21
+ x_test, y_test = load("mnist/t10k-images-idx3-ubyte.gz", "mnist/t10k-labels-idx1-ubyte.gz")
22
+ marshal = Marshal.dump([x_test, y_test])
23
+ File.binwrite("mnist/test.marshal", marshal)
24
+ [x_test, y_test]
25
+ end
26
+ end
27
+
28
+ def self.categorical(y_data)
29
+ y_data = y_data.map do |label|
30
+ classes = Array.new(10, 0)
31
+ classes[label] = 1
32
+ classes
33
+ end
34
+ end
35
+
36
+ private_class_method
37
+
38
+ def self.load(images_file_name, labels_file_name)
39
+ images = []
40
+ labels = nil
41
+ Zlib::GzipReader.open(images_file_name) do |f|
42
+ magic, n_images = f.read(8).unpack("N2")
43
+ n_rows, n_cols = f.read(8).unpack("N2")
44
+ n_images.times do
45
+ images << f.read(n_rows * n_cols).unpack("C*")
46
+ end
47
+ end
48
+ Zlib::GzipReader.open(labels_file_name) do |f|
49
+ magic, n_labels = f.read(8).unpack("N2")
50
+ labels = f.read(n_labels).unpack("C*")
51
+ end
52
+ [images, labels]
53
+ end
54
+ end
@@ -0,0 +1,2 @@
1
+ require "nn"
2
+
@@ -0,0 +1,39 @@
1
+
2
+ lib = File.expand_path("../lib", __FILE__)
3
+ $LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
4
+ require "nn"
5
+
6
+ Gem::Specification.new do |spec|
7
+ spec.name = "nn"
8
+ spec.version = NN::VERSION
9
+ spec.authors = ["unagiootoro"]
10
+ spec.email = ["ootoro838861@outlook.jp"]
11
+
12
+ spec.summary = %q{Ruby用ニューラルネットワークライブラリ}
13
+ spec.description = %q{Rubyでニューラルネットワークを作成できます。}
14
+ spec.homepage = "https://github.com/unagiootoro/nn.git"
15
+ spec.license = "MIT"
16
+
17
+ spec.add_dependency "numo-narray"
18
+
19
+ # Prevent pushing this gem to RubyGems.org. To allow pushes either set the 'allowed_push_host'
20
+ # to allow pushing to a single host or delete this section to allow pushing to any host.
21
+ =begin
22
+ if spec.respond_to?(:metadata)
23
+ spec.metadata["allowed_push_host"] = "TODO: Set to 'http://mygemserver.com'"
24
+ else
25
+ raise "RubyGems 2.0 or newer is required to protect against " \
26
+ "public gem pushes."
27
+ end
28
+ =end
29
+
30
+ spec.files = `git ls-files -z`.split("\x0").reject do |f|
31
+ f.match(%r{^(test|spec|features)/})
32
+ end
33
+ spec.bindir = "exe"
34
+ spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
35
+ spec.require_paths = ["lib"]
36
+
37
+ spec.add_development_dependency "bundler", "~> 1.16"
38
+ spec.add_development_dependency "rake", "~> 10.0"
39
+ end
metadata ADDED
@@ -0,0 +1,98 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: nn
3
+ version: !ruby/object:Gem::Version
4
+ version: '1.5'
5
+ platform: ruby
6
+ authors:
7
+ - unagiootoro
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2018-03-27 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '1.16'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '1.16'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '10.0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '10.0'
55
+ description: Rubyでニューラルネットワークを作成できます。
56
+ email:
57
+ - ootoro838861@outlook.jp
58
+ executables: []
59
+ extensions: []
60
+ extra_rdoc_files: []
61
+ files:
62
+ - ".gitignore"
63
+ - Gemfile
64
+ - LICENSE.txt
65
+ - README.md
66
+ - Rakefile
67
+ - bin/console
68
+ - bin/setup
69
+ - document.txt
70
+ - lib/nn.rb
71
+ - lib/nn/mnist.rb
72
+ - lib/nn/version.rb
73
+ - nn.gemspec
74
+ homepage: https://github.com/unagiootoro/nn.git
75
+ licenses:
76
+ - MIT
77
+ metadata: {}
78
+ post_install_message:
79
+ rdoc_options: []
80
+ require_paths:
81
+ - lib
82
+ required_ruby_version: !ruby/object:Gem::Requirement
83
+ requirements:
84
+ - - ">="
85
+ - !ruby/object:Gem::Version
86
+ version: '0'
87
+ required_rubygems_version: !ruby/object:Gem::Requirement
88
+ requirements:
89
+ - - ">="
90
+ - !ruby/object:Gem::Version
91
+ version: '0'
92
+ requirements: []
93
+ rubyforge_project:
94
+ rubygems_version: 2.7.3
95
+ signing_key:
96
+ specification_version: 4
97
+ summary: Ruby用ニューラルネットワークライブラリ
98
+ test_files: []