ai4r 1.13 → 2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +174 -0
  3. data/examples/classifiers/hyperpipes_data.csv +14 -0
  4. data/examples/classifiers/hyperpipes_example.rb +22 -0
  5. data/examples/classifiers/ib1_example.rb +12 -0
  6. data/examples/classifiers/id3_example.rb +15 -10
  7. data/examples/classifiers/id3_graphviz_example.rb +17 -0
  8. data/examples/classifiers/logistic_regression_example.rb +11 -0
  9. data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
  10. data/examples/classifiers/naive_bayes_example.rb +12 -13
  11. data/examples/classifiers/one_r_example.rb +27 -0
  12. data/examples/classifiers/parameter_tutorial.rb +29 -0
  13. data/examples/classifiers/prism_nominal_example.rb +15 -0
  14. data/examples/classifiers/prism_numeric_example.rb +21 -0
  15. data/examples/classifiers/simple_linear_regression_example.rb +14 -11
  16. data/examples/classifiers/zero_and_one_r_example.rb +34 -0
  17. data/examples/classifiers/zero_one_r_data.csv +8 -0
  18. data/examples/clusterers/clusterer_example.rb +40 -34
  19. data/examples/clusterers/dbscan_example.rb +17 -0
  20. data/examples/clusterers/dendrogram_example.rb +17 -0
  21. data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
  22. data/examples/clusterers/kmeans_custom_example.rb +26 -0
  23. data/examples/genetic_algorithm/bitstring_example.rb +41 -0
  24. data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
  25. data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
  26. data/examples/neural_network/backpropagation_example.rb +48 -48
  27. data/examples/neural_network/hopfield_example.rb +45 -0
  28. data/examples/neural_network/patterns_with_base_noise.rb +39 -39
  29. data/examples/neural_network/patterns_with_noise.rb +41 -39
  30. data/examples/neural_network/train_epochs_callback.rb +25 -0
  31. data/examples/neural_network/training_patterns.rb +39 -39
  32. data/examples/neural_network/transformer_text_classification.rb +78 -0
  33. data/examples/neural_network/xor_example.rb +23 -22
  34. data/examples/reinforcement/q_learning_example.rb +10 -0
  35. data/examples/som/som_data.rb +155 -152
  36. data/examples/som/som_multi_node_example.rb +12 -13
  37. data/examples/som/som_single_example.rb +12 -15
  38. data/examples/transformer/decode_classifier_example.rb +68 -0
  39. data/examples/transformer/deterministic_example.rb +10 -0
  40. data/examples/transformer/seq2seq_example.rb +16 -0
  41. data/lib/ai4r/classifiers/classifier.rb +24 -16
  42. data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
  43. data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
  44. data/lib/ai4r/classifiers/ib1.rb +122 -32
  45. data/lib/ai4r/classifiers/id3.rb +524 -145
  46. data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
  47. data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
  48. data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
  49. data/lib/ai4r/classifiers/one_r.rb +112 -44
  50. data/lib/ai4r/classifiers/prism.rb +167 -76
  51. data/lib/ai4r/classifiers/random_forest.rb +72 -0
  52. data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
  53. data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
  54. data/lib/ai4r/classifiers/votes.rb +57 -0
  55. data/lib/ai4r/classifiers/zero_r.rb +71 -30
  56. data/lib/ai4r/clusterers/average_linkage.rb +46 -27
  57. data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
  58. data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
  59. data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
  60. data/lib/ai4r/clusterers/clusterer.rb +29 -14
  61. data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
  62. data/lib/ai4r/clusterers/dbscan.rb +134 -0
  63. data/lib/ai4r/clusterers/diana.rb +75 -49
  64. data/lib/ai4r/clusterers/k_means.rb +270 -135
  65. data/lib/ai4r/clusterers/median_linkage.rb +49 -33
  66. data/lib/ai4r/clusterers/single_linkage.rb +196 -88
  67. data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
  68. data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
  69. data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
  70. data/lib/ai4r/data/data_set.rb +223 -103
  71. data/lib/ai4r/data/parameterizable.rb +31 -25
  72. data/lib/ai4r/data/proximity.rb +62 -62
  73. data/lib/ai4r/data/statistics.rb +46 -35
  74. data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
  75. data/lib/ai4r/experiment/split.rb +39 -0
  76. data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
  77. data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
  78. data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
  79. data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
  80. data/lib/ai4r/neural_network/activation_functions.rb +37 -0
  81. data/lib/ai4r/neural_network/backpropagation.rb +399 -134
  82. data/lib/ai4r/neural_network/hopfield.rb +175 -58
  83. data/lib/ai4r/neural_network/transformer.rb +194 -0
  84. data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
  85. data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
  86. data/lib/ai4r/reinforcement/q_learning.rb +51 -0
  87. data/lib/ai4r/search/a_star.rb +76 -0
  88. data/lib/ai4r/search/bfs.rb +50 -0
  89. data/lib/ai4r/search/dfs.rb +50 -0
  90. data/lib/ai4r/search/mcts.rb +118 -0
  91. data/lib/ai4r/search.rb +12 -0
  92. data/lib/ai4r/som/distance_metrics.rb +29 -0
  93. data/lib/ai4r/som/layer.rb +28 -17
  94. data/lib/ai4r/som/node.rb +61 -32
  95. data/lib/ai4r/som/som.rb +158 -41
  96. data/lib/ai4r/som/two_phase_layer.rb +21 -25
  97. data/lib/ai4r/version.rb +3 -0
  98. data/lib/ai4r.rb +57 -28
  99. metadata +79 -109
  100. data/README.rdoc +0 -39
  101. data/test/classifiers/hyperpipes_test.rb +0 -84
  102. data/test/classifiers/ib1_test.rb +0 -78
  103. data/test/classifiers/id3_test.rb +0 -220
  104. data/test/classifiers/multilayer_perceptron_test.rb +0 -79
  105. data/test/classifiers/naive_bayes_test.rb +0 -43
  106. data/test/classifiers/one_r_test.rb +0 -62
  107. data/test/classifiers/prism_test.rb +0 -85
  108. data/test/classifiers/simple_linear_regression_test.rb +0 -37
  109. data/test/classifiers/zero_r_test.rb +0 -50
  110. data/test/clusterers/average_linkage_test.rb +0 -51
  111. data/test/clusterers/bisecting_k_means_test.rb +0 -66
  112. data/test/clusterers/centroid_linkage_test.rb +0 -53
  113. data/test/clusterers/complete_linkage_test.rb +0 -57
  114. data/test/clusterers/diana_test.rb +0 -69
  115. data/test/clusterers/k_means_test.rb +0 -167
  116. data/test/clusterers/median_linkage_test.rb +0 -53
  117. data/test/clusterers/single_linkage_test.rb +0 -122
  118. data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
  119. data/test/clusterers/ward_linkage_test.rb +0 -53
  120. data/test/clusterers/weighted_average_linkage_test.rb +0 -53
  121. data/test/data/data_set_test.rb +0 -104
  122. data/test/data/proximity_test.rb +0 -87
  123. data/test/data/statistics_test.rb +0 -65
  124. data/test/experiment/classifier_evaluator_test.rb +0 -76
  125. data/test/genetic_algorithm/chromosome_test.rb +0 -57
  126. data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
  127. data/test/neural_network/backpropagation_test.rb +0 -82
  128. data/test/neural_network/hopfield_test.rb +0 -72
  129. data/test/som/som_test.rb +0 -97
@@ -1,61 +1,77 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens (implementation)
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://www.ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
- require File.dirname(__FILE__) + '/../data/data_set'
11
- require File.dirname(__FILE__) + '/../clusterers/single_linkage'
12
+ require_relative '../data/data_set'
13
+ require_relative '../clusterers/single_linkage'
14
+ require_relative '../clusterers/cluster_tree'
12
15
 
13
16
  module Ai4r
14
17
  module Clusterers
15
-
16
- # Implementation of an Agglomerative Hierarchical clusterer with
17
- # weighted average linkage algorithm, aka weighted pair group method
18
+ # Implementation of an Agglomerative Hierarchical clusterer with
19
+ # weighted average linkage algorithm, aka weighted pair group method
18
20
  # average or WPGMA (Jain and Dubes, 1988 ; McQuitty, 1966 )
19
- # Hierarchical clusterer create one cluster per element, and then
21
+ # Hierarchical clusterer create one cluster per element, and then
20
22
  # progressively merge clusters, until the required number of clusters
21
23
  # is reached.
22
- # Similar to AverageLinkage, but the distances between clusters are
24
+ # Similar to AverageLinkage, but the distances between clusters are
23
25
  # weighted based on the number of data items in each of them.
24
- #
26
+ #
25
27
  # D(cx, (ci U cj)) = ( ni * D(cx, ci) + nj * D(cx, cj)) / (ni + nj)
26
28
  class WeightedAverageLinkage < SingleLinkage
27
-
28
- parameters_info :distance_function =>
29
- "Custom implementation of distance function. " +
30
- "It must be a closure receiving two data items and return the " +
31
- "distance between them. By default, this algorithm uses " +
32
- "euclidean distance of numeric attributes to the power of 2."
33
-
29
+ include ClusterTree
30
+
31
+ parameters_info distance_function:
32
+ 'Custom implementation of distance function. ' \
33
+ 'It must be a closure receiving two data items and return the ' \
34
+ 'distance between them. By default, this algorithm uses ' \
35
+ 'euclidean distance of numeric attributes to the power of 2.'
36
+
34
37
  # Build a new clusterer, using data examples found in data_set.
35
38
  # Items will be clustered in "number_of_clusters" different
36
39
  # clusters.
37
- def build(data_set, number_of_clusters)
40
+ # @param data_set [Object]
41
+ # @param number_of_clusters [Object]
42
+ # @param options [Hash] additional keyword options, passed along by the bare super call
43
+ # @return [Object]
44
+ def build(data_set, number_of_clusters = 1, **options)
38
45
  super
39
46
  end
40
-
41
- # This algorithms does not allow classification of new data items
47
+
48
+ # This algorithm does not allow classification of new data items
42
49
  # once it has been built. Rebuild the cluster including your data element.
43
- def eval(data_item)
44
- Raise "Eval of new data is not supported by this algorithm."
50
+ # @param _data_item [Object]
51
+ # @return [Object]
52
+ def eval(_data_item)
53
+ raise NotImplementedError, 'Eval of new data is not supported by this algorithm.'
45
54
  end
46
-
55
+
56
+ # @return [Object]
57
+ def supports_eval?
58
+ false
59
+ end
60
+
47
61
  protected
48
-
62
+
49
63
  # return distance between cluster cx and cluster (ci U cj),
50
64
  # using weighted average linkage
51
- def linkage_distance(cx, ci, cj)
52
- ni = @index_clusters[ci].length
53
- nj = @index_clusters[cj].length
54
- (1.0 * ni * read_distance_matrix(cx, ci)+
55
- nj * read_distance_matrix(cx, cj))/(ni+nj)
65
+ # @param cluster_x [Object]
66
+ # @param cluster_i [Object]
67
+ # @param cluster_j [Object]
68
+ # @return [Object]
69
+ def linkage_distance(cluster_x, cluster_i, cluster_j)
70
+ ni = @index_clusters[cluster_i].length
71
+ nj = @index_clusters[cluster_j].length
72
+ ((1.0 * ni * read_distance_matrix(cluster_x, cluster_i)) +
73
+ (nj * read_distance_matrix(cluster_x, cluster_j))) / (ni + nj)
56
74
  end
57
-
58
75
  end
59
76
  end
60
77
  end
61
-
@@ -1,34 +1,51 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Author:: Sergio Fierens
2
4
  # License:: MPL 1.1
3
5
  # Project:: ai4r
4
- # Url:: http://ai4r.org/
6
+ # Url:: https://github.com/SergioFierens/ai4r
5
7
  #
6
- # You can redistribute it and/or modify it under the terms of
7
- # the Mozilla Public License version 1.1 as published by the
8
+ # You can redistribute it and/or modify it under the terms of
9
+ # the Mozilla Public License version 1.1 as published by the
8
10
  # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
9
11
 
10
12
  require 'csv'
11
13
  require 'set'
12
- require File.dirname(__FILE__) + '/statistics'
14
+ require_relative 'statistics'
13
15
 
14
16
  module Ai4r
15
17
  module Data
16
-
17
- # A data set is a collection of N data items. Each data item is
18
+ # A data set is a collection of N data items. Each data item is
18
19
  # described by a set of attributes, represented as an array.
19
- # Optionally, you can assign a label to the attributes, using
20
+ # Optionally, you can assign a label to the attributes, using
20
21
  # the data_labels property.
21
22
  class DataSet
22
-
23
23
  attr_reader :data_labels, :data_items
24
24
 
25
+ # Return a new DataSet with numeric attributes normalized.
26
+ # Available methods are:
27
+ # * +:zscore+ - subtract the mean and divide by the standard deviation
28
+ # * +:minmax+ - scale values to the [0,1] range
29
+ # @param data_set [Object]
30
+ # @param method [Object]
31
+ # @return [Object]
32
+ def self.normalized(data_set, method: :zscore)
33
+ new_set = DataSet.new(
34
+ data_items: data_set.data_items.map(&:dup),
35
+ data_labels: data_set.data_labels.dup
36
+ )
37
+ new_set.normalize!(method)
38
+ end
39
+
25
40
  # Create a new DataSet. By default, empty.
26
41
  # Optionally, you can provide the initial data items and data labels.
27
- #
42
+ #
28
43
  # e.g. DataSet.new(:data_items => data_items, :data_labels => labels)
29
- #
44
+ #
30
45
  # If you provide data items, but no data labels, the data set will
31
46
  # use the default data label values (see set_data_labels)
47
+ # @param options [Object]
48
+ # @return [Object]
32
49
  def initialize(options = {})
33
50
  @data_labels = []
34
51
  @data_items = options[:data_items] || []
@@ -36,60 +53,70 @@ module Ai4r
36
53
  set_data_items(options[:data_items]) if options[:data_items]
37
54
  end
38
55
 
39
- # Retrieve a new DataSet, with the item(s) selected by the provided
56
+ # Retrieve a new DataSet, with the item(s) selected by the provided
40
57
  # index. You can specify an index range, too.
58
+ # @param index [Object]
59
+ # @return [Object]
41
60
  def [](index)
42
- selected_items = (index.is_a?(Fixnum)) ?
43
- [@data_items[index]] : @data_items[index]
44
- return DataSet.new(:data_items => selected_items,
45
- :data_labels =>@data_labels)
61
+ selected_items = if index.is_a?(Integer)
62
+ [@data_items[index]]
63
+ else
64
+ @data_items[index]
65
+ end
66
+ DataSet.new(data_items: selected_items,
67
+ data_labels: @data_labels)
46
68
  end
47
69
 
48
70
  # Load data items from csv file
49
- def load_csv(filepath)
50
- items = []
51
- open_csv_file(filepath) do |entry|
52
- items << entry
53
- end
54
- set_data_items(items)
55
- end
56
-
57
- # opens a csv-file and reads it line by line
58
- # for each line, a block is called and the row is passed to the block
59
- # ruby1.8 and 1.9 safe
60
- def open_csv_file(filepath, &block)
61
- if CSV.const_defined? :Reader
62
- CSV::Reader.parse(File.open(filepath, 'r')) do |row|
63
- block.call row
64
- end
71
+ # @param filepath [Object]
72
+ # @return [Object]
73
+ def load_csv(filepath, parse_numeric: false)
74
+ if parse_numeric
75
+ parse_csv(filepath)
65
76
  else
66
- CSV.parse(File.open(filepath, 'r')) do |row|
67
- block.call row
77
+ items = []
78
+ open_csv_file(filepath) do |entry|
79
+ items << entry
68
80
  end
81
+ set_data_items(items)
69
82
  end
70
83
  end
71
84
 
85
+ # Open a CSV file and yield each row to the provided block.
86
+ # @param filepath [Object]
87
+ # @param block [Object]
88
+ # @return [Object]
89
+ def open_csv_file(filepath, &)
90
+ CSV.foreach(filepath, &)
91
+ end
92
+
72
93
  # Load data items from csv file. The first row is used as data labels.
73
- def load_csv_with_labels(filepath)
74
- load_csv(filepath)
94
+ # @param filepath [Object]
95
+ # @return [Object]
96
+ def load_csv_with_labels(filepath, parse_numeric: false)
97
+ load_csv(filepath, parse_numeric: parse_numeric)
75
98
  @data_labels = @data_items.shift
76
- return self
99
+ self
77
100
  end
78
101
 
79
102
  # Same as load_csv, but it will try to convert cell contents as numbers.
103
+ # @param filepath [Object]
104
+ # @return [Object]
80
105
  def parse_csv(filepath)
81
106
  items = []
82
107
  open_csv_file(filepath) do |row|
83
- items << row.collect{|x| is_number?(x) ? Float(x) : x }
108
+ items << row.collect do |x|
109
+ number?(x) ? Float(x, exception: false) : x
110
+ end
84
111
  end
85
112
  set_data_items(items)
86
113
  end
87
114
 
88
115
  # Same as load_csv_with_labels, but it will try to convert cell contents as numbers.
116
+ # @param filepath [Object]
117
+ # @return [Object]
89
118
  def parse_csv_with_labels(filepath)
90
- parse_csv(filepath)
91
- @data_labels = @data_items.shift
92
- return self
119
+ load_csv_with_labels(filepath, parse_numeric: true)
93
120
  end
94
121
 
95
122
  # Set data labels.
@@ -98,23 +125,25 @@ module Ai4r
98
125
  #
99
126
  # If you do not provide labels for your data, the following labels will
100
127
  # be created by default:
101
- # [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
128
+ # [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
129
+ # @param labels [Object]
130
+ # @return [Object]
102
131
  def set_data_labels(labels)
103
132
  check_data_labels(labels)
104
133
  @data_labels = labels
105
- return self
134
+ self
106
135
  end
107
136
 
108
137
  # Set the data items.
109
- # M data items with N attributes must have the following
138
+ # M data items with N attributes must have the following
110
139
  # format:
111
- #
112
- # [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
113
- # [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
140
+ #
141
+ # [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
142
+ # [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
114
143
  # ...
115
- # [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
144
+ # [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
116
145
  # ]
117
- #
146
+ #
118
147
  # e.g.
119
148
  # [ ['New York', '<30', 'M', 'Y'],
120
149
  # ['Chicago', '<30', 'M', 'Y'],
@@ -132,144 +161,235 @@ module Ai4r
132
161
  # ['New York', '[50-80]', 'F', 'N'],
133
162
  # ['Chicago', '>80', 'F', 'Y']
134
163
  # ]
135
- #
164
+ #
136
165
  # This method returns the data set (self), allowing method chaining.
166
+ # @param items [Object]
167
+ # @return [Object]
137
168
  def set_data_items(items)
138
169
  check_data_items(items)
139
170
  @data_labels = default_data_labels(items) if @data_labels.empty?
140
171
  @data_items = items
141
- return self
172
+ self
142
173
  end
143
174
 
144
175
  # Returns an array with the domain of each attribute:
145
176
  # * Set instance containing all possible values for nominal attributes
146
177
  # * Array with min and max values for numeric attributes (i.e. [min, max])
147
- #
178
+ #
148
179
  # Return example:
149
- # => [#<Set: {"New York", "Chicago"}>,
150
- # #<Set: {"<30", "[30-50)", "[50-80]", ">80"}>,
180
+ # => [#<Set: {"New York", "Chicago"}>,
181
+ # #<Set: {"<30", "[30-50)", "[50-80]", ">80"}>,
151
182
  # #<Set: {"M", "F"}>,
152
- # [5, 85],
183
+ # [5, 85],
153
184
  # #<Set: {"Y", "N"}>]
185
+ # @return [Object]
154
186
  def build_domains
155
- @data_labels.collect {|attr_label| build_domain(attr_label) }
187
+ @data_labels.collect { |attr_label| build_domain(attr_label) }
156
188
  end
157
189
 
158
190
  # Returns a Set instance containing all possible values for an attribute
159
191
  # The parameter can be an attribute label or index (0 based).
160
192
  # * Set instance containing all possible values for nominal attributes
161
193
  # * Array with min and max values for numeric attributes (i.e. [min, max])
162
- #
194
+ #
163
195
  # build_domain("city")
164
196
  # => #<Set: {"New York", "Chicago"}>
165
- #
197
+ #
166
198
  # build_domain("age")
167
199
  # => [5, 85]
168
- #
200
+ #
169
201
  # build_domain(2) # In this example, the third attribute is gender
170
202
  # => #<Set: {"M", "F"}>
203
+ # @param attr [Object]
204
+ # @return [Object]
171
205
  def build_domain(attr)
172
206
  index = get_index(attr)
173
- if @data_items.first[index].is_a?(Numeric)
174
- return [Statistics.min(self, index), Statistics.max(self, index)]
175
- else
176
- return @data_items.inject(Set.new){|domain, x| domain << x[index]}
177
- end
207
+ return [Statistics.min(self, index), Statistics.max(self, index)] if @data_items.first[index].is_a?(Numeric)
208
+
209
+ @data_items.inject(Set.new) { |domain, x| domain << x[index] }
178
210
  end
179
211
 
180
212
  # Returns attributes number, including class attribute
213
+ # @return [Object]
181
214
  def num_attributes
182
- return (@data_items.empty?) ? 0 : @data_items.first.size
215
+ @data_items.empty? ? 0 : @data_items.first.size
183
216
  end
184
217
 
185
218
  # Returns the index of a given attribute (0-based).
186
219
  # For example, if "gender" is the third attribute, then:
187
- # get_index("gender")
220
+ # get_index("gender")
188
221
  # => 2
222
+ # @param attr [Object]
223
+ # @return [Object]
189
224
  def get_index(attr)
190
- return (attr.is_a?(Fixnum) || attr.is_a?(Range)) ? attr : @data_labels.index(attr)
225
+ attr.is_a?(Integer) || attr.is_a?(Range) ? attr : @data_labels.index(attr)
191
226
  end
192
227
 
193
228
  # Raise an exception if there is no data item.
229
+ # @return [Object]
194
230
  def check_not_empty
195
- if @data_items.empty?
196
- raise ArgumentError, "Examples data set must not be empty."
197
- end
231
+ return unless @data_items.empty?
232
+
233
+ raise ArgumentError, 'Examples data set must not be empty.'
198
234
  end
199
235
 
200
236
  # Add a data item to the data set
201
- def << data_item
237
+ # @return [Object]
238
+ def <<(data_item)
202
239
  if data_item.nil? || !data_item.is_a?(Enumerable) || data_item.empty?
203
- raise ArgumentError, "Data must not be an non empty array."
240
+ raise ArgumentError, 'Data must not be an non empty array.'
204
241
  elsif @data_items.empty?
205
242
  set_data_items([data_item])
206
243
  elsif data_item.length != num_attributes
207
- raise ArgumentError, "Number of attributes do not match. " +
208
- "#{data_item.length} attributes provided, " +
209
- "#{num_attributes} attributes expected."
244
+ raise ArgumentError, 'Number of attributes do not match. ' \
245
+ "#{data_item.length} attributes provided, " \
246
+ "#{num_attributes} attributes expected."
210
247
  else
211
248
  @data_items << data_item
212
249
  end
213
250
  end
214
251
 
215
- # Returns an array with the mean value of numeric attributes, and
252
+ # Returns an array with the mean value of numeric attributes, and
216
253
  # the most frequent value of non numeric attributes
254
+ # @return [Object]
217
255
  def get_mean_or_mode
218
256
  mean = []
219
257
  num_attributes.times do |i|
220
258
  mean[i] =
221
- if @data_items.first[i].is_a?(Numeric)
222
- Statistics.mean(self, i)
223
- else
224
- Statistics.mode(self, i)
225
- end
259
+ if @data_items.first[i].is_a?(Numeric)
260
+ Statistics.mean(self, i)
261
+ else
262
+ Statistics.mode(self, i)
263
+ end
226
264
  end
227
- return mean
265
+ mean
266
+ end
267
+
268
+ # Normalize numeric attributes in place. Supported methods are
269
+ # +:zscore+ (default) and +:minmax+.
270
+ # @param method [Object]
271
+ # @return [Object]
272
+ def normalize!(method = :zscore)
273
+ numeric_indices = (0...num_attributes).select do |i|
274
+ @data_items.first[i].is_a?(Numeric)
275
+ end
276
+
277
+ case method
278
+ when :zscore
279
+ means = numeric_indices.map { |i| Statistics.mean(self, i) }
280
+ sds = numeric_indices.map { |i| Statistics.standard_deviation(self, i) }
281
+ @data_items.each do |row|
282
+ numeric_indices.each_with_index do |idx, j|
283
+ sd = sds[j]
284
+ row[idx] = sd.zero? ? 0 : (row[idx] - means[j]) / sd
285
+ end
286
+ end
287
+ when :minmax
288
+ mins = numeric_indices.map { |i| Statistics.min(self, i) }
289
+ maxs = numeric_indices.map { |i| Statistics.max(self, i) }
290
+ @data_items.each do |row|
291
+ numeric_indices.each_with_index do |idx, j|
292
+ range = maxs[j] - mins[j]
293
+ row[idx] = range.zero? ? 0 : (row[idx] - mins[j]) / range.to_f
294
+ end
295
+ end
296
+ else
297
+ raise ArgumentError, "Unknown normalization method #{method}"
298
+ end
299
+
300
+ self
301
+ end
302
+
303
+ # Randomizes the order of data items in place.
304
+ # If a +seed+ is provided, it is used to initialize the random number
305
+ # generator for deterministic shuffling.
306
+ #
307
+ # data_set.shuffle!(seed: 123)
308
+ #
309
+ # @param seed [Integer, nil] Seed for the RNG
310
+ # @return [DataSet] self
311
+ def shuffle!(seed: nil)
312
+ rng = seed ? Random.new(seed) : Random.new
313
+ @data_items.shuffle!(random: rng)
314
+ self
315
+ end
316
+
317
+ # Split the dataset into two new DataSet instances using the given ratio
318
+ # for the first set.
319
+ #
320
+ # train, test = data_set.split(ratio: 0.8)
321
+ #
322
+ # @param ratio [Float] fraction of items to place in the first set
323
+ # @return [Array<DataSet, DataSet>] the two resulting datasets
324
+ def split(ratio:)
325
+ raise ArgumentError, 'ratio must be between 0 and 1' unless ratio.positive? && ratio < 1
326
+
327
+ pivot = (ratio * @data_items.length).round
328
+ first_items = @data_items[0...pivot].map(&:dup)
329
+ second_items = @data_items[pivot..].map(&:dup)
330
+
331
+ [
332
+ DataSet.new(data_items: first_items, data_labels: @data_labels.dup),
333
+ DataSet.new(data_items: second_items, data_labels: @data_labels.dup)
334
+ ]
335
+ end
336
+
337
+ # Returns label of category
338
+ # @return [Object]
339
+ def category_label
340
+ data_labels.last
228
341
  end
229
342
 
230
343
  protected
231
344
 
232
- def is_number?(x)
233
- true if Float(x) rescue false
345
+ # @param x [Object]
346
+ # @return [Object]
347
+ def number?(x)
348
+ !Float(x, exception: false).nil?
234
349
  end
235
350
 
351
+ # @param data_items [Object]
352
+ # @return [Object]
236
353
  def check_data_items(data_items)
237
354
  if !data_items || data_items.empty?
238
- raise ArgumentError, "Examples data set must not be empty."
355
+ raise ArgumentError, 'Examples data set must not be empty.'
239
356
  elsif !data_items.first.is_a?(Enumerable)
240
- raise ArgumentError, "Unkown format for example data."
357
+ raise ArgumentError, 'Unkown format for example data.'
241
358
  end
359
+
242
360
  attributes_num = data_items.first.length
243
361
  data_items.each_index do |index|
244
- if data_items[index].length != attributes_num
245
- raise ArgumentError,
246
- "Quantity of attributes is inconsistent. " +
247
- "The first item has #{attributes_num} attributes "+
248
- "and row #{index} has #{data_items[index].length} attributes"
249
- end
362
+ next unless data_items[index].length != attributes_num
363
+
364
+ raise ArgumentError,
365
+ 'Quantity of attributes is inconsistent. ' \
366
+ "The first item has #{attributes_num} attributes " \
367
+ "and row #{index} has #{data_items[index].length} attributes"
250
368
  end
251
369
  end
252
370
 
371
+ # @param labels [Object]
372
+ # @return [Object]
253
373
  def check_data_labels(labels)
254
- if !@data_items.empty?
255
- if labels.length != @data_items.first.length
256
- raise ArgumentError,
257
- "Number of labels and attributes do not match. " +
258
- "#{labels.length} labels and " +
259
- "#{@data_items.first.length} attributes found."
260
- end
261
- end
374
+ return if @data_items.empty?
375
+ return unless labels.length != @data_items.first.length
376
+
377
+ raise ArgumentError,
378
+ 'Number of labels and attributes do not match. ' \
379
+ "#{labels.length} labels and " \
380
+ "#{@data_items.first.length} attributes found."
262
381
  end
263
382
 
383
+ # @param data_items [Object]
384
+ # @return [Object]
264
385
  def default_data_labels(data_items)
265
386
  data_labels = []
266
387
  data_items[0][0..-2].each_index do |i|
267
- data_labels[i] = "attribute_#{i+1}"
388
+ data_labels[i] = "attribute_#{i + 1}"
268
389
  end
269
- data_labels[data_labels.length]="class_value"
270
- return data_labels
390
+ data_labels[data_labels.length] = 'class_value'
391
+ data_labels
271
392
  end
272
-
273
393
  end
274
394
  end
275
395
  end