rann 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 671da77af2c018b9391e64bc6d6a3bdcd3267648
4
- data.tar.gz: edeb0b1f7c6c59b5afb692872efc1c1e80492971
3
+ metadata.gz: 81b3a045a8873df309575045595a2379a7e4b28c
4
+ data.tar.gz: 5cab5e7a20030b091341ff9dc9a1d6acb3279ea7
5
5
  SHA512:
6
- metadata.gz: 6527cce72e2d732c2efee6e1acf1ca9c01187e5dbc69f27600842de1d2685766e6631183abab171a1dabb029ce45d64486a62af6b69c40f656d9dc4afe58fca8
7
- data.tar.gz: 617ea35293d0a34f40f4e613124e73d6771bad29f47cb5860786a0e03564cd798619b51b84f0397624f5f3100a7b15a3e0ba2db0d1bef143f9e7c5e07e37bb0e
6
+ metadata.gz: f223e733f0eed2cc347050d6dc6b48eb64a1705ee0208da4eccb25d644c09cfa7bbd08313d427b0990dfb10ff136a27dcf3ecb39a4a9d81183423c181a724879
7
+ data.tar.gz: aca9f044286fd055e6868098375c084574162102e2e0d40a1e17c92c8f0795e1d9a7ca13216ec54010192b3f75d354166fbe32dde1804dc712a27a2accf16b35
data/CHANGES.md CHANGED
@@ -1,3 +1,7 @@
1
+ - Add save and restore methods to backprop.
2
+
3
+ *Michael Campbell*
4
+
1
5
  - Fix alias error.
2
6
 
3
7
  *Michael Campbell*
data/lib/rann/backprop.rb CHANGED
@@ -20,10 +20,10 @@ module RANN
20
20
 
21
21
  attr_accessor :network
22
22
 
23
- def initialize network, opts = {}, restore = {}
23
+ def initialize network, opts = {}
24
24
  @network = network
25
25
  @connections_hash = network.connections.each.with_object({}){ |c, h| h[c.id] = c }
26
- @optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts, restore
26
+ @optimiser = RANN::Optimisers.const_get(opts[:optimiser] || 'RMSProp').new opts
27
27
  @batch_count = 0.to_d
28
28
  end
29
29
 
@@ -182,8 +182,31 @@ module RANN
182
182
  [gradients, error]
183
183
  end
184
184
 
185
- def state
186
- { historical_gradient: @historical_gradient }
185
+ def save filepath = nil
186
+ filepath ||= "rann_savepoint_#{DateTime.now.strftime('%Y-%m-%d-%H-%M-%S')}.yml"
187
+
188
+ weights = @network.params
189
+ opt_vars = @optimiser.state
190
+
191
+ File.open filepath, "w" do |f|
192
+ f.write YAML.dump [weights, opt_vars]
193
+ end
194
+ end
195
+
196
+ def restore filepath
197
+ unless filepath
198
+ filepath = Dir['*'].select{ |f| f =~ /rann_savepoint_.*/ }.sort.last
199
+
200
+ unless filepath
201
+ @network.init_normalised!
202
+ puts "No savepoints found—initialised normalised weights"
203
+ return
204
+ end
205
+ end
206
+
207
+ weights, opt_vars = YAML.load_file(filepath)
208
+ @network.impose(weights)
209
+ @network.optimiser.load_state(opt_vars)
187
210
  end
188
211
 
189
212
  def self.reset! network
@@ -5,10 +5,10 @@ require "bigdecimal/util"
5
5
  module RANN
6
6
  module Optimisers
7
7
  class AdaGrad
8
- def initialize opts = {}, restore = {}
8
+ def initialize opts = {}
9
9
  @fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
10
10
  @learning_rate = opts[:learning_rate] || 0.1.to_d
11
- @historical_gradient = (restore[:historical_gradient] || {}).tap{ |h| h.default = 0.to_d }
11
+ @historical_gradient = {}.tap{ |h| h.default = 0.to_d }
12
12
  end
13
13
 
14
14
  def update grad, cid
@@ -16,6 +16,19 @@ module RANN
16
16
 
17
17
  grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
18
18
  end
19
+
20
+ # anything that gets modified over the course of training
21
+ def state
22
+ {
23
+ historical_gradient: @historical_gradient,
24
+ }
25
+ end
26
+
27
+ def load_state state
28
+ state.each do |name, value|
29
+ instance_variable_set("@#{name}", value)
30
+ end
31
+ end
19
32
  end
20
33
  end
21
34
  end
@@ -5,11 +5,11 @@ require "bigdecimal/util"
5
5
  module RANN
6
6
  module Optimisers
7
7
  class RMSProp
8
- def initialize opts = {}, restore = {}
8
+ def initialize opts = {}
9
9
  @decay = opts[:decay] || 0.9.to_d
10
10
  @fudge_factor = opts[:fudge_factor] || 0.00000001.to_d
11
11
  @learning_rate = opts[:learning_rate] || 0.01.to_d
12
- @historical_gradient = (restore[:historical_gradient] || {}).tap{ |h| h.default = 0.to_d }
12
+ @historical_gradient = {}.tap{ |h| h.default = 0.to_d }
13
13
  end
14
14
 
15
15
  def update grad, cid
@@ -17,6 +17,19 @@ module RANN
17
17
 
18
18
  grad.mult(- @learning_rate.div(@fudge_factor + @historical_gradient[cid].sqrt(10), 10), 10)
19
19
  end
20
+
21
+ # anything that gets modified over the course of training
22
+ def state
23
+ {
24
+ historical_gradient: @historical_gradient,
25
+ }
26
+ end
27
+
28
+ def load_state state
29
+ state.each do |name, value|
30
+ instance_variable_set("@#{name}", value)
31
+ end
32
+ end
20
33
  end
21
34
  end
22
35
  end
data/lib/rann/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module RANN
2
- VERSION = "0.2.4"
2
+ VERSION = "0.2.5"
3
3
  end
data/rann.gemspec CHANGED
@@ -12,6 +12,7 @@ Gem::Specification.new do |spec|
12
12
  spec.summary = %q{Ruby Artificial Neural Networks}
13
13
  spec.description = %q{Libary for working with neural networks in Ruby.}
14
14
  spec.homepage = "https://github.com/mikecmpbll/rann"
15
+ spec.licenses = 'Apache-2.0'
15
16
 
16
17
  spec.files = `git ls-files -z`.split("\x0").reject do |f|
17
18
  f.match(%r{^(test|spec|features)/})
@@ -20,8 +21,8 @@ Gem::Specification.new do |spec|
20
21
  spec.executables = spec.files.grep(%r{^exe/}){ |f| File.basename(f) }
21
22
  spec.require_paths = ["lib"]
22
23
 
23
- spec.add_runtime_dependency "parallel", "~> 1.12.0"
24
- spec.add_runtime_dependency "ruby-graphviz", "~> 1.2.3"
24
+ spec.add_runtime_dependency 'parallel', '~> 1.12', '>= 1.12.0'
25
+ spec.add_runtime_dependency 'ruby-graphviz', '~> 1.2', '>= 1.2.3'
25
26
 
26
27
  spec.add_development_dependency "bundler", "~> 1.16"
27
28
  spec.add_development_dependency "rake", "~> 10.0"
metadata CHANGED
@@ -1,20 +1,23 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rann
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.4
4
+ version: 0.2.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Michael Campbell
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2017-11-28 00:00:00.000000000 Z
11
+ date: 2017-11-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: parallel
15
15
  requirement: !ruby/object:Gem::Requirement
16
16
  requirements:
17
17
  - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.12'
20
+ - - ">="
18
21
  - !ruby/object:Gem::Version
19
22
  version: 1.12.0
20
23
  type: :runtime
@@ -22,6 +25,9 @@ dependencies:
22
25
  version_requirements: !ruby/object:Gem::Requirement
23
26
  requirements:
24
27
  - - "~>"
28
+ - !ruby/object:Gem::Version
29
+ version: '1.12'
30
+ - - ">="
25
31
  - !ruby/object:Gem::Version
26
32
  version: 1.12.0
27
33
  - !ruby/object:Gem::Dependency
@@ -29,6 +35,9 @@ dependencies:
29
35
  requirement: !ruby/object:Gem::Requirement
30
36
  requirements:
31
37
  - - "~>"
38
+ - !ruby/object:Gem::Version
39
+ version: '1.2'
40
+ - - ">="
32
41
  - !ruby/object:Gem::Version
33
42
  version: 1.2.3
34
43
  type: :runtime
@@ -36,6 +45,9 @@ dependencies:
36
45
  version_requirements: !ruby/object:Gem::Requirement
37
46
  requirements:
38
47
  - - "~>"
48
+ - !ruby/object:Gem::Version
49
+ version: '1.2'
50
+ - - ">="
39
51
  - !ruby/object:Gem::Version
40
52
  version: 1.2.3
41
53
  - !ruby/object:Gem::Dependency
@@ -114,7 +126,8 @@ files:
114
126
  - lib/rann/version.rb
115
127
  - rann.gemspec
116
128
  homepage: https://github.com/mikecmpbll/rann
117
- licenses: []
129
+ licenses:
130
+ - Apache-2.0
118
131
  metadata: {}
119
132
  post_install_message:
120
133
  rdoc_options: []