confusion_matrix 1.1.0

checksums.yaml ADDED
@@ -0,0 +1,7 @@
---
SHA256:
  metadata.gz: bc3f9ed40d45bb892541c4a3dc97c04c5893e335ec79b2526d072d291d08a7b7
  data.tar.gz: 602f9857a4e45283357117974943d46f4802fbd9c5ce7b1be0ec704472d29ba4
SHA512:
  metadata.gz: 20cc86e92c2ad0867206ee2c2a46fe31de7c08c15e90491d481a8c40fd94541b4dea87b575792652cedbb49fea284e70e4f72250ab6f2e83de2e9b57ea7136c6
  data.tar.gz: 972386c261254d2fc44b09f320d48b41639c09682f16c0ddeb88cd2697b683f7b178d257e7edcb3bcba51d0dfc4c625bf478fcadd1a6e2e81ed8354a588266ce

data/LICENSE.rdoc ADDED
@@ -0,0 +1,22 @@
= MIT License

Copyright (c) 2020-23, Peter Lane

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

data/README.rdoc ADDED
@@ -0,0 +1,88 @@
= Confusion Matrix

Install from {RubyGems}[https://rubygems.org/gems/confusion_matrix/]:

  > gem install confusion_matrix

source:: https://notabug.org/peterlane/confusion-matrix-ruby/

== Description

A confusion matrix is used in data-mining as a summary of the performance of a
classification algorithm. Each row represents the _actual_ class of an
instance, and each column represents the _predicted_ class, i.e. the class
the instance was classified as. The number in each (row, column) cell is the
total number of instances of actual class "row" which were predicted to fall
in class "column".

A two-class example is:

  Classified   Classified  |
   Positive     Negative   | Actual
  -------------------------+------------
      a            b       | Positive
      c            d       | Negative

Here the value:

a:: is the number of true positives (those labelled positive and classified positive)
b:: is the number of false negatives (those labelled positive but classified negative)
c:: is the number of false positives (those labelled negative but classified positive)
d:: is the number of true negatives (those labelled negative and classified negative)

From this table we can calculate statistics like:

true_positive_rate:: a/(a+b)
positive precision:: a/(a+c)
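
For instance, with the illustrative counts a = 10, b = 3, c = 5 and d = 20
(hypothetical numbers, matching the example further below):

  a, b, c, d = 10, 3, 5, 20

  a.to_f / (a + b)   # true_positive_rate => 0.7692...
  a.to_f / (a + c)   # positive precision => 0.6666...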

The implementation supports confusion matrices with more than two
classes, and hence most statistics are calculated with reference to a
named class. When more than two classes are in use, the statistics
are calculated as if the named class were positive and all the other
classes were grouped together as negative.

For example, in a three-class case:

  Classified  Classified  Classified |
     Red         Blue       Green    | Actual
  -----------------------------------+------------
      a           b           c      | Red
      d           e           f      | Blue
      g           h           i      | Green

We can calculate:

true_red_rate:: a/(a+b+c)
red precision:: a/(a+d+g)
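
A minimal sketch of this one-vs-rest grouping, using hypothetical counts
for the table above:

  require 'confusion_matrix'

  cm = ConfusionMatrix.new :red, :blue, :green
  cm.add_for(:red, :red, 10)     # a
  cm.add_for(:red, :blue, 7)     # b
  cm.add_for(:red, :green, 5)    # c
  cm.add_for(:blue, :red, 20)    # d
  cm.add_for(:green, :red, 30)   # g

  # :red counts as positive; :blue and :green are grouped as negative
  cm.recall(:red)      # => a/(a+b+c) = 10.0/22
  cm.precision(:red)   # => a/(a+d+g) = 10.0/60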

== Example

The following example creates a simple two-class confusion matrix,
prints a few statistics and displays the table.

  require 'confusion_matrix'

  cm = ConfusionMatrix.new :pos, :neg
  cm.add_for(:pos, :pos, 10)
  3.times { cm.add_for(:pos, :neg) }
  20.times { cm.add_for(:neg, :neg) }
  5.times { cm.add_for(:neg, :pos) }

  puts "Precision: #{cm.precision}"
  puts "Recall: #{cm.recall}"
  puts "MCC: #{cm.matthews_correlation}"
  puts
  puts(cm.to_s)

Output:

  Precision: 0.6666666666666666
  Recall: 0.7692307692307693
  MCC: 0.5524850114241865

  Predicted |
  pos neg   | Actual
  ----------+-------
   10   3   | pos
    5  20   | neg

data/lib/confusion_matrix.rb ADDED
@@ -0,0 +1,451 @@

# This class holds the confusion matrix information.
# It is designed to be called incrementally, as results are obtained
# from the classifier model.
#
# At any point, statistics may be obtained by calling the relevant methods.
#
# A two-class example is:
#
#    Classified   Classified  |
#     Positive     Negative   | Actual
#    -------------------------+------------
#        a            b       | Positive
#        c            d       | Negative
#
# Statistical methods will be described with reference to this example.
#
class ConfusionMatrix
  # Creates a new, empty instance of a confusion matrix.
  #
  # @param labels [Array<String, Symbol>] if provided, makes the matrix
  #   use the first label as a default label, and also checks that
  #   all operations use one of the pre-defined labels.
  # @raise [ArgumentError] if labels are provided but fewer than two are unique.
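  #
  # A minimal illustration of the constructor's contract (hypothetical labels):
  #
  #   ConfusionMatrix.new               # open matrix; labels grow as results arrive
  #   ConfusionMatrix.new(:pos, :neg)   # fixed labels, with :pos the default
  #   ConfusionMatrix.new(:pos)         # raises ArgumentError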
  def initialize(*labels)
    @matrix = {}
    @labels = labels.uniq
    if @labels.size == 1
      raise ArgumentError.new("If labels are provided, there must be at least two.")
    else # preset the matrix Hash
      @labels.each do |actual|
        @matrix[actual] = {}
        @labels.each do |predicted|
          @matrix[actual][predicted] = 0
        end
      end
    end
  end

  # Returns a list of labels used in the matrix.
  #
  #   cm = ConfusionMatrix.new
  #   cm.add_for(:pos, :neg)
  #   cm.labels # => [:neg, :pos]
  #
  # @return [Array<String, Symbol>] labels used in the matrix.
  def labels
    if @labels.size >= 2 # if we defined some labels, return them
      @labels
    else
      result = []

      @matrix.each_pair do |key, predictions|
        result << key
        predictions.each_key do |prediction|
          result << prediction
        end
      end

      result.uniq.sort
    end
  end

  # Returns the count for the (actual, prediction) pair.
  #
  #   cm = ConfusionMatrix.new
  #   cm.add_for(:pos, :neg)
  #   cm.count_for(:pos, :neg) # => 1
  #
  # @param actual [String, Symbol] the actual class of the instance,
  #   which we expect the classifier to predict
  # @param prediction [String, Symbol] the predicted class of the instance,
  #   as output from the classifier
  # @return [Integer] number of observations of the (actual, prediction) pair
  # @raise [ArgumentError] if +actual+ or +prediction+ is not one of the
  #   pre-defined labels in the matrix
  def count_for(actual, prediction)
    validate_label actual, prediction
    predictions = @matrix.fetch(actual, {})
    predictions.fetch(prediction, 0)
  end

  # Adds one result to the matrix for a given (actual, prediction) pair of labels.
  # If the matrix was given a pre-defined list of labels on construction, then
  # the given labels must be from the pre-defined list.
  # If no pre-defined list of labels was used in constructing the matrix, then
  # the labels are added to the matrix.
  #
  # Class labels may be any hashable value, though ideally they are strings or symbols.
  #
  # @param actual [String, Symbol] the actual class of the instance,
  #   which we expect the classifier to predict
  # @param prediction [String, Symbol] the predicted class of the instance,
  #   as output from the classifier
  # @param n [Integer] number of observations to add
  # @raise [ArgumentError] if +n+ is not a positive Integer
  # @raise [ArgumentError] if +actual+ or +prediction+ is not one of the
  #   pre-defined labels in the matrix
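  #
  # For example:
  #
  #   cm = ConfusionMatrix.new
  #   cm.add_for(:pos, :neg)      # record a single result
  #   cm.add_for(:pos, :neg, 4)   # record four results at once
  #   cm.count_for(:pos, :neg)    # => 5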
  def add_for(actual, prediction, n = 1)
    validate_label actual, prediction
    # validate n before touching the matrix, so an invalid call leaves no trace
    unless n.is_a?(Integer) and n.positive?
      raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
    end

    @matrix[actual] = {} unless @matrix.has_key?(actual)
    predictions = @matrix[actual]
    predictions[prediction] = 0 unless predictions.has_key?(prediction)

    @matrix[actual][prediction] += n
  end

  # Returns the number of instances of the given class label which
  # are incorrectly classified.
  #
  #   false_negative(:positive) = b
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Integer] number of false negatives
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def false_negative(label = @labels.first)
    validate_label label
    predictions = @matrix.fetch(label, {})
    total = 0

    predictions.each_pair do |key, count|
      if key != label
        total += count
      end
    end

    total
  end

  # Returns the number of instances incorrectly classified with the given
  # class label.
  #
  #   false_positive(:positive) = c
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Integer] number of false positives
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def false_positive(label = @labels.first)
    validate_label label
    total = 0

    @matrix.each_pair do |key, predictions|
      if key != label
        total += predictions.fetch(label, 0)
      end
    end

    total
  end

  # The false rate for a given class label is the proportion of instances
  # incorrectly classified as that label, out of all those instances
  # not originally of that label.
  #
  #   false_rate(:positive) = c/(c+d)
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the false rate
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def false_rate(label = @labels.first)
    validate_label label
    fp = false_positive(label)
    tn = true_negative(label)

    divide(fp, fp+tn)
  end

  # The F-measure for a given label is the harmonic mean of the precision
  # and recall for that label.
  #
  #   F = 2*(precision*recall)/(precision+recall)
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the F-measure
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def f_measure(label = @labels.first)
    validate_label label
    # use the safe divide so precision = recall = 0 yields 0.0 rather than NaN
    divide(2*precision(label)*recall(label), precision(label) + recall(label))
  end

  # The geometric mean is the nth root of the product of the true_rate for
  # each of the n labels. For the two-class example:
  #
  #   a1 = a/(a+b)
  #   a2 = d/(c+d)
  #   geometric_mean = Math.sqrt(a1*a2)
  #
  # @return [Float] value of the geometric mean
  def geometric_mean
    product = 1

    @matrix.each_key do |key|
      product *= true_rate(key)
    end

    product**(1.0/@matrix.size)
  end

  # The Kappa statistic compares the observed accuracy with an expected
  # (chance) accuracy.
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of Cohen's Kappa statistic
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
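  #
  # In terms of the two-class table (restating the calculation below):
  #
  #   observed = (a+d)/(a+b+c+d)
  #   expected = ((a+b)*(a+c) + (c+d)*(b+d))/(a+b+c+d)**2
  #   kappa    = (observed - expected)/(1 - expected)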
  def kappa(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp+fn+fp+tn

    total_accuracy = divide(tp+tn, tp+tn+fp+fn)
    random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)

    divide(total_accuracy - random_accuracy, 1 - random_accuracy)
  end

  # Matthews Correlation Coefficient is a measure of the quality of binary
  # classifications.
  #
  #   matthews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the Matthews Correlation Coefficient
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
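  #
  # For example, with a=10, b=3, c=5, d=20 this gives
  # (10*20 - 5*3)/sqrt(15*13*25*23), which is roughly 0.5525.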
  def matthews_correlation(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)

    divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
  end

  # The overall accuracy is the proportion of instances which are
  # correctly labelled.
  #
  #   overall_accuracy = (a+d)/(a+b+c+d)
  #
  # @return [Float] value of the overall accuracy
  def overall_accuracy
    total_correct = 0

    @matrix.each_key do |key|
      total_correct += true_positive(key)
    end

    divide(total_correct, total)
  end

  # The precision for a given class label is the proportion of instances
  # classified as that class which are correct.
  #
  #   precision(:positive) = a/(a+c)
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the precision
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def precision(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fp = false_positive(label)

    divide(tp, tp+fp)
  end

  # The prevalence for a given class label is the proportion of instances
  # actually of that label, out of the total.
  #
  #   prevalence(:positive) = (a+b)/(a+b+c+d)
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the prevalence
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def prevalence(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp+fn+fp+tn

    divide(tp+fn, total)
  end

  # Recall is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def recall(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # Sensitivity is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def sensitivity(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # The specificity for a given class label is 1 - false_rate(label).
  #
  # In the two-class case, specificity = 1 - false_positive_rate.
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] value of the specificity
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def specificity(label = @labels.first)
    validate_label label
    1 - false_rate(label)
  end

  # Returns the matrix in a string format, representing the entries as a
  # printable table.
  #
  # @return [String] representation as a printable table.
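  #
  # For example (output for the README's two-class data):
  #
  #   puts cm.to_s
  #   # Predicted |
  #   # pos neg   | Actual
  #   # ----------+-------
  #   #  10   3   | pos
  #   #   5  20   | neg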
  def to_s
    ls = labels
    result = ""

    title_line = "Predicted "
    label_line = ""
    ls.each { |l| label_line << "#{l} " }
    label_line << " " while label_line.size < title_line.size
    title_line << " " while title_line.size < label_line.size
    result << title_line << "|\n" << label_line << "| Actual\n"
    result << "-"*title_line.size << "+-------\n"

    ls.each do |l|
      count_line = ""
      ls.each_with_index do |m, i|
        count_line << count_for(l, m).to_s.rjust(m.to_s.size) << " "
      end
      result << count_line.ljust(title_line.size) << "| #{l}\n"
    end

    result
  end

  # Returns the total number of instances referenced in the matrix.
  #
  #   total = a+b+c+d
  #
  # @return [Integer] total number of instances referenced in the matrix.
  def total
    total = 0

    @matrix.each_value do |predictions|
      predictions.each_value do |count|
        total += count
      end
    end

    total
  end

  # Returns the number of instances NOT of the given class label which
  # are correctly classified.
  #
  #   true_negative(:positive) = d
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Integer] number of instances not of the given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
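  #
  # In the multi-class case this sums the diagonal entries for every other
  # label: with the README's three-class table, true_negative(:red) = e + i.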
  def true_negative(label = @labels.first)
    validate_label label
    total = 0

    @matrix.each_pair do |key, predictions|
      if key != label
        total += predictions.fetch(key, 0)
      end
    end

    total
  end

  # Returns the number of instances of the given class label which are
  # correctly classified.
  #
  #   true_positive(:positive) = a
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Integer] number of instances of the given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def true_positive(label = @labels.first)
    validate_label label
    predictions = @matrix.fetch(label, {})
    predictions.fetch(label, 0)
  end

  # The true rate for a given class label is the proportion of instances of
  # that class which are correctly classified.
  #
  #   true_rate(:positive) = a/(a+b)
  #
  # @param label [String, Symbol] class to use, defaults to the first of any pre-defined labels in the matrix
  # @return [Float] proportion of instances of the given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of the pre-defined labels in the matrix
  def true_rate(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)

    divide(tp, tp+fn)
  end

  private

  # A form of "safe divide".
  # Checks if the divisor is zero, and returns 0.0 if so.
  # This avoids returning Infinity or NaN.
  # Also ensures floating-point division is done.
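  #
  #   divide(3, 4)  # => 0.75
  #   divide(1, 0)  # => 0.0, rather than Infinity or an error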
  def divide(x, y)
    if y.zero?
      0.0
    else
      x.to_f/y
    end
  end

  # Checks that each given label is non-nil and one of @labels,
  # or that @labels is empty (i.e. the matrix is open).
  # Raises ArgumentError if not.
  def validate_label(*labels)
    return true if @labels.empty?
    labels.each do |label|
      unless label and @labels.include?(label)
        raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
      end
    end
  end
end

data/test/matrix_test.rb ADDED
@@ -0,0 +1,160 @@
require 'confusion_matrix'
require 'minitest/autorun'

class TestConfusionMatrix < Minitest::Test
  def test_empty_case
    cm = ConfusionMatrix.new
    assert_equal(0, cm.total)
    assert_equal(0, cm.true_positive(:none))
    assert_equal(0, cm.false_negative(:none))
    assert_equal(0, cm.false_positive(:none))
    assert_equal(0, cm.true_negative(:none))
    assert_in_delta(0, cm.true_rate(:none))
  end

  def test_two_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:pos, :pos) }
    5.times { cm.add_for(:pos, :neg) }
    20.times { cm.add_for(:neg, :neg) }
    5.times { cm.add_for(:neg, :pos) }

    assert_equal([:neg, :pos], cm.labels)
    assert_equal(10, cm.count_for(:pos, :pos))
    assert_equal(5, cm.count_for(:pos, :neg))
    assert_equal(20, cm.count_for(:neg, :neg))
    assert_equal(5, cm.count_for(:neg, :pos))

    assert_equal(40, cm.total)
    assert_equal(10, cm.true_positive(:pos))
    assert_equal(5, cm.false_negative(:pos))
    assert_equal(5, cm.false_positive(:pos))
    assert_equal(20, cm.true_negative(:pos))
    assert_equal(20, cm.true_positive(:neg))
    assert_equal(5, cm.false_negative(:neg))
    assert_equal(5, cm.false_positive(:neg))
    assert_equal(10, cm.true_negative(:neg))

    assert_in_delta(0.6667, cm.true_rate(:pos))
    assert_in_delta(0.8, cm.true_rate(:neg))
    assert_in_delta(0.2, cm.false_rate(:pos))
    assert_in_delta(0.3333, cm.false_rate(:neg))
    assert_in_delta(0.6667, cm.precision(:pos))
    assert_in_delta(0.8, cm.precision(:neg))
    assert_in_delta(0.6667, cm.recall(:pos))
    assert_in_delta(0.8, cm.recall(:neg))
    assert_in_delta(0.6667, cm.sensitivity(:pos))
    assert_in_delta(0.8, cm.sensitivity(:neg))
    assert_in_delta(0.75, cm.overall_accuracy)
    assert_in_delta(0.6667, cm.f_measure(:pos))
    assert_in_delta(0.8, cm.f_measure(:neg))
    assert_in_delta(0.7303, cm.geometric_mean)
  end

  # Example from:
  # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
  def test_two_classes_2
    cm = ConfusionMatrix.new
    5.times { cm.add_for(:pos, :pos) }
    1.times { cm.add_for(:pos, :neg) }
    3.times { cm.add_for(:neg, :neg) }
    2.times { cm.add_for(:neg, :pos) }

    assert_equal(11, cm.total)
    assert_equal(5, cm.true_positive(:pos))
    assert_equal(1, cm.false_negative(:pos))
    assert_equal(2, cm.false_positive(:pos))
    assert_equal(3, cm.true_negative(:pos))

    assert_in_delta(0.7142, cm.precision(:pos))
    assert_in_delta(0.8333, cm.recall(:pos))
    assert_in_delta(0.7272, cm.overall_accuracy)
    assert_in_delta(0.7692, cm.f_measure(:pos))
    assert_in_delta(0.8333, cm.sensitivity(:pos))
    assert_in_delta(0.6, cm.specificity(:pos))
    assert_in_delta(0.4407, cm.kappa(:pos))
    assert_in_delta(0.5454, cm.prevalence(:pos))
  end

  # Examples from:
  # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
  # Builds a two-class matrix from the four counts, then checks five statistics.
  def two_class_case(tp, fn, tn, fp, mcc, prec, rec, f, kappa)
    cm = ConfusionMatrix.new
    tp.times { cm.add_for(:pos, :pos) }
    fn.times { cm.add_for(:pos, :neg) }
    tn.times { cm.add_for(:neg, :neg) }
    fp.times { cm.add_for(:neg, :pos) }

    assert_in_delta(mcc, cm.matthews_correlation(:pos))
    assert_in_delta(prec, cm.precision(:pos))
    assert_in_delta(rec, cm.recall(:pos))
    assert_in_delta(f, cm.f_measure(:pos))
    assert_in_delta(kappa, cm.kappa(:pos))
  end

  def test_two_classes_3
    two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
    two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
    two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
  end

  def test_three_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:red, :red) }
    7.times { cm.add_for(:red, :blue) }
    5.times { cm.add_for(:red, :green) }
    20.times { cm.add_for(:blue, :red) }
    5.times { cm.add_for(:blue, :blue) }
    15.times { cm.add_for(:blue, :green) }
    30.times { cm.add_for(:green, :red) }
    12.times { cm.add_for(:green, :blue) }
    8.times { cm.add_for(:green, :green) }

    assert_equal([:blue, :green, :red], cm.labels)
    assert_equal(112, cm.total)
    assert_equal(10, cm.true_positive(:red))
    assert_equal(12, cm.false_negative(:red))
    assert_equal(50, cm.false_positive(:red))
    assert_equal(13, cm.true_negative(:red))
    assert_equal(5, cm.true_positive(:blue))
    assert_equal(35, cm.false_negative(:blue))
    assert_equal(19, cm.false_positive(:blue))
    assert_equal(18, cm.true_negative(:blue))
    assert_equal(8, cm.true_positive(:green))
    assert_equal(42, cm.false_negative(:green))
    assert_equal(20, cm.false_positive(:green))
    assert_equal(15, cm.true_negative(:green))
  end

  def test_add_for_n
    cm = ConfusionMatrix.new
    cm.add_for(:pos, :pos, 3)
    cm.add_for(:pos, :neg)
    cm.add_for(:neg, :pos, 2)
    cm.add_for(:neg, :neg, 1)
    assert_equal(7, cm.total)
    assert_equal(3, cm.count_for(:pos, :pos))
    # - check errors
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
  end

  def test_use_labels
    # - check errors
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }
    # - check created matrix
    cm = ConfusionMatrix.new(:pos, :neg)
    assert_equal([:pos, :neg], cm.labels)
    assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }
    cm.add_for(:pos, :neg, 3)
    cm.add_for(:neg, :pos, 2)
    assert_equal(2, cm.false_negative(:neg))
    assert_equal(3, cm.false_negative(:pos))
    assert_equal(3, cm.false_negative())
    assert_raises(ArgumentError) { cm.false_negative(:nothing) }
    assert_raises(ArgumentError) { cm.false_negative(nil) }
  end
end
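
The test suite uses Minitest; assuming the gem's standard layout, it can be
run from the project root with:

  ruby -Ilib test/matrix_test.rb
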
metadata ADDED
@@ -0,0 +1,55 @@
--- !ruby/object:Gem::Specification
name: confusion_matrix
version: !ruby/object:Gem::Version
  version: 1.1.0
platform: ruby
authors:
- Peter Lane
autorequire:
bindir: bin
cert_chain: []
date: 2023-02-04 00:00:00.000000000 Z
dependencies: []
description: "A confusion matrix is used in data-mining as a summary of the performance
  of a\nclassification algorithm. This library allows the user to incrementally add
  \nresults to a confusion matrix, and then retrieve statistical information.\n"
email: peterlane@gmx.com
executables: []
extensions: []
extra_rdoc_files:
- README.rdoc
- LICENSE.rdoc
files:
- LICENSE.rdoc
- README.rdoc
- lib/confusion_matrix.rb
- test/matrix_test.rb
homepage:
licenses:
- MIT
metadata: {}
post_install_message:
rdoc_options:
- "-m"
- README.rdoc
require_paths:
- lib
required_ruby_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
      version: '2.5'
  - - "<"
    - !ruby/object:Gem::Version
      version: '4.0'
required_rubygems_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
      version: '0'
requirements: []
rubygems_version: 3.4.5
signing_key:
specification_version: 4
summary: Construct a confusion matrix and retrieve statistical information from it.
test_files: []