ai4r 1.12 → 2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +174 -0
- data/examples/classifiers/hyperpipes_data.csv +14 -0
- data/examples/classifiers/hyperpipes_example.rb +22 -0
- data/examples/classifiers/ib1_example.rb +12 -0
- data/examples/classifiers/id3_example.rb +15 -10
- data/examples/classifiers/id3_graphviz_example.rb +17 -0
- data/examples/classifiers/logistic_regression_example.rb +11 -0
- data/examples/classifiers/naive_bayes_attributes_example.rb +13 -0
- data/examples/classifiers/naive_bayes_example.rb +12 -13
- data/examples/classifiers/one_r_example.rb +27 -0
- data/examples/classifiers/parameter_tutorial.rb +29 -0
- data/examples/classifiers/prism_nominal_example.rb +15 -0
- data/examples/classifiers/prism_numeric_example.rb +21 -0
- data/examples/classifiers/simple_linear_regression_example.csv +159 -0
- data/examples/classifiers/simple_linear_regression_example.rb +18 -0
- data/examples/classifiers/zero_and_one_r_example.rb +34 -0
- data/examples/classifiers/zero_one_r_data.csv +8 -0
- data/examples/clusterers/clusterer_example.rb +62 -0
- data/examples/clusterers/dbscan_example.rb +17 -0
- data/examples/clusterers/dendrogram_example.rb +17 -0
- data/examples/clusterers/hierarchical_dendrogram_example.rb +20 -0
- data/examples/clusterers/kmeans_custom_example.rb +26 -0
- data/examples/genetic_algorithm/bitstring_example.rb +41 -0
- data/examples/genetic_algorithm/genetic_algorithm_example.rb +26 -18
- data/examples/genetic_algorithm/kmeans_seed_tuning.rb +45 -0
- data/examples/neural_network/backpropagation_example.rb +49 -48
- data/examples/neural_network/hopfield_example.rb +45 -0
- data/examples/neural_network/patterns_with_base_noise.rb +39 -39
- data/examples/neural_network/patterns_with_noise.rb +41 -39
- data/examples/neural_network/train_epochs_callback.rb +25 -0
- data/examples/neural_network/training_patterns.rb +39 -39
- data/examples/neural_network/transformer_text_classification.rb +78 -0
- data/examples/neural_network/xor_example.rb +23 -22
- data/examples/reinforcement/q_learning_example.rb +10 -0
- data/examples/som/som_data.rb +155 -152
- data/examples/som/som_multi_node_example.rb +12 -13
- data/examples/som/som_single_example.rb +12 -15
- data/examples/transformer/decode_classifier_example.rb +68 -0
- data/examples/transformer/deterministic_example.rb +10 -0
- data/examples/transformer/seq2seq_example.rb +16 -0
- data/lib/ai4r/classifiers/classifier.rb +24 -16
- data/lib/ai4r/classifiers/gradient_boosting.rb +64 -0
- data/lib/ai4r/classifiers/hyperpipes.rb +119 -43
- data/lib/ai4r/classifiers/ib1.rb +122 -32
- data/lib/ai4r/classifiers/id3.rb +527 -144
- data/lib/ai4r/classifiers/logistic_regression.rb +96 -0
- data/lib/ai4r/classifiers/multilayer_perceptron.rb +75 -59
- data/lib/ai4r/classifiers/naive_bayes.rb +112 -48
- data/lib/ai4r/classifiers/one_r.rb +112 -44
- data/lib/ai4r/classifiers/prism.rb +167 -76
- data/lib/ai4r/classifiers/random_forest.rb +72 -0
- data/lib/ai4r/classifiers/simple_linear_regression.rb +143 -0
- data/lib/ai4r/classifiers/support_vector_machine.rb +91 -0
- data/lib/ai4r/classifiers/votes.rb +57 -0
- data/lib/ai4r/classifiers/zero_r.rb +71 -30
- data/lib/ai4r/clusterers/average_linkage.rb +46 -27
- data/lib/ai4r/clusterers/bisecting_k_means.rb +50 -44
- data/lib/ai4r/clusterers/centroid_linkage.rb +52 -36
- data/lib/ai4r/clusterers/cluster_tree.rb +50 -0
- data/lib/ai4r/clusterers/clusterer.rb +28 -24
- data/lib/ai4r/clusterers/complete_linkage.rb +42 -31
- data/lib/ai4r/clusterers/dbscan.rb +134 -0
- data/lib/ai4r/clusterers/diana.rb +75 -49
- data/lib/ai4r/clusterers/k_means.rb +309 -72
- data/lib/ai4r/clusterers/median_linkage.rb +49 -33
- data/lib/ai4r/clusterers/single_linkage.rb +196 -88
- data/lib/ai4r/clusterers/ward_linkage.rb +51 -35
- data/lib/ai4r/clusterers/ward_linkage_hierarchical.rb +63 -0
- data/lib/ai4r/clusterers/weighted_average_linkage.rb +48 -32
- data/lib/ai4r/data/data_set.rb +229 -100
- data/lib/ai4r/data/parameterizable.rb +31 -25
- data/lib/ai4r/data/proximity.rb +72 -50
- data/lib/ai4r/data/statistics.rb +46 -35
- data/lib/ai4r/experiment/classifier_evaluator.rb +84 -32
- data/lib/ai4r/experiment/split.rb +39 -0
- data/lib/ai4r/genetic_algorithm/chromosome_base.rb +43 -0
- data/lib/ai4r/genetic_algorithm/genetic_algorithm.rb +92 -170
- data/lib/ai4r/genetic_algorithm/tsp_chromosome.rb +83 -0
- data/lib/ai4r/hmm/hidden_markov_model.rb +134 -0
- data/lib/ai4r/neural_network/activation_functions.rb +37 -0
- data/lib/ai4r/neural_network/backpropagation.rb +419 -143
- data/lib/ai4r/neural_network/hopfield.rb +175 -58
- data/lib/ai4r/neural_network/transformer.rb +194 -0
- data/lib/ai4r/neural_network/weight_initializations.rb +40 -0
- data/lib/ai4r/reinforcement/policy_iteration.rb +66 -0
- data/lib/ai4r/reinforcement/q_learning.rb +51 -0
- data/lib/ai4r/search/a_star.rb +76 -0
- data/lib/ai4r/search/bfs.rb +50 -0
- data/lib/ai4r/search/dfs.rb +50 -0
- data/lib/ai4r/search/mcts.rb +118 -0
- data/lib/ai4r/search.rb +12 -0
- data/lib/ai4r/som/distance_metrics.rb +29 -0
- data/lib/ai4r/som/layer.rb +28 -17
- data/lib/ai4r/som/node.rb +61 -32
- data/lib/ai4r/som/som.rb +158 -41
- data/lib/ai4r/som/two_phase_layer.rb +21 -25
- data/lib/ai4r/version.rb +3 -0
- data/lib/ai4r.rb +58 -27
- metadata +117 -106
- data/README.rdoc +0 -44
- data/test/classifiers/hyperpipes_test.rb +0 -84
- data/test/classifiers/ib1_test.rb +0 -78
- data/test/classifiers/id3_test.rb +0 -208
- data/test/classifiers/multilayer_perceptron_test.rb +0 -79
- data/test/classifiers/naive_bayes_test.rb +0 -43
- data/test/classifiers/one_r_test.rb +0 -62
- data/test/classifiers/prism_test.rb +0 -85
- data/test/classifiers/zero_r_test.rb +0 -50
- data/test/clusterers/average_linkage_test.rb +0 -51
- data/test/clusterers/bisecting_k_means_test.rb +0 -66
- data/test/clusterers/centroid_linkage_test.rb +0 -53
- data/test/clusterers/complete_linkage_test.rb +0 -57
- data/test/clusterers/diana_test.rb +0 -69
- data/test/clusterers/k_means_test.rb +0 -100
- data/test/clusterers/median_linkage_test.rb +0 -53
- data/test/clusterers/single_linkage_test.rb +0 -122
- data/test/clusterers/ward_linkage_test.rb +0 -53
- data/test/clusterers/weighted_average_linkage_test.rb +0 -53
- data/test/data/data_set_test.rb +0 -96
- data/test/data/proximity_test.rb +0 -81
- data/test/data/statistics_test.rb +0 -65
- data/test/experiment/classifier_evaluator_test.rb +0 -76
- data/test/genetic_algorithm/chromosome_test.rb +0 -57
- data/test/genetic_algorithm/genetic_algorithm_test.rb +0 -81
- data/test/neural_network/backpropagation_test.rb +0 -82
- data/test/neural_network/hopfield_test.rb +0 -72
- data/test/som/som_test.rb +0 -97
@@ -0,0 +1,63 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
# Author:: Peter Lubell-Doughtie
|
4
|
+
# License:: BSD 3 Clause
|
5
|
+
# Project:: ai4r
|
6
|
+
# Url:: http://peet.ldee.org
|
7
|
+
|
8
|
+
require_relative '../clusterers/ward_linkage'
|
9
|
+
require_relative '../clusterers/cluster_tree'
|
10
|
+
|
11
|
+
module Ai4r
  module Clusterers
    # Ward linkage clusterer that also records the cluster configuration
    # reached at each merge step, so the hierarchy can be inspected after
    # #build. When +depth+ is given, only the last levels of the hierarchy
    # are recorded.
    class WardLinkageHierarchical < WardLinkage
      include ClusterTree

      # @param depth [Integer, nil] how many of the final hierarchy levels to
      #   record; +nil+ records every level
      # @return [WardLinkageHierarchical]
      def initialize(depth = nil)
        @cluster_tree = []
        @depth = depth
        @merges_so_far = 0
        super(depth)
      end

      # Build the clusters, recording intermediate levels in @cluster_tree.
      # The tree and the merge counter are reset on every call, so rebuilding
      # the same instance does not mix in levels from a previous run.
      # @param data_set [Object] data set whose +data_items+ are clustered
      # @param number_of_clusters [Integer] number of clusters to stop at
      # @param options [Hash] forwarded unchanged to WardLinkage#build
      # @return [WardLinkageHierarchical] self
      def build(data_set, number_of_clusters = 1, **options)
        @cluster_tree = []
        @merges_so_far = 0
        @total_merges = data_set.data_items.length - number_of_clusters
        super
        # Snapshots were appended before each merge; add the final
        # configuration and reverse so the last merge comes first.
        @cluster_tree << clusters
        @cluster_tree.reverse!
        self
      end

      # @return [Boolean] always false: new items cannot be evaluated
      def supports_eval?
        false
      end

      protected

      # Record the current clusters (when within the tracked depth) before
      # delegating the actual merge to the parent implementation.
      # @param index_a [Object] first cluster to merge (forwarded to super)
      # @param index_b [Object] second cluster to merge (forwarded to super)
      # @param index_clusters [Array] current clusters, as consumed by
      #   build_clusters_from_index_clusters
      # @return [Object] the result of the parent's merge_clusters
      def merge_clusters(index_a, index_b, index_clusters)
        # Record a level when no depth limit is set, or when this merge is
        # among the last (@depth - 1) merges; the final configuration added
        # in #build supplies the remaining level.
        if @depth.nil? || (@merges_so_far > @total_merges - @depth)
          # Keep @distance_matrix exactly as it was before taking the
          # snapshot; presumably build_clusters_from_index_clusters may
          # modify it — TODO confirm.
          stored_distance_matrix = @distance_matrix.dup
          @cluster_tree << build_clusters_from_index_clusters(index_clusters)
          @distance_matrix = stored_distance_matrix
        end
        @merges_so_far += 1
        super
      end
    end
  end
end
|
@@ -1,61 +1,77 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens (implementation)
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
|
-
|
11
|
-
|
12
|
+
require_relative '../data/data_set'
|
13
|
+
require_relative '../clusterers/single_linkage'
|
14
|
+
require_relative '../clusterers/cluster_tree'
|
12
15
|
|
13
16
|
module Ai4r
  module Clusterers
    # Agglomerative hierarchical clusterer using weighted average linkage,
    # aka weighted pair group method average or WPGMA (Jain and Dubes, 1988;
    # McQuitty, 1966).
    # Hierarchical clusterers start with one cluster per element and then
    # progressively merge clusters until the required number of clusters
    # is reached.
    # Similar to AverageLinkage, but the distances between clusters are
    # weighted based on the number of data items in each of them:
    #
    #   D(cx, (ci U cj)) = ( ni * D(cx, ci) + nj * D(cx, cj)) / (ni + nj)
    class WeightedAverageLinkage < SingleLinkage
      include ClusterTree

      parameters_info distance_function:
        'Custom implementation of distance function. ' \
        'It must be a closure receiving two data items and return the ' \
        'distance between them. By default, this algorithm uses ' \
        'euclidean distance of numeric attributes to the power of 2.'

      # Build a new clusterer, using data examples found in data_set.
      # Items will be clustered in "number_of_clusters" different clusters.
      # The work is done by SingleLinkage#build; this class customizes
      # linkage_distance.
      # @param data_set [Object] data set to cluster
      # @param number_of_clusters [Integer] clusters to stop at (default 1)
      # @param options [Hash] forwarded unchanged to the parent
      # @return [Object] whatever SingleLinkage#build returns
      def build(data_set, number_of_clusters = 1, **options)
        super
      end

      # This algorithm does not allow classification of new data items once
      # it has been built. Rebuild the clusters including your data item.
      # @param _data_item [Object] ignored
      # @raise [NotImplementedError] always
      def eval(_data_item)
        raise NotImplementedError, 'Eval of new data is not supported by this algorithm.'
      end

      # @return [Boolean] false, see #eval
      def supports_eval?
        false
      end

      protected

      # Distance between cluster cx and the merged cluster (ci U cj): the two
      # existing distances averaged with weights equal to cluster sizes.
      # @param cluster_x [Object] position of cx in @index_clusters
      # @param cluster_i [Object] position of ci in @index_clusters
      # @param cluster_j [Object] position of cj in @index_clusters
      # @return [Float] weighted average linkage distance
      def linkage_distance(cluster_x, cluster_i, cluster_j)
        size_i = @index_clusters[cluster_i].length
        size_j = @index_clusters[cluster_j].length
        # 1.0 forces float arithmetic for the final division.
        weighted_i = 1.0 * size_i * read_distance_matrix(cluster_x, cluster_i)
        weighted_j = size_j * read_distance_matrix(cluster_x, cluster_j)
        (weighted_i + weighted_j) / (size_i + size_j)
      end
    end
  end
end
|
61
|
-
|
data/lib/ai4r/data/data_set.rb
CHANGED
@@ -1,36 +1,51 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
1
3
|
# Author:: Sergio Fierens
|
2
4
|
# License:: MPL 1.1
|
3
5
|
# Project:: ai4r
|
4
|
-
# Url::
|
6
|
+
# Url:: https://github.com/SergioFierens/ai4r
|
5
7
|
#
|
6
|
-
# You can redistribute it and/or modify it under the terms of
|
7
|
-
# the Mozilla Public License version 1.1 as published by the
|
8
|
+
# You can redistribute it and/or modify it under the terms of
|
9
|
+
# the Mozilla Public License version 1.1 as published by the
|
8
10
|
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
|
9
11
|
|
10
12
|
require 'csv'
|
11
13
|
require 'set'
|
12
|
-
|
14
|
+
require_relative 'statistics'
|
13
15
|
|
14
16
|
module Ai4r
|
15
17
|
module Data
|
16
|
-
|
17
|
-
# A data set is a collection of N data items. Each data item is
|
18
|
+
# A data set is a collection of N data items. Each data item is
|
18
19
|
# described by a set of attributes, represented as an array.
|
19
|
-
# Optionally, you can assign a label to the attributes, using
|
20
|
+
# Optionally, you can assign a label to the attributes, using
|
20
21
|
# the data_labels property.
|
21
22
|
class DataSet
|
22
|
-
|
23
|
-
@@number_regex = /(((\b[0-9]+)?\.)?\b[0-9]+([eE][-+]?[0-9]+)?\b)/
|
24
|
-
|
25
23
|
attr_reader :data_labels, :data_items
|
26
24
|
|
25
|
+
# Return a new DataSet with numeric attributes normalized, leaving
# +data_set+ untouched: each row and the label list are duplicated first.
# Available methods are:
# * +:zscore+ - subtract the mean and divide by the standard deviation
# * +:minmax+ - scale values to the [0,1] range
# @param data_set [DataSet] source data set
# @param method [Symbol] +:zscore+ (default) or +:minmax+
# @return [DataSet] the normalized copy
# @raise [ArgumentError] for an unknown +method+ (raised by normalize!)
def self.normalized(data_set, method: :zscore)
  row_copies = data_set.data_items.map(&:dup)
  label_copy = data_set.data_labels.dup
  copy = DataSet.new(data_items: row_copies, data_labels: label_copy)
  copy.normalize!(method)
end
|
39
|
+
|
27
40
|
# Create a new DataSet. By default, empty.
|
28
41
|
# Optionaly, you can provide the initial data items and data labels.
|
29
|
-
#
|
42
|
+
#
|
30
43
|
# e.g. DataSet.new(:data_items => data_items, :data_labels => labels)
|
31
|
-
#
|
44
|
+
#
|
32
45
|
# If you provide data items, but no data labels, the data set will
|
33
46
|
# use the default data label values (see set_data_labels)
|
47
|
+
# @param options [Object]
|
48
|
+
# @return [Object]
|
34
49
|
def initialize(options = {})
|
35
50
|
@data_labels = []
|
36
51
|
@data_items = options[:data_items] || []
|
@@ -38,78 +53,97 @@ module Ai4r
|
|
38
53
|
set_data_items(options[:data_items]) if options[:data_items]
|
39
54
|
end
|
40
55
|
|
41
|
-
# Retrieve a new DataSet, with the item(s) selected by the provided
# index. You can specify an index range, too.
# @param index [Integer, Range] position, or range of positions, to select
# @return [DataSet] a new data set built with this set's labels
def [](index)
  # A single position yields one item, which must be wrapped in a list.
  picked = index.is_a?(Integer) ? [@data_items[index]] : @data_items[index]
  DataSet.new(data_items: picked, data_labels: @data_labels)
end
|
49
69
|
|
50
70
|
# Load data items from csv file.
# Every cell is kept as read unless +parse_numeric+ is true, in which case
# loading is delegated to #parse_csv, which converts numeric-looking cells
# to Float.
# @param filepath [String] path to the CSV file
# @param parse_numeric [Boolean] convert numeric cells (default false)
# @return [DataSet] self
def load_csv(filepath, parse_numeric: false)
  return parse_csv(filepath) if parse_numeric

  rows = []
  open_csv_file(filepath) { |row| rows << row }
  set_data_items(rows)
end
|
73
84
|
|
85
|
+
# Read the CSV file at +filepath+ row by row, handing each parsed row to
# the given block.
# @param filepath [String] path to the CSV file
# @yieldparam row [Array] one parsed CSV row
# @return [Object] whatever CSV.foreach returns
def open_csv_file(filepath, &block)
  CSV.foreach(filepath, &block)
end
|
92
|
+
|
74
93
|
# Load data items from csv file. The first row is used as data labels.
# When +parse_numeric+ is true, numeric-looking cells of the data rows are
# converted to Float (the same conversion #parse_csv applies). The header
# row is kept as read, so labels such as "1990" stay strings.
# @param filepath [String] path to the CSV file
# @param parse_numeric [Boolean] convert numeric cells of the data rows
# @return [DataSet] self
def load_csv_with_labels(filepath, parse_numeric: false)
  # Always read raw cells so the header row is never coerced to numbers.
  load_csv(filepath, parse_numeric: false)
  @data_labels = @data_items.shift
  if parse_numeric
    @data_items.each do |row|
      # Float(..., exception: false) is nil for non-numeric cells; keep those.
      row.map! { |cell| Float(cell, exception: false) || cell }
    end
  end
  self
end
|
80
101
|
|
81
102
|
# Same as load_csv, but it will try to convert cell contents as numbers.
# Cells accepted by Kernel#Float become Floats; all other cells (including
# nil for empty fields) are kept unchanged.
# @param filepath [String] path to the CSV file
# @return [DataSet] self
def parse_csv(filepath)
  parsed_rows = []
  open_csv_file(filepath) do |row|
    parsed_rows << row.map { |cell| Float(cell, exception: false) || cell }
  end
  set_data_items(parsed_rows)
end
|
89
114
|
|
115
|
+
# Same as load_csv_with_labels, but numeric-looking cells are converted to
# numbers.
# @param filepath [String] path to the CSV file
# @return [DataSet] whatever load_csv_with_labels returns (self)
def parse_csv_with_labels(filepath)
  convert_numbers = true
  load_csv_with_labels(filepath, parse_numeric: convert_numbers)
end
|
121
|
+
|
90
122
|
# Set data labels.
# Data labels must have the following format:
#   [ 'city', 'age_range', 'gender', 'marketing_target' ]
#
# If you do not provide labels for you data, the following labels will
# be created by default:
#   [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
# @param labels [Array] one label per attribute, class label last
# @return [DataSet] self, allowing method chaining
# @raise [ArgumentError] if the number of labels does not match the
#   number of attributes of the current data items
def set_data_labels(labels)
  check_data_labels(labels) # validate before touching state
  @data_labels = labels
  self
end
|
102
136
|
|
103
137
|
# Set the data items.
|
104
|
-
# M data items with N attributes must have the following
|
138
|
+
# M data items with N attributes must have the following
|
105
139
|
# format:
|
106
|
-
#
|
107
|
-
# [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
|
108
|
-
# [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
|
140
|
+
#
|
141
|
+
# [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CLASS_VAL1],
|
142
|
+
# [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CLASS_VAL2],
|
109
143
|
# ...
|
110
|
-
# [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
|
144
|
+
# [ATTM1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CLASS_VALM],
|
111
145
|
# ]
|
112
|
-
#
|
146
|
+
#
|
113
147
|
# e.g.
|
114
148
|
# [ ['New York', '<30', 'M', 'Y'],
|
115
149
|
# ['Chicago', '<30', 'M', 'Y'],
|
@@ -127,140 +161,235 @@ module Ai4r
|
|
127
161
|
# ['New York', '[50-80]', 'F', 'N'],
|
128
162
|
# ['Chicago', '>80', 'F', 'Y']
|
129
163
|
# ]
|
130
|
-
#
|
164
|
+
#
|
131
165
|
# This method returns the data set (self), allowing method chaining.
# If no labels have been set yet, default ones are generated from the
# items (see default_data_labels).
# @param items [Array<Array>] the data items
# @return [DataSet] self
# @raise [ArgumentError] if +items+ is empty, its first element is not
#   Enumerable, or rows have different numbers of attributes
def set_data_items(items)
  check_data_items(items)
  if @data_labels.empty?
    @data_labels = default_data_labels(items)
  end
  @data_items = items
  self
end
|
138
174
|
|
139
175
|
# Returns an array with the domain of each attribute:
# * Set instance containing all possible values for nominal attributes
# * Array with min and max values for numeric attributes (i.e. [min, max])
#
# Return example:
#   => [#<Set: {"New York", "Chicago"}>,
#       #<Set: {"<30", "[30-50)", "[50-80]", ">80"}>,
#       #<Set: {"M", "F"}>,
#       [5, 85],
#       #<Set: {"Y", "N"}>]
# @return [Array<Set, Array>] one domain per data label, in label order
def build_domains
  @data_labels.map { |label| build_domain(label) }
end
|
152
189
|
|
153
190
|
# Returns the domain of an attribute. The parameter can be an attribute
# label or index (0 based).
# * Set instance containing all possible values for nominal attributes
# * Array with min and max values for numeric attributes (i.e. [min, max])
# An attribute counts as numeric when its value in the first data item is
# Numeric.
#
#   build_domain("city")
#   => #<Set: {"New York", "Chicago"}>
#
#   build_domain("age")
#   => [5, 85]
#
#   build_domain(2) # In this example, the third attribute is gender
#   => #<Set: {"M", "F"}>
# @param attr [String, Integer] attribute label or 0-based index
# @return [Set, Array] the attribute domain
# @raise [ArgumentError] if +attr+ is a label not present in data_labels
def build_domain(attr)
  index = get_index(attr)
  # get_index returns nil for unknown labels; fail with a clear message
  # instead of a TypeError from Array#[] further down.
  raise ArgumentError, "Unknown attribute #{attr.inspect}" if index.nil?
  return [Statistics.min(self, index), Statistics.max(self, index)] if @data_items.first[index].is_a?(Numeric)

  @data_items.each_with_object(Set.new) { |item, domain| domain << item[index] }
end
|
174
211
|
|
175
212
|
# Returns attributes number, including class attribute.
# Counted from the first data item; an empty data set has 0 attributes.
# @return [Integer]
def num_attributes
  return 0 if @data_items.empty?

  @data_items.first.size
end
|
179
217
|
|
180
218
|
# Returns the index of a given attribute (0-based).
# Integers and ranges are returned unchanged; anything else is looked up
# in data_labels (nil when the label is unknown).
# For example, if "gender" is the third attribute, then:
#   get_index("gender")
#   => 2
# @param attr [String, Integer, Range] attribute label, index or range
# @return [Integer, Range, nil]
def get_index(attr)
  case attr
  when Integer, Range
    attr
  else
    @data_labels.index(attr)
  end
end
|
187
227
|
|
188
228
|
# Raise an exception if there is no data item.
# @return [nil]
# @raise [ArgumentError] when the data set has no items
def check_not_empty
  raise ArgumentError, 'Examples data set must not be empty.' if @data_items.empty?
end
|
194
235
|
|
195
236
|
# Add a data item to the data set.
# The first item added to an empty set goes through set_data_items, which
# also generates default labels when none are set. Later items must have
# the same number of attributes as the existing ones.
# @param data_item [Array] attribute values, class value last
# @return [DataSet] self, so appends can be chained with full validation
#   (ds << a << b)
# @raise [ArgumentError] if the item is nil, not Enumerable, empty, or has
#   the wrong number of attributes
def <<(data_item)
  if data_item.nil? || !data_item.is_a?(Enumerable) || data_item.empty?
    raise ArgumentError, 'Data must not be an non empty array.'
  elsif @data_items.empty?
    set_data_items([data_item])
  elsif data_item.length != num_attributes
    raise ArgumentError, 'Number of attributes do not match. ' \
                         "#{data_item.length} attributes provided, " \
                         "#{num_attributes} attributes expected."
  else
    @data_items << data_item
  end
  # Returning the raw array here would let a chained << skip validation.
  self
end
|
209
251
|
|
210
|
-
# Returns an array with the mean value of numeric attributes, and
# the most frequent value of non numeric attributes.
# An attribute counts as numeric when its value in the first item is
# Numeric.
# @return [Array] one entry per attribute, in attribute order
def get_mean_or_mode
  sample = @data_items.first
  Array.new(num_attributes) do |i|
    if sample[i].is_a?(Numeric)
      Statistics.mean(self, i)
    else
      Statistics.mode(self, i)
    end
  end
end
|
267
|
+
|
268
|
+
# Normalize numeric attributes in place. Supported methods are
# +:zscore+ (default) and +:minmax+.
# Numeric columns are detected from the first data item. All statistics
# are computed before any value is changed. Constant columns (zero
# standard deviation or zero range) become 0.
# @param method [Symbol] +:zscore+ or +:minmax+
# @return [DataSet] self
# @raise [ArgumentError] for an unknown +method+ (before any mutation)
def normalize!(method = :zscore)
  numeric_cols = (0...num_attributes).select { |col| @data_items.first[col].is_a?(Numeric) }

  case method
  when :zscore
    means = numeric_cols.map { |col| Statistics.mean(self, col) }
    deviations = numeric_cols.map { |col| Statistics.standard_deviation(self, col) }
    rescale = lambda do |value, k|
      deviations[k].zero? ? 0 : (value - means[k]) / deviations[k]
    end
  when :minmax
    lows = numeric_cols.map { |col| Statistics.min(self, col) }
    highs = numeric_cols.map { |col| Statistics.max(self, col) }
    rescale = lambda do |value, k|
      span = highs[k] - lows[k]
      span.zero? ? 0 : (value - lows[k]) / span.to_f
    end
  else
    raise ArgumentError, "Unknown normalization method #{method}"
  end

  @data_items.each do |row|
    numeric_cols.each_with_index { |col, k| row[col] = rescale.call(row[col], k) }
  end
  self
end
|
302
|
+
|
303
|
+
# Randomizes the order of data items in place.
# Passing a +seed+ makes the shuffle reproducible; without one a freshly
# seeded generator is used.
#
#   data_set.shuffle!(seed: 123)
#
# @param seed [Integer, nil] Seed for the RNG
# @return [DataSet] self
def shuffle!(seed: nil)
  generator =
    if seed
      Random.new(seed)
    else
      Random.new
    end
  @data_items.shuffle!(random: generator)
  self
end
|
316
|
+
|
317
|
+
# Split the dataset into two new DataSet instances using the given ratio
# for the first set. Each row is dup'ed, so the original is not modified.
#
#   train, test = data_set.split(ratio: 0.8)
#
# The split point is +ratio * size+ rounded, then clamped so both halves
# get at least one item. DataSet.new rejects an empty item list, so
# without the clamp, small sets could fail with a confusing "must not be
# empty" error.
# @param ratio [Float] fraction of items to place in the first set,
#   strictly between 0 and 1
# @return [Array<DataSet, DataSet>] the two resulting datasets
# @raise [ArgumentError] if +ratio+ is out of range or the data set has
#   fewer than two items
def split(ratio:)
  raise ArgumentError, 'ratio must be between 0 and 1' unless ratio.positive? && ratio < 1

  size = @data_items.length
  raise ArgumentError, 'data set must contain at least 2 items to be split' if size < 2

  pivot = (ratio * size).round.clamp(1, size - 1)
  first_items = @data_items[0...pivot].map(&:dup)
  second_items = @data_items[pivot..].map(&:dup)

  [
    DataSet.new(data_items: first_items, data_labels: @data_labels.dup),
    DataSet.new(data_items: second_items, data_labels: @data_labels.dup)
  ]
end
|
336
|
+
|
337
|
+
# Returns the label of the category (class) attribute, i.e. the last
# data label.
# @return [Object, nil] last data label, or nil when there are no labels
def category_label
  labels = data_labels
  labels.last
end
|
224
342
|
|
225
343
|
protected
|
226
344
|
|
345
|
+
# Whether +x+ can be converted with Kernel#Float (e.g. "3" or "2.5e1").
# Values Float cannot convert, including nil, give false.
# @param x [Object] value to test, typically a CSV cell
# @return [Boolean]
def number?(x)
  converted = Float(x, exception: false)
  !converted.nil?
end
|
350
|
+
|
351
|
+
# Validate a list of data items: it must be non-empty, its first element
# must be Enumerable, and every row must have as many attributes as the
# first one.
# @param data_items [Array<Array>] candidate data items
# @return [Array<Array>] +data_items+ unchanged when valid
# @raise [ArgumentError] describing the first problem found
def check_data_items(data_items)
  raise ArgumentError, 'Examples data set must not be empty.' if !data_items || data_items.empty?
  raise ArgumentError, 'Unkown format for example data.' unless data_items.first.is_a?(Enumerable)

  expected = data_items.first.length
  data_items.each_with_index do |item, row|
    next if item.length == expected

    raise ArgumentError,
          'Quantity of attributes is inconsistent. ' \
          "The first item has #{expected} attributes " \
          "and row #{row} has #{item.length} attributes"
  end
end
|
243
370
|
|
371
|
+
# Validate that +labels+ has one entry per attribute of the current data
# items. Nothing is checked while the data set is empty.
# @param labels [Array] candidate labels
# @return [nil]
# @raise [ArgumentError] when the label count and attribute count differ
def check_data_labels(labels)
  return if @data_items.empty?

  expected = @data_items.first.length
  return if labels.length == expected

  raise ArgumentError,
        'Number of labels and attributes do not match. ' \
        "#{labels.length} labels and " \
        "#{expected} attributes found."
end
|
254
382
|
|
383
|
+
# Generate default labels from the first data item: "attribute_1" ...
# "attribute_N" for every value but the last, then "class_value".
# @param data_items [Array<Array>] data items (only the first is read)
# @return [Array<String>] the generated labels
def default_data_labels(data_items)
  feature_values = data_items[0][0..-2]
  labels = feature_values.each_index.map { |i| "attribute_#{i + 1}" }
  labels << 'class_value'
end
|
263
|
-
|
264
393
|
end
|
265
394
|
end
|
266
395
|
end
|