ai4r 1.13 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.rb +14 -11
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +40 -34
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +48 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +524 -145
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +29 -14
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +270 -135
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +223 -103
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +62 -62
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +399 -134
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +57 -28
- metadata +79 -109
- data/README.rdoc +0 -39
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -220
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/simple_linear_regression_test.rb +0 -37
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -167
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -104
- data/test/data/proximity_test.rb +0 -87
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -0,0 +1,76 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI Assistant
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# Generic A* search algorithm.
|
8
|
+
|
9
|
+
require 'set'
|
10
|
+
|
11
|
+
module Ai4r
|
12
|
+
module Search
|
13
|
+
# Generic A* search implementation operating on arbitrary graph
|
14
|
+
# representations.
|
15
|
+
#
|
16
|
+
# Initialize with start node, a goal predicate, a neighbor function and a
|
17
|
+
# heuristic.
|
18
|
+
#
|
19
|
+
# The neighbor function must return a hash mapping neighbor nodes to edge
|
20
|
+
# costs. The heuristic receives a node and returns the estimated remaining
|
21
|
+
# cost to reach the goal.
|
22
|
+
class AStar
  # @param start [Object] initial node
  # @param goal_test [Proc] predicate returning true when node is goal
  # @param neighbor_fn [Proc] -> node { neighbor => cost, ... }
  # @param heuristic_fn [Proc] -> node => estimated remaining cost
  def initialize(start, goal_test, neighbor_fn, heuristic_fn)
    @start = start
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @heuristic_fn = heuristic_fn
  end

  # Execute the search and return the path as an array of nodes or nil
  # if the goal cannot be reached.
  #
  # A node is (re-)queued whenever a strictly cheaper route to it is
  # discovered, even if it has already been expanded. This keeps the result
  # optimal for heuristics that are admissible but not consistent. Edge
  # costs are expected to be non-negative.
  #
  # @return [Array<Object>, nil] nodes from start to goal, or nil
  def search
    # Set iterates in insertion order, so ties on f are broken exactly as
    # with an append-only list while membership/delete stay O(1).
    open_set = Set[@start]
    came_from = {}
    g = { @start => 0 }
    f = { @start => @heuristic_fn.call(@start) }

    until open_set.empty?
      current = open_set.min_by { |n| f[n] }
      return reconstruct_path(came_from, current) if @goal_test.call(current)

      open_set.delete(current)
      @neighbor_fn.call(current).each do |neighbor, cost|
        tentative_g = g[current] + cost
        # Only a strictly better route is recorded; this also prevents
        # endless re-queuing on zero-cost cycles.
        next unless g[neighbor].nil? || tentative_g < g[neighbor]

        came_from[neighbor] = current
        g[neighbor] = tentative_g
        f[neighbor] = tentative_g + @heuristic_fn.call(neighbor)
        open_set << neighbor
      end
    end
    nil
  end

  private

  # Walk the predecessor links back from +node+ and return the path in
  # start-to-node order.
  def reconstruct_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path.unshift(node)
    end
    path
  end
end
|
75
|
+
end
|
76
|
+
end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI Assistant
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# Basic breadth-first search implementation.
|
8
|
+
|
9
|
+
module Ai4r
|
10
|
+
module Search
|
11
|
+
# Explore nodes in breadth-first order until a goal is found.
|
12
|
+
class BFS
  # Create a new BFS searcher.
  #
  # goal_test:: lambda returning true for a goal node
  # neighbor_fn:: lambda returning adjacent nodes for a given node
  # start:: optional starting node
  def initialize(goal_test, neighbor_fn, start = nil)
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @start = start
  end

  # Find a path from the start node to a goal.
  #
  # start:: initial node if not provided on initialization
  #
  # Returns an array of nodes representing the path, or nil if no goal was found.
  # Raises ArgumentError when no start node is available.
  def search(start = nil)
    start ||= @start
    raise ArgumentError, 'start node required' unless start

    queue = [start]
    visited = { start => true }
    # Predecessor of every discovered node; the path is rebuilt once at the
    # goal instead of copying a path array for each queued node.
    came_from = {}
    until queue.empty?
      node = queue.shift
      return build_path(came_from, node) if @goal_test.call(node)

      @neighbor_fn.call(node).each do |n|
        next if visited[n]

        visited[n] = true
        came_from[n] = node
        queue << n
      end
    end
    nil
  end

  private

  # Follow predecessor links back from +node+ (the start node has none)
  # and return the nodes in start-to-node order.
  def build_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path << node
    end
    path.reverse
  end
end
|
49
|
+
end
|
50
|
+
end
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI Assistant
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# Basic depth-first search implementation.
|
8
|
+
|
9
|
+
module Ai4r
|
10
|
+
module Search
|
11
|
+
# Explore nodes in depth-first order until a goal is found.
|
12
|
+
class DFS
  # Create a new DFS searcher.
  #
  # goal_test:: lambda returning true for a goal node
  # neighbor_fn:: lambda returning adjacent nodes for a given node
  # start:: optional starting node
  def initialize(goal_test, neighbor_fn, start = nil)
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @start = start
  end

  # Find a path from the start node to a goal.
  #
  # start:: initial node if not provided on initialization
  #
  # Returns an array of nodes representing the path, or nil if no goal was found.
  # Raises ArgumentError when no start node is available.
  def search(start = nil)
    start ||= @start
    raise ArgumentError, 'start node required' unless start

    stack = [start]
    visited = { start => true }
    # Predecessor of every discovered node; the path is rebuilt once at the
    # goal instead of copying a path array for each stacked node.
    came_from = {}
    until stack.empty?
      node = stack.pop
      return build_path(came_from, node) if @goal_test.call(node)

      @neighbor_fn.call(node).each do |n|
        # Nodes are marked when pushed, so each node is stacked at most once.
        next if visited[n]

        visited[n] = true
        came_from[n] = node
        stack << n
      end
    end
    nil
  end

  private

  # Follow predecessor links back from +node+ (the start node has none)
  # and return the nodes in start-to-node order.
  def build_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path << node
    end
    path.reverse
  end
end
|
49
|
+
end
|
50
|
+
end
|
@@ -0,0 +1,118 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: OpenAI Assistant
|
4
|
+
# License:: MPL 1.1
|
5
|
+
# Project:: ai4r
|
6
|
+
#
|
7
|
+
# Minimal Monte Carlo Tree Search implementation.
|
8
|
+
|
9
|
+
require_relative '../data/parameterizable'
|
10
|
+
|
11
|
+
module Ai4r
|
12
|
+
module Search
|
13
|
+
# Basic UCT-style Monte Carlo Tree Search.
|
14
|
+
#
|
15
|
+
# This generic implementation expects four callbacks:
|
16
|
+
# - +actions_fn.call(state)+ returns available actions for a state.
|
17
|
+
# - +transition_fn.call(state, action)+ computes the next state.
|
18
|
+
# - +terminal_fn.call(state)+ returns true if the state has no children.
|
19
|
+
# - +reward_fn.call(state)+ yields a numeric payoff for terminal states.
|
20
|
+
#
|
21
|
+
# Example:
|
22
|
+
# env = {
|
23
|
+
# actions_fn: ->(s) { s == :root ? %i[a b] : [] },
|
24
|
+
# transition_fn: ->(s, a) { a == :a ? :win : :lose },
|
25
|
+
# terminal_fn: ->(s) { %i[win lose].include?(s) },
|
26
|
+
# reward_fn: ->(s) { s == :win ? 1.0 : 0.0 }
|
27
|
+
# }
|
28
|
+
# mcts = Ai4r::Search::MCTS.new(**env)
|
29
|
+
# best = mcts.search(:root, 50)
|
30
|
+
# # => :a
|
31
|
+
class MCTS
  include Ai4r::Data::Parameterizable

  # Search-tree node: the state it represents, the parent node and the
  # action that led here, its expanded children, and the visit count and
  # accumulated reward used by the UCT formula.
  Node = Struct.new(:state, :parent, :action, :children, :visits, :value) do
    def initialize(state, parent = nil, action = nil)
      super(state, parent, action, [], 0, 0.0)
    end
  end

  parameters_info exploration: 'UCT exploration constant'

  # Create a new search object.
  #
  # actions_fn:: returns available actions for a state
  # transition_fn:: computes the next state given a state and action
  # terminal_fn:: predicate to detect terminal states
  # reward_fn:: numeric payoff for terminal states
  # exploration:: UCT exploration constant (defaults to sqrt(2))
  def initialize(actions_fn:, transition_fn:, terminal_fn:, reward_fn:,
                 exploration: Math.sqrt(2))
    @actions_fn = actions_fn
    @transition_fn = transition_fn
    @terminal_fn = terminal_fn
    @reward_fn = reward_fn
    @exploration = exploration
  end

  # Run MCTS starting from +root_state+ for a number of +iterations+.
  # Returns the action considered best from the root, or nil when the root
  # was never expanded (zero iterations, or a root without actions).
  def search(root_state, iterations)
    root = Node.new(root_state)
    iterations.times do
      node = tree_policy(root)
      reward = default_policy(node.state)
      backup(node, reward)
    end
    # c = 0 disables the exploration term: pick the best average reward.
    best_child(root, 0)&.action
  end

  private

  # Selection/expansion step: descend through fully expanded nodes using
  # UCT and return a freshly expanded child, or the node where descent
  # stopped (a terminal state or a dead end).
  def tree_policy(node)
    until @terminal_fn.call(node.state)
      actions = @actions_fn.call(node.state)
      untried = actions - node.children.map(&:action)
      return expand(node, untried) unless untried.empty?
      # Dead end: the state is not terminal yet offers no actions, so there
      # is no child to descend into. Treat the node itself as the leaf.
      break if node.children.empty?

      node = best_child(node, @exploration)
    end
    node
  end

  # Add one child to +node+ for a randomly chosen action out of +untried+
  # (must be non-empty) and return that child.
  def expand(node, untried)
    action = untried.sample
    state = @transition_fn.call(node.state, action)
    child = Node.new(state, node, action)
    node.children << child
    child
  end

  # Child of +node+ maximising average reward plus +c+ times the UCT
  # exploration bonus; nil when +node+ has no children.
  def best_child(node, c)
    node.children.max_by do |child|
      # Unvisited children use a divisor of 1 to avoid dividing by zero.
      exploitation = child.value / (child.visits.nonzero? || 1)
      exploration = c * Math.sqrt(Math.log(node.visits + 1) / (child.visits.nonzero? || 1))
      exploitation + exploration
    end
  end

  # Rollout: apply uniformly random actions from +state+ until a terminal
  # state is reached and return its reward. If a non-terminal state offers
  # no actions the rollout stops there and that state is scored instead.
  def default_policy(state)
    current = state
    until @terminal_fn.call(current)
      actions = @actions_fn.call(current)
      break if actions.empty?

      current = @transition_fn.call(current, actions.sample)
    end
    @reward_fn.call(current)
  end

  # Propagate +reward+ from +node+ up to the root, updating visit counts
  # and accumulated values along the way.
  def backup(node, reward)
    while node
      node.visits += 1
      node.value += reward
      node = node.parent
    end
  end
end
|
117
|
+
end
|
118
|
+
end
|
data/lib/ai4r/search.rb
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Ai4r
|
4
|
+
module Som
|
5
|
+
# Helper module with distance metrics for node coordinates
|
6
|
+
module DistanceMetrics
  class << self
    # Chebyshev (chessboard) distance: the larger absolute offset.
    #
    # @param dx [Numeric] offset along the first grid axis
    # @param dy [Numeric] offset along the second grid axis
    # @return [Numeric]
    def chebyshev(dx, dy)
      first = dx.abs
      second = dy.abs
      # On a tie the first offset is returned.
      second > first ? second : first
    end

    # Euclidean (straight-line) distance.
    #
    # @param dx [Numeric] offset along the first grid axis
    # @param dy [Numeric] offset along the second grid axis
    # @return [Float]
    def euclidean(dx, dy)
      squared_sum = (dx**2) + (dy**2)
      Math.sqrt(squared_sum)
    end

    # Manhattan (city-block) distance: the sum of the absolute offsets.
    #
    # @param dx [Numeric] offset along the first grid axis
    # @param dy [Numeric] offset along the second grid axis
    # @return [Numeric]
    def manhattan(dx, dy)
      first = dx.abs
      second = dy.abs
      first + second
    end
  end
end
|
28
|
+
end
|
29
|
+
end
|
data/lib/ai4r/som/layer.rb
CHANGED
@@ -1,18 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Thomas Kern
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
8
|
# You can redistribute it and/or modify it under the terms of
|
7
9
|
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
|
-
|
12
|
+
require_relative '../data/parameterizable'
|
11
13
|
|
12
14
|
module Ai4r
|
13
|
-
|
14
15
|
module Som
|
15
|
-
|
16
16
|
# responsible for the implementation of the algorithm's decays
|
17
17
|
# currently has methods for the decay of the radius, influence and learning rate.
|
18
18
|
# Has only one phase, which ends after the number of epochs is passed by the Som-class.
|
@@ -24,16 +24,22 @@ module Ai4r
|
|
24
24
|
# * epochs => number of epochs the algorithm runs, has to be an integer. By default it is set to 100
|
25
25
|
# * learning_rate => sets the initial learning rate
|
26
26
|
class Layer
|
27
|
-
|
28
27
|
include Ai4r::Data::Parameterizable
|
29
28
|
|
30
|
-
parameters_info :
|
31
|
-
:
|
32
|
-
:
|
29
|
+
parameters_info nodes: 'number of nodes, has to be equal to the som',
|
30
|
+
epochs: 'number of epochs the algorithm has to run',
|
31
|
+
radius: 'sets the initial neighborhoud radius',
|
32
|
+
distance_metric: 'metric used to compute node distance'
|
33
|
+
|
34
|
+
# @param nodes [Object]
|
35
|
+
# @param radius [Object]
|
36
|
+
# @param epochs [Object]
|
37
|
+
# @param learning_rate [Object]
|
38
|
+
# @param options [Object]
|
39
|
+
# @return [Object]
|
40
|
+
def initialize(nodes, radius, epochs = 100, learning_rate = 0.7, options = {})
|
41
|
+
raise('Too few nodes') if nodes < 3
|
33
42
|
|
34
|
-
def initialize(nodes, radius, epochs = 100, learning_rate = 0.7)
|
35
|
-
raise("Too few nodes") if nodes < 3
|
36
|
-
|
37
43
|
@nodes = nodes
|
38
44
|
@epochs = epochs
|
39
45
|
@radius = radius
|
@@ -41,28 +47,33 @@ module Ai4r
|
|
41
47
|
@time_for_epoch = @epochs + 1.0 if @time_for_epoch < @epochs
|
42
48
|
|
43
49
|
@initial_learning_rate = learning_rate
|
50
|
+
@distance_metric = options[:distance_metric] || :chebyshev
|
44
51
|
end
|
45
52
|
|
46
53
|
# calculates the influence decay for a certain distance and the current radius
|
47
54
|
# of the epoch
|
55
|
+
# @param distance [Object]
|
56
|
+
# @param radius [Object]
|
57
|
+
# @return [Object]
|
48
58
|
def influence_decay(distance, radius)
|
49
|
-
Math.exp(- (distance.to_f**2 / 2.0 / radius.to_f**2))
|
59
|
+
Math.exp(- ((distance.to_f**2) / 2.0 / (radius.to_f**2)))
|
50
60
|
end
|
51
61
|
|
52
62
|
# calculates the radius decay for the current epoch. Uses @time_for_epoch
|
53
63
|
# which has to be higher than the number of epochs, otherwise the decay will be - Infinity
|
64
|
+
# @param epoch [Object]
|
65
|
+
# @return [Object]
|
54
66
|
def radius_decay(epoch)
|
55
|
-
(@radius * (
|
67
|
+
(@radius * (1 - (epoch / @time_for_epoch))).round
|
56
68
|
end
|
57
69
|
|
58
70
|
# calculates the learning rate decay. uses @time_for_epoch again and same rule applies:
|
59
71
|
# @time_for_epoch has to be higher than the number of epochs, otherwise the decay will be - Infinity
|
72
|
+
# @param epoch [Object]
|
73
|
+
# @return [Object]
|
60
74
|
def learning_rate_decay(epoch)
|
61
|
-
@initial_learning_rate * (
|
75
|
+
@initial_learning_rate * (1 - (epoch / @time_for_epoch))
|
62
76
|
end
|
63
|
-
|
64
77
|
end
|
65
|
-
|
66
78
|
end
|
67
|
-
|
68
79
|
end
|
data/lib/ai4r/som/node.rb
CHANGED
@@ -1,19 +1,20 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Thomas Kern
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
8
|
# You can redistribute it and/or modify it under the terms of
|
7
9
|
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
|
-
|
11
|
-
|
12
|
+
require_relative '../data/parameterizable'
|
13
|
+
require_relative 'layer'
|
14
|
+
require_relative 'distance_metrics'
|
12
15
|
|
13
16
|
module Ai4r
|
14
|
-
|
15
17
|
module Som
|
16
|
-
|
17
18
|
# this class is used for the individual node and will be (nodes * nodes)-time instantiated
|
18
19
|
#
|
19
20
|
# = attributes
|
@@ -24,37 +25,67 @@ module Ai4r
|
|
24
25
|
# * weights => values of the current weights are stored in an array of dimension 'dimensions'.
|
25
26
|
# Weights are of type float
|
26
27
|
# * instantiated_weight => the values of the first instantiation of weights. these values are
|
27
|
-
# never changed
|
28
|
+
# never changed
|
28
29
|
|
30
|
+
# Represents a node in the self-organizing map grid.
|
29
31
|
class Node
|
30
|
-
|
31
32
|
include Ai4r::Data::Parameterizable
|
32
33
|
|
33
|
-
parameters_info :
|
34
|
-
:
|
35
|
-
:
|
36
|
-
:
|
37
|
-
:
|
34
|
+
parameters_info weights: 'holds the current weight',
|
35
|
+
instantiated_weight: 'holds the very first weight',
|
36
|
+
x: 'holds the row ID of the unit in the map',
|
37
|
+
y: 'holds the column ID of the unit in the map',
|
38
|
+
id: 'id of the node',
|
39
|
+
distance_metric: 'metric used to compute node distance'
|
38
40
|
|
39
41
|
# creates an instance of Node and instantiates the weights
|
40
|
-
#
|
41
|
-
#
|
42
|
-
|
42
|
+
#
|
43
|
+
# +id+:: unique identifier for this node
|
44
|
+
# +rows+:: number of rows of the SOM grid
|
45
|
+
# +columns+:: number of columns of the SOM grid
|
46
|
+
# +dimensions+:: dimension of the input vector
|
47
|
+
# @param id [Object]
|
48
|
+
# @param rows [Object]
|
49
|
+
# @param columns [Object]
|
50
|
+
# @param dimensions [Object]
|
51
|
+
# @param options [Object]
|
52
|
+
# @option options [Range] :range (0..1) range used to initialize weights
|
53
|
+
# @option options [Integer] :random_seed Seed for Ruby's RNG. The
|
54
|
+
# deprecated :seed key is supported for backward compatibility.
|
55
|
+
# @return [Object]
|
56
|
+
def self.create(id, _rows, columns, dimensions, options = {})
|
43
57
|
n = Node.new
|
44
58
|
n.id = id
|
45
|
-
n.
|
46
|
-
n.
|
47
|
-
n.
|
59
|
+
n.distance_metric = options[:distance_metric] || :chebyshev
|
60
|
+
n.instantiate_weight dimensions, options
|
61
|
+
n.x = id % columns
|
62
|
+
n.y = (id / columns.to_f).to_i
|
48
63
|
n
|
49
64
|
end
|
50
65
|
|
51
66
|
# instantiates the weights to the dimension (of the input vector)
|
52
67
|
# for backup reasons, the instantiated weight is stored into @instantiated_weight as well
|
53
|
-
|
68
|
+
# @param dimensions [Object]
|
69
|
+
# @param options [Object]
|
70
|
+
# @option options [Range] :range (0..1) range used to initialize weights
|
71
|
+
# @option options [Integer] :random_seed Seed for Ruby's RNG. The
|
72
|
+
# deprecated :seed key is supported for backward compatibility.
|
73
|
+
# @return [Object]
|
74
|
+
def instantiate_weight(dimensions, options = {})
|
75
|
+
opts = { range: 0..1, random_seed: nil, seed: nil, rng: nil }.merge(options)
|
76
|
+
rng = opts[:rng]
|
77
|
+
unless rng
|
78
|
+
seed = opts[:random_seed] || opts[:seed]
|
79
|
+
rng = seed.nil? ? Random.new : Random.new(seed)
|
80
|
+
end
|
81
|
+
range = opts[:range] || (0..1)
|
82
|
+
min = range.first.to_f
|
83
|
+
max = range.last.to_f
|
84
|
+
span = max - min
|
54
85
|
@weights = Array.new dimensions
|
55
86
|
@instantiated_weight = Array.new dimensions
|
56
|
-
@weights.
|
57
|
-
@weights[index] = rand
|
87
|
+
@weights.each_index do |index|
|
88
|
+
@weights[index] = min + (rng.rand * span)
|
58
89
|
@instantiated_weight[index] = @weights[index]
|
59
90
|
end
|
60
91
|
end
|
@@ -62,10 +93,12 @@ module Ai4r
|
|
62
93
|
# returns the square distance between the current weights and the input
|
63
94
|
# the input is a vector/array of the same size as weights
|
64
95
|
# at the end, the square root is extracted from the sum of differences
|
96
|
+
# @param input [Object]
|
97
|
+
# @return [Object]
|
65
98
|
def distance_to_input(input)
|
66
99
|
dist = 0
|
67
100
|
input.each_with_index do |i, index|
|
68
|
-
dist += (i - @weights[index])
|
101
|
+
dist += (i - @weights[index])**2
|
69
102
|
end
|
70
103
|
|
71
104
|
Math.sqrt(dist)
|
@@ -79,18 +112,14 @@ module Ai4r
|
|
79
112
|
# 2 1 1 1 2
|
80
113
|
# 2 2 2 2 2
|
81
114
|
# 0 being the current node
|
115
|
+
# @param node [Object]
|
116
|
+
# @return [Object]
|
82
117
|
def distance_to_node(node)
|
83
|
-
|
118
|
+
dx = x - node.x
|
119
|
+
dy = y - node.y
|
120
|
+
metric = @distance_metric || :chebyshev
|
121
|
+
DistanceMetrics.send(metric, dx, dy)
|
84
122
|
end
|
85
|
-
|
86
|
-
private
|
87
|
-
|
88
|
-
def max(a, b)
|
89
|
-
a > b ? a : b
|
90
|
-
end
|
91
|
-
|
92
123
|
end
|
93
|
-
|
94
124
|
end
|
95
|
-
|
96
125
|
end
|