ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -0,0 +1,76 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI Assistant
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Generic A* search algorithm.
8
+
9
+ require 'set'
10
+
11
+ module Ai4r
12
+ module Search
13
# Generic A* search implementation operating on arbitrary graph
# representations.
#
# Initialize with a start node, a goal predicate, a neighbor function and a
# heuristic.
#
# The neighbor function must return a hash mapping neighbor nodes to edge
# costs. The heuristic receives a node and returns the estimated remaining
# cost to reach the goal.
#
# Nodes are stored as Hash keys and Set members, so they must implement
# +hash+ and +eql?+ consistently. Expanded nodes are never reopened, so a
# heuristic that is admissible but not consistent may yield a path that is
# not the cheapest one.
class AStar
  # @param start [Object] initial node
  # @param goal_test [#call] predicate returning true when node is goal
  # @param neighbor_fn [#call] node -> { neighbor => cost, ... }
  # @param heuristic_fn [#call] node -> estimated remaining cost
  def initialize(start, goal_test, neighbor_fn, heuristic_fn)
    @start = start
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @heuristic_fn = heuristic_fn
  end

  # Execute the search.
  #
  # @return [Array<Object>, nil] nodes from the start node to the first goal
  #   node selected for expansion (both included), or nil if the goal
  #   cannot be reached
  def search
    # A Set keeps insertion order (so ties on the f-score are broken exactly
    # as with an Array) while making membership tests and removals O(1), and
    # it compares nodes the same way the score hashes below do.
    open_set = Set[@start]
    came_from = {}
    g_score = { @start => 0 }
    f_score = { @start => @heuristic_fn.call(@start) }
    closed = Set.new

    until open_set.empty?
      current = open_set.min_by { |node| f_score[node] }
      return reconstruct_path(came_from, current) if @goal_test.call(current)

      open_set.delete(current)
      closed << current
      @neighbor_fn.call(current).each do |neighbor, cost|
        next if closed.include?(neighbor)

        tentative_g = g_score[current] + cost
        known_g = g_score[neighbor]
        # Only keep this route if it is the first or the cheapest so far.
        next unless known_g.nil? || tentative_g < known_g

        came_from[neighbor] = current
        g_score[neighbor] = tentative_g
        f_score[neighbor] = tentative_g + @heuristic_fn.call(neighbor)
        open_set << neighbor # no-op when the neighbor is already queued
      end
    end
    nil
  end

  private

  # Walk the predecessor links back from +node+ and return the path in
  # start-to-goal order.
  def reconstruct_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path << node
    end
    path.reverse
  end
end
75
+ end
76
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI Assistant
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Basic breadth-first search implementation.
8
+
9
+ module Ai4r
10
+ module Search
11
# Explore nodes in breadth-first order until a goal is found.
#
# Nodes are used as Hash keys, so they must implement +hash+ and +eql?+
# consistently.
class BFS
  # Create a new BFS searcher.
  #
  # goal_test:: lambda returning true for a goal node
  # neighbor_fn:: lambda returning adjacent nodes for a given node
  # start:: optional starting node
  def initialize(goal_test, neighbor_fn, start = nil)
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @start = start
  end

  # Find a path from the start node to a goal.
  #
  # start:: initial node if not provided on initialization
  #
  # Returns an array of nodes representing the path, or nil if no goal was found.
  # Raises ArgumentError when no start node was given here or on initialization.
  def search(start = nil)
    start ||= @start
    raise ArgumentError, 'start node required' unless start

    # Remember only each node's predecessor instead of copying the whole
    # path into every queue entry; the path is rebuilt once, on success.
    came_from = {}
    visited = { start => true }
    queue = [start]
    until queue.empty?
      node = queue.shift
      return build_path(came_from, node) if @goal_test.call(node)

      @neighbor_fn.call(node).each do |n|
        next if visited[n]

        visited[n] = true
        came_from[n] = node
        queue << n
      end
    end
    nil
  end

  private

  # Follow predecessor links back from +node+ and return the path in
  # start-to-goal order. The start node has no predecessor entry, which
  # ends the walk.
  def build_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path << node
    end
    path.reverse
  end
end
49
+ end
50
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI Assistant
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Basic depth-first search implementation.
8
+
9
+ module Ai4r
10
+ module Search
11
# Explore nodes in depth-first order until a goal is found.
#
# Nodes are marked as visited when they are pushed on the stack, so every
# node is reached through the first expanded node that lists it as a
# neighbor. Nodes are used as Hash keys, so they must implement +hash+ and
# +eql?+ consistently.
class DFS
  # Create a new DFS searcher.
  #
  # goal_test:: lambda returning true for a goal node
  # neighbor_fn:: lambda returning adjacent nodes for a given node
  # start:: optional starting node
  def initialize(goal_test, neighbor_fn, start = nil)
    @goal_test = goal_test
    @neighbor_fn = neighbor_fn
    @start = start
  end

  # Find a path from the start node to a goal.
  #
  # start:: initial node if not provided on initialization
  #
  # Returns an array of nodes representing the path, or nil if no goal was found.
  # Raises ArgumentError when no start node was given here or on initialization.
  def search(start = nil)
    start ||= @start
    raise ArgumentError, 'start node required' unless start

    # Remember only each node's predecessor instead of copying the whole
    # path into every stack entry; the path is rebuilt once, on success.
    came_from = {}
    visited = { start => true }
    stack = [start]
    until stack.empty?
      node = stack.pop
      return build_path(came_from, node) if @goal_test.call(node)

      @neighbor_fn.call(node).each do |n|
        next if visited[n]

        visited[n] = true
        came_from[n] = node
        stack << n
      end
    end
    nil
  end

  private

  # Follow predecessor links back from +node+ and return the path in
  # start-to-goal order. The start node has no predecessor entry, which
  # ends the walk.
  def build_path(came_from, node)
    path = [node]
    while came_from.key?(node)
      node = came_from[node]
      path << node
    end
    path.reverse
  end
end
49
+ end
50
+ end
@@ -0,0 +1,118 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI Assistant
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Minimal Monte Carlo Tree Search implementation.
8
+
9
+ require_relative '../data/parameterizable'
10
+
11
+ module Ai4r
12
+ module Search
13
# Basic UCT-style Monte Carlo Tree Search.
#
# This generic implementation expects four callbacks:
# - +actions_fn.call(state)+ returns available actions for a state.
# - +transition_fn.call(state, action)+ computes the next state.
# - +terminal_fn.call(state)+ returns true if the state has no children.
# - +reward_fn.call(state)+ yields a numeric payoff for terminal states.
#
# A state that is not reported as terminal but offers no actions is
# treated as a leaf: the tree descent and the random playout both stop
# there and +reward_fn+ is evaluated on that state.
#
# Example:
#   env = {
#     actions_fn: ->(s) { s == :root ? %i[a b] : [] },
#     transition_fn: ->(s, a) { a == :a ? :win : :lose },
#     terminal_fn: ->(s) { %i[win lose].include?(s) },
#     reward_fn: ->(s) { s == :win ? 1.0 : 0.0 }
#   }
#   mcts = Ai4r::Search::MCTS.new(**env)
#   best = mcts.search(:root, 50)
#   # => :a
class MCTS
  include Ai4r::Data::Parameterizable

  # Search-tree node: the state it represents, the parent node and the
  # action that led here, plus the child list and the visit/value
  # statistics updated by +backup+.
  Node = Struct.new(:state, :parent, :action, :children, :visits, :value) do
    def initialize(state, parent = nil, action = nil)
      super(state, parent, action, [], 0, 0.0)
    end
  end

  parameters_info exploration: 'UCT exploration constant'

  # Create a new search object.
  #
  # actions_fn:: returns available actions for a state
  # transition_fn:: computes the next state given a state and action
  # terminal_fn:: predicate to detect terminal states
  # reward_fn:: numeric payoff for terminal states
  # exploration:: UCT exploration constant (defaults to sqrt(2))
  def initialize(actions_fn:, transition_fn:, terminal_fn:, reward_fn:,
                 exploration: Math.sqrt(2))
    @actions_fn = actions_fn
    @transition_fn = transition_fn
    @terminal_fn = terminal_fn
    @reward_fn = reward_fn
    @exploration = exploration
  end

  # Run MCTS starting from +root_state+ for a number of +iterations+.
  # Returns the action considered best from the root, or nil when the root
  # never got a child (zero iterations, or a root without actions).
  def search(root_state, iterations)
    root = Node.new(root_state)
    iterations.times do
      node = tree_policy(root)
      reward = default_policy(node.state)
      backup(node, reward)
    end
    # Exploration weight 0: pick purely by average reward.
    best_child(root, 0)&.action
  end

  private

  # Descend from +node+ until a terminal state, a dead end or a node with
  # an untried action is found. In the last case a new child is created
  # and returned.
  def tree_policy(node)
    until @terminal_fn.call(node.state)
      actions = @actions_fn.call(node.state)
      # Dead end: nothing to expand and no child to descend into. Without
      # this guard best_child would return nil and the loop would crash.
      break if actions.empty?
      return expand(node, actions) if node.children.length < actions.length

      node = best_child(node, @exploration)
    end
    node
  end

  # Add a child for one randomly chosen action that has no child yet.
  def expand(node, actions)
    tried = node.children.map(&:action)
    action = (actions - tried).sample
    child = Node.new(@transition_fn.call(node.state, action), node, action)
    node.children << child
    child
  end

  # Child of +node+ with the highest UCT score; +c+ weights the
  # exploration term. Returns nil when the node has no children.
  def best_child(node, c)
    log_parent = Math.log(node.visits + 1)
    node.children.max_by do |child|
      # Use 1 as divisor for a child that has not been visited yet.
      visits = child.visits.nonzero? || 1
      (child.value / visits) + (c * Math.sqrt(log_parent / visits))
    end
  end

  # Random playout from +state+; returns the reward of the state where the
  # playout stops.
  def default_policy(state)
    current = state
    until @terminal_fn.call(current)
      actions = @actions_fn.call(current)
      # Dead end: stop instead of feeding a nil action to transition_fn.
      break if actions.empty?

      current = @transition_fn.call(current, actions.sample)
    end
    @reward_fn.call(current)
  end

  # Add one visit and +reward+ to +node+ and to every ancestor up to the
  # root.
  def backup(node, reward)
    while node
      node.visits += 1
      node.value += reward
      node = node.parent
    end
  end
end
117
+ end
118
+ end
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'search/a_star'
4
+ require_relative 'search/bfs'
5
+ require_relative 'search/dfs'
6
+ require_relative 'search/mcts'
7
+
8
+ module Ai4r
9
+ # Namespace for search algorithms (A*, BFS, DFS and MCTS).
10
+ module Search
11
+ end
12
+ end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ai4r
4
+ module Som
5
# Helper module with distance metrics for node coordinates.
#
# Every metric takes the two per-axis offsets between a pair of nodes and
# returns a single non-negative distance.
module DistanceMetrics
  # Chebyshev distance: the larger of the two absolute offsets.
  #
  # @param dx [Numeric] offset along the first axis
  # @param dy [Numeric] offset along the second axis
  # @return [Numeric]
  def self.chebyshev(dx, dy)
    [dx.abs, dy.abs].max
  end

  # Euclidean (straight-line) distance.
  #
  # Math.hypot is used instead of squaring by hand so that very large or
  # very small offsets do not overflow or underflow in the intermediate
  # squares.
  #
  # @param dx [Numeric] offset along the first axis
  # @param dy [Numeric] offset along the second axis
  # @return [Float]
  def self.euclidean(dx, dy)
    Math.hypot(dx, dy)
  end

  # Manhattan distance: the sum of the two absolute offsets.
  #
  # @param dx [Numeric] offset along the first axis
  # @param dy [Numeric] offset along the second axis
  # @return [Numeric]
  def self.manhattan(dx, dy)
    dx.abs + dy.abs
  end
end
28
+ end
29
+ end
@@ -1,18 +1,18 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Thomas Kern
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
8
  # You can redistribute it and/or modify it under the terms of
7
9
  # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/../data/parameterizable'
12
+ require_relative '../data/parameterizable'
11
13
 
12
14
  module Ai4r
13
-
14
15
  module Som
15
-
16
16
  # responsible for the implementation of the algorithm's decays
17
17
  # currently has methods for the decay of the radius, influence and learning rate.
18
18
  # Has only one phase, which ends after the number of epochs is passed by the Som-class.
@@ -24,16 +24,22 @@ module Ai4r
24
24
  # * epochs => number of epochs the algorithm runs, has to be an integer. By default it is set to 100
25
25
  # * learning_rate => sets the initial learning rate
26
26
  class Layer
27
-
28
27
  include Ai4r::Data::Parameterizable
29
28
 
30
- parameters_info :nodes => "number of nodes, has to be equal to the som",
31
- :epochs => "number of epochs the algorithm has to run",
32
- :radius => "sets the initial neighborhoud radius"
29
+ parameters_info nodes: 'number of nodes, has to be equal to the som',
30
+ epochs: 'number of epochs the algorithm has to run',
31
+ radius: 'sets the initial neighborhoud radius',
32
+ distance_metric: 'metric used to compute node distance'
33
+
34
+ # @param nodes [Object]
35
+ # @param radius [Object]
36
+ # @param epochs [Object]
37
+ # @param learning_rate [Object]
38
+ # @param options [Object]
39
+ # @return [Object]
40
+ def initialize(nodes, radius, epochs = 100, learning_rate = 0.7, options = {})
41
+ raise('Too few nodes') if nodes < 3
33
42
 
34
- def initialize(nodes, radius, epochs = 100, learning_rate = 0.7)
35
- raise("Too few nodes") if nodes < 3
36
-
37
43
  @nodes = nodes
38
44
  @epochs = epochs
39
45
  @radius = radius
@@ -41,28 +47,33 @@ module Ai4r
41
47
  @time_for_epoch = @epochs + 1.0 if @time_for_epoch < @epochs
42
48
 
43
49
  @initial_learning_rate = learning_rate
50
+ @distance_metric = options[:distance_metric] || :chebyshev
44
51
  end
45
52
 
46
53
  # calculates the influence decay for a certain distance and the current radius
47
54
  # of the epoch
55
+ # @param distance [Object]
56
+ # @param radius [Object]
57
+ # @return [Object]
48
58
  def influence_decay(distance, radius)
49
- Math.exp(- (distance.to_f**2 / 2.0 / radius.to_f**2))
59
+ Math.exp(- ((distance.to_f**2) / 2.0 / (radius.to_f**2)))
50
60
  end
51
61
 
52
62
  # calculates the radius decay for the current epoch. Uses @time_for_epoch
53
63
  # which has to be higher than the number of epochs, otherwise the decay will be - Infinity
64
+ # @param epoch [Object]
65
+ # @return [Object]
54
66
  def radius_decay(epoch)
55
- (@radius * ( 1 - epoch/ @time_for_epoch)).round
67
+ (@radius * (1 - (epoch / @time_for_epoch))).round
56
68
  end
57
69
 
58
70
  # calculates the learning rate decay. uses @time_for_epoch again and same rule applies:
59
71
  # @time_for_epoch has to be higher than the number of epochs, otherwise the decay will be - Infinity
72
+ # @param epoch [Object]
73
+ # @return [Object]
60
74
  def learning_rate_decay(epoch)
61
- @initial_learning_rate * ( 1 - epoch / @time_for_epoch)
75
+ @initial_learning_rate * (1 - (epoch / @time_for_epoch))
62
76
  end
63
-
64
77
  end
65
-
66
78
  end
67
-
68
79
  end
data/lib/ai4r/som/node.rb CHANGED
@@ -1,19 +1,20 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Thomas Kern
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
8
  # You can redistribute it and/or modify it under the terms of
7
9
  # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/../data/parameterizable'
11
- require File.dirname(__FILE__) + '/layer'
12
+ require_relative '../data/parameterizable'
13
+ require_relative 'layer'
14
+ require_relative 'distance_metrics'
12
15
 
13
16
  module Ai4r
14
-
15
17
  module Som
16
-
17
18
  # this class is used for the individual node and will be (nodes * nodes)-time instantiated
18
19
  #
19
20
  # = attributes
@@ -24,37 +25,67 @@ module Ai4r
24
25
  # * weights => values of the current weights are stored in an array of dimension 'dimensions'.
25
26
  # Weights are of type float
26
27
  # * instantiated_weight => the values of the first instantiation of weights. these values are
27
- # never changed
28
+ # never changed
28
29
 
30
+ # Represents a node in the self-organizing map grid.
29
31
  class Node
30
-
31
32
  include Ai4r::Data::Parameterizable
32
33
 
33
- parameters_info :weights => "holds the current weight",
34
- :instantiated_weight => "holds the very first weight",
35
- :x => "holds the row ID of the unit in the map",
36
- :y => "holds the column ID of the unit in the map",
37
- :id => "id of the node"
34
+ parameters_info weights: 'holds the current weight',
35
+ instantiated_weight: 'holds the very first weight',
36
+ x: 'holds the row ID of the unit in the map',
37
+ y: 'holds the column ID of the unit in the map',
38
+ id: 'id of the node',
39
+ distance_metric: 'metric used to compute node distance'
38
40
 
39
41
  # creates an instance of Node and instantiates the weights
40
- # the parameters is a uniq and sequential ID as well as the number of total nodes
41
- # dimensions signals the dimension of the input vector
42
- def self.create(id, total, dimensions)
42
+ #
43
+ # +id+:: unique identifier for this node
44
+ # +rows+:: number of rows of the SOM grid
45
+ # +columns+:: number of columns of the SOM grid
46
+ # +dimensions+:: dimension of the input vector
47
+ # @param id [Object]
48
+ # @param rows [Object]
49
+ # @param columns [Object]
50
+ # @param dimensions [Object]
51
+ # @param options [Object]
52
+ # @option options [Range] :range (0..1) range used to initialize weights
53
+ # @option options [Integer] :random_seed Seed for Ruby's RNG. The
54
+ # deprecated :seed key is supported for backward compatibility.
55
+ # @return [Object]
56
+ def self.create(id, _rows, columns, dimensions, options = {})
43
57
  n = Node.new
44
58
  n.id = id
45
- n.instantiate_weight dimensions
46
- n.x = id % total
47
- n.y = (id / total.to_f).to_i
59
+ n.distance_metric = options[:distance_metric] || :chebyshev
60
+ n.instantiate_weight dimensions, options
61
+ n.x = id % columns
62
+ n.y = (id / columns.to_f).to_i
48
63
  n
49
64
  end
50
65
 
51
66
  # instantiates the weights to the dimension (of the input vector)
52
67
  # for backup reasons, the instantiated weight is stored into @instantiated_weight as well
53
- def instantiate_weight(dimensions)
68
+ # @param dimensions [Object]
69
+ # @param options [Object]
70
+ # @option options [Range] :range (0..1) range used to initialize weights
71
+ # @option options [Integer] :random_seed Seed for Ruby's RNG. The
72
+ # deprecated :seed key is supported for backward compatibility.
73
+ # @return [Object]
74
+ def instantiate_weight(dimensions, options = {})
75
+ opts = { range: 0..1, random_seed: nil, seed: nil, rng: nil }.merge(options)
76
+ rng = opts[:rng]
77
+ unless rng
78
+ seed = opts[:random_seed] || opts[:seed]
79
+ rng = seed.nil? ? Random.new : Random.new(seed)
80
+ end
81
+ range = opts[:range] || (0..1)
82
+ min = range.first.to_f
83
+ max = range.last.to_f
84
+ span = max - min
54
85
  @weights = Array.new dimensions
55
86
  @instantiated_weight = Array.new dimensions
56
- @weights.each_with_index do |weight, index|
57
- @weights[index] = rand
87
+ @weights.each_index do |index|
88
+ @weights[index] = min + (rng.rand * span)
58
89
  @instantiated_weight[index] = @weights[index]
59
90
  end
60
91
  end
@@ -62,10 +93,12 @@ module Ai4r
62
93
  # returns the square distance between the current weights and the input
63
94
  # the input is a vector/array of the same size as weights
64
95
  # at the end, the square root is extracted from the sum of differences
96
+ # @param input [Object]
97
+ # @return [Object]
65
98
  def distance_to_input(input)
66
99
  dist = 0
67
100
  input.each_with_index do |i, index|
68
- dist += (i - @weights[index]) ** 2
101
+ dist += (i - @weights[index])**2
69
102
  end
70
103
 
71
104
  Math.sqrt(dist)
@@ -79,18 +112,14 @@ module Ai4r
79
112
  # 2 1 1 1 2
80
113
  # 2 2 2 2 2
81
114
  # 0 being the current node
115
+ # @param node [Object]
116
+ # @return [Object]
82
117
  def distance_to_node(node)
83
- max((self.x - node.x).abs, (self.y - node.y).abs)
118
+ dx = x - node.x
119
+ dy = y - node.y
120
+ metric = @distance_metric || :chebyshev
121
+ DistanceMetrics.send(metric, dx, dy)
84
122
  end
85
-
86
- private
87
-
88
- def max(a, b)
89
- a > b ? a : b
90
- end
91
-
92
123
  end
93
-
94
124
  end
95
-
96
125
  end