synaptical 0.0.1.pre.beta1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +11 -0
- data/.rspec +3 -0
- data/.ruby-gemset +1 -0
- data/.ruby-version +1 -0
- data/.travis.yml +7 -0
- data/Gemfile +14 -0
- data/Gemfile.lock +47 -0
- data/README.md +101 -0
- data/Rakefile +96 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/lib/synaptical.rb +18 -0
- data/lib/synaptical/architect/perceptron.rb +32 -0
- data/lib/synaptical/connection.rb +31 -0
- data/lib/synaptical/cost/mse.rb +21 -0
- data/lib/synaptical/layer.rb +143 -0
- data/lib/synaptical/layer_connection.rb +74 -0
- data/lib/synaptical/network.rb +125 -0
- data/lib/synaptical/neuron.rb +312 -0
- data/lib/synaptical/serializer/json.rb +81 -0
- data/lib/synaptical/squash/logistic.rb +27 -0
- data/lib/synaptical/squash/tanh.rb +27 -0
- data/lib/synaptical/trainer.rb +89 -0
- data/lib/synaptical/version.rb +5 -0
- data/synaptical.gemspec +29 -0
- metadata +110 -0
@@ -0,0 +1,81 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  module Serializer
    # Serializes Synaptical networks to a plain Hash or a JSON string.
    module JSON
      class << self
        # Generates a serialized Hash from a network
        # @param network [Synaptical::Network] Network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [Hash] Serialized network as hash with :neurons and
        #   :connections arrays; connections refer to neurons by their index
        #   in the network's neuron list
        def as_json(network)
          unless network.is_a?(Synaptical::Network)
            raise ArgumentError, 'Only Networks can be serialized'
          end

          list = network.neurons
          ids = index_by_id(list)

          {
            neurons: list.map { |entry| serialize_neuron(entry) },
            connections: list.flat_map { |entry| serialize_connections(entry[:neuron], ids) }
          }
        end

        # Produces a serialized JSON string
        # @param network [Synaptical::Network] network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [String] Serialized network as JSON
        def dump(network)
          require 'json'
          # A bare `JSON` would resolve lexically to this module
          # (Synaptical::Serializer::JSON) and recurse, so the stdlib module
          # must be addressed from the top level.
          ::JSON.dump(as_json(network))
        end

        # Loads a network from JSON (not implemented yet)
        # @param json [String] JSON string
        #
        # @raise [RuntimeError] always, until de-serialization is implemented
        # @return [Synaptical::Network] De-serialized network
        def load(_json)
          raise 'TODO'
        end

        private

        # Maps each neuron's id to its position in the neuron list. If an id
        # occurs more than once, the last position wins.
        def index_by_id(list)
          ids = {}
          list.each_with_index { |entry, index| ids[entry[:neuron].id] = index }
          ids
        end

        # Copies the state of one neuron list entry ({ neuron:, layer: }).
        def serialize_neuron(entry)
          neuron = entry[:neuron]
          {
            # NOTE(review): "elegibility" spelling presumably mirrors
            # Synaptic.js; check consumers before renaming the key.
            trace: { elegibility: {}, extended: {} },
            state: neuron.state,
            old: neuron.old,
            activation: neuron.activation,
            bias: neuron.bias,
            layer: entry[:layer],
            # NOTE(review): always reported as LOGISTIC, even if a neuron
            # uses another squash function (e.g. Tanh) — confirm.
            squash: 'LOGISTIC'
          }
        end

        # Serializes a neuron's projected connections, followed by its
        # self-connection when it has one.
        def serialize_connections(neuron, ids)
          serialized = neuron.connections.projected.map do |_id, conn|
            serialize_connection(conn, conn.from, conn.to, ids)
          end
          if neuron.selfconnected?
            serialized << serialize_connection(neuron.selfconnection, neuron, neuron, ids)
          end
          serialized
        end

        # Builds one connection record, replacing neurons by their list index.
        def serialize_connection(conn, from, to, ids)
          {
            from: ids[from.id],
            to: ids[to.id],
            weight: conn.weight,
            gater: conn.gater ? ids[conn.gater.id] : nil
          }
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Squashing functions
  module Squash
    # Logistic (sigmoid) squashing function
    module Logistic
      # Evaluates the logistic function 1 / (1 + e^-x) at x_val
      # @param x_val [Numeric] input value
      #
      # @return [Float] squashed value between 0 and 1
      def self.call(x_val)
        denominator = ::Math.exp(-x_val) + 1.0
        1.0 / denominator
      end

      # Computes x_val * (1 - x_val). This equals the logistic derivative
      # when x_val is the function's output y = call(z), not the raw input z.
      # @param x_val [Numeric] squashed value
      #
      # @return [Numeric] derivative value
      def self.derivate(x_val)
        (1 - x_val) * x_val
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Squashing functions
  module Squash
    # Hyperbolic tangent squashing function
    module Tanh
      # Evaluates tanh at x_val
      # @param x_val [Numeric] input value
      #
      # @return [Float] squashed value between -1 and 1
      def self.call(x_val)
        ::Math.tanh(x_val)
      end

      # Computes 1 - x_val**2. This equals the tanh derivative when x_val is
      # the function's output y = tanh(z), not the raw input z.
      # @param x_val [Numeric] squashed value
      #
      # @return [Float] derivative value
      def self.derivate(x_val)
        squared = x_val**2
        1.0 - squared
      end
    end
  end
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Trains a network by repeatedly activating it on a sample set and
  # back-propagating the expected outputs.
  class Trainer
    attr_accessor :network, :rate, :iterations, :error, :cost, :cross_validate

    # Outcome of a run: final error, iterations performed (nil for #test)
    # and elapsed time in seconds.
    Result = Struct.new(:error, :iterations, :time)

    # Creates a new network trainer
    # @param network [Synaptical::Network] Network to train
    # @param rate: 0.2 [Float, Array<Float>, Proc] Learning rate. An Array is
    #   split into equal buckets over `iterations`; a Proc is called every
    #   iteration with (iteration, last_error) and returns the rate to use.
    # @param iterations: 100_000 [Integer] Maximum number of training iterations
    # @param error: 0.005 [Float] Training stops once the error is at or below this
    # @param cost: Synaptical::Cost::Mse [#call] Cost function, called as cost.call(target, output)
    # @param cross_validate: false [false, #test_size, #test_error] When set,
    #   the trailing `test_size` fraction of the set is held out for testing,
    #   and training stops once the test error is at or below `test_error`.
    def initialize(network, rate: 0.2, iterations: 100_000, error: 0.005, cost: Synaptical::Cost::Mse, cross_validate: false)
      @network = network
      @rate = rate
      @iterations = iterations
      @error = error
      @cost = cost
      @cross_validate = cross_validate
    end

    # Trains the network on a set until the target error or the iteration
    # limit is reached. An empty set yields a NaN error and stops after one
    # iteration.
    # @param set [Array<Hash>] Samples as { input:, output: } hashes
    # @param options [nil] Reserved; per-call options are not implemented yet
    #
    # @raise [RuntimeError] if options are given (not implemented yet)
    # @return [Result] Final error, iterations run and elapsed seconds
    def train(set, options = nil)
      # Fail before touching the network rather than after a training pass.
      raise 'TODO' if options

      current_error = 1.0
      last_error = 0.0
      iteration = 0
      current_rate = rate.is_a?(Array) ? rate.first : rate
      # Each rate in an Array schedule covers an equal share of the
      # *maximum* iteration count.
      bucket_size = rate.is_a?(Array) && !rate.empty? ? iterations.fdiv(rate.size).floor : 0
      started = monotonic_now

      training_set, testing_set = split_for_cross_validation(set) if cross_validate

      while iteration < iterations && current_error > error
        break if cross_validate && current_error <= cross_validate.test_error

        iteration += 1
        current_rate = scheduled_rate(iteration, bucket_size, current_rate, last_error)

        current_error =
          if cross_validate
            train_set(training_set, current_rate, cost)
            test(testing_set).error
          else
            train_set(set, current_rate, cost) / set.size.to_f
          end
        last_error = current_error
      end

      Result.new(current_error, iteration, monotonic_now - started)
    end

    # Runs one pass over a set, back-propagating each sample.
    # @param set [Array<Hash>] Samples as { input:, output: } hashes
    # @param current_rate [Float] Learning rate passed to network.propagate
    # @param cost_function [#call] Called as cost_function.call(target, output)
    #
    # @return [Float] Summed (not averaged) cost over the set
    def train_set(set, current_rate, cost_function)
      set.reduce(0.0) do |sum, item|
        input = item[:input]
        target = item[:output]
        output = network.activate(input)
        network.propagate(current_rate, target)
        sum + cost_function.call(target, output)
      end
    end

    # Measures the network's average cost on a set without propagating.
    # An empty set yields a NaN error.
    # @param set [Array<Hash>] Samples as { input:, output: } hashes
    #
    # @return [Result] Average error and elapsed seconds (iterations is nil)
    def test(set)
      started = monotonic_now
      total = set.reduce(0.0) do |sum, item|
        sum + cost.call(item[:output], network.activate(item[:input]))
      end
      Result.new(total / set.size.to_f, nil, monotonic_now - started)
    end

    private

    # Picks the learning rate for an iteration. A Proc is asked for the rate.
    # An Array uses the bucket the iteration falls in, keeping the previous
    # rate once past the last bucket. A fixed rate is returned as-is.
    def scheduled_rate(iteration, bucket_size, current_rate, last_error)
      case rate
      when Proc
        rate.call(iteration, last_error)
      when Array
        return current_rate unless bucket_size.positive?

        rate[iteration.fdiv(bucket_size).floor] || current_rate
      else
        rate
      end
    end

    # Splits the set into leading training samples and trailing test samples.
    # The two parts do not overlap.
    def split_for_cross_validation(set)
      num_train = ((1 - cross_validate.test_size) * set.size).ceil.clamp(0, set.size)
      [set.take(num_train), set.drop(num_train)]
    end

    # Seconds from a monotonic clock, for measuring elapsed time.
    def monotonic_now
      Process.clock_gettime(Process::CLOCK_MONOTONIC)
    end
  end
end
|
data/synaptical.gemspec
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'synaptical/version'

Gem::Specification.new do |spec|
  spec.name          = 'synaptical'
  spec.version       = Synaptical::VERSION
  spec.authors       = ['Sebastian Wallin']
  spec.email         = ['sebastian.wallin@gmail.com']

  spec.summary       = 'Ruby port of Synaptic.js'
  spec.description   = 'Ruby port of Synaptic.js'
  spec.homepage      = 'https://github.com/castle/synaptical'

  # Specify which files should be added to the gem when it is released.
  # `git ls-files -z` lists the git-tracked files NUL-separated; test, spec
  # and features directories are excluded. `\A` anchors at the start of the
  # path itself (unlike `^`, which also matches after an embedded newline).
  spec.files = Dir.chdir(File.expand_path('..', __FILE__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match?(%r{\A(?:test|spec|features)/}) }
  end
  spec.bindir        = 'exe'
  spec.executables   = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
  spec.require_paths = ['lib']

  spec.add_development_dependency 'bundler', '~> 1.16'
  spec.add_development_dependency 'rake', '~> 10.0'
  spec.add_development_dependency 'rspec', '~> 3.0'
end
|
metadata
ADDED
@@ -0,0 +1,110 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: synaptical
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.0.1.pre.beta1
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Sebastian Wallin
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2018-08-04 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: bundler
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '1.16'
|
20
|
+
type: :development
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '1.16'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rake
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '10.0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '10.0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rspec
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '3.0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '3.0'
|
55
|
+
description: Ruby port of Synaptic.js
|
56
|
+
email:
|
57
|
+
- sebastian.wallin@gmail.com
|
58
|
+
executables: []
|
59
|
+
extensions: []
|
60
|
+
extra_rdoc_files: []
|
61
|
+
files:
|
62
|
+
- ".gitignore"
|
63
|
+
- ".rspec"
|
64
|
+
- ".ruby-gemset"
|
65
|
+
- ".ruby-version"
|
66
|
+
- ".travis.yml"
|
67
|
+
- Gemfile
|
68
|
+
- Gemfile.lock
|
69
|
+
- README.md
|
70
|
+
- Rakefile
|
71
|
+
- bin/console
|
72
|
+
- bin/setup
|
73
|
+
- lib/synaptical.rb
|
74
|
+
- lib/synaptical/architect/perceptron.rb
|
75
|
+
- lib/synaptical/connection.rb
|
76
|
+
- lib/synaptical/cost/mse.rb
|
77
|
+
- lib/synaptical/layer.rb
|
78
|
+
- lib/synaptical/layer_connection.rb
|
79
|
+
- lib/synaptical/network.rb
|
80
|
+
- lib/synaptical/neuron.rb
|
81
|
+
- lib/synaptical/serializer/json.rb
|
82
|
+
- lib/synaptical/squash/logistic.rb
|
83
|
+
- lib/synaptical/squash/tanh.rb
|
84
|
+
- lib/synaptical/trainer.rb
|
85
|
+
- lib/synaptical/version.rb
|
86
|
+
- synaptical.gemspec
|
87
|
+
homepage: https://github.com/castle/synaptical
|
88
|
+
licenses: []
|
89
|
+
metadata: {}
|
90
|
+
post_install_message:
|
91
|
+
rdoc_options: []
|
92
|
+
require_paths:
|
93
|
+
- lib
|
94
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
95
|
+
requirements:
|
96
|
+
- - ">="
|
97
|
+
- !ruby/object:Gem::Version
|
98
|
+
version: '0'
|
99
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">"
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: 1.3.1
|
104
|
+
requirements: []
|
105
|
+
rubyforge_project:
|
106
|
+
rubygems_version: 2.7.4
|
107
|
+
signing_key:
|
108
|
+
specification_version: 4
|
109
|
+
summary: Ruby port of Synaptic.js
|
110
|
+
test_files: []
|