synaptical 0.0.1.pre.beta1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,81 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  module Serializer
    # Serializes a Synaptical::Network into a JSON-compatible structure: a list
    # of neuron hashes plus a list of connection hashes that refer to neurons
    # by their index in that list.
    module JSON
      class << self
        # Generates a serialized Hash from a network
        # @param network [Synaptical::Network] Network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [Hash] Serialized network as { neurons: [...], connections: [...] }
        def as_json(network)
          unless network.is_a?(Synaptical::Network)
            raise ArgumentError, 'Only Networks can be serialized'
          end

          entries = network.neurons
          # neuron id => index in the serialized neurons array. It must be fully
          # populated before any connection is serialized.
          ids = {}
          neurons = entries.each_with_index.map do |entry, index|
            neuron = entry[:neuron]
            ids[neuron.id] = index
            serialize_neuron(neuron, entry[:layer])
          end
          connections = entries.flat_map { |entry| serialize_connections(entry[:neuron], ids) }

          { neurons: neurons, connections: connections }
        end

        # Produces a serialized JSON string
        # @param network [Synaptical::Network] network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [String] Serialized network as JSON
        def dump(network)
          require 'json'
          # Must be the top-level ::JSON. A bare `JSON` here resolves lexically
          # to this Synaptical::Serializer::JSON module, which would call this
          # #dump again with a Hash and fail in #as_json.
          ::JSON.dump(as_json(network))
        end

        # Loads a network from JSON
        # @param json [String] JSON string
        #
        # @raise [RuntimeError] always; de-serialization is not implemented yet
        # @return [Synaptical::Network] De-serialized network
        def load(_json)
          raise 'TODO'
        end

        private

        # Builds the hash for a single neuron. Trace hashes are always emitted
        # empty.
        # NOTE(review): the 'elegibility' spelling is a serialized key and is
        # kept verbatim for format compatibility. squash is always 'LOGISTIC',
        # even though Synaptical::Squash::Tanh exists — confirm whether the
        # neuron's actual squash function should be recorded here.
        def serialize_neuron(neuron, layer)
          {
            trace: { elegibility: {}, extended: {} },
            state: neuron.state,
            old: neuron.old,
            activation: neuron.activation,
            bias: neuron.bias,
            layer: layer,
            squash: 'LOGISTIC'
          }
        end

        # Serializes a neuron's projected (outgoing) connections, followed by
        # its self-connection when it has one.
        def serialize_connections(neuron, ids)
          serialized = []
          neuron.connections.projected.each do |_id, conn|
            serialized << serialize_connection(ids, conn.from, conn.to, conn)
          end
          if neuron.selfconnected?
            serialized << serialize_connection(ids, neuron, neuron, neuron.selfconnection)
          end
          serialized
        end

        # Builds one connection hash, mapping neurons to their indices. gater
        # is nil when the connection is not gated.
        def serialize_connection(ids, from, to, conn)
          {
            from: ids[from.id],
            to: ids[to.id],
            weight: conn.weight,
            gater: conn.gater ? ids[conn.gater.id] : nil
          }
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Squashing functions
  module Squash
    # Logistic (sigmoid) squashing function.
    module Logistic
      # Applies the logistic function 1 / (1 + e^-x) to x_val.
      # @param x_val [Numeric] input value
      #
      # @return [Float] squashed value
      def self.call(x_val)
        denominator = 1.0 + ::Math.exp(-x_val)
        1.0 / denominator
      end

      # Derivative of the logistic function written in terms of its output:
      # for y = logistic(x) the derivative is y * (1 - y).
      # NOTE(review): this presumes callers pass the already squashed value,
      # not the raw input — confirm at the call sites.
      # @param x_val [Numeric] value to differentiate at
      #
      # @return [Numeric] x_val * (1 - x_val)
      def self.derivate(x_val)
        complement = 1 - x_val
        x_val * complement
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Squashing functions
  module Squash
    # Hyperbolic tangent squashing function.
    module Tanh
      # Applies tanh to x_val.
      # @param x_val [Numeric] input value
      #
      # @return [Float] tanh(x_val)
      def self.call(x_val)
        ::Math.tanh(x_val)
      end

      # Derivative of tanh written in terms of its output: for y = tanh(x) the
      # derivative is 1 - y**2.
      # NOTE(review): this presumes callers pass the already squashed value,
      # not the raw input — confirm at the call sites.
      # @param x_val [Numeric] value to differentiate at
      #
      # @return [Float] 1.0 - x_val**2
      def self.derivate(x_val)
        squared = x_val**2
        1.0 - squared
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Trains a network by repeatedly activating it on each sample and
  # propagating the expected output back through it.
  class Trainer
    attr_accessor :network, :rate, :iterations, :error, :cost, :cross_validate

    # Outcome of #train: final error, iterations run and elapsed seconds.
    Result = Struct.new(:error, :iterations, :time)

    # Creates a new network trainer
    # @param network [Synaptical::Network] Network to train
    # @param rate [Float, Array<Float>, Proc] Learning rate. An Array spreads its
    #   entries evenly over the iteration budget. A Proc is called as
    #   rate.call(iteration, last_error) on every iteration.
    # @param iterations [Integer] Maximum number of training iterations
    # @param error [Float] Target error; training stops once the error is <= this
    # @param cost [#call] Cost function, called as cost.call(target, output)
    # @param cross_validate [Object] Stored, but not used by #train yet
    #   (cross-validation is not implemented)
    def initialize(network, rate: 0.2, iterations: 100_000, error: 0.005, cost: Synaptical::Cost::Mse, cross_validate: false)
      @network = network
      @rate = rate
      @iterations = iterations
      @error = error
      @cost = cost
      @cross_validate = cross_validate
    end

    # Trains the network on +set+ until the error target or the iteration
    # limit is reached.
    # @param set [Array<Hash>] samples with :input and :output keys
    # @param options [nil] Not supported yet; any truthy value raises
    #
    # @raise [RuntimeError] ('TODO') when options are given
    # @raise [ArgumentError] when set is empty
    # @return [Result] final error, iterations run and elapsed seconds
    def train(set, options = nil)
      # Fail before the network is modified, not after a full iteration.
      raise 'TODO' if options
      # An empty set would make the mean error 0.0 / 0.0 = NaN.
      raise ArgumentError, 'Training set must not be empty' if set.empty?

      start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
      cost = self.cost
      error = 1.0
      last_error = 0.0
      iterations = 0
      bucket_size = rate_bucket_size
      # Cross-validation is not implemented yet (#test raises), so it stays
      # disabled regardless of the cross_validate attribute.
      cross_validate = false

      if cross_validate
        # Hold out the trailing test_size fraction of the set. The slices must
        # not overlap.
        num_train = ((1 - cross_validate.test_size) * set.size).ceil
        train_samples = set[0...num_train]
        test_samples = set[num_train..-1]
      end

      while iterations < self.iterations && error > self.error
        # NOTE(review): ported placeholder; revisit this threshold when
        # cross-validation is implemented.
        break if cross_validate && error <= self.cross_validate

        iterations += 1
        current_rate = rate_for(iterations, bucket_size, last_error)

        error =
          if cross_validate
            train_set(train_samples, current_rate, cost)
            test(test_samples).error
          else
            # Mean cost per sample for this pass.
            train_set(set, current_rate, cost) / set.size.to_f
          end
        last_error = error
      end

      Result.new(error, iterations, Process.clock_gettime(Process::CLOCK_MONOTONIC) - start)
    end

    # Runs one pass over +set+, activating and propagating every sample.
    # @param set [Array<Hash>] samples with :input and :output keys
    # @param current_rate [Numeric] learning rate passed to network.propagate
    # @param cost_function [#call] called as cost_function.call(target, output)
    #
    # @return [Float] sum of the per-sample costs
    def train_set(set, current_rate, cost_function)
      set.reduce(0.0) do |sum, item|
        input = item[:input]
        target = item[:output]
        output = network.activate(input)
        network.propagate(current_rate, target)
        sum + cost_function.call(target, output)
      end
    end

    # Evaluates the network on a test set.
    # @raise [RuntimeError] always; not implemented yet
    def test(_set)
      raise 'TODO'
    end

    private

    # Iterations per entry when rate is an Array (nil for other rate types).
    # Uses the iterations *attribute*: the training budget, not a loop counter.
    # Clamped to at least 1 so a budget smaller than rate.size still selects
    # entries instead of passing the Array itself to the network.
    def rate_bucket_size
      return unless rate.is_a?(Array)

      [self.iterations.fdiv(rate.size).floor, 1].max
    end

    # Learning rate for the given 1-based iteration. A Proc is asked directly.
    # An Array yields the entry for the iteration's bucket, staying on its last
    # entry once the buckets run out. Any other rate is used as-is.
    def rate_for(iteration, bucket_size, last_error)
      case rate
      when Proc then rate.call(iteration, last_error)
      when Array then rate.fetch(iteration.fdiv(bucket_size).floor) { rate.last }
      else rate
      end
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Gem version string. RubyGems normalises the hyphenated pre-release part,
  # so the published gem version reads 0.0.1.pre.beta1.
  VERSION = '0.0.1-beta1'
end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
# Directory containing this gemspec. Its lib/ is put on the load path so the
# version constant can be read without the gem being installed.
root = File.expand_path('..', __FILE__)
lib = File.join(root, 'lib')
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'synaptical/version'

Gem::Specification.new do |s|
  s.name = 'synaptical'
  s.version = Synaptical::VERSION
  s.authors = ['Sebastian Wallin']
  s.email = ['sebastian.wallin@gmail.com']

  s.summary = 'Ruby port of Synaptic.js'
  s.description = 'Ruby port of Synaptic.js'
  s.homepage = 'https://github.com/castle/synaptical'

  # Package every file tracked by git, except the test/, spec/ and features/
  # trees.
  s.files = Dir.chdir(root) do
    tracked = `git ls-files -z`.split("\x0")
    tracked.reject { |path| path =~ %r{^(test|spec|features)/} }
  end
  s.bindir = 'exe'
  s.executables = s.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  s.require_paths = ['lib']

  s.add_development_dependency 'bundler', '~> 1.16'
  s.add_development_dependency 'rake', '~> 10.0'
  s.add_development_dependency 'rspec', '~> 3.0'
end
metadata ADDED
@@ -0,0 +1,110 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: synaptical
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.0.1.pre.beta1
5
+ platform: ruby
6
+ authors:
7
+ - Sebastian Wallin
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2018-08-04 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: bundler
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.16'
20
+ type: :development
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.16'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rake
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '10.0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '10.0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rspec
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '3.0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '3.0'
55
+ description: Ruby port of Synaptic.js
56
+ email:
57
+ - sebastian.wallin@gmail.com
58
+ executables: []
59
+ extensions: []
60
+ extra_rdoc_files: []
61
+ files:
62
+ - ".gitignore"
63
+ - ".rspec"
64
+ - ".ruby-gemset"
65
+ - ".ruby-version"
66
+ - ".travis.yml"
67
+ - Gemfile
68
+ - Gemfile.lock
69
+ - README.md
70
+ - Rakefile
71
+ - bin/console
72
+ - bin/setup
73
+ - lib/synaptical.rb
74
+ - lib/synaptical/architect/perceptron.rb
75
+ - lib/synaptical/connection.rb
76
+ - lib/synaptical/cost/mse.rb
77
+ - lib/synaptical/layer.rb
78
+ - lib/synaptical/layer_connection.rb
79
+ - lib/synaptical/network.rb
80
+ - lib/synaptical/neuron.rb
81
+ - lib/synaptical/serializer/json.rb
82
+ - lib/synaptical/squash/logistic.rb
83
+ - lib/synaptical/squash/tanh.rb
84
+ - lib/synaptical/trainer.rb
85
+ - lib/synaptical/version.rb
86
+ - synaptical.gemspec
87
+ homepage: https://github.com/castle/synaptical
88
+ licenses: []
89
+ metadata: {}
90
+ post_install_message:
91
+ rdoc_options: []
92
+ require_paths:
93
+ - lib
94
+ required_ruby_version: !ruby/object:Gem::Requirement
95
+ requirements:
96
+ - - ">="
97
+ - !ruby/object:Gem::Version
98
+ version: '0'
99
+ required_rubygems_version: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - ">"
102
+ - !ruby/object:Gem::Version
103
+ version: 1.3.1
104
+ requirements: []
105
+ rubyforge_project:
106
+ rubygems_version: 2.7.4
107
+ signing_key:
108
+ specification_version: 4
109
+ summary: Ruby port of Synaptic.js
110
+ test_files: []