confusion_matrix 1.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/LICENSE.rdoc +22 -0
- data/README.rdoc +88 -0
- data/lib/confusion_matrix.rb +451 -0
- data/test/matrix_test.rb +160 -0
- metadata +55 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: bc3f9ed40d45bb892541c4a3dc97c04c5893e335ec79b2526d072d291d08a7b7
|
4
|
+
data.tar.gz: 602f9857a4e45283357117974943d46f4802fbd9c5ce7b1be0ec704472d29ba4
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 20cc86e92c2ad0867206ee2c2a46fe31de7c08c15e90491d481a8c40fd94541b4dea87b575792652cedbb49fea284e70e4f72250ab6f2e83de2e9b57ea7136c6
|
7
|
+
data.tar.gz: 972386c261254d2fc44b09f320d48b41639c09682f16c0ddeb88cd2697b683f7b178d257e7edcb3bcba51d0dfc4c625bf478fcadd1a6e2e81ed8354a588266ce
|
data/LICENSE.rdoc
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
= MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2020-23, Peter Lane
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
22
|
+
|
data/README.rdoc
ADDED
@@ -0,0 +1,88 @@
|
|
1
|
+
= Confusion Matrix
|
2
|
+
|
3
|
+
Install from {RubyGems}[https://rubygems.org/gems/confusion_matrix/]:
|
4
|
+
|
5
|
+
> gem install confusion_matrix
|
6
|
+
|
7
|
+
source:: https://notabug.org/peterlane/confusion-matrix-ruby/
|
8
|
+
|
9
|
+
== Description
|
10
|
+
|
11
|
+
A confusion matrix is used in data-mining as a summary of the performance of a
|
12
|
+
classification algorithm. Each row represents the _actual_ class of an
|
13
|
+
instance, and each column represents the _predicted_ class of that instance,
|
14
|
+
i.e. the class that they were classified as. Numbers at each (row, column)
|
15
|
+
reflect the total number of instances of actual class "row" which were
|
16
|
+
predicted to fall in class "column".
|
17
|
+
|
18
|
+
A two-class example is:
|
19
|
+
|
20
|
+
Classified Classified |
|
21
|
+
Positive Negative | Actual
|
22
|
+
------------------------------+------------
|
23
|
+
a b | Positive
|
24
|
+
c d | Negative
|
25
|
+
|
26
|
+
Here the value:
|
27
|
+
|
28
|
+
a:: is the number of true positives (those labelled positive and classified positive)
|
29
|
+
b:: is the number of false negatives (those labelled positive but classified negative)
|
30
|
+
c:: is the number of false positives (those labelled negative but classified positive)
|
31
|
+
d:: is the number of true negatives (those labelled negative and classified negative)
|
32
|
+
|
33
|
+
From this table we can calculate statistics like:
|
34
|
+
|
35
|
+
true_positive_rate:: a/(a+b)
|
36
|
+
positive recall:: a/(a+c)
|
37
|
+
|
38
|
+
The implementation supports confusion matrices with more than two
|
39
|
+
classes, and hence most statistics are calculated with reference to a
|
40
|
+
named class. When more than two classes are in use, the statistics
|
41
|
+
are calculated as if the named class were positive and all the other
|
42
|
+
classes are grouped as if negative.
|
43
|
+
|
44
|
+
For example, in a three-class example:
|
45
|
+
|
46
|
+
Classified Classified Classified |
|
47
|
+
Red Blue Green | Actual
|
48
|
+
--------------------------------------------+------------
|
49
|
+
a b c | Red
|
50
|
+
d e f | Blue
|
51
|
+
g h i | Green
|
52
|
+
|
53
|
+
We can calculate:
|
54
|
+
|
55
|
+
true_red_rate:: a/(a+b+c)
|
56
|
+
red recall:: a/(a+d+g)
|
57
|
+
|
58
|
+
== Example
|
59
|
+
|
60
|
+
The following example creates a simple two-class confusion matrix,
|
61
|
+
prints a few statistics and displays the table.
|
62
|
+
|
63
|
+
require 'confusion_matrix'
|
64
|
+
|
65
|
+
cm = ConfusionMatrix.new :pos, :neg
|
66
|
+
cm.add_for(:pos, :pos, 10)
|
67
|
+
3.times { cm.add_for(:pos, :neg) }
|
68
|
+
20.times { cm.add_for(:neg, :neg) }
|
69
|
+
5.times { cm.add_for(:neg, :pos) }
|
70
|
+
|
71
|
+
puts "Precision: #{cm.precision}"
|
72
|
+
puts "Recall: #{cm.recall}"
|
73
|
+
puts "MCC: #{cm.matthews_correlation}"
|
74
|
+
puts
|
75
|
+
puts(cm.to_s)
|
76
|
+
|
77
|
+
Output:
|
78
|
+
|
79
|
+
Precision: 0.6666666666666666
|
80
|
+
Recall: 0.7692307692307693
|
81
|
+
MCC: 0.5524850114241865
|
82
|
+
|
83
|
+
Predicted |
|
84
|
+
pos neg | Actual
|
85
|
+
----------+-------
|
86
|
+
10 3 | pos
|
87
|
+
5 20 | neg
|
88
|
+
|
@@ -0,0 +1,451 @@
|
|
1
|
+
|
2
|
+
# This class holds the confusion matrix information.
# It is designed to be called incrementally, as results are obtained
# from the classifier model.
#
# At any point, statistics may be obtained by calling the relevant methods.
#
# A two-class example is:
#
#     Classified  Classified |
#     Positive    Negative   | Actual
#    ---------------------------+------------
#        a           b       | Positive
#        c           d       | Negative
#
# Statistical methods will be described with reference to this example.
#
class ConfusionMatrix
  # Creates a new, empty instance of a confusion matrix.
  #
  # @param labels [<String, Symbol>, ...] if provided, makes the matrix
  #        use the first label as a default label, and also check
  #        all operations use one of the pre-defined labels.
  # @raise [ArgumentError] if there are not at least two unique labels, when provided.
  def initialize(*labels)
    @matrix = {}
    @labels = labels.uniq
    if @labels.size == 1
      raise ArgumentError.new("If labels are provided, there must be at least two.")
    else # preset the matrix Hash so every (actual, predicted) pair starts at zero
      @labels.each do |actual|
        @matrix[actual] = {}
        @labels.each do |predicted|
          @matrix[actual][predicted] = 0
        end
      end
    end
  end

  # Returns a list of labels used in the matrix.
  #
  #   cm = ConfusionMatrix.new
  #   cm.add_for(:pos, :neg)
  #   cm.labels # => [:neg, :pos]
  #
  # @return [Array<String>] labels used in the matrix.
  def labels
    if @labels.size >= 2 # if we defined some labels, return them
      @labels
    else
      result = []

      # collect every label seen as an actual class or as a prediction
      @matrix.each_pair do |actual, predictions|
        result << actual
        predictions.each_key do |predicted|
          result << predicted
        end
      end

      result.uniq.sort
    end
  end

  # Return the count for (actual,prediction) pair.
  #
  #   cm = ConfusionMatrix.new
  #   cm.add_for(:pos, :neg)
  #   cm.count_for(:pos, :neg) # => 1
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @return [Integer] number of observations of (actual, prediction) pair
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def count_for(actual, prediction)
    validate_label actual, prediction
    predictions = @matrix.fetch(actual, {})
    predictions.fetch(prediction, 0)
  end

  # Adds one result to the matrix for a given (actual, prediction) pair of labels.
  # If the matrix was given a pre-defined list of labels on construction, then
  # these given labels must be from the pre-defined list.
  # If no pre-defined list of labels was used in constructing the matrix, then
  # labels will be added to matrix.
  #
  # Class labels may be any hashable value, though ideally they are strings or symbols.
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @param n [Integer] number of observations to add
  # @raise [ArgumentError] if +n+ is not a positive Integer
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def add_for(actual, prediction, n = 1)
    validate_label actual, prediction
    # Validate +n+ BEFORE touching @matrix: a rejected call must not
    # create empty label entries (which would leak into #labels).
    unless n.is_a?(Integer) && n.positive?
      raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
    end

    @matrix[actual] ||= {}
    @matrix[actual][prediction] ||= 0
    @matrix[actual][prediction] += n
  end

  # Returns the number of instances of the given class label which
  # are incorrectly classified.
  #
  #   false_negative(:positive) = b
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] value of false negative
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_negative(label = @labels.first)
    validate_label label
    predictions = @matrix.fetch(label, {})
    total = 0

    predictions.each_pair do |key, count|
      if key != label
        total += count
      end
    end

    total
  end

  # Returns the number of instances incorrectly classified with the given
  # class label.
  #
  #   false_positive(:positive) = c
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] value of false positive
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_positive(label = @labels.first)
    validate_label label
    total = 0

    @matrix.each_pair do |key, predictions|
      if key != label
        total += predictions.fetch(label, 0)
      end
    end

    total
  end

  # The false rate for a given class label is the proportion of instances
  # incorrectly classified as that label, out of all those instances
  # not originally of that label.
  #
  #   false_rate(:positive) = c/(c+d)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of false rate
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_rate(label = @labels.first)
    validate_label label
    fp = false_positive(label)
    tn = true_negative(label)

    divide(fp, fp+tn)
  end

  # The F-measure for a given label is the harmonic mean of the precision
  # and recall for that label.
  #
  #   F = 2*(precision*recall)/(precision+recall)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of F-measure
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def f_measure(label = @labels.first)
    validate_label label
    p = precision(label)
    r = recall(label)

    # use safe divide so precision = recall = 0 yields 0.0 rather than NaN
    divide(2*p*r, p + r)
  end

  # The geometric mean is the nth-root of the product of the true_rate for
  # each label.
  #
  #   a1 = a/(a+b)
  #   a2 = d/(c+d)
  #   geometric_mean = Math.sqrt(a1*a2)
  #
  # @return [Float] value of geometric mean
  def geometric_mean
    product = 1

    @matrix.each_key do |key|
      product *= true_rate(key)
    end

    product**(1.0/@matrix.size)
  end

  # The Kappa statistic compares the observed accuracy with an expected
  # accuracy.
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Cohen's Kappa Statistic
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def kappa(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp+fn+fp+tn

    total_accuracy = divide(tp+tn, tp+tn+fp+fn)
    random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)

    divide(total_accuracy - random_accuracy, 1 - random_accuracy)
  end

  # Matthews Correlation Coefficient is a measure of the quality of binary
  # classifications.
  #
  #   matthews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Matthews Correlation Coefficient
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def matthews_correlation(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)

    divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
  end

  # The overall accuracy is the proportion of instances which are
  # correctly labelled.
  #
  #   overall_accuracy = (a+d)/(a+b+c+d)
  #
  # @return [Float] value of overall accuracy
  def overall_accuracy
    total_correct = 0

    @matrix.each_key do |key|
      total_correct += true_positive(key)
    end

    divide(total_correct, total)
  end

  # The precision for a given class label is the proportion of instances
  # classified as that class which are correct.
  #
  #   precision(:positive) = a/(a+c)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of precision
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def precision(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fp = false_positive(label)

    divide(tp, tp+fp)
  end

  # The prevalence for a given class label is the proportion of instances
  # which were classified as of that label, out of the total.
  #
  #   prevalence(:positive) = (a+c)/(a+b+c+d)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of prevalence
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def prevalence(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp+fn+fp+tn

    divide(tp+fn, total)
  end

  # The recall is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def recall(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # Sensitivity is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def sensitivity(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # The specificity for a given class label is 1 - false_rate(label)
  #
  # In two-class case, specificity = 1 - false_positive_rate
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of specificity
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def specificity(label = @labels.first)
    validate_label label
    1-false_rate(label)
  end

  # Returns the table in a string format, representing the entries as a
  # printable table.
  #
  # @return [String] representation as a printable table.
  def to_s
    ls = labels
    result = ""

    title_line = "Predicted "
    label_line = ""
    ls.each { |l| label_line << "#{l} " }
    # pad the shorter of the two header lines to a common width
    label_line << " " while label_line.size < title_line.size
    title_line << " " while title_line.size < label_line.size
    result << title_line << "|\n" << label_line << "| Actual\n"
    result << "-"*title_line.size << "+-------\n"

    ls.each do |l|
      count_line = ""
      ls.each_with_index do |m, i|
        count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
      end
      result << count_line.ljust(title_line.size) << "| #{l}\n"
    end

    result
  end

  # Returns the total number of instances referenced in the matrix.
  #
  #   total = a+b+c+d
  #
  # @return [Integer] total number of instances referenced in the matrix.
  def total
    total = 0

    @matrix.each_value do |predictions|
      predictions.each_value do |count|
        total += count
      end
    end

    total
  end

  # Returns the number of instances NOT of the given class label which
  # are correctly classified.
  #
  #   true_negative(:positive) = d
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances not of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_negative(label = @labels.first)
    validate_label label
    total = 0

    @matrix.each_pair do |key, predictions|
      if key != label
        total += predictions.fetch(key, 0)
      end
    end

    total
  end

  # Returns the number of instances of the given class label which are
  # correctly classified.
  #
  #   true_positive(:positive) = a
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_positive(label = @labels.first)
    validate_label label
    predictions = @matrix.fetch(label, {})
    predictions.fetch(label, 0)
  end

  # The true rate for a given class label is the proportion of instances of
  # that class which are correctly classified.
  #
  #   true_rate(:positive) = a/(a+b)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] proportion of instances which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_rate(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)

    divide(tp, tp+fn)
  end

  private

  # A form of "safe divide".
  # Checks if divisor is zero, and returns 0.0 if so.
  # This avoids a run-time error.
  # Also, ensures floating point division is done.
  def divide(x, y)
    if y.zero?
      0.0
    else
      x.to_f/y
    end
  end

  # Checks if given label(s) is non-nil and in @labels, or if @labels is empty
  # Raises ArgumentError if not
  def validate_label *labels
    return true if @labels.empty?
    labels.each do |label|
      unless label and @labels.include?(label)
        raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
      end
    end
  end
end
|
450
|
+
|
451
|
+
|
data/test/matrix_test.rb
ADDED
@@ -0,0 +1,160 @@
|
|
1
|
+
require 'confusion_matrix'
require 'minitest/autorun'

# Unit tests for ConfusionMatrix, covering empty, two-class and
# three-class matrices, bulk additions, and pre-defined label checks.
class TestConfusionMatrix < Minitest::Test
  def test_empty_case
    cm = ConfusionMatrix.new
    # NB: assert_equal, not assert(0, ...) — 0 is truthy in Ruby, so the
    # original assert form would pass regardless of the actual value.
    assert_equal(0, cm.total)
    assert_equal(0, cm.true_positive(:none))
    assert_equal(0, cm.false_negative(:none))
    assert_equal(0, cm.false_positive(:none))
    assert_equal(0, cm.true_negative(:none))
    assert_in_delta(0, cm.true_rate(:none))
  end

  def test_two_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:pos, :pos) }
    5.times { cm.add_for(:pos, :neg) }
    20.times { cm.add_for(:neg, :neg) }
    5.times { cm.add_for(:neg, :pos) }

    assert_equal([:neg, :pos], cm.labels)
    assert_equal(10, cm.count_for(:pos, :pos))
    assert_equal(5, cm.count_for(:pos, :neg))
    assert_equal(20, cm.count_for(:neg, :neg))
    assert_equal(5, cm.count_for(:neg, :pos))

    assert_equal(40, cm.total)
    assert_equal(10, cm.true_positive(:pos))
    assert_equal(5, cm.false_negative(:pos))
    assert_equal(5, cm.false_positive(:pos))
    assert_equal(20, cm.true_negative(:pos))
    assert_equal(20, cm.true_positive(:neg))
    assert_equal(5, cm.false_negative(:neg))
    assert_equal(5, cm.false_positive(:neg))
    assert_equal(10, cm.true_negative(:neg))

    assert_in_delta(0.6667, cm.true_rate(:pos))
    assert_in_delta(0.8, cm.true_rate(:neg))
    assert_in_delta(0.2, cm.false_rate(:pos))
    assert_in_delta(0.3333, cm.false_rate(:neg))
    assert_in_delta(0.6667, cm.precision(:pos))
    assert_in_delta(0.8, cm.precision(:neg))
    assert_in_delta(0.6667, cm.recall(:pos))
    assert_in_delta(0.8, cm.recall(:neg))
    assert_in_delta(0.6667, cm.sensitivity(:pos))
    assert_in_delta(0.8, cm.sensitivity(:neg))
    assert_in_delta(0.75, cm.overall_accuracy)
    assert_in_delta(0.6667, cm.f_measure(:pos))
    assert_in_delta(0.8, cm.f_measure(:neg))
    assert_in_delta(0.7303, cm.geometric_mean)
  end

  # Example from:
  # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
  def test_two_classes_2
    cm = ConfusionMatrix.new
    5.times { cm.add_for(:pos, :pos) }
    1.times { cm.add_for(:pos, :neg) }
    3.times { cm.add_for(:neg, :neg) }
    2.times { cm.add_for(:neg, :pos) }

    assert_equal(11, cm.total)
    assert_equal(5, cm.true_positive(:pos))
    assert_equal(1, cm.false_negative(:pos))
    assert_equal(2, cm.false_positive(:pos))
    assert_equal(3, cm.true_negative(:pos))

    assert_in_delta(0.7142, cm.precision(:pos))
    assert_in_delta(0.8333, cm.recall(:pos))
    assert_in_delta(0.7272, cm.overall_accuracy)
    assert_in_delta(0.7692, cm.f_measure(:pos))
    assert_in_delta(0.8333, cm.sensitivity(:pos))
    assert_in_delta(0.6, cm.specificity(:pos))
    assert_in_delta(0.4407, cm.kappa(:pos))
    assert_in_delta(0.5454, cm.prevalence(:pos))
  end

  # Examples from:
  # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
  #
  # Builds a two-class matrix from counts (a..d) and checks the expected
  # statistics (e..i).
  def two_class_case(a,b,c,d,e,f,g,h,i)
    cm = ConfusionMatrix.new
    a.times { cm.add_for(:pos, :pos) }
    b.times { cm.add_for(:pos, :neg) }
    c.times { cm.add_for(:neg, :neg) }
    d.times { cm.add_for(:neg, :pos) }

    assert_in_delta(e, cm.matthews_correlation(:pos))
    assert_in_delta(f, cm.precision(:pos))
    assert_in_delta(g, cm.recall(:pos))
    assert_in_delta(h, cm.f_measure(:pos))
    assert_in_delta(i, cm.kappa(:pos))
  end

  def test_two_classes_3
    two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
    two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
    two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
  end

  def test_three_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:red, :red) }
    7.times { cm.add_for(:red, :blue) }
    5.times { cm.add_for(:red, :green) }
    20.times { cm.add_for(:blue, :red) }
    5.times { cm.add_for(:blue, :blue) }
    15.times { cm.add_for(:blue, :green) }
    30.times { cm.add_for(:green, :red) }
    12.times { cm.add_for(:green, :blue) }
    8.times { cm.add_for(:green, :green) }

    assert_equal([:blue, :green, :red], cm.labels)
    assert_equal(112, cm.total)
    assert_equal(10, cm.true_positive(:red))
    assert_equal(12, cm.false_negative(:red))
    assert_equal(50, cm.false_positive(:red))
    assert_equal(13, cm.true_negative(:red))
    assert_equal(5, cm.true_positive(:blue))
    assert_equal(35, cm.false_negative(:blue))
    assert_equal(19, cm.false_positive(:blue))
    assert_equal(18, cm.true_negative(:blue))
    assert_equal(8, cm.true_positive(:green))
    assert_equal(42, cm.false_negative(:green))
    assert_equal(20, cm.false_positive(:green))
    assert_equal(15, cm.true_negative(:green))
  end

  def test_add_for_n
    cm = ConfusionMatrix.new
    cm.add_for(:pos, :pos, 3)
    cm.add_for(:pos, :neg)
    cm.add_for(:neg, :pos, 2)
    cm.add_for(:neg, :neg, 1)
    assert_equal(7, cm.total)
    assert_equal(3, cm.count_for(:pos, :pos))
    # - check errors
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
  end

  def test_use_labels
    # - check errors
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }
    # - check created matrix
    cm = ConfusionMatrix.new(:pos, :neg)
    assert_equal([:pos, :neg], cm.labels)
    assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }
    cm.add_for(:pos, :neg, 3)
    cm.add_for(:neg, :pos, 2)
    assert_equal(2, cm.false_negative(:neg))
    assert_equal(3, cm.false_negative(:pos))
    assert_equal(3, cm.false_negative())
    assert_raises(ArgumentError) { cm.false_negative(:nothing) }
    assert_raises(ArgumentError) { cm.false_negative(nil) }
  end
end
|
160
|
+
|
metadata
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: confusion_matrix
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 1.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Peter Lane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2023-02-04 00:00:00.000000000 Z
|
12
|
+
dependencies: []
|
13
|
+
description: "A confusion matrix is used in data-mining as a summary of the performance
|
14
|
+
of a\nclassification algorithm. This library allows the user to incrementally add
|
15
|
+
\nresults to a confusion matrix, and then retrieve statistical information.\n"
|
16
|
+
email: peterlane@gmx.com
|
17
|
+
executables: []
|
18
|
+
extensions: []
|
19
|
+
extra_rdoc_files:
|
20
|
+
- README.rdoc
|
21
|
+
- LICENSE.rdoc
|
22
|
+
files:
|
23
|
+
- LICENSE.rdoc
|
24
|
+
- README.rdoc
|
25
|
+
- lib/confusion_matrix.rb
|
26
|
+
- test/matrix_test.rb
|
27
|
+
homepage:
|
28
|
+
licenses:
|
29
|
+
- MIT
|
30
|
+
metadata: {}
|
31
|
+
post_install_message:
|
32
|
+
rdoc_options:
|
33
|
+
- "-m"
|
34
|
+
- README.rdoc
|
35
|
+
require_paths:
|
36
|
+
- lib
|
37
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
38
|
+
requirements:
|
39
|
+
- - ">="
|
40
|
+
- !ruby/object:Gem::Version
|
41
|
+
version: '2.5'
|
42
|
+
- - "<"
|
43
|
+
- !ruby/object:Gem::Version
|
44
|
+
version: '4.0'
|
45
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
46
|
+
requirements:
|
47
|
+
- - ">="
|
48
|
+
- !ruby/object:Gem::Version
|
49
|
+
version: '0'
|
50
|
+
requirements: []
|
51
|
+
rubygems_version: 3.4.5
|
52
|
+
signing_key:
|
53
|
+
specification_version: 4
|
54
|
+
summary: Construct a confusion matrix and retrieve statistical information from it.
|
55
|
+
test_files: []
|