rann 0.2.4 → 0.2.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGES.md +4 -0
- data/lib/rann/backprop.rb +27 -4
- data/lib/rann/optimisers/adagrad.rb +15 -2
- data/lib/rann/optimisers/rmsprop.rb +15 -2
- data/lib/rann/version.rb +1 -1
- data/rann.gemspec +3 -2
- metadata +16 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 81b3a045a8873df309575045595a2379a7e4b28c
|
4
|
+
data.tar.gz: 5cab5e7a20030b091341ff9dc9a1d6acb3279ea7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f223e733f0eed2cc347050d6dc6b48eb64a1705ee0208da4eccb25d644c09cfa7bbd08313d427b0990dfb10ff136a27dcf3ecb39a4a9d81183423c181a724879
|
7
|
+
data.tar.gz: aca9f044286fd055e6868098375c084574162102e2e0d40a1e17c92c8f0795e1d9a7ca13216ec54010192b3f75d354166fbe32dde1804dc712a27a2accf16b35
|
data/CHANGES.md
CHANGED
data/lib/rann/backprop.rb
CHANGED
@@ -20,10 +20,10 @@ module RANN
|
|
20
20
|
|
21
21
|
attr_accessor :network
|
22
22
|
|
23
|
-
def initialize network, opts = {}
|
23
|
+
def initialize network, opts = {}
|
24
24
|
@network = network
|
25
25
|
@connections_hash = network.connections.each.with_object({}){ |c, h| h[c.id] = c }
|
26
|
-
@optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
|
26
|
+
@optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
|
27
27
|
@batch_count = 0.to_d
|
28
28
|
end
|
29
29
|
|
@@ -182,8 +182,31 @@ module RANN
|
|
182
182
|
[gradients, error]
|
183
183
|
end
|
184
184
|
|
185
|
-
def
|
186
|
-
|
185
|
+
def save filepath = nil
|
186
|
+
filepath ||= "rann_savepoint_#{DateTime.now.strftime('%Y-%m-%d-%H-%M-%S')}.yml"
|
187
|
+
|
188
|
+
weights = @network.params
|
189
|
+
opt_vars = @optimiser.state
|
190
|
+
|
191
|
+
File.open filepath, "w" do |f|
|
192
|
+
f.write YAML.dump [weights, opt_vars]
|
193
|
+
end
|
194
|
+
end
|
195
|
+
|
196
|
+
def restore filepath
|
197
|
+
unless filepath
|
198
|
+
filepath = Dir['*'].select{ |f| f =~ /rann_savepoint_.*/ }.sort.last
|
199
|
+
|
200
|
+
unless filepath
|
201
|
+
@network.init_normalised!
|
202
|
+
puts "No savepoints found—initialised normalised weights"
|
203
|
+
return
|
204
|
+
end
|
205
|
+
end
|
206
|
+
|
207
|
+
weights, opt_vars = YAML.load_file(filepath)
|
208
|
+
@network.impose(weights)
|
209
|
+
@network.optimiser.load_state(opt_vars)
|
187
210
|
end
|
188
211
|
|
189
212
|
def self.reset! network
|
@@ -5,10 +5,10 @@ require "bigdecimal/util"
|
|
5
5
|
module RANN
|
6
6
|
module Optimisers
|
7
7
|
class AdaGrad
|
8
|
-
def initialize opts = {}
|
8
|
+
def initialize opts = {}
|
9
9
|
@fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
|
10
10
|
@learning_rate = opts[:learning_rate] || 0.1.to_d
|
11
|
-
@historical_gradient =
|
11
|
+
@historical_gradient = {}.tap{ |h| h.default = 0.to_d }
|
12
12
|
end
|
13
13
|
|
14
14
|
def update grad, cid
|
@@ -16,6 +16,19 @@ module RANN
|
|
16
16
|
|
17
17
|
grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
|
18
18
|
end
|
19
|
+
|
20
|
+
# anything that gets modified over the course of training
|
21
|
+
def state
|
22
|
+
{
|
23
|
+
historical_gradient: @historical_gradient,
|
24
|
+
}
|
25
|
+
end
|
26
|
+
|
27
|
+
def load_state state
|
28
|
+
state.each do |name, value|
|
29
|
+
instance_variable_set("@#{name}", value)
|
30
|
+
end
|
31
|
+
end
|
19
32
|
end
|
20
33
|
end
|
21
34
|
end
|
@@ -5,11 +5,11 @@ require "bigdecimal/util"
|
|
5
5
|
module RANN
|
6
6
|
module Optimisers
|
7
7
|
class RMSProp
|
8
|
-
def initialize opts = {}
|
8
|
+
def initialize opts = {}
|
9
9
|
@decay = opts[:decay] || 0.9.to_d
|
10
10
|
@fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
|
11
11
|
@learning_rate = opts[:learning_rate] || 0.01.to_d
|
12
|
-
@historical_gradient =
|
12
|
+
@historical_gradient = {}.tap{ |h| h.default = 0.to_d }
|
13
13
|
end
|
14
14
|
|
15
15
|
def update grad, cid
|
@@ -17,6 +17,19 @@ module RANN
|
|
17
17
|
|
18
18
|
grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
|
19
19
|
end
|
20
|
+
|
21
|
+
# anything that gets modified over the course of training
|
22
|
+
def state
|
23
|
+
{
|
24
|
+
historical_gradient: @historical_gradient,
|
25
|
+
}
|
26
|
+
end
|
27
|
+
|
28
|
+
def load_state state
|
29
|
+
state.each do |name, value|
|
30
|
+
instance_variable_set("@#{name}", value)
|
31
|
+
end
|
32
|
+
end
|
20
33
|
end
|
21
34
|
end
|
22
35
|
end
|
data/lib/rann/version.rb
CHANGED
data/rann.gemspec
CHANGED
@@ -12,6 +12,7 @@ Gem::Specification.new do |spec|
|
|
12
12
|
spec.summary = %q{Ruby Artificial Neural Networks}
|
13
13
|
spec.description = %q{Libary for working with neural networks in Ruby.}
|
14
14
|
spec.homepage = "https://github.com/mikecmpbll/rann"
|
15
|
+
spec.licenses = 'Apache-2.0'
|
15
16
|
|
16
17
|
spec.files = `git ls-files -z`.split("\x0").reject do |f|
|
17
18
|
f.match(%r{^(test|spec|features)/})
|
@@ -20,8 +21,8 @@ Gem::Specification.new do |spec|
|
|
20
21
|
spec.executables = spec.files.grep(%r{^exe/}){ |f| File.basename(f) }
|
21
22
|
spec.require_paths = ["lib"]
|
22
23
|
|
23
|
-
spec.add_runtime_dependency
|
24
|
-
spec.add_runtime_dependency
|
24
|
+
spec.add_runtime_dependency 'parallel', '~> 1.12', '>= 1.12.0'
|
25
|
+
spec.add_runtime_dependency 'ruby-graphviz', '~> 1.2', '>= 1.2.3'
|
25
26
|
|
26
27
|
spec.add_development_dependency "bundler", "~> 1.16"
|
27
28
|
spec.add_development_dependency "rake", "~> 10.0"
|
metadata
CHANGED
@@ -1,20 +1,23 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rann
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Michael Campbell
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2017-11-
|
11
|
+
date: 2017-11-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: parallel
|
15
15
|
requirement: !ruby/object:Gem::Requirement
|
16
16
|
requirements:
|
17
17
|
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '1.12'
|
20
|
+
- - ">="
|
18
21
|
- !ruby/object:Gem::Version
|
19
22
|
version: 1.12.0
|
20
23
|
type: :runtime
|
@@ -22,6 +25,9 @@ dependencies:
|
|
22
25
|
version_requirements: !ruby/object:Gem::Requirement
|
23
26
|
requirements:
|
24
27
|
- - "~>"
|
28
|
+
- !ruby/object:Gem::Version
|
29
|
+
version: '1.12'
|
30
|
+
- - ">="
|
25
31
|
- !ruby/object:Gem::Version
|
26
32
|
version: 1.12.0
|
27
33
|
- !ruby/object:Gem::Dependency
|
@@ -29,6 +35,9 @@ dependencies:
|
|
29
35
|
requirement: !ruby/object:Gem::Requirement
|
30
36
|
requirements:
|
31
37
|
- - "~>"
|
38
|
+
- !ruby/object:Gem::Version
|
39
|
+
version: '1.2'
|
40
|
+
- - ">="
|
32
41
|
- !ruby/object:Gem::Version
|
33
42
|
version: 1.2.3
|
34
43
|
type: :runtime
|
@@ -36,6 +45,9 @@ dependencies:
|
|
36
45
|
version_requirements: !ruby/object:Gem::Requirement
|
37
46
|
requirements:
|
38
47
|
- - "~>"
|
48
|
+
- !ruby/object:Gem::Version
|
49
|
+
version: '1.2'
|
50
|
+
- - ">="
|
39
51
|
- !ruby/object:Gem::Version
|
40
52
|
version: 1.2.3
|
41
53
|
- !ruby/object:Gem::Dependency
|
@@ -114,7 +126,8 @@ files:
|
|
114
126
|
- lib/rann/version.rb
|
115
127
|
- rann.gemspec
|
116
128
|
homepage: https://github.com/mikecmpbll/rann
|
117
|
-
licenses:
|
129
|
+
licenses:
|
130
|
+
- Apache-2.0
|
118
131
|
metadata: {}
|
119
132
|
post_install_message:
|
120
133
|
rdoc_options: []
|