synaptical 0.0.1.pre.beta1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +11 -0
- data/.rspec +3 -0
- data/.ruby-gemset +1 -0
- data/.ruby-version +1 -0
- data/.travis.yml +7 -0
- data/Gemfile +14 -0
- data/Gemfile.lock +47 -0
- data/README.md +101 -0
- data/Rakefile +96 -0
- data/bin/console +15 -0
- data/bin/setup +8 -0
- data/lib/synaptical.rb +18 -0
- data/lib/synaptical/architect/perceptron.rb +32 -0
- data/lib/synaptical/connection.rb +31 -0
- data/lib/synaptical/cost/mse.rb +21 -0
- data/lib/synaptical/layer.rb +143 -0
- data/lib/synaptical/layer_connection.rb +74 -0
- data/lib/synaptical/network.rb +125 -0
- data/lib/synaptical/neuron.rb +312 -0
- data/lib/synaptical/serializer/json.rb +81 -0
- data/lib/synaptical/squash/logistic.rb +27 -0
- data/lib/synaptical/squash/tanh.rb +27 -0
- data/lib/synaptical/trainer.rb +89 -0
- data/lib/synaptical/version.rb +5 -0
- data/synaptical.gemspec +29 -0
- metadata +110 -0
@@ -0,0 +1,81 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  module Serializer
    # Serializes Synaptical networks into a plain structure with a
    # `neurons` list and a `connections` list.
    module JSON
      class << self
        # Generates a serialized Hash from a network
        #
        # Each neuron is identified by its position in `network.neurons`; the
        # `from`, `to` and `gater` fields of a serialized connection are those
        # positions, not neuron ids.
        #
        # @param network [Synaptical::Network] Network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [Hash] Serialized network as hash (`:neurons`, `:connections`)
        def as_json(network)
          unless network.is_a?(Synaptical::Network)
            raise ArgumentError, 'Only Networks can be serialized'
          end

          list = network.neurons
          ids = {}

          neurons = list.each_with_index.map do |entry, index|
            neuron = entry[:neuron]
            # Remember the list position so connections can reference it.
            ids[neuron.id] = index
            {
              # Key spelling kept as-is for compatibility with the serialized format.
              trace: { elegibility: {}, extended: {} },
              state: neuron.state,
              old: neuron.old,
              activation: neuron.activation,
              bias: neuron.bias,
              layer: entry[:layer],
              # NOTE(review): always 'LOGISTIC', even if a neuron uses another
              # squash function (e.g. Tanh) — confirm against Neuron's API.
              squash: 'LOGISTIC'
            }
          end

          connections = []
          list.each do |entry|
            neuron = entry[:neuron]

            neuron.connections.projected.each do |_id, conn|
              connections << {
                from: ids[conn.from.id],
                to: ids[conn.to.id],
                weight: conn.weight,
                gater: conn.gater ? ids[conn.gater.id] : nil
              }
            end

            next unless neuron.selfconnected?

            selfconnection = neuron.selfconnection
            connections << {
              from: ids[neuron.id],
              to: ids[neuron.id],
              weight: selfconnection.weight,
              gater: selfconnection.gater ? ids[selfconnection.gater.id] : nil
            }
          end

          { neurons: neurons, connections: connections }
        end

        # Produces a serialized JSON string
        #
        # @param network [Synaptical::Network] network to serialize
        #
        # @raise [ArgumentError] if network is not a Synaptical::Network
        # @return [String] Serialized network as JSON
        def dump(network)
          require 'json'
          # A bare `JSON` here resolves lexically to this module
          # (Synaptical::Serializer::JSON), which would call #dump again with
          # a Hash and fail; the stdlib module must be referenced explicitly.
          ::JSON.dump(as_json(network))
        end

        # Loads a network from JSON
        #
        # @param json [String] JSON string
        #
        # @raise [RuntimeError] always; loading is not implemented yet
        # @return [Synaptical::Network] De-serialized network
        def load(_json)
          raise 'TODO'
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Squashing functions
  module Squash
    # Logistic (sigmoid) squashing function, 1 / (1 + e^-x).
    module Logistic
      # Evaluates the logistic function at x_val.
      #
      # @param x_val [Numeric] input value
      #
      # @return [Float] result between 0 and 1
      def self.call(x_val)
        denominator = 1.0 + ::Math.exp(-x_val)
        1.0 / denominator
      end

      # Derivative of the logistic function, computed as y * (1 - y).
      #
      # NOTE(review): because of that form, x_val is presumably the
      # already-squashed value (e.g. a neuron's activation) rather than the
      # raw input — confirm against the caller.
      #
      # @param x_val [Numeric] value
      #
      # @return [Numeric] derivative value (Float for Float input)
      def self.derivate(x_val)
        complement = 1 - x_val
        x_val * complement
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Squashing functions
  module Squash
    # Hyperbolic tangent squashing function.
    module Tanh
      # Evaluates tanh at x_val.
      #
      # @param x_val [Numeric] input value
      #
      # @return [Float] result between -1 and 1
      def self.call(x_val)
        ::Math.tanh(x_val)
      end

      # Derivative of tanh, computed as 1 - y^2.
      #
      # NOTE(review): because of that form, x_val is presumably the
      # already-squashed value rather than the raw input — confirm against
      # the caller.
      #
      # @param x_val [Numeric] value
      #
      # @return [Float] derivative value
      def self.derivate(x_val)
        squared = x_val**2
        1.0 - squared
      end
    end
  end
end
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Synaptical
  # Trains a network by repeatedly activating it on every sample of a set and
  # propagating the expected output back through it.
  class Trainer
    attr_accessor :network, :rate, :iterations, :error, :cost, :cross_validate

    # Outcome of #train: final mean error, iterations performed and elapsed
    # time in seconds.
    Result = Struct.new(:error, :iterations, :time)

    # Creates a new network trainer
    #
    # @param network [Synaptical::Network] Network to train
    # @param rate [Float, Array<Float>, Proc] Learning rate. An Array is a
    #   schedule whose entries each cover an equal share of `iterations`; a
    #   Proc is called as `rate.call(iteration, last_error)` every iteration.
    # @param iterations [Integer] Maximum number of training iterations
    # @param error [Float] Target error; training stops once the mean error
    #   is at or below this value
    # @param cost [#call] Cost function, called as `cost.call(target, output)`
    # @param cross_validate [false, Object] Cross-validation settings (read via
    #   `test_size` / `test_error`); not applied by #train yet
    def initialize(network, rate: 0.2, iterations: 100_000, error: 0.005,
                   cost: Synaptical::Cost::Mse, cross_validate: false)
      @network = network
      @rate = rate
      @iterations = iterations
      @error = error
      @cost = cost
      @cross_validate = cross_validate
    end

    # Trains the network on `set` until `iterations` passes have run or the
    # mean error drops to `self.error`.
    #
    # @param set [Array<Hash>] Samples, each `{ input: [...], output: [...] }`
    # @param options [nil] Reserved; passing options is not supported yet
    #
    # @raise [RuntimeError] ('TODO') when options are given
    # @return [Result] Final mean error, iterations run and elapsed seconds
    def train(set, options = nil)
      # Options are unimplemented: fail before any pass mutates the network.
      raise 'TODO' if options

      error = 1.0
      iteration = 0
      cost_function = cost
      start = Process.clock_gettime(Process::CLOCK_MONOTONIC)

      # An Array of rates is a schedule; size its buckets from the configured
      # iteration limit (the local counter is still 0 at this point).
      bucket_size = rate.is_a?(Array) ? iterations.fdiv(rate.size).floor : 0
      # Never hand the whole schedule Array to the network as a rate.
      current_rate = rate.is_a?(Array) ? rate.first : rate

      # NOTE(review): cross-validation stays disabled because #test is not
      # implemented yet; the configured `cross_validate` setting is ignored.
      cross_validating = false
      if cross_validating
        num_train = ((1 - cross_validate.test_size) * set.size).ceil
        # Exclusive range so the train and test sets do not share a sample.
        training_samples = set[0...num_train]
        test_samples = set[num_train..-1]
      end

      last_error = 0.0

      while iteration < iterations && error > self.error
        # Compare against the configured test error, not the settings object.
        break if cross_validating && error <= cross_validate.test_error

        current_set_size = set.size
        error = 0.0
        iteration += 1

        if bucket_size.positive?
          current_bucket = iteration.fdiv(bucket_size).floor
          # Past the end of the schedule keep the last rate that was applied.
          current_rate = rate[current_bucket] || current_rate
        end

        current_rate = rate.call(iteration, last_error) if rate.is_a?(Proc)

        if cross_validating
          train_set(training_samples, current_rate, cost_function)
          error += test(test_samples).error
          current_set_size = 1
        else
          error += train_set(set, current_rate, cost_function)
        end

        # Mean error per sample for this iteration.
        error /= current_set_size.to_f
        last_error = error
      end

      Result.new(error, iteration, Process.clock_gettime(Process::CLOCK_MONOTONIC) - start)
    end

    # Runs one pass over `set`, activating the network on each sample's
    # input and propagating its expected output.
    #
    # @param set [Array<Hash>] Samples, each `{ input: [...], output: [...] }`
    # @param current_rate [Float] Learning rate for this pass
    # @param cost_function [#call] Called as `cost_function.call(target, output)`
    #
    # @return [Float] Sum of the cost over all samples
    def train_set(set, current_rate, cost_function)
      set.reduce(0.0) do |sum, item|
        input = item[:input]
        target = item[:output]
        output = network.activate(input)
        network.propagate(current_rate, target)
        sum + cost_function.call(target, output)
      end
    end

    # Evaluates the network on a test set.
    #
    # @raise [RuntimeError] always; testing is not implemented yet
    def test(_set)
      raise 'TODO'
    end
  end
end
|
data/synaptical.gemspec
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Put lib/ on the load path so the gem version constant can be read.
lib_dir = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib_dir) unless $LOAD_PATH.include?(lib_dir)
require 'synaptical/version'

Gem::Specification.new do |s|
  s.name        = 'synaptical'
  s.version     = Synaptical::VERSION
  s.authors     = ['Sebastian Wallin']
  s.email       = ['sebastian.wallin@gmail.com']

  s.summary     = 'Ruby port of Synaptic.js'
  s.description = 'Ruby port of Synaptic.js'
  s.homepage    = 'https://github.com/castle/synaptical'

  # Ship every git-tracked file except test, spec and feature sources.
  project_root = File.expand_path('..', __FILE__)
  s.files = Dir.chdir(project_root) do
    tracked = `git ls-files -z`.split("\x0")
    tracked.reject { |path| path.match(%r{^(test|spec|features)/}) }
  end

  s.bindir        = 'exe'
  s.executables   = s.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  s.require_paths = ['lib']

  s.add_development_dependency 'bundler', '~> 1.16'
  s.add_development_dependency 'rake', '~> 10.0'
  s.add_development_dependency 'rspec', '~> 3.0'
end
|
metadata
ADDED
@@ -0,0 +1,110 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: synaptical
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.0.1.pre.beta1
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Sebastian Wallin
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2018-08-04 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: bundler
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '1.16'
|
20
|
+
type: :development
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '1.16'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rake
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - "~>"
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '10.0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - "~>"
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '10.0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rspec
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '3.0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '3.0'
|
55
|
+
description: Ruby port of Synaptic.js
|
56
|
+
email:
|
57
|
+
- sebastian.wallin@gmail.com
|
58
|
+
executables: []
|
59
|
+
extensions: []
|
60
|
+
extra_rdoc_files: []
|
61
|
+
files:
|
62
|
+
- ".gitignore"
|
63
|
+
- ".rspec"
|
64
|
+
- ".ruby-gemset"
|
65
|
+
- ".ruby-version"
|
66
|
+
- ".travis.yml"
|
67
|
+
- Gemfile
|
68
|
+
- Gemfile.lock
|
69
|
+
- README.md
|
70
|
+
- Rakefile
|
71
|
+
- bin/console
|
72
|
+
- bin/setup
|
73
|
+
- lib/synaptical.rb
|
74
|
+
- lib/synaptical/architect/perceptron.rb
|
75
|
+
- lib/synaptical/connection.rb
|
76
|
+
- lib/synaptical/cost/mse.rb
|
77
|
+
- lib/synaptical/layer.rb
|
78
|
+
- lib/synaptical/layer_connection.rb
|
79
|
+
- lib/synaptical/network.rb
|
80
|
+
- lib/synaptical/neuron.rb
|
81
|
+
- lib/synaptical/serializer/json.rb
|
82
|
+
- lib/synaptical/squash/logistic.rb
|
83
|
+
- lib/synaptical/squash/tanh.rb
|
84
|
+
- lib/synaptical/trainer.rb
|
85
|
+
- lib/synaptical/version.rb
|
86
|
+
- synaptical.gemspec
|
87
|
+
homepage: https://github.com/castle/synaptical
|
88
|
+
licenses: []
|
89
|
+
metadata: {}
|
90
|
+
post_install_message:
|
91
|
+
rdoc_options: []
|
92
|
+
require_paths:
|
93
|
+
- lib
|
94
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
95
|
+
requirements:
|
96
|
+
- - ">="
|
97
|
+
- !ruby/object:Gem::Version
|
98
|
+
version: '0'
|
99
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">"
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: 1.3.1
|
104
|
+
requirements: []
|
105
|
+
rubyforge_project:
|
106
|
+
rubygems_version: 2.7.4
|
107
|
+
signing_key:
|
108
|
+
specification_version: 4
|
109
|
+
summary: Ruby port of Synaptic.js
|
110
|
+
test_files: []
|