ai4r 1.12 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.csv +159 -0
  16. data/examples/classifiers/simple_linear_regression_example.rb +18 -0
  17. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  18. data/examples/classifiers/zero_one_r_data.csv +8 -0
  19. data/examples/clusterers/clusterer_example.rb +62 -0
  20. data/examples/clusterers/dbscan_example.rb +17 -0
  21. data/examples/clusterers/dendrogram_example.rb +17 -0
  22. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  23. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  24. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  25. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  26. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  27. data/examples/neural_network/backpropagation_example.rb +49 -48
  28. data/examples/neural_network/hopfield_example.rb +45 -0
  29. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  30. data/examples/neural_network/patterns_with_noise.rb +41 -39
  31. data/examples/neural_network/train_epochs_callback.rb +25 -0
  32. data/examples/neural_network/training_patterns.rb +39 -39
  33. data/examples/neural_network/transformer_text_classification.rb +78 -0
  34. data/examples/neural_network/xor_example.rb +23 -22
  35. data/examples/reinforcement/q_learning_example.rb +10 -0
  36. data/examples/som/som_data.rb +155 -152
  37. data/examples/som/som_multi_node_example.rb +12 -13
  38. data/examples/som/som_single_example.rb +12 -15
  39. data/examples/transformer/decode_classifier_example.rb +68 -0
  40. data/examples/transformer/deterministic_example.rb +10 -0
  41. data/examples/transformer/seq2seq_example.rb +16 -0
  42. data/lib/ai4r/classifiers/classifier.rb +24 -16
  43. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  44. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  45. data/lib/ai4r/classifiers/ib1.rb +122 -32
  46. data/lib/ai4r/classifiers/id3.rb +527 -144
  47. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  48. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  49. data/lib/ai4r/classifiers/naive_bayes.rb +112 -48
  50. data/lib/ai4r/classifiers/one_r.rb +112 -44
  51. data/lib/ai4r/classifiers/prism.rb +167 -76
  52. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  53. data/lib/ai4r/classifiers/simple_linear_regression.rb +143 -0
  54. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  55. data/lib/ai4r/classifiers/votes.rb +57 -0
  56. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  57. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  58. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  59. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  60. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  61. data/lib/ai4r/clusterers/clusterer.rb +28 -24
  62. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  63. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  64. data/lib/ai4r/clusterers/diana.rb +75 -49
  65. data/lib/ai4r/clusterers/k_means.rb +309 -72
  66. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  67. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  68. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  69. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +63 -0
  70. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  71. data/lib/ai4r/data/data_set.rb +229 -100
  72. data/lib/ai4r/data/parameterizable.rb +31 -25
  73. data/lib/ai4r/data/proximity.rb +72 -50
  74. data/lib/ai4r/data/statistics.rb +46 -35
  75. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  76. data/lib/ai4r/experiment/split.rb +39 -0
  77. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  78. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  79. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  80. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  81. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  82. data/lib/ai4r/neural_network/backpropagation.rb +419 -143
  83. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  84. data/lib/ai4r/neural_network/transformer.rb +194 -0
  85. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  86. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  87. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  88. data/lib/ai4r/search/a_star.rb +76 -0
  89. data/lib/ai4r/search/bfs.rb +50 -0
  90. data/lib/ai4r/search/dfs.rb +50 -0
  91. data/lib/ai4r/search/mcts.rb +118 -0
  92. data/lib/ai4r/search.rb +12 -0
  93. data/lib/ai4r/som/distance_metrics.rb +29 -0
  94. data/lib/ai4r/som/layer.rb +28 -17
  95. data/lib/ai4r/som/node.rb +61 -32
  96. data/lib/ai4r/som/som.rb +158 -41
  97. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  98. data/lib/ai4r/version.rb +3 -0
  99. data/lib/ai4r.rb +58 -27
  100. metadata +117 -106
  101. data/README.rdoc +0 -44
  102. data/test/classifiers/hyperpipes_test.rb +0 -84
  103. data/test/classifiers/ib1_test.rb +0 -78
  104. data/test/classifiers/id3_test.rb +0 -208
  105. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  106. data/test/classifiers/naive_bayes_test.rb +0 -43
  107. data/test/classifiers/one_r_test.rb +0 -62
  108. data/test/classifiers/prism_test.rb +0 -85
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -100
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_test.rb +0 -53
  119. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  120. data/test/data/data_set_test.rb +0 -96
  121. data/test/data/proximity_test.rb +0 -81
  122. data/test/data/statistics_test.rb +0 -65
  123. data/test/experiment/classifier_evaluator_test.rb +0 -76
  124. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  125. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  126. data/test/neural_network/backpropagation_test.rb +0 -82
  127. data/test/neural_network/hopfield_test.rb +0 -72
  128. data/test/som/som_test.rb +0 -97
@@ -1,156 +1,159 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # data is from the iris dataset (http://archive.ics.uci.edu/ml/datasets/Iris)
2
4
  # it is the full dataset, removing the last column
3
- # website provides additional information on the dataset itself (attributes, class distribution, etc)
5
+ # website provides additional information on the dataset itself
6
+ # (attributes, class distribution, etc)
4
7
 
5
8
  SOM_DATA = [
6
- [5.1, 3.5, 1.4, 0.2],
7
- [4.9, 3.0, 1.4, 0.2],
8
- [4.7, 3.2, 1.3, 0.2],
9
- [4.6, 3.1, 1.5, 0.2],
10
- [5.0, 3.6, 1.4, 0.2],
11
- [5.4, 3.9, 1.7, 0.4],
12
- [4.6, 3.4, 1.4, 0.3],
13
- [5.0, 3.4, 1.5, 0.2],
14
- [4.4, 2.9, 1.4, 0.2],
15
- [4.9, 3.1, 1.5, 0.1],
16
- [5.4, 3.7, 1.5, 0.2],
17
- [4.8, 3.4, 1.6, 0.2],
18
- [4.8, 3.0, 1.4, 0.1],
19
- [4.3, 3.0, 1.1, 0.1],
20
- [5.8, 4.0, 1.2, 0.2],
21
- [5.7, 4.4, 1.5, 0.4],
22
- [5.4, 3.9, 1.3, 0.4],
23
- [5.1, 3.5, 1.4, 0.3],
24
- [5.7, 3.8, 1.7, 0.3],
25
- [5.1, 3.8, 1.5, 0.3],
26
- [5.4, 3.4, 1.7, 0.2],
27
- [5.1, 3.7, 1.5, 0.4],
28
- [4.6, 3.6, 1.0, 0.2],
29
- [5.1, 3.3, 1.7, 0.5],
30
- [4.8, 3.4, 1.9, 0.2],
31
- [5.0, 3.0, 1.6, 0.2],
32
- [5.0, 3.4, 1.6, 0.4],
33
- [5.2, 3.5, 1.5, 0.2],
34
- [5.2, 3.4, 1.4, 0.2],
35
- [4.7, 3.2, 1.6, 0.2],
36
- [4.8, 3.1, 1.6, 0.2],
37
- [5.4, 3.4, 1.5, 0.4],
38
- [5.2, 4.1, 1.5, 0.1],
39
- [5.5, 4.2, 1.4, 0.2],
40
- [4.9, 3.1, 1.5, 0.1],
41
- [5.0, 3.2, 1.2, 0.2],
42
- [5.5, 3.5, 1.3, 0.2],
43
- [4.9, 3.1, 1.5, 0.1],
44
- [4.4, 3.0, 1.3, 0.2],
45
- [5.1, 3.4, 1.5, 0.2],
46
- [5.0, 3.5, 1.3, 0.3],
47
- [4.5, 2.3, 1.3, 0.3],
48
- [4.4, 3.2, 1.3, 0.2],
49
- [5.0, 3.5, 1.6, 0.6],
50
- [5.1, 3.8, 1.9, 0.4],
51
- [4.8, 3.0, 1.4, 0.3],
52
- [5.1, 3.8, 1.6, 0.2],
53
- [4.6, 3.2, 1.4, 0.2],
54
- [5.3, 3.7, 1.5, 0.2],
55
- [5.0, 3.3, 1.4, 0.2],
56
- [7.0, 3.2, 4.7, 1.4],
57
- [6.4, 3.2, 4.5, 1.5],
58
- [6.9, 3.1, 4.9, 1.5],
59
- [5.5, 2.3, 4.0, 1.3],
60
- [6.5, 2.8, 4.6, 1.5],
61
- [5.7, 2.8, 4.5, 1.3],
62
- [6.3, 3.3, 4.7, 1.6],
63
- [4.9, 2.4, 3.3, 1.0],
64
- [6.6, 2.9, 4.6, 1.3],
65
- [5.2, 2.7, 3.9, 1.4],
66
- [5.0, 2.0, 3.5, 1.0],
67
- [5.9, 3.0, 4.2, 1.5],
68
- [6.0, 2.2, 4.0, 1.0],
69
- [6.1, 2.9, 4.7, 1.4],
70
- [5.6, 2.9, 3.6, 1.3],
71
- [6.7, 3.1, 4.4, 1.4],
72
- [5.6, 3.0, 4.5, 1.5],
73
- [5.8, 2.7, 4.1, 1.0],
74
- [6.2, 2.2, 4.5, 1.5],
75
- [5.6, 2.5, 3.9, 1.1],
76
- [5.9, 3.2, 4.8, 1.8],
77
- [6.1, 2.8, 4.0, 1.3],
78
- [6.3, 2.5, 4.9, 1.5],
79
- [6.1, 2.8, 4.7, 1.2],
80
- [6.4, 2.9, 4.3, 1.3],
81
- [6.6, 3.0, 4.4, 1.4],
82
- [6.8, 2.8, 4.8, 1.4],
83
- [6.7, 3.0, 5.0, 1.7],
84
- [6.0, 2.9, 4.5, 1.5],
85
- [5.7, 2.6, 3.5, 1.0],
86
- [5.5, 2.4, 3.8, 1.1],
87
- [5.5, 2.4, 3.7, 1.0],
88
- [5.8, 2.7, 3.9, 1.2],
89
- [6.0, 2.7, 5.1, 1.6],
90
- [5.4, 3.0, 4.5, 1.5],
91
- [6.0, 3.4, 4.5, 1.6],
92
- [6.7, 3.1, 4.7, 1.5],
93
- [6.3, 2.3, 4.4, 1.3],
94
- [5.6, 3.0, 4.1, 1.3],
95
- [5.5, 2.5, 4.0, 1.3],
96
- [5.5, 2.6, 4.4, 1.2],
97
- [6.1, 3.0, 4.6, 1.4],
98
- [5.8, 2.6, 4.0, 1.2],
99
- [5.0, 2.3, 3.3, 1.0],
100
- [5.6, 2.7, 4.2, 1.3],
101
- [5.7, 3.0, 4.2, 1.2],
102
- [5.7, 2.9, 4.2, 1.3],
103
- [6.2, 2.9, 4.3, 1.3],
104
- [5.1, 2.5, 3.0, 1.1],
105
- [5.7, 2.8, 4.1, 1.3],
106
- [6.3, 3.3, 6.0, 2.5],
107
- [5.8, 2.7, 5.1, 1.9],
108
- [7.1, 3.0, 5.9, 2.1],
109
- [6.3, 2.9, 5.6, 1.8],
110
- [6.5, 3.0, 5.8, 2.2],
111
- [7.6, 3.0, 6.6, 2.1],
112
- [4.9, 2.5, 4.5, 1.7],
113
- [7.3, 2.9, 6.3, 1.8],
114
- [6.7, 2.5, 5.8, 1.8],
115
- [7.2, 3.6, 6.1, 2.5],
116
- [6.5, 3.2, 5.1, 2.0],
117
- [6.4, 2.7, 5.3, 1.9],
118
- [6.8, 3.0, 5.5, 2.1],
119
- [5.7, 2.5, 5.0, 2.0],
120
- [5.8, 2.8, 5.1, 2.4],
121
- [6.4, 3.2, 5.3, 2.3],
122
- [6.5, 3.0, 5.5, 1.8],
123
- [7.7, 3.8, 6.7, 2.2],
124
- [7.7, 2.6, 6.9, 2.3],
125
- [6.0, 2.2, 5.0, 1.5],
126
- [6.9, 3.2, 5.7, 2.3],
127
- [5.6, 2.8, 4.9, 2.0],
128
- [7.7, 2.8, 6.7, 2.0],
129
- [6.3, 2.7, 4.9, 1.8],
130
- [6.7, 3.3, 5.7, 2.1],
131
- [7.2, 3.2, 6.0, 1.8],
132
- [6.2, 2.8, 4.8, 1.8],
133
- [6.1, 3.0, 4.9, 1.8],
134
- [6.4, 2.8, 5.6, 2.1],
135
- [7.2, 3.0, 5.8, 1.6],
136
- [7.4, 2.8, 6.1, 1.9],
137
- [7.9, 3.8, 6.4, 2.0],
138
- [6.4, 2.8, 5.6, 2.2],
139
- [6.3, 2.8, 5.1, 1.5],
140
- [6.1, 2.6, 5.6, 1.4],
141
- [7.7, 3.0, 6.1, 2.3],
142
- [6.3, 3.4, 5.6, 2.4],
143
- [6.4, 3.1, 5.5, 1.8],
144
- [6.0, 3.0, 4.8, 1.8],
145
- [6.9, 3.1, 5.4, 2.1],
146
- [6.7, 3.1, 5.6, 2.4],
147
- [6.9, 3.1, 5.1, 2.3],
148
- [5.8, 2.7, 5.1, 1.9],
149
- [6.8, 3.2, 5.9, 2.3],
150
- [6.7, 3.3, 5.7, 2.5],
151
- [6.7, 3.0, 5.2, 2.3],
152
- [6.3, 2.5, 5.0, 1.9],
153
- [6.5, 3.0, 5.2, 2.0],
154
- [6.2, 3.4, 5.4, 2.3],
155
- [5.9, 3.0, 5.1, 1.8],
156
- ]
9
+ [5.1, 3.5, 1.4, 0.2],
10
+ [4.9, 3.0, 1.4, 0.2],
11
+ [4.7, 3.2, 1.3, 0.2],
12
+ [4.6, 3.1, 1.5, 0.2],
13
+ [5.0, 3.6, 1.4, 0.2],
14
+ [5.4, 3.9, 1.7, 0.4],
15
+ [4.6, 3.4, 1.4, 0.3],
16
+ [5.0, 3.4, 1.5, 0.2],
17
+ [4.4, 2.9, 1.4, 0.2],
18
+ [4.9, 3.1, 1.5, 0.1],
19
+ [5.4, 3.7, 1.5, 0.2],
20
+ [4.8, 3.4, 1.6, 0.2],
21
+ [4.8, 3.0, 1.4, 0.1],
22
+ [4.3, 3.0, 1.1, 0.1],
23
+ [5.8, 4.0, 1.2, 0.2],
24
+ [5.7, 4.4, 1.5, 0.4],
25
+ [5.4, 3.9, 1.3, 0.4],
26
+ [5.1, 3.5, 1.4, 0.3],
27
+ [5.7, 3.8, 1.7, 0.3],
28
+ [5.1, 3.8, 1.5, 0.3],
29
+ [5.4, 3.4, 1.7, 0.2],
30
+ [5.1, 3.7, 1.5, 0.4],
31
+ [4.6, 3.6, 1.0, 0.2],
32
+ [5.1, 3.3, 1.7, 0.5],
33
+ [4.8, 3.4, 1.9, 0.2],
34
+ [5.0, 3.0, 1.6, 0.2],
35
+ [5.0, 3.4, 1.6, 0.4],
36
+ [5.2, 3.5, 1.5, 0.2],
37
+ [5.2, 3.4, 1.4, 0.2],
38
+ [4.7, 3.2, 1.6, 0.2],
39
+ [4.8, 3.1, 1.6, 0.2],
40
+ [5.4, 3.4, 1.5, 0.4],
41
+ [5.2, 4.1, 1.5, 0.1],
42
+ [5.5, 4.2, 1.4, 0.2],
43
+ [4.9, 3.1, 1.5, 0.1],
44
+ [5.0, 3.2, 1.2, 0.2],
45
+ [5.5, 3.5, 1.3, 0.2],
46
+ [4.9, 3.1, 1.5, 0.1],
47
+ [4.4, 3.0, 1.3, 0.2],
48
+ [5.1, 3.4, 1.5, 0.2],
49
+ [5.0, 3.5, 1.3, 0.3],
50
+ [4.5, 2.3, 1.3, 0.3],
51
+ [4.4, 3.2, 1.3, 0.2],
52
+ [5.0, 3.5, 1.6, 0.6],
53
+ [5.1, 3.8, 1.9, 0.4],
54
+ [4.8, 3.0, 1.4, 0.3],
55
+ [5.1, 3.8, 1.6, 0.2],
56
+ [4.6, 3.2, 1.4, 0.2],
57
+ [5.3, 3.7, 1.5, 0.2],
58
+ [5.0, 3.3, 1.4, 0.2],
59
+ [7.0, 3.2, 4.7, 1.4],
60
+ [6.4, 3.2, 4.5, 1.5],
61
+ [6.9, 3.1, 4.9, 1.5],
62
+ [5.5, 2.3, 4.0, 1.3],
63
+ [6.5, 2.8, 4.6, 1.5],
64
+ [5.7, 2.8, 4.5, 1.3],
65
+ [6.3, 3.3, 4.7, 1.6],
66
+ [4.9, 2.4, 3.3, 1.0],
67
+ [6.6, 2.9, 4.6, 1.3],
68
+ [5.2, 2.7, 3.9, 1.4],
69
+ [5.0, 2.0, 3.5, 1.0],
70
+ [5.9, 3.0, 4.2, 1.5],
71
+ [6.0, 2.2, 4.0, 1.0],
72
+ [6.1, 2.9, 4.7, 1.4],
73
+ [5.6, 2.9, 3.6, 1.3],
74
+ [6.7, 3.1, 4.4, 1.4],
75
+ [5.6, 3.0, 4.5, 1.5],
76
+ [5.8, 2.7, 4.1, 1.0],
77
+ [6.2, 2.2, 4.5, 1.5],
78
+ [5.6, 2.5, 3.9, 1.1],
79
+ [5.9, 3.2, 4.8, 1.8],
80
+ [6.1, 2.8, 4.0, 1.3],
81
+ [6.3, 2.5, 4.9, 1.5],
82
+ [6.1, 2.8, 4.7, 1.2],
83
+ [6.4, 2.9, 4.3, 1.3],
84
+ [6.6, 3.0, 4.4, 1.4],
85
+ [6.8, 2.8, 4.8, 1.4],
86
+ [6.7, 3.0, 5.0, 1.7],
87
+ [6.0, 2.9, 4.5, 1.5],
88
+ [5.7, 2.6, 3.5, 1.0],
89
+ [5.5, 2.4, 3.8, 1.1],
90
+ [5.5, 2.4, 3.7, 1.0],
91
+ [5.8, 2.7, 3.9, 1.2],
92
+ [6.0, 2.7, 5.1, 1.6],
93
+ [5.4, 3.0, 4.5, 1.5],
94
+ [6.0, 3.4, 4.5, 1.6],
95
+ [6.7, 3.1, 4.7, 1.5],
96
+ [6.3, 2.3, 4.4, 1.3],
97
+ [5.6, 3.0, 4.1, 1.3],
98
+ [5.5, 2.5, 4.0, 1.3],
99
+ [5.5, 2.6, 4.4, 1.2],
100
+ [6.1, 3.0, 4.6, 1.4],
101
+ [5.8, 2.6, 4.0, 1.2],
102
+ [5.0, 2.3, 3.3, 1.0],
103
+ [5.6, 2.7, 4.2, 1.3],
104
+ [5.7, 3.0, 4.2, 1.2],
105
+ [5.7, 2.9, 4.2, 1.3],
106
+ [6.2, 2.9, 4.3, 1.3],
107
+ [5.1, 2.5, 3.0, 1.1],
108
+ [5.7, 2.8, 4.1, 1.3],
109
+ [6.3, 3.3, 6.0, 2.5],
110
+ [5.8, 2.7, 5.1, 1.9],
111
+ [7.1, 3.0, 5.9, 2.1],
112
+ [6.3, 2.9, 5.6, 1.8],
113
+ [6.5, 3.0, 5.8, 2.2],
114
+ [7.6, 3.0, 6.6, 2.1],
115
+ [4.9, 2.5, 4.5, 1.7],
116
+ [7.3, 2.9, 6.3, 1.8],
117
+ [6.7, 2.5, 5.8, 1.8],
118
+ [7.2, 3.6, 6.1, 2.5],
119
+ [6.5, 3.2, 5.1, 2.0],
120
+ [6.4, 2.7, 5.3, 1.9],
121
+ [6.8, 3.0, 5.5, 2.1],
122
+ [5.7, 2.5, 5.0, 2.0],
123
+ [5.8, 2.8, 5.1, 2.4],
124
+ [6.4, 3.2, 5.3, 2.3],
125
+ [6.5, 3.0, 5.5, 1.8],
126
+ [7.7, 3.8, 6.7, 2.2],
127
+ [7.7, 2.6, 6.9, 2.3],
128
+ [6.0, 2.2, 5.0, 1.5],
129
+ [6.9, 3.2, 5.7, 2.3],
130
+ [5.6, 2.8, 4.9, 2.0],
131
+ [7.7, 2.8, 6.7, 2.0],
132
+ [6.3, 2.7, 4.9, 1.8],
133
+ [6.7, 3.3, 5.7, 2.1],
134
+ [7.2, 3.2, 6.0, 1.8],
135
+ [6.2, 2.8, 4.8, 1.8],
136
+ [6.1, 3.0, 4.9, 1.8],
137
+ [6.4, 2.8, 5.6, 2.1],
138
+ [7.2, 3.0, 5.8, 1.6],
139
+ [7.4, 2.8, 6.1, 1.9],
140
+ [7.9, 3.8, 6.4, 2.0],
141
+ [6.4, 2.8, 5.6, 2.2],
142
+ [6.3, 2.8, 5.1, 1.5],
143
+ [6.1, 2.6, 5.6, 1.4],
144
+ [7.7, 3.0, 6.1, 2.3],
145
+ [6.3, 3.4, 5.6, 2.4],
146
+ [6.4, 3.1, 5.5, 1.8],
147
+ [6.0, 3.0, 4.8, 1.8],
148
+ [6.9, 3.1, 5.4, 2.1],
149
+ [6.7, 3.1, 5.6, 2.4],
150
+ [6.9, 3.1, 5.1, 2.3],
151
+ [5.8, 2.7, 5.1, 1.9],
152
+ [6.8, 3.2, 5.9, 2.3],
153
+ [6.7, 3.3, 5.7, 2.5],
154
+ [6.7, 3.0, 5.2, 2.3],
155
+ [6.3, 2.5, 5.0, 1.9],
156
+ [6.5, 3.0, 5.2, 2.0],
157
+ [6.2, 3.4, 5.4, 2.3],
158
+ [5.9, 3.0, 5.1, 1.8]
159
+ ].freeze
@@ -1,22 +1,21 @@
1
- # this example shows the impact of the size of a som on the global error distance
2
- require File.dirname(__FILE__) + '/../../lib/ai4r/som/som'
3
- require File.dirname(__FILE__) + '/som_data'
1
+ # frozen_string_literal: true
2
+
3
+ # Demonstrates how map size impacts error and uses early stopping.
4
+ require_relative '../../lib/ai4r/som/som'
5
+ require_relative 'som_data'
4
6
  require 'benchmark'
5
7
 
6
8
  10.times do |t|
7
- t += 3 # minimum number of nodes
9
+ nodes = t + 3 # minimum number of nodes
8
10
 
9
- puts "Nodes: #{t}"
10
- som = Ai4r::Som::Som.new 4, 8, Ai4r::Som::TwoPhaseLayer.new(t)
11
+ puts "Nodes: #{nodes}"
12
+ som = Ai4r::Som::Som.new 4, 8, 8, Ai4r::Som::TwoPhaseLayer.new(nodes)
11
13
  som.initiate_map
12
14
 
13
- puts "global error distance: #{som.global_error(SOM_DATA)}"
14
- puts "\ntraining the som\n"
15
-
15
+ puts "Initial error: #{som.global_error(SOM_DATA)}"
16
16
  times = Benchmark.measure do
17
- som.train SOM_DATA
17
+ som.train(SOM_DATA, error_threshold: 1000)
18
18
  end
19
-
20
19
  puts "Elapsed time for training: #{times}"
21
- puts "global error distance: #{som.global_error(SOM_DATA)}\n\n"
22
- end
20
+ puts "Final error: #{som.global_error(SOM_DATA)}\n\n"
21
+ end
@@ -1,24 +1,21 @@
1
- require File.dirname(__FILE__) + '/../../lib/ai4r/som/som'
2
- require File.dirname(__FILE__) + '/som_data'
1
+ # frozen_string_literal: true
2
+
3
+ require_relative '../../lib/ai4r/som/som'
4
+ require_relative 'som_data'
3
5
  require 'benchmark'
4
6
 
5
- som = Ai4r::Som::Som.new 4, 8, Ai4r::Som::TwoPhaseLayer.new(10)
7
+ # Train a small SOM and stop early when the global error drops below 1000.
8
+ som = Ai4r::Som::Som.new 4, 8, 8, Ai4r::Som::TwoPhaseLayer.new(10)
6
9
  som.initiate_map
7
10
 
8
- som.nodes.each do |node|
9
- p node.weights
10
- end
11
-
12
- puts "global error distance: #{som.global_error(SOM_DATA)}"
13
- puts "\ntraining the som\n"
11
+ puts "Initial global error: #{som.global_error(SOM_DATA)}"
14
12
 
13
+ puts "\nTraining the SOM (early stopping threshold = 1000)\n"
15
14
  times = Benchmark.measure do
16
- som.train SOM_DATA
17
- end
18
-
19
- som.nodes.each do |node|
20
- p node.weights
15
+ som.train(SOM_DATA, error_threshold: 1000) do |error|
16
+ puts "Epoch #{som.epoch}: error = #{error}"
17
+ end
21
18
  end
22
19
 
23
20
  puts "Elapsed time for training: #{times}"
24
- puts "global error distance: #{som.global_error(SOM_DATA)}\n\n"
21
+ puts "Final global error: #{som.global_error(SOM_DATA)}\n"
@@ -0,0 +1,68 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative '../../lib/ai4r/neural_network/transformer'
4
+ require_relative '../../lib/ai4r/classifiers/logistic_regression'
5
+ require_relative '../../lib/ai4r/data/data_set'
6
+
7
+ # Tiny dataset of greetings (label 0) and farewells (label 1)
8
+ sentences = [
9
+ %w[hello there],
10
+ %w[how are you],
11
+ %w[good morning],
12
+ %w[nice to meet you],
13
+ %w[goodbye],
14
+ %w[see you later],
15
+ %w[have a nice day],
16
+ %w[take care]
17
+ ]
18
+ labels = [0, 0, 0, 0, 1, 1, 1, 1]
19
+
20
+ # Build vocabulary
21
+ vocab = {}
22
+ next_id = 0
23
+ sentences.each do |tokens|
24
+ tokens.each do |t|
25
+ unless vocab.key?(t)
26
+ vocab[t] = next_id
27
+ next_id += 1
28
+ end
29
+ end
30
+ end
31
+
32
+ vocab_size = vocab.length
33
+ max_len = sentences.map(&:length).max
34
+
35
+ transformer = Ai4r::NeuralNetwork::Transformer.new(
36
+ vocab_size: vocab_size,
37
+ max_len: max_len,
38
+ architecture: :decoder
39
+ )
40
+ embed_dim = transformer.embed_dim
41
+
42
+ # Encode each sentence and average embeddings
43
+ items = []
44
+ sentences.each_with_index do |tokens, idx|
45
+ ids = tokens.map { |t| vocab[t] }
46
+ vecs = transformer.eval(ids)
47
+ avg = Array.new(embed_dim, 0.0)
48
+ vecs.each do |v|
49
+ v.each_index { |i| avg[i] += v[i] }
50
+ end
51
+ avg.map! { |v| v / vecs.length }
52
+ items << (avg + [labels[idx]])
53
+ end
54
+
55
+ labels_names = (0...embed_dim).map { |i| "x#{i}" } + ['class']
56
+ set = Ai4r::Data::DataSet.new(data_items: items, data_labels: labels_names)
57
+
58
+ classifier = Ai4r::Classifiers::LogisticRegression.new
59
+ classifier.set_parameters(lr: 0.5, iterations: 500).build(set)
60
+
61
+ # Classify a short greeting
62
+ sample = %w[hello]
63
+ ids = sample.map { |t| vocab[t] }
64
+ vecs = transformer.eval(ids)
65
+ avg = Array.new(embed_dim, 0.0)
66
+ vecs.each { |v| v.each_index { |i| avg[i] += v[i] } }
67
+ avg.map! { |v| v / vecs.length }
68
+ puts "Prediction: #{classifier.eval(avg)} (0=greeting, 1=farewell)"
@@ -0,0 +1,10 @@
1
+ require_relative '../../lib/ai4r/neural_network/transformer'
2
+
3
+ # Demonstrates deterministic initialization using the :seed parameter.
4
+ model_a = Ai4r::NeuralNetwork::Transformer.new(vocab_size: 5, max_len: 3, seed: 42)
5
+ model_b = Ai4r::NeuralNetwork::Transformer.new(vocab_size: 5, max_len: 3, seed: 42)
6
+
7
+ output_a = model_a.eval([0, 1, 2])
8
+ output_b = model_b.eval([0, 1, 2])
9
+
10
+ puts "Outputs identical? #{output_a == output_b}"
@@ -0,0 +1,16 @@
1
+ require_relative '../../lib/ai4r/neural_network/transformer'
2
+
3
+ # Simple demo of the seq2seq architecture.
4
+ # The model returns random vectors but shows how
5
+ # to provide encoder and decoder inputs.
6
+ model = Ai4r::NeuralNetwork::Transformer.new(
7
+ vocab_size: 10,
8
+ max_len: 5,
9
+ architecture: :seq2seq
10
+ )
11
+
12
+ encoder_input = [1, 2, 3]
13
+ decoder_input = [4, 5]
14
+
15
+ output = model.eval(encoder_input, decoder_input)
16
+ puts "Output length: #{output.length}"
@@ -1,62 +1,70 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://ai4r.org
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
-
10
- require File.dirname(__FILE__) + '/../data/parameterizable'
11
-
11
+
12
+ require_relative '../data/parameterizable'
13
+
12
14
  module Ai4r
13
15
  module Classifiers
14
-
15
16
  # This class defines a common API for classifiers.
16
17
  # All methods in this class must be implemented in subclasses.
17
18
  class Classifier
19
+ include Ai4r::Data::Parameterizable
18
20
 
19
- include Ai4r::Data::Parameterizable
20
-
21
21
  # Build a new classifier, using data examples found in data_set.
22
22
  # The last attribute of each item is considered as the
23
23
  # item class.
24
+ # @param data_set [Object]
25
+ # @return [Object]
24
26
  def build(data_set)
25
27
  raise NotImplementedError
26
28
  end
27
-
29
+
28
30
  # You can evaluate new data, predicting its class.
29
31
  # e.g.
30
32
  # classifier.eval(['New York', '<30', 'F']) # => 'Y'
33
+ # @param data [Object]
34
+ # @return [Object]
31
35
  def eval(data)
32
36
  raise NotImplementedError
33
37
  end
34
-
38
+
35
39
  # This method returns the generated rules in ruby code.
36
40
  # e.g.
37
- #
41
+ #
38
42
  # classifier.get_rules
39
43
  # # => if age_range=='<30' then marketing_target='Y'
40
44
  # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
41
45
  # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
42
46
  # elsif age_range=='[50-80]' then marketing_target='N'
43
47
  # elsif age_range=='>80' then marketing_target='Y'
44
- # else raise 'There was not enough information during training to do a proper induction for this data element' end
48
+ # else
49
+ # raise 'There was not enough information during training to do a '
50
+ # 'proper induction for this data element'
51
+ # end
45
52
  #
46
- # It is a nice way to inspect induction results, and also to execute them:
53
+ # It is a nice way to inspect induction results, and also to execute them:
47
54
  # age_range = '<30'
48
55
  # city='New York'
49
56
  # marketing_target = nil
50
- # eval classifier.get_rules
57
+ # eval classifier.get_rules
51
58
  # puts marketing_target
52
59
  # # => 'Y'
53
60
  #
54
61
  # Note, however, that not all classifiers are able to produce rules.
55
62
  # This method is not implemented in such classifiers.
63
+ # @return [Object]
56
64
  def get_rules
57
65
  raise NotImplementedError
58
66
  end
59
-
67
+ # rubocop:enable Naming/AccessorMethodName
60
68
  end
61
69
  end
62
70
  end
@@ -0,0 +1,64 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Author:: OpenAI ChatGPT
4
+ # License:: MPL 1.1
5
+ # Project:: ai4r
6
+ #
7
+ # Very small gradient boosting implementation for regression using
8
+ # simple linear regression as base learner.
9
+
10
+ require_relative 'simple_linear_regression'
11
+ require_relative '../data/data_set'
12
+ require_relative '../classifiers/classifier'
13
+
14
+ module Ai4r
15
+ module Classifiers
16
+ # Gradient boosting regressor using simple linear regression base learners.
17
+ class GradientBoosting < Classifier
18
+ parameters_info n_estimators: 'Number of boosting iterations. Default 10.',
19
+ learning_rate: 'Shrinkage parameter for each learner. Default 0.1.'
20
+
21
+ attr_reader :initial_value, :learners
22
+
23
+ def initialize
24
+ super()
25
+ @n_estimators = 10
26
+ @learning_rate = 0.1
27
+ end
28
+
29
+ def build(data_set)
30
+ data_set.check_not_empty
31
+ @learners = []
32
+ targets = data_set.data_items.map(&:last)
33
+ @initial_value = targets.sum.to_f / targets.length
34
+ predictions = Array.new(targets.length, @initial_value)
35
+ @n_estimators.times do
36
+ residuals = targets.zip(predictions).map { |y, f| y - f }
37
+ items = data_set.data_items.each_with_index.map do |item, idx|
38
+ item[0...-1] + [residuals[idx]]
39
+ end
40
+ ds = Ai4r::Data::DataSet.new(data_items: items, data_labels: data_set.data_labels)
41
+ learner = SimpleLinearRegression.new.build(ds)
42
+ @learners << learner
43
+ pred = items.map { |it| learner.eval(it[0...-1]) }
44
+ predictions = predictions.zip(pred).map { |f, p| f + (@learning_rate * p) }
45
+ end
46
+ self
47
+ end
48
+ # rubocop:enable Metrics/AbcSize
49
+
50
+ def eval(data)
51
+ value = @initial_value
52
+ @learners.each do |learner|
53
+ value += @learning_rate * learner.eval(data)
54
+ end
55
+ value
56
+ end
57
+
58
+ def get_rules
59
+ 'GradientBoosting does not support rule extraction.'
60
+ end
61
+ # rubocop:enable Naming/AccessorMethodName
62
+ end
63
+ end
64
+ end