rann 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 671da77af2c018b9391e64bc6d6a3bdcd3267648
4
- data.tar.gz: edeb0b1f7c6c59b5afb692872efc1c1e80492971
3
+ metadata.gz: 81b3a045a8873df309575045595a2379a7e4b28c
4
+ data.tar.gz: 5cab5e7a20030b091341ff9dc9a1d6acb3279ea7
5
5
  SHA512:
6
- metadata.gz: 6527cce72e2d732c2efee6e1acf1ca9c01187e5dbc69f27600842de1d2685766e6631183abab171a1dabb029ce45d64486a62af6b69c40f656d9dc4afe58fca8
7
- data.tar.gz: 617ea35293d0a34f40f4e613124e73d6771bad29f47cb5860786a0e03564cd798619b51b84f0397624f5f3100a7b15a3e0ba2db0d1bef143f9e7c5e07e37bb0e
6
+ metadata.gz: f223e733f0eed2cc347050d6dc6b48eb64a1705ee0208da4eccb25d644c09cfa7bbd08313d427b0990dfb10ff136a27dcf3ecb39a4a9d81183423c181a724879
7
+ data.tar.gz: aca9f044286fd055e6868098375c084574162102e2e0d40a1e17c92c8f0795e1d9a7ca13216ec54010192b3f75d354166fbe32dde1804dc712a27a2accf16b35
data/CHANGES.md CHANGED
@@ -1,3 +1,7 @@
1
+ - Add save and restore methods to backprop.
2
+
3
+ *Michael Campbell*
4
+
1
5
  - Fix alias error.
2
6
 
3
7
  *Michael Campbell*
data/lib/rann/backprop.rb CHANGED
@@ -20,10 +20,10 @@ module RANN
20
20
 
21
21
  attr_accessor :network
22
22
 
23
- def initialize network, opts = {}, restore = {}
23
+ def initialize network, opts = {}
24
24
  @network = network
25
25
  @connections_hash = network.connections.each.with_object({}){ |c, h| h[c.id] = c }
26
- @optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts, restore
26
+ @optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
27
27
  @batch_count = 0.to_d
28
28
  end
29
29
 
@@ -182,8 +182,31 @@ module RANN
182
182
  [gradients, error]
183
183
  end
184
184
 
185
- def state
186
- { historical_gradient: @historical_gradient }
185
+ def save filepath = nil
186
+ filepath ||= "rann_savepoint_#{DateTime.now.strftime('%Y-%m-%d-%H-%M-%S')}.yml"
187
+
188
+ weights = @network.params
189
+ opt_vars = @optimiser.state
190
+
191
+ File.open filepath, "w" do |f|
192
+ f.write YAML.dump [weights, opt_vars]
193
+ end
194
+ end
195
+
196
+ def restore filepath
197
+ unless filepath
198
+ filepath = Dir['*'].select{ |f| f =~ /rann_savepoint_.*/ }.sort.last
199
+
200
+ unless filepath
201
+ @network.init_normalised!
202
+ puts "No savepoints found—initialised normalised weights"
203
+ return
204
+ end
205
+ end
206
+
207
+ weights, opt_vars = YAML.load_file(filepath)
208
+ @network.impose(weights)
209
+ @network.optimiser.load_state(opt_vars)
187
210
  end
188
211
 
189
212
  def self.reset! network
@@ -5,10 +5,10 @@ require "bigdecimal/util"
5
5
  module RANN
6
6
  module Optimisers
7
7
  class AdaGrad
8
- def initialize opts = {}, restore = {}
8
+ def initialize opts = {}
9
9
  @fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
10
10
  @learning_rate = opts[:learning_rate] || 0.1.to_d
11
- @historical_gradient = (restore[:historical_gradient] || {}).tap{ |h| h.default = 0.to_d }
11
+ @historical_gradient = {}.tap{ |h| h.default = 0.to_d }
12
12
  end
13
13
 
14
14
  def update grad, cid
@@ -16,6 +16,19 @@ module RANN
16
16
 
17
17
  grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
18
18
  end
19
+
20
+ # anything that gets modified over the course of training
21
+ def state
22
+ {
23
+ historical_gradient: @historical_gradient,
24
+ }
25
+ end
26
+
27
+ def load_state state
28
+ state.each do |name, value|
29
+ instance_variable_set("@#{name}", value)
30
+ end
31
+ end
19
32
  end
20
33
  end
21
34
  end
@@ -5,11 +5,11 @@ require "bigdecimal/util"
5
5
  module RANN
6
6
  module Optimisers
7
7
  class RMSProp
8
- def initialize opts = {}, restore = {}
8
+ def initialize opts = {}
9
9
  @decay = opts[:decay] || 0.9.to_d
10
10
  @fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
11
11
  @learning_rate = opts[:learning_rate] || 0.01.to_d
12
- @historical_gradient = (restore[:historical_gradient] || {}).tap{ |h| h.default = 0.to_d }
12
+ @historical_gradient = {}.tap{ |h| h.default = 0.to_d }
13
13
  end
14
14
 
15
15
  def update grad, cid
@@ -17,6 +17,19 @@ module RANN
17
17
 
18
18
  grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
19
19
  end
20
+
21
+ # anything that gets modified over the course of training
22
+ def state
23
+ {
24
+ historical_gradient: @historical_gradient,
25
+ }
26
+ end
27
+
28
+ def load_state state
29
+ state.each do |name, value|
30
+ instance_variable_set("@#{name}", value)
31
+ end
32
+ end
20
33
  end
21
34
  end
22
35
  end
data/lib/rann/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module RANN
2
- VERSION = "0.2.4"
2
+ VERSION = "0.2.5"
3
3
  end
data/rann.gemspec CHANGED
@@ -12,6 +12,7 @@ Gem::Specification.new do |spec|
12
12
  spec.summary = %q{Ruby Artificial Neural Networks}
13
13
  spec.description = %q{Libary for working with neural networks in Ruby.}
14
14
  spec.homepage = "https://github.com/mikecmpbll/rann"
15
+ spec.licenses = 'Apache-2.0'
15
16
 
16
17
  spec.files = `git ls-files -z`.split("\x0").reject do |f|
17
18
  f.match(%r{^(test|spec|features)/})
@@ -20,8 +21,8 @@ Gem::Specification.new do |spec|
20
21
  spec.executables = spec.files.grep(%r{^exe/}){ |f| File.basename(f) }
21
22
  spec.require_paths = ["lib"]
22
23
 
23
- spec.add_runtime_dependency "parallel", "~> 1.12.0"
24
- spec.add_runtime_dependency "ruby-graphviz", "~> 1.2.3"
24
+ spec.add_runtime_dependency 'parallel', '~> 1.12', '>= 1.12.0'
25
+ spec.add_runtime_dependency 'ruby-graphviz', '~> 1.2', '>= 1.2.3'
25
26
 
26
27
  spec.add_development_dependency "bundler", "~> 1.16"
27
28
  spec.add_development_dependency "rake", "~> 10.0"
metadata CHANGED
@@ -1,20 +1,23 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rann
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.4
4
+ version: 0.2.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Michael Campbell
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2017-11-28 00:00:00.000000000 Z
11
+ date: 2017-11-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: parallel
15
15
  requirement: !ruby/object:Gem::Requirement
16
16
  requirements:
17
17
  - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.12'
20
+ - - ">="
18
21
  - !ruby/object:Gem::Version
19
22
  version: 1.12.0
20
23
  type: :runtime
@@ -22,6 +25,9 @@ dependencies:
22
25
  version_requirements: !ruby/object:Gem::Requirement
23
26
  requirements:
24
27
  - - "~>"
28
+ - !ruby/object:Gem::Version
29
+ version: '1.12'
30
+ - - ">="
25
31
  - !ruby/object:Gem::Version
26
32
  version: 1.12.0
27
33
  - !ruby/object:Gem::Dependency
@@ -29,6 +35,9 @@ dependencies:
29
35
  requirement: !ruby/object:Gem::Requirement
30
36
  requirements:
31
37
  - - "~>"
38
+ - !ruby/object:Gem::Version
39
+ version: '1.2'
40
+ - - ">="
32
41
  - !ruby/object:Gem::Version
33
42
  version: 1.2.3
34
43
  type: :runtime
@@ -36,6 +45,9 @@ dependencies:
36
45
  version_requirements: !ruby/object:Gem::Requirement
37
46
  requirements:
38
47
  - - "~>"
48
+ - !ruby/object:Gem::Version
49
+ version: '1.2'
50
+ - - ">="
39
51
  - !ruby/object:Gem::Version
40
52
  version: 1.2.3
41
53
  - !ruby/object:Gem::Dependency
@@ -114,7 +126,8 @@ files:
114
126
  - lib/rann/version.rb
115
127
  - rann.gemspec
116
128
  homepage: https://github.com/mikecmpbll/rann
117
- licenses: []
129
+ licenses:
130
+ - Apache-2.0
118
131
  metadata: {}
119
132
  post_install_message:
120
133
  rdoc_options: []