confusion_matrix 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.rdoc +22 -0
- data/README.rdoc +88 -0
- data/lib/confusion_matrix.rb +451 -0
- data/test/matrix_test.rb +160 -0
- metadata +55 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
+---
+SHA256:
+  metadata.gz: bc3f9ed40d45bb892541c4a3dc97c04c5893e335ec79b2526d072d291d08a7b7
+  data.tar.gz: 602f9857a4e45283357117974943d46f4802fbd9c5ce7b1be0ec704472d29ba4
+SHA512:
+  metadata.gz: 20cc86e92c2ad0867206ee2c2a46fe31de7c08c15e90491d481a8c40fd94541b4dea87b575792652cedbb49fea284e70e4f72250ab6f2e83de2e9b57ea7136c6
+  data.tar.gz: 972386c261254d2fc44b09f320d48b41639c09682f16c0ddeb88cd2697b683f7b178d257e7edcb3bcba51d0dfc4c625bf478fcadd1a6e2e81ed8354a588266ce
data/LICENSE.rdoc
ADDED
@@ -0,0 +1,22 @@
+= MIT License
+
+Copyright (c) 2020-23, Peter Lane
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
data/README.rdoc
ADDED
@@ -0,0 +1,88 @@
+= Confusion Matrix
+
+Install from {RubyGems}[https://rubygems.org/gems/confusion_matrix/]:
+
+  > gem install confusion_matrix
+
+source:: https://notabug.org/peterlane/confusion-matrix-ruby/
+
+== Description
+
+A confusion matrix is used in data-mining as a summary of the performance of a
+classification algorithm. Each row represents the _actual_ class of an
+instance, and each column represents the _predicted_ class of that instance,
+i.e. the class that it was classified as. The number at each (row, column)
+is the total number of instances of actual class "row" which were
+predicted to fall in class "column".
+
+A two-class example is:
+
+   Classified   Classified    |
+   Positive     Negative      | Actual
+  ------------------------------+------------
+       a            b         | Positive
+       c            d         | Negative
+
+Here the value:
+
+a:: is the number of true positives (those labelled positive and classified positive)
+b:: is the number of false negatives (those labelled positive but classified negative)
+c:: is the number of false positives (those labelled negative but classified positive)
+d:: is the number of true negatives (those labelled negative and classified negative)
+
+From this table we can calculate statistics like:
+
+true_positive_rate:: a/(a+b)
+positive precision:: a/(a+c)
+
+The implementation supports confusion matrices with more than two
+classes, and hence most statistics are calculated with reference to a
+named class. When more than two classes are in use, the statistics
+are calculated as if the named class were positive and all the other
+classes were grouped as if negative.
+
+For example, in a three-class case:
+
+   Classified   Classified   Classified     |
+   Red          Blue         Green          | Actual
+  --------------------------------------------+------------
+       a            b            c          | Red
+       d            e            f          | Blue
+       g            h            i          | Green
+
+We can calculate:
+
+true_red_rate:: a/(a+b+c)
+red precision:: a/(a+d+g)
+
+== Example
+
+The following example creates a simple two-class confusion matrix,
+prints a few statistics and displays the table.
+
+  require 'confusion_matrix'
+
+  cm = ConfusionMatrix.new :pos, :neg
+  cm.add_for(:pos, :pos, 10)
+  3.times { cm.add_for(:pos, :neg) }
+  20.times { cm.add_for(:neg, :neg) }
+  5.times { cm.add_for(:neg, :pos) }
+
+  puts "Precision: #{cm.precision}"
+  puts "Recall: #{cm.recall}"
+  puts "MCC: #{cm.matthews_correlation}"
+  puts
+  puts(cm.to_s)
+
+Output:
+
+  Precision: 0.6666666666666666
+  Recall: 0.7692307692307693
+  MCC: 0.5524850114241865
+
+  Predicted |
+  pos neg   | Actual
+  ----------+-------
+   10   3   | pos
+    5  20   | neg
+
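As a quick illustration of the multi-class statistics described in the README, the following sketch (written against the API in lib/confusion_matrix.rb below; not one of the released files) fills in the three-class table and reads back the statistics for the red class:

  require 'confusion_matrix'

  cm = ConfusionMatrix.new(:red, :blue, :green)
  cm.add_for(:red, :red, 10)     # a
  cm.add_for(:red, :blue, 7)     # b
  cm.add_for(:red, :green, 5)    # c
  cm.add_for(:blue, :red, 20)    # d
  cm.add_for(:blue, :blue, 5)    # e
  cm.add_for(:blue, :green, 15)  # f
  cm.add_for(:green, :red, 30)   # g
  cm.add_for(:green, :blue, 12)  # h
  cm.add_for(:green, :green, 8)  # i

  cm.true_rate(:red) # => a/(a+b+c) = 10/22, approx. 0.4545
  cm.precision(:red) # => a/(a+d+g) = 10/60, approx. 0.1667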
data/lib/confusion_matrix.rb
ADDED
@@ -0,0 +1,451 @@
+
+# This class holds the confusion matrix information.
+# It is designed to be called incrementally, as results are obtained
+# from the classifier model.
+#
+# At any point, statistics may be obtained by calling the relevant methods.
+#
+# A two-class example is:
+#
+#     Classified   Classified    |
+#     Positive     Negative      | Actual
+#    ------------------------------+------------
+#         a            b         | Positive
+#         c            d         | Negative
+#
+# Statistical methods will be described with reference to this example.
+#
+class ConfusionMatrix
+  # Creates a new, empty instance of a confusion matrix.
+  #
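+  #   cm = ConfusionMatrix.new              # labels are collected as results arrive
+  #   cm = ConfusionMatrix.new(:pos, :neg)  # fixed labels, with :pos the default
+  #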
+  # @param labels [Array<String, Symbol>] if provided, makes the matrix
+  #        use the first label as a default label, and also check that
+  #        all operations use one of the pre-defined labels.
+  # @raise [ArgumentError] if there are not at least two unique labels, when provided.
+  def initialize(*labels)
+    @matrix = {}
+    @labels = labels.uniq
+    if @labels.size == 1
+      raise ArgumentError.new("If labels are provided, there must be at least two.")
+    else # preset the matrix Hash
+      @labels.each do |actual|
+        @matrix[actual] = {}
+        @labels.each do |predicted|
+          @matrix[actual][predicted] = 0
+        end
+      end
+    end
+  end
+
+  # Returns a list of labels used in the matrix.
+  #
+  #   cm = ConfusionMatrix.new
+  #   cm.add_for(:pos, :neg)
+  #   cm.labels # => [:neg, :pos]
+  #
+  # @return [Array<String, Symbol>] labels used in the matrix.
+  def labels
+    if @labels.size >= 2 # if we defined some labels, return them
+      @labels
+    else
+      result = []
+
+      @matrix.each_pair do |actual, predictions|
+        result << actual
+        predictions.each_key do |prediction|
+          result << prediction
+        end
+      end
+
+      result.uniq.sort
+    end
+  end
+
+  # Return the count for an (actual, prediction) pair.
+  #
+  #   cm = ConfusionMatrix.new
+  #   cm.add_for(:pos, :neg)
+  #   cm.count_for(:pos, :neg) # => 1
+  #
+  # @param actual [String, Symbol] is the actual class of the instance,
+  #        which we expect the classifier to predict
+  # @param prediction [String, Symbol] is the predicted class of the instance,
+  #        as output from the classifier
+  # @return [Integer] number of observations of the (actual, prediction) pair
+  # @raise [ArgumentError] if +actual+ or +prediction+ is not one of any
+  #        pre-defined labels in matrix
+  def count_for(actual, prediction)
+    validate_label actual, prediction
+    predictions = @matrix.fetch(actual, {})
+    predictions.fetch(prediction, 0)
+  end
+
+  # Adds one result to the matrix for a given (actual, prediction) pair of labels.
+  # If the matrix was given a pre-defined list of labels on construction, then
+  # these given labels must be from the pre-defined list.
+  # If no pre-defined list of labels was used in constructing the matrix, then
+  # labels will be added to the matrix.
+  #
+  # Class labels may be any hashable value, though ideally they are strings or symbols.
+  #
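+  #   cm = ConfusionMatrix.new
+  #   cm.add_for(:pos, :neg)     # record one observation
+  #   cm.add_for(:pos, :neg, 5)  # record five more at once
+  #   cm.count_for(:pos, :neg)   # => 6
+  #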
+  # @param actual [String, Symbol] is the actual class of the instance,
+  #        which we expect the classifier to predict
+  # @param prediction [String, Symbol] is the predicted class of the instance,
+  #        as output from the classifier
+  # @param n [Integer] number of observations to add
+  # @raise [ArgumentError] if +n+ is not a positive Integer
+  # @raise [ArgumentError] if +actual+ or +prediction+ is not one of any
+  #        pre-defined labels in matrix
+  def add_for(actual, prediction, n = 1)
+    validate_label actual, prediction
+    unless n.is_a?(Integer) && n.positive?
+      raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
+    end
+
+    if !@matrix.has_key?(actual)
+      @matrix[actual] = {}
+    end
+    predictions = @matrix[actual]
+    if !predictions.has_key?(prediction)
+      predictions[prediction] = 0
+    end
+
+    @matrix[actual][prediction] += n
+  end
+
+  # Returns the number of instances of the given class label which
+  # are incorrectly classified.
+  #
+  #   false_negative(:positive) = b
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Integer] number of false negatives for the given label
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def false_negative(label = @labels.first)
+    validate_label label
+    predictions = @matrix.fetch(label, {})
+    total = 0
+
+    predictions.each_pair do |prediction, count|
+      if prediction != label
+        total += count
+      end
+    end
+
+    total
+  end
+
+  # Returns the number of instances incorrectly classified with the given
+  # class label.
+  #
+  #   false_positive(:positive) = c
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Integer] number of false positives for the given label
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def false_positive(label = @labels.first)
+    validate_label label
+    total = 0
+
+    @matrix.each_pair do |actual, predictions|
+      if actual != label
+        total += predictions.fetch(label, 0)
+      end
+    end
+
+    total
+  end
+
+  # The false rate for a given class label is the proportion of instances
+  # incorrectly classified as that label, out of all those instances
+  # not originally of that label.
+  #
+  #   false_rate(:positive) = c/(c+d)
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of false rate
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def false_rate(label = @labels.first)
+    validate_label label
+    fp = false_positive(label)
+    tn = true_negative(label)
+
+    divide(fp, fp+tn)
+  end
+
+  # The F-measure for a given label is the harmonic mean of the precision
+  # and recall for that label.
+  #
+  #   F = 2*(precision*recall)/(precision+recall)
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of F-measure
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def f_measure(label = @labels.first)
+    validate_label label
+    2*precision(label)*recall(label)/(precision(label) + recall(label))
+  end
+
+  # The geometric mean is the nth-root of the product of the true_rate for
+  # each label. For the two-class example:
+  #
+  #   a1 = a/(a+b)
+  #   a2 = d/(c+d)
+  #   geometric_mean = Math.sqrt(a1*a2)
+  #
+  # @return [Float] value of geometric mean
+  def geometric_mean
+    product = 1
+
+    @matrix.each_key do |label|
+      product *= true_rate(label)
+    end
+
+    product**(1.0/@matrix.size)
+  end
+
+  # The Kappa statistic compares the observed accuracy with an expected
+  # (chance) accuracy.
+  #
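+  # In terms of the two-class example:
+  #
+  #   observed = (a+d)/(a+b+c+d)
+  #   expected = ((d+c)*(d+b) + (b+a)*(c+a))/(a+b+c+d)**2
+  #   kappa(:positive) = (observed - expected)/(1 - expected)
+  #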
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of Cohen's Kappa Statistic
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def kappa(label = @labels.first)
+    validate_label label
+    tp = true_positive(label)
+    fn = false_negative(label)
+    fp = false_positive(label)
+    tn = true_negative(label)
+    total = tp+fn+fp+tn
+
+    total_accuracy = divide(tp+tn, total)
+    random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
+
+    divide(total_accuracy - random_accuracy, 1 - random_accuracy)
+  end
+
+  # Matthews Correlation Coefficient is a measure of the quality of binary
+  # classifications.
+  #
+  #   matthews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of Matthews Correlation Coefficient
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def matthews_correlation(label = @labels.first)
+    validate_label label
+    tp = true_positive(label)
+    fn = false_negative(label)
+    fp = false_positive(label)
+    tn = true_negative(label)
+
+    divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
+  end
+
+  # The overall accuracy is the proportion of instances which are
+  # correctly labelled.
+  #
+  #   overall_accuracy = (a+d)/(a+b+c+d)
+  #
+  # @return [Float] value of overall accuracy
+  def overall_accuracy
+    total_correct = 0
+
+    @matrix.each_key do |label|
+      total_correct += true_positive(label)
+    end
+
+    divide(total_correct, total)
+  end
+
+  # The precision for a given class label is the proportion of instances
+  # classified as that class which are correct.
+  #
+  #   precision(:positive) = a/(a+c)
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of precision
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def precision(label = @labels.first)
+    validate_label label
+    tp = true_positive(label)
+    fp = false_positive(label)
+
+    divide(tp, tp+fp)
+  end
+
+  # The prevalence for a given class label is the proportion of instances
+  # which are actually of that label, out of the total.
+  #
+  #   prevalence(:positive) = (a+b)/(a+b+c+d)
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of prevalence
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def prevalence(label = @labels.first)
+    validate_label label
+    tp = true_positive(label)
+    fn = false_negative(label)
+    fp = false_positive(label)
+    tn = true_negative(label)
+    total = tp+fn+fp+tn
+
+    divide(tp+fn, total)
+  end
+
+  # The recall is another name for the true rate.
+  #
+  # @see true_rate
+  # @param (see #true_rate)
+  # @return (see #true_rate)
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def recall(label = @labels.first)
+    validate_label label
+    true_rate(label)
+  end
+
+  # Sensitivity is another name for the true rate.
+  #
+  # @see true_rate
+  # @param (see #true_rate)
+  # @return (see #true_rate)
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def sensitivity(label = @labels.first)
+    validate_label label
+    true_rate(label)
+  end
+
+  # The specificity for a given class label is 1 - false_rate(label).
+  #
+  # In the two-class case, specificity = 1 - false_positive_rate.
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] value of specificity
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def specificity(label = @labels.first)
+    validate_label label
+    1 - false_rate(label)
+  end
+
+  # Returns the table in a string format, representing the entries as a
+  # printable table.
+  #
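+  # For the two-class example in the README this produces:
+  #
+  #   Predicted |
+  #   pos neg   | Actual
+  #   ----------+-------
+  #    10   3   | pos
+  #     5  20   | neg
+  #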
+  # @return [String] representation as a printable table.
+  def to_s
+    ls = labels
+    result = ""
+
+    title_line = "Predicted "
+    label_line = ""
+    ls.each { |l| label_line << "#{l} " }
+    label_line << " " while label_line.size < title_line.size
+    title_line << " " while title_line.size < label_line.size
+    result << title_line << "|\n" << label_line << "| Actual\n"
+    result << "-"*title_line.size << "+-------\n"
+
+    ls.each do |l|
+      count_line = ""
+      ls.each_with_index do |m, i|
+        count_line << "#{count_for(l, m)}".rjust(ls[i].size) << " "
+      end
+      result << count_line.ljust(title_line.size) << "| #{l}\n"
+    end
+
+    result
+  end
+
+  # Returns the total number of instances referenced in the matrix.
+  #
+  #   total = a+b+c+d
+  #
+  # @return [Integer] total number of instances referenced in the matrix.
+  def total
+    total = 0
+
+    @matrix.each_value do |predictions|
+      predictions.each_value do |count|
+        total += count
+      end
+    end
+
+    total
+  end
+
+  # Returns the number of instances NOT of the given class label which
+  # are correctly classified.
+  #
+  #   true_negative(:positive) = d
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Integer] number of instances not of given label which are correctly classified
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def true_negative(label = @labels.first)
+    validate_label label
+    total = 0
+
+    @matrix.each_pair do |actual, predictions|
+      if actual != label
+        total += predictions.fetch(actual, 0)
+      end
+    end
+
+    total
+  end
+
+  # Returns the number of instances of the given class label which are
+  # correctly classified.
+  #
+  #   true_positive(:positive) = a
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Integer] number of instances of given label which are correctly classified
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def true_positive(label = @labels.first)
+    validate_label label
+    predictions = @matrix.fetch(label, {})
+    predictions.fetch(label, 0)
+  end
+
+  # The true rate for a given class label is the proportion of instances of
+  # that class which are correctly classified.
+  #
+  #   true_rate(:positive) = a/(a+b)
+  #
+  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
+  # @return [Float] proportion of instances which are correctly classified
+  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+  def true_rate(label = @labels.first)
+    validate_label label
+    tp = true_positive(label)
+    fn = false_negative(label)
+
+    divide(tp, tp+fn)
+  end
+
+  private
+
+  # A form of "safe divide".
+  # Checks if the divisor is zero, and returns 0.0 if so.
+  # This avoids a run-time error.
+  # Also ensures floating-point division is done.
+  def divide(x, y)
+    if y.zero?
+      0.0
+    else
+      x.to_f/y
+    end
+  end
+
+  # Checks the given label(s) are non-nil and in @labels, or that @labels is empty.
+  # Raises ArgumentError if not.
+  def validate_label(*labels)
+    return true if @labels.empty?
+    labels.each do |label|
+      unless label and @labels.include?(label)
+        raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
+      end
+    end
+  end
+end
+
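One consequence of the private divide helper worth noting: statistics on an empty matrix return 0.0 rather than raising ZeroDivisionError, e.g. (a sketch, not one of the released files):

  ConfusionMatrix.new.true_rate(:anything) # => 0.0

The test_empty_case test below exercises exactly this behaviour.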
data/test/matrix_test.rb
ADDED
@@ -0,0 +1,160 @@
+require 'confusion_matrix'
+require 'minitest/autorun'
+
+class TestConfusionMatrix < MiniTest::Test
+  def test_empty_case
+    cm = ConfusionMatrix.new
+    assert_equal(0, cm.total)
+    assert_equal(0, cm.true_positive(:none))
+    assert_equal(0, cm.false_negative(:none))
+    assert_equal(0, cm.false_positive(:none))
+    assert_equal(0, cm.true_negative(:none))
+    assert_in_delta(0, cm.true_rate(:none))
+  end
+
+  def test_two_classes
+    cm = ConfusionMatrix.new
+    10.times { cm.add_for(:pos, :pos) }
+    5.times { cm.add_for(:pos, :neg) }
+    20.times { cm.add_for(:neg, :neg) }
+    5.times { cm.add_for(:neg, :pos) }
+
+    assert_equal([:neg, :pos], cm.labels)
+    assert_equal(10, cm.count_for(:pos, :pos))
+    assert_equal(5, cm.count_for(:pos, :neg))
+    assert_equal(20, cm.count_for(:neg, :neg))
+    assert_equal(5, cm.count_for(:neg, :pos))
+
+    assert_equal(40, cm.total)
+    assert_equal(10, cm.true_positive(:pos))
+    assert_equal(5, cm.false_negative(:pos))
+    assert_equal(5, cm.false_positive(:pos))
+    assert_equal(20, cm.true_negative(:pos))
+    assert_equal(20, cm.true_positive(:neg))
+    assert_equal(5, cm.false_negative(:neg))
+    assert_equal(5, cm.false_positive(:neg))
+    assert_equal(10, cm.true_negative(:neg))
+
+    assert_in_delta(0.6667, cm.true_rate(:pos))
+    assert_in_delta(0.8, cm.true_rate(:neg))
+    assert_in_delta(0.2, cm.false_rate(:pos))
+    assert_in_delta(0.3333, cm.false_rate(:neg))
+    assert_in_delta(0.6667, cm.precision(:pos))
+    assert_in_delta(0.8, cm.precision(:neg))
+    assert_in_delta(0.6667, cm.recall(:pos))
+    assert_in_delta(0.8, cm.recall(:neg))
+    assert_in_delta(0.6667, cm.sensitivity(:pos))
+    assert_in_delta(0.8, cm.sensitivity(:neg))
+    assert_in_delta(0.75, cm.overall_accuracy)
+    assert_in_delta(0.6667, cm.f_measure(:pos))
+    assert_in_delta(0.8, cm.f_measure(:neg))
+    assert_in_delta(0.7303, cm.geometric_mean)
+  end
+
+  # Example from:
+  # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
+  def test_two_classes_2
+    cm = ConfusionMatrix.new
+    5.times { cm.add_for(:pos, :pos) }
+    1.times { cm.add_for(:pos, :neg) }
+    3.times { cm.add_for(:neg, :neg) }
+    2.times { cm.add_for(:neg, :pos) }
+
+    assert_equal(11, cm.total)
+    assert_equal(5, cm.true_positive(:pos))
+    assert_equal(1, cm.false_negative(:pos))
+    assert_equal(2, cm.false_positive(:pos))
+    assert_equal(3, cm.true_negative(:pos))
+
+    assert_in_delta(0.7142, cm.precision(:pos))
+    assert_in_delta(0.8333, cm.recall(:pos))
+    assert_in_delta(0.7272, cm.overall_accuracy)
+    assert_in_delta(0.7692, cm.f_measure(:pos))
+    assert_in_delta(0.8333, cm.sensitivity(:pos))
+    assert_in_delta(0.6, cm.specificity(:pos))
+    assert_in_delta(0.4407, cm.kappa(:pos))
+    assert_in_delta(0.5454, cm.prevalence(:pos))
+  end
+
+  # Examples from:
+  # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
+  # a..d fill the matrix; the remaining parameters are the expected statistics.
+  def two_class_case(a, b, c, d, mcc, precision, recall, f_measure, kappa)
+    cm = ConfusionMatrix.new
+    a.times { cm.add_for(:pos, :pos) }
+    b.times { cm.add_for(:pos, :neg) }
+    c.times { cm.add_for(:neg, :neg) }
+    d.times { cm.add_for(:neg, :pos) }
+
+    assert_in_delta(mcc, cm.matthews_correlation(:pos))
+    assert_in_delta(precision, cm.precision(:pos))
+    assert_in_delta(recall, cm.recall(:pos))
+    assert_in_delta(f_measure, cm.f_measure(:pos))
+    assert_in_delta(kappa, cm.kappa(:pos))
+  end
+
+  def test_two_classes_3
+    two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
+    two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
+    two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
+  end
+
+  def test_three_classes
+    cm = ConfusionMatrix.new
+    10.times { cm.add_for(:red, :red) }
+    7.times { cm.add_for(:red, :blue) }
+    5.times { cm.add_for(:red, :green) }
+    20.times { cm.add_for(:blue, :red) }
+    5.times { cm.add_for(:blue, :blue) }
+    15.times { cm.add_for(:blue, :green) }
+    30.times { cm.add_for(:green, :red) }
+    12.times { cm.add_for(:green, :blue) }
+    8.times { cm.add_for(:green, :green) }
+
+    assert_equal([:blue, :green, :red], cm.labels)
+    assert_equal(112, cm.total)
+    assert_equal(10, cm.true_positive(:red))
+    assert_equal(12, cm.false_negative(:red))
+    assert_equal(50, cm.false_positive(:red))
+    assert_equal(13, cm.true_negative(:red))
+    assert_equal(5, cm.true_positive(:blue))
+    assert_equal(35, cm.false_negative(:blue))
+    assert_equal(19, cm.false_positive(:blue))
+    assert_equal(18, cm.true_negative(:blue))
+    assert_equal(8, cm.true_positive(:green))
+    assert_equal(42, cm.false_negative(:green))
+    assert_equal(20, cm.false_positive(:green))
+    assert_equal(15, cm.true_negative(:green))
+  end
+
+  def test_add_for_n
+    cm = ConfusionMatrix.new
+    cm.add_for(:pos, :pos, 3)
+    cm.add_for(:pos, :neg)
+    cm.add_for(:neg, :pos, 2)
+    cm.add_for(:neg, :neg, 1)
+    assert_equal(7, cm.total)
+    assert_equal(3, cm.count_for(:pos, :pos))
+    # - check errors
+    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
+    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
+    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
+  end
+
+  def test_use_labels
+    # - check errors
+    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
+    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }
+    # - check created matrix
+    cm = ConfusionMatrix.new(:pos, :neg)
+    assert_equal([:pos, :neg], cm.labels)
+    assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }
+    cm.add_for(:pos, :neg, 3)
+    cm.add_for(:neg, :pos, 2)
+    assert_equal(2, cm.false_negative(:neg))
+    assert_equal(3, cm.false_negative(:pos))
+    assert_equal(3, cm.false_negative())
+    assert_raises(ArgumentError) { cm.false_negative(:nothing) }
+    assert_raises(ArgumentError) { cm.false_negative(nil) }
+  end
+end
+
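The suite is plain Minitest, so it should run directly from an unpacked copy of the gem (assuming the lib/ and test/ layout listed in the metadata below):

  ruby -Ilib test/matrix_test.rb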
metadata
ADDED
@@ -0,0 +1,55 @@
+--- !ruby/object:Gem::Specification
+name: confusion_matrix
+version: !ruby/object:Gem::Version
+  version: 1.1.0
+platform: ruby
+authors:
+- Peter Lane
+autorequire:
+bindir: bin
+cert_chain: []
+date: 2023-02-04 00:00:00.000000000 Z
+dependencies: []
+description: "A confusion matrix is used in data-mining as a summary of the performance
+  of a\nclassification algorithm. This library allows the user to incrementally add
+  \nresults to a confusion matrix, and then retrieve statistical information.\n"
+email: peterlane@gmx.com
+executables: []
+extensions: []
+extra_rdoc_files:
+- README.rdoc
+- LICENSE.rdoc
+files:
+- LICENSE.rdoc
+- README.rdoc
+- lib/confusion_matrix.rb
+- test/matrix_test.rb
+homepage:
+licenses:
+- MIT
+metadata: {}
+post_install_message:
+rdoc_options:
+- "-m"
+- README.rdoc
+require_paths:
+- lib
+required_ruby_version: !ruby/object:Gem::Requirement
+  requirements:
+  - - ">="
+    - !ruby/object:Gem::Version
+      version: '2.5'
+  - - "<"
+    - !ruby/object:Gem::Version
+      version: '4.0'
+required_rubygems_version: !ruby/object:Gem::Requirement
+  requirements:
+  - - ">="
+    - !ruby/object:Gem::Version
+      version: '0'
+requirements: []
+rubygems_version: 3.4.5
+signing_key:
+specification_version: 4
+summary: Construct a confusion matrix and retrieve statistical information from it.
+test_files: []