ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -1,31 +1,32 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
-
10
11
 
11
12
  TRIANGLE = [
12
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
13
- [ 0, 0, 0, 0, 0, 0, 1, 9, 9, 1, 0, 0, 0, 0, 0, 0],
14
- [ 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0],
15
- [ 0, 0, 0, 0, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 0, 0],
16
- [ 0, 0, 0, 0, 0, 5, 5, 0, 0, 5, 5, 0, 0, 0, 0, 0],
17
- [ 0, 0, 0, 0, 1, 9, 1, 0, 0, 1, 9, 1, 0, 0, 0, 0],
18
- [ 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0],
19
- [ 0, 0, 0, 1, 9, 1, 0, 0, 0, 0, 1, 9, 1, 0, 0, 0],
20
- [ 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0],
21
- [ 0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
22
- [ 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0],
23
- [ 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
24
- [ 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0],
25
- [ 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1],
26
- [ 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
13
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
14
+ [0, 0, 0, 0, 0, 0, 1, 9, 9, 1, 0, 0, 0, 0, 0, 0],
15
+ [0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0],
16
+ [0, 0, 0, 0, 0, 1, 9, 1, 1, 9, 1, 0, 0, 0, 0, 0],
17
+ [0, 0, 0, 0, 0, 5, 5, 0, 0, 5, 5, 0, 0, 0, 0, 0],
18
+ [0, 0, 0, 0, 1, 9, 1, 0, 0, 1, 9, 1, 0, 0, 0, 0],
19
+ [0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0],
20
+ [0, 0, 0, 1, 9, 1, 0, 0, 0, 0, 1, 9, 1, 0, 0, 0],
21
+ [0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0],
22
+ [0, 0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0, 0],
23
+ [0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0],
24
+ [0, 1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1, 0],
25
+ [0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0],
26
+ [1, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 9, 1],
27
+ [5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5],
27
28
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
28
- ]
29
+ ].freeze
29
30
 
30
31
  SQUARE = [
31
32
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
@@ -44,25 +45,24 @@ SQUARE = [
44
45
  [10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
45
46
  [10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
46
47
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
47
-
48
- ]
48
+
49
+ ].freeze
49
50
 
50
51
  CROSS = [
51
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
52
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
53
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
54
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
55
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
56
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
57
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
58
- [ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
59
- [ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
60
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
61
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
62
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
63
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
64
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
65
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
66
- [ 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
67
- ]
68
-
52
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
53
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
54
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
55
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
56
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
57
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
58
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
59
+ [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
60
+ [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
61
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
62
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
63
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
64
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
65
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
66
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0],
67
+ [0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0]
68
+ ].freeze
@@ -0,0 +1,78 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI Assistant
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ # Url:: http://www.ai4r.org/
7
+ #
8
+ # Toy example of using the minimal Transformer encoder for
9
+ # text classification. We build random sentence embeddings
10
+ # with the Transformer and train a logistic regression
11
+ # classifier on a tiny sentiment dataset.
12
+
13
+ require_relative '../../lib/ai4r/neural_network/transformer'
14
+ require_relative '../../lib/ai4r/classifiers/logistic_regression'
15
+ require_relative '../../lib/ai4r/data/data_set'
16
+
17
+ # Vocabulary for our miniature dataset
18
+ VOCAB = {
19
+ 'good' => 0,
20
+ 'great' => 1,
21
+ 'bad' => 2,
22
+ 'awful' => 3,
23
+ 'movie' => 4,
24
+ 'film' => 5,
25
+ '<pad>' => 6
26
+ }.freeze
27
+
28
+ MAX_LEN = 2
29
+
30
+ # Helper that converts space separated text into an array of token ids
31
+ # and pads it to MAX_LEN tokens.
32
+ def encode(text)
33
+ tokens = text.split.map { |w| VOCAB[w] }
34
+ tokens.fill(VOCAB['<pad>'], tokens.length...MAX_LEN)
35
+ end
36
+
37
+ # Build encoder only Transformer
38
+ model = Ai4r::NeuralNetwork::Transformer.new(
39
+ vocab_size: VOCAB.size,
40
+ max_len: MAX_LEN
41
+ )
42
+
43
+ train_texts = ['good movie', 'great film', 'bad movie', 'awful film']
44
+ labels = [1, 1, 0, 0]
45
+
46
+ # Obtain sentence embeddings by averaging token representations
47
+ train_features = train_texts.map do |text|
48
+ tokens = encode(text)
49
+ enc = model.eval(tokens)
50
+ mean = Array.new(model.embed_dim, 0.0)
51
+ enc.each do |vec|
52
+ vec.each_with_index { |v, i| mean[i] += v }
53
+ end
54
+ mean.map { |v| v / enc.length }
55
+ end
56
+
57
+ data_items = train_features.each_with_index.map { |feat, i| feat + [labels[i]] }
58
+ labels_names = (0...model.embed_dim).map { |i| "f#{i}" } + ['class']
59
+
60
+ dataset = Ai4r::Data::DataSet.new(
61
+ data_items: data_items,
62
+ data_labels: labels_names
63
+ )
64
+
65
+ classifier = Ai4r::Classifiers::LogisticRegression.new
66
+ classifier.set_parameters(lr: 0.5, iterations: 2000).build(dataset)
67
+
68
+ puts 'Predictions:'
69
+ ['good film', 'awful movie'].each do |text|
70
+ tokens = encode(text)
71
+ enc = model.eval(tokens)
72
+ mean = Array.new(model.embed_dim, 0.0)
73
+ enc.each do |vec|
74
+ vec.each_with_index { |v, i| mean[i] += v }
75
+ end
76
+ mean.map! { |v| v / enc.length }
77
+ puts "#{text} => #{classifier.eval(mean)}"
78
+ end
@@ -1,35 +1,36 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
6
  # Url:: http://www.ai4r.org/
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/../../lib/ai4r/neural_network/backpropagation'
12
+ require_relative '../../lib/ai4r/neural_network/backpropagation'
11
13
  require 'benchmark'
12
14
 
13
15
  times = Benchmark.measure do
16
+ srand 1
17
+
18
+ net = Ai4r::NeuralNetwork::Backpropagation.new([2, 2, 1])
19
+
20
+ puts 'Training the network, please wait.'
21
+ 2001.times do |i|
22
+ net.train([0, 0], [0])
23
+ net.train([0, 1], [1])
24
+ net.train([1, 0], [1])
25
+ error = net.train([1, 1], [0])
26
+ puts "Error after iteration #{i}:\t#{error}" if (i % 200).zero?
27
+ end
14
28
 
15
- srand 1
16
-
17
- net = Ai4r::NeuralNetwork::Backpropagation.new([2, 2, 1])
18
-
19
- puts "Training the network, please wait."
20
- 2001.times do |i|
21
- net.train([0,0], [0])
22
- net.train([0,1], [1])
23
- net.train([1,0], [1])
24
- error = net.train([1,1], [0])
25
- puts "Error after iteration #{i}:\t#{error}" if i%200 == 0
26
- end
27
-
28
- puts "Test data"
29
- puts "[0,0] = > #{net.eval([0,0]).inspect}"
30
- puts "[0,1] = > #{net.eval([0,1]).inspect}"
31
- puts "[1,0] = > #{net.eval([1,0]).inspect}"
32
- puts "[1,1] = > #{net.eval([1,1]).inspect}"
29
+ puts 'Test data'
30
+ puts "[0,0] = > #{net.eval([0, 0]).inspect}"
31
+ puts "[0,1] = > #{net.eval([0, 1]).inspect}"
32
+ puts "[1,0] = > #{net.eval([1, 0]).inspect}"
33
+ puts "[1,1] = > #{net.eval([1, 1]).inspect}"
33
34
  end
34
35
 
35
- puts "Elapsed time: #{times}"
36
+ puts "Elapsed time: #{times}"
@@ -0,0 +1,10 @@
1
+ require 'ai4r/reinforcement/q_learning'
2
+
3
+ agent = Ai4r::Reinforcement::QLearning.new
4
+ agent.set_parameters(learning_rate: 0.5, discount: 1.0, exploration: 0.0)
5
+
6
+ # Simple two-state MDP
7
+ agent.update(:s1, :a, 0, :s2)
8
+ agent.update(:s1, :b, 1, :s1)
9
+
10
+ puts "Best action from s1: #{agent.choose_action(:s1)}"
@@ -1,156 +1,159 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # data is from the iris dataset (http://archive.ics.uci.edu/ml/datasets/Iris)
2
4
  # it is the full dataset, removing the last column
3
- # website provides additional information on the dataset itself (attributes, class distribution, etc)
5
+ # website provides additional information on the dataset itself
6
+ # (attributes, class distribution, etc)
4
7
 
5
8
  SOM_DATA = [
6
- [5.1, 3.5, 1.4, 0.2],
7
- [4.9, 3.0, 1.4, 0.2],
8
- [4.7, 3.2, 1.3, 0.2],
9
- [4.6, 3.1, 1.5, 0.2],
10
- [5.0, 3.6, 1.4, 0.2],
11
- [5.4, 3.9, 1.7, 0.4],
12
- [4.6, 3.4, 1.4, 0.3],
13
- [5.0, 3.4, 1.5, 0.2],
14
- [4.4, 2.9, 1.4, 0.2],
15
- [4.9, 3.1, 1.5, 0.1],
16
- [5.4, 3.7, 1.5, 0.2],
17
- [4.8, 3.4, 1.6, 0.2],
18
- [4.8, 3.0, 1.4, 0.1],
19
- [4.3, 3.0, 1.1, 0.1],
20
- [5.8, 4.0, 1.2, 0.2],
21
- [5.7, 4.4, 1.5, 0.4],
22
- [5.4, 3.9, 1.3, 0.4],
23
- [5.1, 3.5, 1.4, 0.3],
24
- [5.7, 3.8, 1.7, 0.3],
25
- [5.1, 3.8, 1.5, 0.3],
26
- [5.4, 3.4, 1.7, 0.2],
27
- [5.1, 3.7, 1.5, 0.4],
28
- [4.6, 3.6, 1.0, 0.2],
29
- [5.1, 3.3, 1.7, 0.5],
30
- [4.8, 3.4, 1.9, 0.2],
31
- [5.0, 3.0, 1.6, 0.2],
32
- [5.0, 3.4, 1.6, 0.4],
33
- [5.2, 3.5, 1.5, 0.2],
34
- [5.2, 3.4, 1.4, 0.2],
35
- [4.7, 3.2, 1.6, 0.2],
36
- [4.8, 3.1, 1.6, 0.2],
37
- [5.4, 3.4, 1.5, 0.4],
38
- [5.2, 4.1, 1.5, 0.1],
39
- [5.5, 4.2, 1.4, 0.2],
40
- [4.9, 3.1, 1.5, 0.1],
41
- [5.0, 3.2, 1.2, 0.2],
42
- [5.5, 3.5, 1.3, 0.2],
43
- [4.9, 3.1, 1.5, 0.1],
44
- [4.4, 3.0, 1.3, 0.2],
45
- [5.1, 3.4, 1.5, 0.2],
46
- [5.0, 3.5, 1.3, 0.3],
47
- [4.5, 2.3, 1.3, 0.3],
48
- [4.4, 3.2, 1.3, 0.2],
49
- [5.0, 3.5, 1.6, 0.6],
50
- [5.1, 3.8, 1.9, 0.4],
51
- [4.8, 3.0, 1.4, 0.3],
52
- [5.1, 3.8, 1.6, 0.2],
53
- [4.6, 3.2, 1.4, 0.2],
54
- [5.3, 3.7, 1.5, 0.2],
55
- [5.0, 3.3, 1.4, 0.2],
56
- [7.0, 3.2, 4.7, 1.4],
57
- [6.4, 3.2, 4.5, 1.5],
58
- [6.9, 3.1, 4.9, 1.5],
59
- [5.5, 2.3, 4.0, 1.3],
60
- [6.5, 2.8, 4.6, 1.5],
61
- [5.7, 2.8, 4.5, 1.3],
62
- [6.3, 3.3, 4.7, 1.6],
63
- [4.9, 2.4, 3.3, 1.0],
64
- [6.6, 2.9, 4.6, 1.3],
65
- [5.2, 2.7, 3.9, 1.4],
66
- [5.0, 2.0, 3.5, 1.0],
67
- [5.9, 3.0, 4.2, 1.5],
68
- [6.0, 2.2, 4.0, 1.0],
69
- [6.1, 2.9, 4.7, 1.4],
70
- [5.6, 2.9, 3.6, 1.3],
71
- [6.7, 3.1, 4.4, 1.4],
72
- [5.6, 3.0, 4.5, 1.5],
73
- [5.8, 2.7, 4.1, 1.0],
74
- [6.2, 2.2, 4.5, 1.5],
75
- [5.6, 2.5, 3.9, 1.1],
76
- [5.9, 3.2, 4.8, 1.8],
77
- [6.1, 2.8, 4.0, 1.3],
78
- [6.3, 2.5, 4.9, 1.5],
79
- [6.1, 2.8, 4.7, 1.2],
80
- [6.4, 2.9, 4.3, 1.3],
81
- [6.6, 3.0, 4.4, 1.4],
82
- [6.8, 2.8, 4.8, 1.4],
83
- [6.7, 3.0, 5.0, 1.7],
84
- [6.0, 2.9, 4.5, 1.5],
85
- [5.7, 2.6, 3.5, 1.0],
86
- [5.5, 2.4, 3.8, 1.1],
87
- [5.5, 2.4, 3.7, 1.0],
88
- [5.8, 2.7, 3.9, 1.2],
89
- [6.0, 2.7, 5.1, 1.6],
90
- [5.4, 3.0, 4.5, 1.5],
91
- [6.0, 3.4, 4.5, 1.6],
92
- [6.7, 3.1, 4.7, 1.5],
93
- [6.3, 2.3, 4.4, 1.3],
94
- [5.6, 3.0, 4.1, 1.3],
95
- [5.5, 2.5, 4.0, 1.3],
96
- [5.5, 2.6, 4.4, 1.2],
97
- [6.1, 3.0, 4.6, 1.4],
98
- [5.8, 2.6, 4.0, 1.2],
99
- [5.0, 2.3, 3.3, 1.0],
100
- [5.6, 2.7, 4.2, 1.3],
101
- [5.7, 3.0, 4.2, 1.2],
102
- [5.7, 2.9, 4.2, 1.3],
103
- [6.2, 2.9, 4.3, 1.3],
104
- [5.1, 2.5, 3.0, 1.1],
105
- [5.7, 2.8, 4.1, 1.3],
106
- [6.3, 3.3, 6.0, 2.5],
107
- [5.8, 2.7, 5.1, 1.9],
108
- [7.1, 3.0, 5.9, 2.1],
109
- [6.3, 2.9, 5.6, 1.8],
110
- [6.5, 3.0, 5.8, 2.2],
111
- [7.6, 3.0, 6.6, 2.1],
112
- [4.9, 2.5, 4.5, 1.7],
113
- [7.3, 2.9, 6.3, 1.8],
114
- [6.7, 2.5, 5.8, 1.8],
115
- [7.2, 3.6, 6.1, 2.5],
116
- [6.5, 3.2, 5.1, 2.0],
117
- [6.4, 2.7, 5.3, 1.9],
118
- [6.8, 3.0, 5.5, 2.1],
119
- [5.7, 2.5, 5.0, 2.0],
120
- [5.8, 2.8, 5.1, 2.4],
121
- [6.4, 3.2, 5.3, 2.3],
122
- [6.5, 3.0, 5.5, 1.8],
123
- [7.7, 3.8, 6.7, 2.2],
124
- [7.7, 2.6, 6.9, 2.3],
125
- [6.0, 2.2, 5.0, 1.5],
126
- [6.9, 3.2, 5.7, 2.3],
127
- [5.6, 2.8, 4.9, 2.0],
128
- [7.7, 2.8, 6.7, 2.0],
129
- [6.3, 2.7, 4.9, 1.8],
130
- [6.7, 3.3, 5.7, 2.1],
131
- [7.2, 3.2, 6.0, 1.8],
132
- [6.2, 2.8, 4.8, 1.8],
133
- [6.1, 3.0, 4.9, 1.8],
134
- [6.4, 2.8, 5.6, 2.1],
135
- [7.2, 3.0, 5.8, 1.6],
136
- [7.4, 2.8, 6.1, 1.9],
137
- [7.9, 3.8, 6.4, 2.0],
138
- [6.4, 2.8, 5.6, 2.2],
139
- [6.3, 2.8, 5.1, 1.5],
140
- [6.1, 2.6, 5.6, 1.4],
141
- [7.7, 3.0, 6.1, 2.3],
142
- [6.3, 3.4, 5.6, 2.4],
143
- [6.4, 3.1, 5.5, 1.8],
144
- [6.0, 3.0, 4.8, 1.8],
145
- [6.9, 3.1, 5.4, 2.1],
146
- [6.7, 3.1, 5.6, 2.4],
147
- [6.9, 3.1, 5.1, 2.3],
148
- [5.8, 2.7, 5.1, 1.9],
149
- [6.8, 3.2, 5.9, 2.3],
150
- [6.7, 3.3, 5.7, 2.5],
151
- [6.7, 3.0, 5.2, 2.3],
152
- [6.3, 2.5, 5.0, 1.9],
153
- [6.5, 3.0, 5.2, 2.0],
154
- [6.2, 3.4, 5.4, 2.3],
155
- [5.9, 3.0, 5.1, 1.8],
156
- ]
9
+ [5.1, 3.5, 1.4, 0.2],
10
+ [4.9, 3.0, 1.4, 0.2],
11
+ [4.7, 3.2, 1.3, 0.2],
12
+ [4.6, 3.1, 1.5, 0.2],
13
+ [5.0, 3.6, 1.4, 0.2],
14
+ [5.4, 3.9, 1.7, 0.4],
15
+ [4.6, 3.4, 1.4, 0.3],
16
+ [5.0, 3.4, 1.5, 0.2],
17
+ [4.4, 2.9, 1.4, 0.2],
18
+ [4.9, 3.1, 1.5, 0.1],
19
+ [5.4, 3.7, 1.5, 0.2],
20
+ [4.8, 3.4, 1.6, 0.2],
21
+ [4.8, 3.0, 1.4, 0.1],
22
+ [4.3, 3.0, 1.1, 0.1],
23
+ [5.8, 4.0, 1.2, 0.2],
24
+ [5.7, 4.4, 1.5, 0.4],
25
+ [5.4, 3.9, 1.3, 0.4],
26
+ [5.1, 3.5, 1.4, 0.3],
27
+ [5.7, 3.8, 1.7, 0.3],
28
+ [5.1, 3.8, 1.5, 0.3],
29
+ [5.4, 3.4, 1.7, 0.2],
30
+ [5.1, 3.7, 1.5, 0.4],
31
+ [4.6, 3.6, 1.0, 0.2],
32
+ [5.1, 3.3, 1.7, 0.5],
33
+ [4.8, 3.4, 1.9, 0.2],
34
+ [5.0, 3.0, 1.6, 0.2],
35
+ [5.0, 3.4, 1.6, 0.4],
36
+ [5.2, 3.5, 1.5, 0.2],
37
+ [5.2, 3.4, 1.4, 0.2],
38
+ [4.7, 3.2, 1.6, 0.2],
39
+ [4.8, 3.1, 1.6, 0.2],
40
+ [5.4, 3.4, 1.5, 0.4],
41
+ [5.2, 4.1, 1.5, 0.1],
42
+ [5.5, 4.2, 1.4, 0.2],
43
+ [4.9, 3.1, 1.5, 0.1],
44
+ [5.0, 3.2, 1.2, 0.2],
45
+ [5.5, 3.5, 1.3, 0.2],
46
+ [4.9, 3.1, 1.5, 0.1],
47
+ [4.4, 3.0, 1.3, 0.2],
48
+ [5.1, 3.4, 1.5, 0.2],
49
+ [5.0, 3.5, 1.3, 0.3],
50
+ [4.5, 2.3, 1.3, 0.3],
51
+ [4.4, 3.2, 1.3, 0.2],
52
+ [5.0, 3.5, 1.6, 0.6],
53
+ [5.1, 3.8, 1.9, 0.4],
54
+ [4.8, 3.0, 1.4, 0.3],
55
+ [5.1, 3.8, 1.6, 0.2],
56
+ [4.6, 3.2, 1.4, 0.2],
57
+ [5.3, 3.7, 1.5, 0.2],
58
+ [5.0, 3.3, 1.4, 0.2],
59
+ [7.0, 3.2, 4.7, 1.4],
60
+ [6.4, 3.2, 4.5, 1.5],
61
+ [6.9, 3.1, 4.9, 1.5],
62
+ [5.5, 2.3, 4.0, 1.3],
63
+ [6.5, 2.8, 4.6, 1.5],
64
+ [5.7, 2.8, 4.5, 1.3],
65
+ [6.3, 3.3, 4.7, 1.6],
66
+ [4.9, 2.4, 3.3, 1.0],
67
+ [6.6, 2.9, 4.6, 1.3],
68
+ [5.2, 2.7, 3.9, 1.4],
69
+ [5.0, 2.0, 3.5, 1.0],
70
+ [5.9, 3.0, 4.2, 1.5],
71
+ [6.0, 2.2, 4.0, 1.0],
72
+ [6.1, 2.9, 4.7, 1.4],
73
+ [5.6, 2.9, 3.6, 1.3],
74
+ [6.7, 3.1, 4.4, 1.4],
75
+ [5.6, 3.0, 4.5, 1.5],
76
+ [5.8, 2.7, 4.1, 1.0],
77
+ [6.2, 2.2, 4.5, 1.5],
78
+ [5.6, 2.5, 3.9, 1.1],
79
+ [5.9, 3.2, 4.8, 1.8],
80
+ [6.1, 2.8, 4.0, 1.3],
81
+ [6.3, 2.5, 4.9, 1.5],
82
+ [6.1, 2.8, 4.7, 1.2],
83
+ [6.4, 2.9, 4.3, 1.3],
84
+ [6.6, 3.0, 4.4, 1.4],
85
+ [6.8, 2.8, 4.8, 1.4],
86
+ [6.7, 3.0, 5.0, 1.7],
87
+ [6.0, 2.9, 4.5, 1.5],
88
+ [5.7, 2.6, 3.5, 1.0],
89
+ [5.5, 2.4, 3.8, 1.1],
90
+ [5.5, 2.4, 3.7, 1.0],
91
+ [5.8, 2.7, 3.9, 1.2],
92
+ [6.0, 2.7, 5.1, 1.6],
93
+ [5.4, 3.0, 4.5, 1.5],
94
+ [6.0, 3.4, 4.5, 1.6],
95
+ [6.7, 3.1, 4.7, 1.5],
96
+ [6.3, 2.3, 4.4, 1.3],
97
+ [5.6, 3.0, 4.1, 1.3],
98
+ [5.5, 2.5, 4.0, 1.3],
99
+ [5.5, 2.6, 4.4, 1.2],
100
+ [6.1, 3.0, 4.6, 1.4],
101
+ [5.8, 2.6, 4.0, 1.2],
102
+ [5.0, 2.3, 3.3, 1.0],
103
+ [5.6, 2.7, 4.2, 1.3],
104
+ [5.7, 3.0, 4.2, 1.2],
105
+ [5.7, 2.9, 4.2, 1.3],
106
+ [6.2, 2.9, 4.3, 1.3],
107
+ [5.1, 2.5, 3.0, 1.1],
108
+ [5.7, 2.8, 4.1, 1.3],
109
+ [6.3, 3.3, 6.0, 2.5],
110
+ [5.8, 2.7, 5.1, 1.9],
111
+ [7.1, 3.0, 5.9, 2.1],
112
+ [6.3, 2.9, 5.6, 1.8],
113
+ [6.5, 3.0, 5.8, 2.2],
114
+ [7.6, 3.0, 6.6, 2.1],
115
+ [4.9, 2.5, 4.5, 1.7],
116
+ [7.3, 2.9, 6.3, 1.8],
117
+ [6.7, 2.5, 5.8, 1.8],
118
+ [7.2, 3.6, 6.1, 2.5],
119
+ [6.5, 3.2, 5.1, 2.0],
120
+ [6.4, 2.7, 5.3, 1.9],
121
+ [6.8, 3.0, 5.5, 2.1],
122
+ [5.7, 2.5, 5.0, 2.0],
123
+ [5.8, 2.8, 5.1, 2.4],
124
+ [6.4, 3.2, 5.3, 2.3],
125
+ [6.5, 3.0, 5.5, 1.8],
126
+ [7.7, 3.8, 6.7, 2.2],
127
+ [7.7, 2.6, 6.9, 2.3],
128
+ [6.0, 2.2, 5.0, 1.5],
129
+ [6.9, 3.2, 5.7, 2.3],
130
+ [5.6, 2.8, 4.9, 2.0],
131
+ [7.7, 2.8, 6.7, 2.0],
132
+ [6.3, 2.7, 4.9, 1.8],
133
+ [6.7, 3.3, 5.7, 2.1],
134
+ [7.2, 3.2, 6.0, 1.8],
135
+ [6.2, 2.8, 4.8, 1.8],
136
+ [6.1, 3.0, 4.9, 1.8],
137
+ [6.4, 2.8, 5.6, 2.1],
138
+ [7.2, 3.0, 5.8, 1.6],
139
+ [7.4, 2.8, 6.1, 1.9],
140
+ [7.9, 3.8, 6.4, 2.0],
141
+ [6.4, 2.8, 5.6, 2.2],
142
+ [6.3, 2.8, 5.1, 1.5],
143
+ [6.1, 2.6, 5.6, 1.4],
144
+ [7.7, 3.0, 6.1, 2.3],
145
+ [6.3, 3.4, 5.6, 2.4],
146
+ [6.4, 3.1, 5.5, 1.8],
147
+ [6.0, 3.0, 4.8, 1.8],
148
+ [6.9, 3.1, 5.4, 2.1],
149
+ [6.7, 3.1, 5.6, 2.4],
150
+ [6.9, 3.1, 5.1, 2.3],
151
+ [5.8, 2.7, 5.1, 1.9],
152
+ [6.8, 3.2, 5.9, 2.3],
153
+ [6.7, 3.3, 5.7, 2.5],
154
+ [6.7, 3.0, 5.2, 2.3],
155
+ [6.3, 2.5, 5.0, 1.9],
156
+ [6.5, 3.0, 5.2, 2.0],
157
+ [6.2, 3.4, 5.4, 2.3],
158
+ [5.9, 3.0, 5.1, 1.8]
159
+ ].freeze
@@ -1,22 +1,21 @@
1
- # this example shows the impact of the size of a som on the global error distance
2
- require File.dirname(__FILE__) + '/../../lib/ai4r/som/som'
3
- require File.dirname(__FILE__) + '/som_data'
1
+ # frozen_string_literal: true
2
+
3
+ # Demonstrates how map size impacts error and uses early stopping.
4
+ require_relative '../../lib/ai4r/som/som'
5
+ require_relative 'som_data'
4
6
  require 'benchmark'
5
7
 
6
8
  10.times do |t|
7
- t += 3 # minimum number of nodes
9
+ nodes = t + 3 # minimum number of nodes
8
10
 
9
- puts "Nodes: #{t}"
10
- som = Ai4r::Som::Som.new 4, 8, Ai4r::Som::TwoPhaseLayer.new(t)
11
+ puts "Nodes: #{nodes}"
12
+ som = Ai4r::Som::Som.new 4, 8, 8, Ai4r::Som::TwoPhaseLayer.new(nodes)
11
13
  som.initiate_map
12
14
 
13
- puts "global error distance: #{som.global_error(SOM_DATA)}"
14
- puts "\ntraining the som\n"
15
-
15
+ puts "Initial error: #{som.global_error(SOM_DATA)}"
16
16
  times = Benchmark.measure do
17
- som.train SOM_DATA
17
+ som.train(SOM_DATA, error_threshold: 1000)
18
18
  end
19
-
20
19
  puts "Elapsed time for training: #{times}"
21
- puts "global error distance: #{som.global_error(SOM_DATA)}\n\n"
22
- end
20
+ puts "Final error: #{som.global_error(SOM_DATA)}\n\n"
21
+ end
@@ -1,24 +1,21 @@
1
- require File.dirname(__FILE__) + '/../../lib/ai4r/som/som'
2
- require File.dirname(__FILE__) + '/som_data'
1
+ # frozen_string_literal: true
2
+
3
+ require_relative '../../lib/ai4r/som/som'
4
+ require_relative 'som_data'
3
5
  require 'benchmark'
4
6
 
5
- som = Ai4r::Som::Som.new 4, 8, Ai4r::Som::TwoPhaseLayer.new(10)
7
+ # Train a small SOM and stop early when the global error drops below 1000.
8
+ som = Ai4r::Som::Som.new 4, 8, 8, Ai4r::Som::TwoPhaseLayer.new(10)
6
9
  som.initiate_map
7
10
 
8
- som.nodes.each do |node|
9
- p node.weights
10
- end
11
-
12
- puts "global error distance: #{som.global_error(SOM_DATA)}"
13
- puts "\ntraining the som\n"
11
+ puts "Initial global error: #{som.global_error(SOM_DATA)}"
14
12
 
13
+ puts "\nTraining the SOM (early stopping threshold = 1000)\n"
15
14
  times = Benchmark.measure do
16
- som.train SOM_DATA
17
- end
18
-
19
- som.nodes.each do |node|
20
- p node.weights
15
+ som.train(SOM_DATA, error_threshold: 1000) do |error|
16
+ puts "Epoch #{som.epoch}: error = #{error}"
17
+ end
21
18
  end
22
19
 
23
20
  puts "Elapsed time for training: #{times}"
24
- puts "global error distance: #{som.global_error(SOM_DATA)}\n\n"
21
+ puts "Final global error: #{som.global_error(SOM_DATA)}\n"