ai4r 1.13 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.rb +14 -11
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +40 -34
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +48 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +524 -145
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +95 -34
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +83 -58
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +29 -14
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +270 -135
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +25 -10
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +223 -103
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +62 -62
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +399 -134
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +57 -28
- metadata +79 -109
- data/README.rdoc +0 -39
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -220
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/simple_linear_regression_test.rb +0 -37
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -167
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_hierarchical_test.rb +0 -81
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -104
- data/test/data/proximity_test.rb +0 -87
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -1,61 +1,77 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens (implementation)
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
|
-
|
11
|
-
|
12
|
+
require_relative '../data/data_set'
|
13
|
+
require_relative '../clusterers/single_linkage'
|
14
|
+
require_relative '../clusterers/cluster_tree'
|
12
15
|
|
13
16
|
module Ai4r
|
14
17
|
module Clusterers
|
15
|
-
|
16
|
-
#
|
17
|
-
# weighted average linkage algorithm, aka weighted pair group method
|
18
|
+
# Implementation of an Agglomerative Hierarchical clusterer with
|
19
|
+
# weighted average linkage algorithm, aka weighted pair group method
|
18
20
|
# average or WPGMA (Jain and Dubes, 1988 ; McQuitty, 1966 )
|
19
|
-
# Hierarchical clusterer create one cluster per element, and then
|
21
|
+
# Hierarchical clusterer create one cluster per element, and then
|
20
22
|
# progressively merge clusters, until the required number of clusters
|
21
23
|
# is reached.
|
22
|
-
# Similar to AverageLinkage, but the distances between clusters are
|
24
|
+
# Similar to AverageLinkage, but the distances between clusters are
|
23
25
|
# weighted based on the number of data items in each of them.
|
24
|
-
#
|
26
|
+
#
|
25
27
|
# D(cx, (ci U cj)) = ( ni * D(cx, ci) + nj * D(cx, cj)) / (ni + nj)
|
26
28
|
class WeightedAverageLinkage < SingleLinkage
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
29
|
+
include ClusterTree
|
30
|
+
|
31
|
+
parameters_info distance_function:
|
32
|
+
'Custom implementation of distance function. ' \
|
33
|
+
'It must be a closure receiving two data items and return the ' \
|
34
|
+
'distance between them. By default, this algorithm uses ' \
|
35
|
+
'euclidean distance of numeric attributes to the power of 2.'
|
36
|
+
|
34
37
|
# Build a new clusterer, using data examples found in data_set.
|
35
38
|
# Items will be clustered in "number_of_clusters" different
|
36
39
|
# clusters.
|
37
|
-
|
40
|
+
# @param data_set [Object]
|
41
|
+
# @param number_of_clusters [Object]
|
42
|
+
# @param *options [Object]
|
43
|
+
# @return [Object]
|
44
|
+
def build(data_set, number_of_clusters = 1, **options)
|
38
45
|
super
|
39
46
|
end
|
40
|
-
|
41
|
-
# This algorithms does not allow classification of new data items
|
47
|
+
|
48
|
+
# This algorithms does not allow classification of new data items
|
42
49
|
# once it has been built. Rebuild the cluster including you data element.
|
43
|
-
|
44
|
-
|
50
|
+
# @param _data_item [Object]
|
51
|
+
# @return [Object]
|
52
|
+
def eval(_data_item)
|
53
|
+
raise NotImplementedError, 'Eval of new data is not supported by this algorithm.'
|
45
54
|
end
|
46
|
-
|
55
|
+
|
56
|
+
# @return [Object]
|
57
|
+
def supports_eval?
|
58
|
+
false
|
59
|
+
end
|
60
|
+
|
47
61
|
protected
|
48
|
-
|
62
|
+
|
49
63
|
# return distance between cluster cx and cluster (ci U cj),
|
50
64
|
# using weighted average linkage
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
65
|
+
# @param cx [Object]
|
66
|
+
# @param ci [Object]
|
67
|
+
# @param cj [Object]
|
68
|
+
# @return [Object]
|
69
|
+
def linkage_distance(cluster_x, cluster_i, cluster_j)
|
70
|
+
ni = @index_clusters[cluster_i].length
|
71
|
+
nj = @index_clusters[cluster_j].length
|
72
|
+
((1.0 * ni * read_distance_matrix(cluster_x, cluster_i)) +
|
73
|
+
(nj * read_distance_matrix(cluster_x, cluster_j))) / (ni + nj)
|
56
74
|
end
|
57
|
-
|
58
75
|
end
|
59
76
|
end
|
60
77
|
end
|
61
|
-
|
data/lib/ai4r/data/data_set.rb
CHANGED
@@ -1,34 +1,51 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
12
|
require 'csv'
|
11
13
|
require 'set'
|
12
|
-
|
14
|
+
require_relative 'statistics'
|
13
15
|
|
14
16
|
module Ai4r
|
15
17
|
module Data
|
16
|
-
|
17
|
-
# A data set is a collection of N data items. Each data item is
|
18
|
+
# A data set is a collection of N data items. Each data item is
|
18
19
|
# described by a set of attributes, represented as an array.
|
19
|
-
# Optionally, you can assign a label to the attributes, using
|
20
|
+
# Optionally, you can assign a label to the attributes, using
|
20
21
|
# the data_labels property.
|
21
22
|
class DataSet
|
22
|
-
|
23
23
|
attr_reader :data_labels, :data_items
|
24
24
|
|
25
|
+
# Return a new DataSet with numeric attributes normalized.
|
26
|
+
# Available methods are:
|
27
|
+
# * +:zscore+ - subtract the mean and divide by the standard deviation
|
28
|
+
# * +:minmax+ - scale values to the [0,1] range
|
29
|
+
# @param data_set [Object]
|
30
|
+
# @param method [Object]
|
31
|
+
# @return [Object]
|
32
|
+
def self.normalized(data_set, method: :zscore)
|
33
|
+
new_set = DataSet.new(
|
34
|
+
data_items: data_set.data_items.map(&:dup),
|
35
|
+
data_labels: data_set.data_labels.dup
|
36
|
+
)
|
37
|
+
new_set.normalize!(method)
|
38
|
+
end
|
39
|
+
|
25
40
|
# Create a new DataSet. By default, empty.
|
26
41
|
# Optionaly, you can provide the initial data items and data labels.
|
27
|
-
#
|
42
|
+
#
|
28
43
|
# e.g. DataSet.new(:data_items => data_items, :data_labels => labels)
|
29
|
-
#
|
44
|
+
#
|
30
45
|
# If you provide data items, but no data labels, the data set will
|
31
46
|
# use the default data label values (see set_data_labels)
|
47
|
+
# @param options [Object]
|
48
|
+
# @return [Object]
|
32
49
|
def initialize(options = {})
|
33
50
|
@data_labels = []
|
34
51
|
@data_items = options[:data_items] || []
|
@@ -36,60 +53,70 @@ module Ai4r
|
|
36
53
|
set_data_items(options[:data_items]) if options[:data_items]
|
37
54
|
end
|
38
55
|
|
39
|
-
# Retrieve a new DataSet, with the item(s) selected by the provided
|
56
|
+
# Retrieve a new DataSet, with the item(s) selected by the provided
|
40
57
|
# index. You can specify an index range, too.
|
58
|
+
# @param index [Object]
|
59
|
+
# @return [Object]
|
41
60
|
def [](index)
|
42
|
-
selected_items =
|
43
|
-
|
44
|
-
|
45
|
-
|
61
|
+
selected_items = if index.is_a?(Integer)
|
62
|
+
[@data_items[index]]
|
63
|
+
else
|
64
|
+
@data_items[index]
|
65
|
+
end
|
66
|
+
DataSet.new(data_items: selected_items,
|
67
|
+
data_labels: @data_labels)
|
46
68
|
end
|
47
69
|
|
48
70
|
# Load data items from csv file
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
set_data_items(items)
|
55
|
-
end
|
56
|
-
|
57
|
-
# opens a csv-file and reads it line by line
|
58
|
-
# for each line, a block is called and the row is passed to the block
|
59
|
-
# ruby1.8 and 1.9 safe
|
60
|
-
def open_csv_file(filepath, &block)
|
61
|
-
if CSV.const_defined? :Reader
|
62
|
-
CSV::Reader.parse(File.open(filepath, 'r')) do |row|
|
63
|
-
block.call row
|
64
|
-
end
|
71
|
+
# @param filepath [Object]
|
72
|
+
# @return [Object]
|
73
|
+
def load_csv(filepath, parse_numeric: false)
|
74
|
+
if parse_numeric
|
75
|
+
parse_csv(filepath)
|
65
76
|
else
|
66
|
-
|
67
|
-
|
77
|
+
items = []
|
78
|
+
open_csv_file(filepath) do |entry|
|
79
|
+
items << entry
|
68
80
|
end
|
81
|
+
set_data_items(items)
|
69
82
|
end
|
70
83
|
end
|
71
84
|
|
85
|
+
# Open a CSV file and yield each row to the provided block.
|
86
|
+
# @param filepath [Object]
|
87
|
+
# @param block [Object]
|
88
|
+
# @return [Object]
|
89
|
+
def open_csv_file(filepath, &)
|
90
|
+
CSV.foreach(filepath, &)
|
91
|
+
end
|
92
|
+
|
72
93
|
# Load data items from csv file. The first row is used as data labels.
|
73
|
-
|
74
|
-
|
94
|
+
# @param filepath [Object]
|
95
|
+
# @return [Object]
|
96
|
+
def load_csv_with_labels(filepath, parse_numeric: false)
|
97
|
+
load_csv(filepath, parse_numeric: parse_numeric)
|
75
98
|
@data_labels = @data_items.shift
|
76
|
-
|
99
|
+
self
|
77
100
|
end
|
78
101
|
|
79
102
|
# Same as load_csv, but it will try to convert cell contents as numbers.
|
103
|
+
# @param filepath [Object]
|
104
|
+
# @return [Object]
|
80
105
|
def parse_csv(filepath)
|
81
106
|
items = []
|
82
107
|
open_csv_file(filepath) do |row|
|
83
|
-
items << row.collect
|
108
|
+
items << row.collect do |x|
|
109
|
+
number?(x) ? Float(x, exception: false) : x
|
110
|
+
end
|
84
111
|
end
|
85
112
|
set_data_items(items)
|
86
113
|
end
|
87
114
|
|
88
115
|
# Same as load_csv_with_labels, but it will try to convert cell contents as numbers.
|
116
|
+
# @param filepath [Object]
|
117
|
+
# @return [Object]
|
89
118
|
def parse_csv_with_labels(filepath)
|
90
|
-
|
91
|
-
@data_labels = @data_items.shift
|
92
|
-
return self
|
119
|
+
load_csv_with_labels(filepath, parse_numeric: true)
|
93
120
|
end
|
94
121
|
|
95
122
|
# Set data labels.
|
@@ -98,23 +125,25 @@ module Ai4r
|
|
98
125
|
#
|
99
126
|
# If you do not provide labels for you data, the following labels will
|
100
127
|
# be created by default:
|
101
|
-
# [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
|
128
|
+
# [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
|
129
|
+
# @param labels [Object]
|
130
|
+
# @return [Object]
|
102
131
|
def set_data_labels(labels)
|
103
132
|
check_data_labels(labels)
|
104
133
|
@data_labels = labels
|
105
|
-
|
134
|
+
self
|
106
135
|
end
|
107
136
|
|
108
137
|
# Set the data items.
|
109
|
-
# M data items with N attributes must have the following
|
138
|
+
# M data items with N attributes must have the following
|
110
139
|
# format:
|
111
|
-
#
|
112
|
-
# [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
|
113
|
-
# [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
|
140
|
+
#
|
141
|
+
# [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
|
142
|
+
# [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
|
114
143
|
# ...
|
115
|
-
# [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
|
144
|
+
# [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
|
116
145
|
# ]
|
117
|
-
#
|
146
|
+
#
|
118
147
|
# e.g.
|
119
148
|
# [ ['New York', '<30', 'M', 'Y'],
|
120
149
|
# ['Chicago', '<30', 'M', 'Y'],
|
@@ -132,144 +161,235 @@ module Ai4r
|
|
132
161
|
# ['New York', '[50-80]', 'F', 'N'],
|
133
162
|
# ['Chicago', '>80', 'F', 'Y']
|
134
163
|
# ]
|
135
|
-
#
|
164
|
+
#
|
136
165
|
# This method returns the classifier (self), allowing method chaining.
|
166
|
+
# @param items [Object]
|
167
|
+
# @return [Object]
|
137
168
|
def set_data_items(items)
|
138
169
|
check_data_items(items)
|
139
170
|
@data_labels = default_data_labels(items) if @data_labels.empty?
|
140
171
|
@data_items = items
|
141
|
-
|
172
|
+
self
|
142
173
|
end
|
143
174
|
|
144
175
|
# Returns an array with the domain of each attribute:
|
145
176
|
# * Set instance containing all possible values for nominal attributes
|
146
177
|
# * Array with min and max values for numeric attributes (i.e. [min, max])
|
147
|
-
#
|
178
|
+
#
|
148
179
|
# Return example:
|
149
|
-
# => [#<Set: {"New York", "Chicago"}>,
|
150
|
-
# #<Set: {"<30", "[30-50)", "[50-80]", ">80"}>,
|
180
|
+
# => [#<Set: {"New York", "Chicago"}>,
|
181
|
+
# #<Set: {"<30", "[30-50)", "[50-80]", ">80"}>,
|
151
182
|
# #<Set: {"M", "F"}>,
|
152
|
-
# [5, 85],
|
183
|
+
# [5, 85],
|
153
184
|
# #<Set: {"Y", "N"}>]
|
185
|
+
# @return [Object]
|
154
186
|
def build_domains
|
155
|
-
@data_labels.collect {|attr_label| build_domain(attr_label) }
|
187
|
+
@data_labels.collect { |attr_label| build_domain(attr_label) }
|
156
188
|
end
|
157
189
|
|
158
190
|
# Returns a Set instance containing all possible values for an attribute
|
159
191
|
# The parameter can be an attribute label or index (0 based).
|
160
192
|
# * Set instance containing all possible values for nominal attributes
|
161
193
|
# * Array with min and max values for numeric attributes (i.e. [min, max])
|
162
|
-
#
|
194
|
+
#
|
163
195
|
# build_domain("city")
|
164
196
|
# => #<Set: {"New York", "Chicago"}>
|
165
|
-
#
|
197
|
+
#
|
166
198
|
# build_domain("age")
|
167
199
|
# => [5, 85]
|
168
|
-
#
|
200
|
+
#
|
169
201
|
# build_domain(2) # In this example, the third attribute is gender
|
170
202
|
# => #<Set: {"M", "F"}>
|
203
|
+
# @param attr [Object]
|
204
|
+
# @return [Object]
|
171
205
|
def build_domain(attr)
|
172
206
|
index = get_index(attr)
|
173
|
-
if @data_items.first[index].is_a?(Numeric)
|
174
|
-
|
175
|
-
|
176
|
-
return @data_items.inject(Set.new){|domain, x| domain << x[index]}
|
177
|
-
end
|
207
|
+
return [Statistics.min(self, index), Statistics.max(self, index)] if @data_items.first[index].is_a?(Numeric)
|
208
|
+
|
209
|
+
@data_items.inject(Set.new) { |domain, x| domain << x[index] }
|
178
210
|
end
|
179
211
|
|
180
212
|
# Returns attributes number, including class attribute
|
213
|
+
# @return [Object]
|
181
214
|
def num_attributes
|
182
|
-
|
215
|
+
@data_items.empty? ? 0 : @data_items.first.size
|
183
216
|
end
|
184
217
|
|
185
218
|
# Returns the index of a given attribute (0-based).
|
186
219
|
# For example, if "gender" is the third attribute, then:
|
187
|
-
# get_index("gender")
|
220
|
+
# get_index("gender")
|
188
221
|
# => 2
|
222
|
+
# @param attr [Object]
|
223
|
+
# @return [Object]
|
189
224
|
def get_index(attr)
|
190
|
-
|
225
|
+
attr.is_a?(Integer) || attr.is_a?(Range) ? attr : @data_labels.index(attr)
|
191
226
|
end
|
192
227
|
|
193
228
|
# Raise an exception if there is no data item.
|
229
|
+
# @return [Object]
|
194
230
|
def check_not_empty
|
195
|
-
|
196
|
-
|
197
|
-
|
231
|
+
return unless @data_items.empty?
|
232
|
+
|
233
|
+
raise ArgumentError, 'Examples data set must not be empty.'
|
198
234
|
end
|
199
235
|
|
200
236
|
# Add a data item to the data set
|
201
|
-
|
237
|
+
# @return [Object]
|
238
|
+
def <<(data_item)
|
202
239
|
if data_item.nil? || !data_item.is_a?(Enumerable) || data_item.empty?
|
203
|
-
raise ArgumentError,
|
240
|
+
raise ArgumentError, 'Data must not be an non empty array.'
|
204
241
|
elsif @data_items.empty?
|
205
242
|
set_data_items([data_item])
|
206
243
|
elsif data_item.length != num_attributes
|
207
|
-
raise ArgumentError,
|
208
|
-
|
209
|
-
|
244
|
+
raise ArgumentError, 'Number of attributes do not match. ' \
|
245
|
+
"#{data_item.length} attributes provided, " \
|
246
|
+
"#{num_attributes} attributes expected."
|
210
247
|
else
|
211
248
|
@data_items << data_item
|
212
249
|
end
|
213
250
|
end
|
214
251
|
|
215
|
-
# Returns an array with the mean value of numeric attributes, and
|
252
|
+
# Returns an array with the mean value of numeric attributes, and
|
216
253
|
# the most frequent value of non numeric attributes
|
254
|
+
# @return [Object]
|
217
255
|
def get_mean_or_mode
|
218
256
|
mean = []
|
219
257
|
num_attributes.times do |i|
|
220
258
|
mean[i] =
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
259
|
+
if @data_items.first[i].is_a?(Numeric)
|
260
|
+
Statistics.mean(self, i)
|
261
|
+
else
|
262
|
+
Statistics.mode(self, i)
|
263
|
+
end
|
226
264
|
end
|
227
|
-
|
265
|
+
mean
|
266
|
+
end
|
267
|
+
|
268
|
+
# Normalize numeric attributes in place. Supported methods are
|
269
|
+
# +:zscore+ (default) and +:minmax+.
|
270
|
+
# @param method [Object]
|
271
|
+
# @return [Object]
|
272
|
+
def normalize!(method = :zscore)
|
273
|
+
numeric_indices = (0...num_attributes).select do |i|
|
274
|
+
@data_items.first[i].is_a?(Numeric)
|
275
|
+
end
|
276
|
+
|
277
|
+
case method
|
278
|
+
when :zscore
|
279
|
+
means = numeric_indices.map { |i| Statistics.mean(self, i) }
|
280
|
+
sds = numeric_indices.map { |i| Statistics.standard_deviation(self, i) }
|
281
|
+
@data_items.each do |row|
|
282
|
+
numeric_indices.each_with_index do |idx, j|
|
283
|
+
sd = sds[j]
|
284
|
+
row[idx] = sd.zero? ? 0 : (row[idx] - means[j]) / sd
|
285
|
+
end
|
286
|
+
end
|
287
|
+
when :minmax
|
288
|
+
mins = numeric_indices.map { |i| Statistics.min(self, i) }
|
289
|
+
maxs = numeric_indices.map { |i| Statistics.max(self, i) }
|
290
|
+
@data_items.each do |row|
|
291
|
+
numeric_indices.each_with_index do |idx, j|
|
292
|
+
range = maxs[j] - mins[j]
|
293
|
+
row[idx] = range.zero? ? 0 : (row[idx] - mins[j]) / range.to_f
|
294
|
+
end
|
295
|
+
end
|
296
|
+
else
|
297
|
+
raise ArgumentError, "Unknown normalization method #{method}"
|
298
|
+
end
|
299
|
+
|
300
|
+
self
|
301
|
+
end
|
302
|
+
|
303
|
+
# Randomizes the order of data items in place.
|
304
|
+
# If a +seed+ is provided, it is used to initialize the random number
|
305
|
+
# generator for deterministic shuffling.
|
306
|
+
#
|
307
|
+
# data_set.shuffle!(seed: 123)
|
308
|
+
#
|
309
|
+
# @param seed [Integer, nil] Seed for the RNG
|
310
|
+
# @return [DataSet] self
|
311
|
+
def shuffle!(seed: nil)
|
312
|
+
rng = seed ? Random.new(seed) : Random.new
|
313
|
+
@data_items.shuffle!(random: rng)
|
314
|
+
self
|
315
|
+
end
|
316
|
+
|
317
|
+
# Split the dataset into two new DataSet instances using the given ratio
|
318
|
+
# for the first set.
|
319
|
+
#
|
320
|
+
# train, test = data_set.split(ratio: 0.8)
|
321
|
+
#
|
322
|
+
# @param ratio [Float] fraction of items to place in the first set
|
323
|
+
# @return [Array<DataSet, DataSet>] the two resulting datasets
|
324
|
+
def split(ratio:)
|
325
|
+
raise ArgumentError, 'ratio must be between 0 and 1' unless ratio.positive? && ratio < 1
|
326
|
+
|
327
|
+
pivot = (ratio * @data_items.length).round
|
328
|
+
first_items = @data_items[0...pivot].map(&:dup)
|
329
|
+
second_items = @data_items[pivot..].map(&:dup)
|
330
|
+
|
331
|
+
[
|
332
|
+
DataSet.new(data_items: first_items, data_labels: @data_labels.dup),
|
333
|
+
DataSet.new(data_items: second_items, data_labels: @data_labels.dup)
|
334
|
+
]
|
335
|
+
end
|
336
|
+
|
337
|
+
# Returns label of category
|
338
|
+
# @return [Object]
|
339
|
+
def category_label
|
340
|
+
data_labels.last
|
228
341
|
end
|
229
342
|
|
230
343
|
protected
|
231
344
|
|
232
|
-
|
233
|
-
|
345
|
+
# @param x [Object]
|
346
|
+
# @return [Object]
|
347
|
+
def number?(x)
|
348
|
+
!Float(x, exception: false).nil?
|
234
349
|
end
|
235
350
|
|
351
|
+
# @param data_items [Object]
|
352
|
+
# @return [Object]
|
236
353
|
def check_data_items(data_items)
|
237
354
|
if !data_items || data_items.empty?
|
238
|
-
raise ArgumentError,
|
355
|
+
raise ArgumentError, 'Examples data set must not be empty.'
|
239
356
|
elsif !data_items.first.is_a?(Enumerable)
|
240
|
-
raise ArgumentError,
|
357
|
+
raise ArgumentError, 'Unkown format for example data.'
|
241
358
|
end
|
359
|
+
|
242
360
|
attributes_num = data_items.first.length
|
243
361
|
data_items.each_index do |index|
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
362
|
+
next unless data_items[index].length != attributes_num
|
363
|
+
|
364
|
+
raise ArgumentError,
|
365
|
+
'Quantity of attributes is inconsistent. ' \
|
366
|
+
"The first item has #{attributes_num} attributes " \
|
367
|
+
"and row #{index} has #{data_items[index].length} attributes"
|
250
368
|
end
|
251
369
|
end
|
252
370
|
|
371
|
+
# @param labels [Object]
|
372
|
+
# @return [Object]
|
253
373
|
def check_data_labels(labels)
|
254
|
-
if
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
end
|
374
|
+
return if @data_items.empty?
|
375
|
+
return unless labels.length != @data_items.first.length
|
376
|
+
|
377
|
+
raise ArgumentError,
|
378
|
+
'Number of labels and attributes do not match. ' \
|
379
|
+
"#{labels.length} labels and " \
|
380
|
+
"#{@data_items.first.length} attributes found."
|
262
381
|
end
|
263
382
|
|
383
|
+
# @param data_items [Object]
|
384
|
+
# @return [Object]
|
264
385
|
def default_data_labels(data_items)
|
265
386
|
data_labels = []
|
266
387
|
data_items[0][0..-2].each_index do |i|
|
267
|
-
data_labels[i] = "attribute_#{i+1}"
|
388
|
+
data_labels[i] = "attribute_#{i + 1}"
|
268
389
|
end
|
269
|
-
data_labels[data_labels.length]=
|
270
|
-
|
390
|
+
data_labels[data_labels.length] = 'class_value'
|
391
|
+
data_labels
|
271
392
|
end
|
272
|
-
|
273
393
|
end
|
274
394
|
end
|
275
395
|
end
|