synaptical 0.0.1.pre.beta1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,81 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  module Serializer
    # Serializes networks into plain Ruby hashes and JSON strings.
    module JSON
      class << self
        # Generates a serialized Hash from a network.
        #
        # Neurons are emitted in the order returned by +network.neurons+ and
        # connections refer to neurons by their index in that list.
        #
        # @param network [Synaptical::Network] Network to serialize
        #
        # @return [Hash] Serialized network with +:neurons+ and +:connections+
        # @raise [ArgumentError] if +network+ is not a Synaptical::Network
        def as_json(network)
          unless network.is_a?(Synaptical::Network)
            raise ArgumentError, 'Only Networks can be serialized'
          end

          list = network.neurons
          ids = {} # neuron id => index in the serialized neuron list

          neurons = list.each_with_index.map do |nr, i|
            neuron = nr[:neuron]
            ids[neuron.id] = i
            {
              trace: { elegibility: {}, extended: {} },
              state: neuron.state,
              old: neuron.old,
              activation: neuron.activation,
              bias: neuron.bias,
              layer: nr[:layer],
              # NOTE(review): the squash name is hard-coded; neurons using a
              # different squashing function are still reported as LOGISTIC.
              squash: 'LOGISTIC'
            }
          end

          # All neuron indexes must be known before connections are resolved,
          # hence the second pass over the list.
          connections = list.flat_map { |nr| connections_for(nr[:neuron], ids) }

          { neurons: neurons, connections: connections }
        end

        # Produces a serialized JSON string.
        #
        # @param network [Synaptical::Network] network to serialize
        #
        # @return [String] Serialized network as JSON
        # @raise [ArgumentError] if +network+ is not a Synaptical::Network
        def dump(network)
          require 'json'
          # The leading `::` is required: inside this namespace a bare `JSON`
          # resolves to Synaptical::Serializer::JSON (this module), which
          # would call this very method again with a Hash argument.
          ::JSON.dump(as_json(network))
        end

        # Loads a network from JSON (not implemented yet).
        #
        # @param json [String] JSON string
        #
        # @return [Synaptical::Network] De-serialized network
        # @raise [RuntimeError] always, until loading is implemented
        def load(_json)
          raise 'TODO'
        end

        private

        # Serializes the projected connections of +neuron+, followed by its
        # self-connection when it has one.
        def connections_for(neuron, ids)
          serialized = neuron.connections.projected.map do |_id, conn|
            {
              from: ids[conn.from.id],
              to: ids[conn.to.id],
              weight: conn.weight,
              gater: conn.gater ? ids[conn.gater.id] : nil
            }
          end
          return serialized unless neuron.selfconnected?

          own = neuron.selfconnection
          serialized << {
            from: ids[neuron.id],
            to: ids[neuron.id],
            weight: own.weight,
            gater: own.gater ? ids[own.gater.id] : nil
          }
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Squashing functions
  module Squash
    # Logistic (sigmoid) squashing function
    module Logistic
      # Evaluates the logistic function 1 / (1 + e^-x).
      #
      # @param x_val [Numeric] X value
      #
      # @return [Float] Y value
      def self.call(x_val)
        1.0 / (1.0 + ::Math.exp(-x_val))
      end

      # Evaluates x_val * (1 - x_val).
      # NOTE(review): this is the logistic derivative only when x_val is
      # already the output of {call} - confirm that callers pass the
      # activation rather than the raw state.
      #
      # @param x_val [Numeric] value
      #
      # @return [Numeric] Derivative value
      def self.derivate(x_val)
        x_val * (1 - x_val)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Squashing functions
  module Squash
    # Hyperbolic tangent squashing function
    module Tanh
      # Evaluates the hyperbolic tangent of x_val.
      #
      # @param x_val [Numeric] X value
      #
      # @return [Float] Y value
      def self.call(x_val)
        ::Math.tanh(x_val)
      end

      # Evaluates 1 - x_val^2.
      # NOTE(review): this is the tanh derivative only when x_val is already
      # the output of {call} - confirm that callers pass the activation
      # rather than the raw state.
      #
      # @param x_val [Numeric] value
      #
      # @return [Float] Derivative value
      def self.derivate(x_val)
        1.0 - x_val**2
      end
    end
  end
end
@@ -0,0 +1,89 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Repeatedly activates and back-propagates a network over a training set
  # until the error target or the iteration limit is reached.
  class Trainer
    attr_accessor :network, :rate, :iterations, :error, :cost, :cross_validate

    # Outcome of a training run: final error, iterations performed and the
    # elapsed time in seconds.
    Result = Struct.new(:error, :iterations, :time)

    # Creates a new network trainer
    #
    # @param network [Synaptical::Network] Network to train
    # @param rate [Float, Array<Float>, Proc] Learning rate. An Array is used
    #   as a schedule spread over the iterations; a Proc is called with
    #   (iteration, last_error) before every iteration.
    # @param iterations [Integer] Maximum number of training iterations
    # @param error [Float] Training stops once the error is at or below this
    # @param cost [#call] Cost function, called with (target, output)
    # @param cross_validate [Object] Cross-validation settings; stored but
    #   not used yet because cross-validation is not implemented
    def initialize(network, rate: 0.2, iterations: 100_000, error: 0.005,
                   cost: Synaptical::Cost::Mse, cross_validate: false)
      @network = network
      @rate = rate
      @iterations = iterations
      @error = error
      @cost = cost
      @cross_validate = cross_validate
    end

    # Trains the network on the given set.
    #
    # @param set [Array<Hash>] Samples with +:input+ and +:output+ keys
    # @param options [Object, nil] Not supported yet; must be nil
    #
    # @return [Result] Final error, iteration count and elapsed seconds
    # @raise [RuntimeError] if +options+ is given
    def train(set, options = nil)
      # Per-call options are not implemented; fail before the network is
      # modified instead of after the first training pass.
      raise 'TODO' if options

      error = 1.0
      last_error = 0.0
      iterations = 0
      cost = self.cost
      # Cross-validation can only be switched on through options, which are
      # not supported yet, so it is always disabled for now.
      validating = false
      start = Process.clock_gettime(Process::CLOCK_MONOTONIC)

      bucket_size = 0
      current_rate = rate
      if rate.is_a?(Array)
        # Number of iterations each scheduled rate is meant to cover. This
        # must use the configured limit, not the local counter (still 0).
        bucket_size = self.iterations.fdiv(rate.size).floor
        # Start from a usable number so the fallback below never hands the
        # whole schedule Array to the network.
        current_rate = rate.first
      end

      if validating
        # NOTE(review): assumes cross_validate responds to #test_size with
        # the fraction of samples held out for testing - confirm.
        num_train = ((1 - cross_validate.test_size) * set.size).ceil
        train_set = set[0...num_train] # exclusive: no overlap with test_set
        test_set = set[num_train..-1]
      end

      while iterations < self.iterations && error > self.error
        # NOTE(review): compares the error with the settings object itself;
        # presumably a test-error threshold was intended - confirm when
        # cross-validation is implemented.
        break if validating && error <= cross_validate

        current_set_size = set.size
        error = 0.0
        iterations += 1

        if bucket_size.positive?
          current_bucket = iterations.fdiv(bucket_size).floor
          # Past the end of the schedule the previous rate stays in effect.
          current_rate = rate[current_bucket] || current_rate
        end

        current_rate = rate.call(iterations, last_error) if rate.is_a?(Proc)

        if validating
          train_set(train_set, current_rate, cost)
          error += test(test_set).error
          current_set_size = 1
        else
          error += train_set(set, current_rate, cost)
        end

        error /= current_set_size.to_f
        last_error = error
      end

      elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start
      Result.new(error, iterations, elapsed)
    end

    # Runs one pass over the set: activates the network with every sample's
    # input, propagates its target output and accumulates the cost.
    #
    # @param set [Array<Hash>] Samples with +:input+ and +:output+ keys
    # @param current_rate [Float] Learning rate for this pass
    # @param cost_function [#call] Called with (target, output)
    #
    # @return [Float] Sum of the costs of all samples
    def train_set(set, current_rate, cost_function)
      set.reduce(0.0) do |sum, item|
        input = item[:input]
        target = item[:output]
        output = network.activate(input)
        network.propagate(current_rate, target)
        sum + cost_function.call(target, output)
      end
    end

    # Tests the network against a set (not implemented yet).
    #
    # @raise [RuntimeError] always, until testing is implemented
    def test(_set)
      raise 'TODO'
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module Synaptical
  # Version string of the gem, read by the gemspec.
  VERSION = '0.0.1-beta1'
end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
# Make lib/ loadable so the version constant can be read below.
lib_dir = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib_dir) unless $LOAD_PATH.include?(lib_dir)
require 'synaptical/version'

Gem::Specification.new do |gem|
  gem.name          = 'synaptical'
  gem.version       = Synaptical::VERSION
  gem.authors       = ['Sebastian Wallin']
  gem.email         = ['sebastian.wallin@gmail.com']

  gem.summary       = 'Ruby port of Synaptic.js'
  gem.description   = 'Ruby port of Synaptic.js'
  gem.homepage      = 'https://github.com/castle/synaptical'

  # Package every file tracked by git (`git ls-files -z` lists them
  # NUL-separated), leaving out test, spec and feature directories.
  gem.files = Dir.chdir(File.expand_path('..', __FILE__)) do
    `git ls-files -z`.split("\x0").grep_v(%r{^(test|spec|features)/})
  end
  gem.bindir        = 'exe'
  gem.executables   = gem.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  gem.require_paths = ['lib']

  # Development-only dependencies, declared in the original order.
  {
    'bundler' => '~> 1.16',
    'rake' => '~> 10.0',
    'rspec' => '~> 3.0'
  }.each do |dep_name, requirement|
    gem.add_development_dependency dep_name, requirement
  end
end
metadata ADDED
@@ -0,0 +1,110 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: synaptical
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.0.1.pre.beta1
5
+ platform: ruby
6
+ authors:
7
+ - Sebastian Wallin
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2018-08-04 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: bundler
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '1.16'
20
+ type: :development
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '1.16'
27
+ - !ruby/object:Gem::Dependency
28
+ name: rake
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: '10.0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: '10.0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rspec
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '3.0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '3.0'
55
+ description: Ruby port of Synaptic.js
56
+ email:
57
+ - sebastian.wallin@gmail.com
58
+ executables: []
59
+ extensions: []
60
+ extra_rdoc_files: []
61
+ files:
62
+ - ".gitignore"
63
+ - ".rspec"
64
+ - ".ruby-gemset"
65
+ - ".ruby-version"
66
+ - ".travis.yml"
67
+ - Gemfile
68
+ - Gemfile.lock
69
+ - README.md
70
+ - Rakefile
71
+ - bin/console
72
+ - bin/setup
73
+ - lib/synaptical.rb
74
+ - lib/synaptical/architect/perceptron.rb
75
+ - lib/synaptical/connection.rb
76
+ - lib/synaptical/cost/mse.rb
77
+ - lib/synaptical/layer.rb
78
+ - lib/synaptical/layer_connection.rb
79
+ - lib/synaptical/network.rb
80
+ - lib/synaptical/neuron.rb
81
+ - lib/synaptical/serializer/json.rb
82
+ - lib/synaptical/squash/logistic.rb
83
+ - lib/synaptical/squash/tanh.rb
84
+ - lib/synaptical/trainer.rb
85
+ - lib/synaptical/version.rb
86
+ - synaptical.gemspec
87
+ homepage: https://github.com/castle/synaptical
88
+ licenses: []
89
+ metadata: {}
90
+ post_install_message:
91
+ rdoc_options: []
92
+ require_paths:
93
+ - lib
94
+ required_ruby_version: !ruby/object:Gem::Requirement
95
+ requirements:
96
+ - - ">="
97
+ - !ruby/object:Gem::Version
98
+ version: '0'
99
+ required_rubygems_version: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - ">"
102
+ - !ruby/object:Gem::Version
103
+ version: 1.3.1
104
+ requirements: []
105
+ rubyforge_project:
106
+ rubygems_version: 2.7.4
107
+ signing_key:
108
+ specification_version: 4
109
+ summary: Ruby port of Synaptic.js
110
+ test_files: []