rann 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGES.md +4 -0
- data/lib/rann/backprop.rb +27 -4
- data/lib/rann/optimisers/adagrad.rb +15 -2
- data/lib/rann/optimisers/rmsprop.rb +15 -2
- data/lib/rann/version.rb +1 -1
- data/rann.gemspec +3 -2
- metadata +16 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 81b3a045a8873df309575045595a2379a7e4b28c
|
4
|
+
data.tar.gz: 5cab5e7a20030b091341ff9dc9a1d6acb3279ea7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f223e733f0eed2cc347050d6dc6b48eb64a1705ee0208da4eccb25d644c09cfa7bbd08313d427b0990dfb10ff136a27dcf3ecb39a4a9d81183423c181a724879
|
7
|
+
data.tar.gz: aca9f044286fd055e6868098375c084574162102e2e0d40a1e17c92c8f0795e1d9a7ca13216ec54010192b3f75d354166fbe32dde1804dc712a27a2accf16b35
|
data/CHANGES.md
CHANGED
data/lib/rann/backprop.rb
CHANGED
@@ -20,10 +20,10 @@ module RANN
|
|
20
20
|
|
21
21
|
attr_accessor :network
|
22
22
|
|
23
|
-
def initialize network, opts = {}
|
23
|
+
def initialize network, opts = {}
|
24
24
|
@network = network
|
25
25
|
@connections_hash = network.connections.each.with_object({}){ |c, h| h[c.id] = c }
|
26
|
-
@optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
|
26
|
+
@optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
|
27
27
|
@batch_count = 0.to_d
|
28
28
|
end
|
29
29
|
|
@@ -182,8 +182,31 @@ module RANN
|
|
182
182
|
[gradients, error]
|
183
183
|
end
|
184
184
|
|
185
|
-
def
|
186
|
-
|
185
|
+
def save filepath = nil
|
186
|
+
filepath ||= "rann_savepoint_#{DateTime.now.strftime('%Y-%m-%d-%H-%M-%S')}.yml"
|
187
|
+
|
188
|
+
weights = @network.params
|
189
|
+
opt_vars = @optimiser.state
|
190
|
+
|
191
|
+
File.open filepath, "w" do |f|
|
192
|
+
f.write YAML.dump [weights, opt_vars]
|
193
|
+
end
|
194
|
+
end
|
195
|
+
|
196
|
+
def restore filepath
|
197
|
+
unless filepath
|
198
|
+
filepath = Dir['*'].select{ |f| f =~ /rann_savepoint_.*/ }.sort.last
|
199
|
+
|
200
|
+
unless filepath
|
201
|
+
@network.init_normalised!
|
202
|
+
puts "No savepoints found—initialised normalised weights"
|
203
|
+
return
|
204
|
+
end
|
205
|
+
end
|
206
|
+
|
207
|
+
weights, opt_vars = YAML.load_file(filepath)
|
208
|
+
@network.impose(weights)
|
209
|
+
@network.optimiser.load_state(opt_vars)
|
187
210
|
end
|
188
211
|
|
189
212
|
def self.reset! network
|
@@ -5,10 +5,10 @@ require "bigdecimal/util"
|
|
5
5
|
module RANN
|
6
6
|
module Optimisers
|
7
7
|
class AdaGrad
|
8
|
-
def initialize opts = {}
|
8
|
+
def initialize opts = {}
|
9
9
|
@fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
|
10
10
|
@learning_rate = opts[:learning_rate] || 0.1.to_d
|
11
|
-
@historical_gradient =
|
11
|
+
@historical_gradient = {}.tap{ |h| h.default = 0.to_d }
|
12
12
|
end
|
13
13
|
|
14
14
|
def update grad, cid
|
@@ -16,6 +16,19 @@ module RANN
|
|
16
16
|
|
17
17
|
grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
|
18
18
|
end
|
19
|
+
|
20
|
+
# anything that gets modified over the course of training
|
21
|
+
def state
|
22
|
+
{
|
23
|
+
historical_gradient: @historical_gradient,
|
24
|
+
}
|
25
|
+
end
|
26
|
+
|
27
|
+
def load_state state
|
28
|
+
state.each do |name, value|
|
29
|
+
instance_variable_set("@#{name}", value)
|
30
|
+
end
|
31
|
+
end
|
19
32
|
end
|
20
33
|
end
|
21
34
|
end
|
@@ -5,11 +5,11 @@ require "bigdecimal/util"
|
|
5
5
|
module RANN
|
6
6
|
module Optimisers
|
7
7
|
class RMSProp
|
8
|
-
def initialize opts = {}
|
8
|
+
def initialize opts = {}
|
9
9
|
@decay = opts[:decay] || 0.9.to_d
|
10
10
|
@fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
|
11
11
|
@learning_rate = opts[:learning_rate] || 0.01.to_d
|
12
|
-
@historical_gradient =
|
12
|
+
@historical_gradient = {}.tap{ |h| h.default = 0.to_d }
|
13
13
|
end
|
14
14
|
|
15
15
|
def update grad, cid
|
@@ -17,6 +17,19 @@ module RANN
|
|
17
17
|
|
18
18
|
grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
|
19
19
|
end
|
20
|
+
|
21
|
+
# anything that gets modified over the course of training
|
22
|
+
def state
|
23
|
+
{
|
24
|
+
historical_gradient: @historical_gradient,
|
25
|
+
}
|
26
|
+
end
|
27
|
+
|
28
|
+
def load_state state
|
29
|
+
state.each do |name, value|
|
30
|
+
instance_variable_set("@#{name}", value)
|
31
|
+
end
|
32
|
+
end
|
20
33
|
end
|
21
34
|
end
|
22
35
|
end
|
data/lib/rann/version.rb
CHANGED
data/rann.gemspec
CHANGED
@@ -12,6 +12,7 @@ Gem::Specification.new do |spec|
|
|
12
12
|
spec.summary = %q{Ruby Artificial Neural Networks}
|
13
13
|
spec.description = %q{Libary for working with neural networks in Ruby.}
|
14
14
|
spec.homepage = "https://github.com/mikecmpbll/rann"
|
15
|
+
spec.licenses = 'Apache-2.0'
|
15
16
|
|
16
17
|
spec.files = `git ls-files -z`.split("\x0").reject do |f|
|
17
18
|
f.match(%r{^(test|spec|features)/})
|
@@ -20,8 +21,8 @@ Gem::Specification.new do |spec|
|
|
20
21
|
spec.executables = spec.files.grep(%r{^exe/}){ |f| File.basename(f) }
|
21
22
|
spec.require_paths = ["lib"]
|
22
23
|
|
23
|
-
spec.add_runtime_dependency
|
24
|
-
spec.add_runtime_dependency
|
24
|
+
spec.add_runtime_dependency 'parallel', '~> 1.12', '>= 1.12.0'
|
25
|
+
spec.add_runtime_dependency 'ruby-graphviz', '~> 1.2', '>= 1.2.3'
|
25
26
|
|
26
27
|
spec.add_development_dependency "bundler", "~> 1.16"
|
27
28
|
spec.add_development_dependency "rake", "~> 10.0"
|
metadata
CHANGED
@@ -1,20 +1,23 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rann
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Michael Campbell
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2017-11-
|
11
|
+
date: 2017-11-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: parallel
|
15
15
|
requirement: !ruby/object:Gem::Requirement
|
16
16
|
requirements:
|
17
17
|
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '1.12'
|
20
|
+
- - ">="
|
18
21
|
- !ruby/object:Gem::Version
|
19
22
|
version: 1.12.0
|
20
23
|
type: :runtime
|
@@ -22,6 +25,9 @@ dependencies:
|
|
22
25
|
version_requirements: !ruby/object:Gem::Requirement
|
23
26
|
requirements:
|
24
27
|
- - "~>"
|
28
|
+
- !ruby/object:Gem::Version
|
29
|
+
version: '1.12'
|
30
|
+
- - ">="
|
25
31
|
- !ruby/object:Gem::Version
|
26
32
|
version: 1.12.0
|
27
33
|
- !ruby/object:Gem::Dependency
|
@@ -29,6 +35,9 @@ dependencies:
|
|
29
35
|
requirement: !ruby/object:Gem::Requirement
|
30
36
|
requirements:
|
31
37
|
- - "~>"
|
38
|
+
- !ruby/object:Gem::Version
|
39
|
+
version: '1.2'
|
40
|
+
- - ">="
|
32
41
|
- !ruby/object:Gem::Version
|
33
42
|
version: 1.2.3
|
34
43
|
type: :runtime
|
@@ -36,6 +45,9 @@ dependencies:
|
|
36
45
|
version_requirements: !ruby/object:Gem::Requirement
|
37
46
|
requirements:
|
38
47
|
- - "~>"
|
48
|
+
- !ruby/object:Gem::Version
|
49
|
+
version: '1.2'
|
50
|
+
- - ">="
|
39
51
|
- !ruby/object:Gem::Version
|
40
52
|
version: 1.2.3
|
41
53
|
- !ruby/object:Gem::Dependency
|
@@ -114,7 +126,8 @@ files:
|
|
114
126
|
- lib/rann/version.rb
|
115
127
|
- rann.gemspec
|
116
128
|
homepage: https://github.com/mikecmpbll/rann
|
117
|
-
licenses:
|
129
|
+
licenses:
|
130
|
+
- Apache-2.0
|
118
131
|
metadata: {}
|
119
132
|
post_install_message:
|
120
133
|
rdoc_options: []
|