confusion_matrix 1.1.0

checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA256:
+   metadata.gz: bc3f9ed40d45bb892541c4a3dc97c04c5893e335ec79b2526d072d291d08a7b7
+   data.tar.gz: 602f9857a4e45283357117974943d46f4802fbd9c5ce7b1be0ec704472d29ba4
+ SHA512:
+   metadata.gz: 20cc86e92c2ad0867206ee2c2a46fe31de7c08c15e90491d481a8c40fd94541b4dea87b575792652cedbb49fea284e70e4f72250ab6f2e83de2e9b57ea7136c6
+   data.tar.gz: 972386c261254d2fc44b09f320d48b41639c09682f16c0ddeb88cd2697b683f7b178d257e7edcb3bcba51d0dfc4c625bf478fcadd1a6e2e81ed8354a588266ce
data/LICENSE.rdoc ADDED
@@ -0,0 +1,22 @@
+ = MIT License
+
+ Copyright (c) 2020-23, Peter Lane
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
data/README.rdoc ADDED
@@ -0,0 +1,88 @@
+ = Confusion Matrix
+
+ Install from {RubyGems}[https://rubygems.org/gems/confusion_matrix/]:
+
+   > gem install confusion_matrix
+
+ source:: https://notabug.org/peterlane/confusion-matrix-ruby/
+
+ == Description
+
+ A confusion matrix is used in data-mining as a summary of the performance of a
+ classification algorithm. Each row represents the _actual_ class of an
+ instance, and each column represents the _predicted_ class of that instance,
+ i.e. the class it was classified as. The number at each (row, column) position
+ is the total number of instances of actual class "row" which were
+ predicted to fall in class "column".
+
+ A two-class example is:
+
+     Classified  Classified  |
+     Positive    Negative    |  Actual
+   --------------------------+-----------
+        a            b       |  Positive
+        c            d       |  Negative
+
+ Here the value:
+
+ a:: is the number of true positives (those labelled positive and classified positive)
+ b:: is the number of false negatives (those labelled positive but classified negative)
+ c:: is the number of false positives (those labelled negative but classified positive)
+ d:: is the number of true negatives (those labelled negative and classified negative)
+
+ From this table we can calculate statistics like:
+
+ true_positive_rate:: a/(a+b)
+ positive precision:: a/(a+c)
+
+ The implementation supports confusion matrices with more than two
+ classes, and hence most statistics are calculated with reference to a
+ named class. When more than two classes are in use, the statistics
+ are calculated as if the named class were positive and all the other
+ classes were grouped together as negative.
+
+ For example, in a three-class case:
+
+     Classified  Classified  Classified  |
+     Red         Blue        Green       |  Actual
+   --------------------------------------+-----------
+        a            b           c       |  Red
+        d            e           f       |  Blue
+        g            h           i       |  Green
+
+ We can calculate:
+
+ true_red_rate:: a/(a+b+c)
+ red precision:: a/(a+d+g)
+
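+ As a sketch (the counts here are invented for illustration; the API is shown
+ in the Example below), these statistics can be read off a matrix built
+ incrementally with the library:
+
+   require 'confusion_matrix'
+
+   cm = ConfusionMatrix.new :red, :blue, :green
+   cm.add_for(:red, :red, 10)    # a
+   cm.add_for(:red, :blue, 3)    # b
+   cm.add_for(:red, :green, 2)   # c
+   cm.add_for(:blue, :red, 4)    # d
+   cm.add_for(:green, :red, 6)   # g
+
+   cm.true_rate(:red)   # a/(a+b+c) = 10/15
+   cm.precision(:red)   # a/(a+d+g) = 10/20
+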
+ == Example
+
+ The following example creates a simple two-class confusion matrix,
+ prints a few statistics and displays the table.
+
+   require 'confusion_matrix'
+
+   cm = ConfusionMatrix.new :pos, :neg
+   cm.add_for(:pos, :pos, 10)
+   3.times { cm.add_for(:pos, :neg) }
+   20.times { cm.add_for(:neg, :neg) }
+   5.times { cm.add_for(:neg, :pos) }
+
+   puts "Precision: #{cm.precision}"
+   puts "Recall: #{cm.recall}"
+   puts "MCC: #{cm.matthews_correlation}"
+   puts
+   puts(cm.to_s)
+
+ Output:
+
+   Precision: 0.6666666666666666
+   Recall: 0.7692307692307693
+   MCC: 0.5524850114241865
+
+   Predicted |
+   pos neg   | Actual
+   ----------+-------
+    10   3   | pos
+     5  20   | neg
+
data/lib/confusion_matrix.rb ADDED
@@ -0,0 +1,451 @@
+
+ # This class holds the confusion matrix information.
+ # It is designed to be called incrementally, as results are obtained
+ # from the classifier model.
+ #
+ # At any point, statistics may be obtained by calling the relevant methods.
+ #
+ # A two-class example is:
+ #
+ #     Classified  Classified  |
+ #     Positive    Negative    |  Actual
+ #   --------------------------+-----------
+ #        a            b       |  Positive
+ #        c            d       |  Negative
+ #
+ # Statistical methods will be described with reference to this example.
+ #
+ class ConfusionMatrix
+   # Creates a new, empty instance of a confusion matrix.
+   #
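+   # For example:
+   #
+   #   cm = ConfusionMatrix.new             # labels are added as results arrive
+   #   cm = ConfusionMatrix.new(:pos, :neg) # fixed labels; :pos is the default
+   #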
+   # @param labels [Array<String, Symbol>] if provided, makes the matrix
+   #   use the first label as a default label, and also checks that
+   #   all operations use one of the pre-defined labels.
+   # @raise [ArgumentError] if labels are provided but fewer than two are unique.
+   def initialize(*labels)
+     @matrix = {}
+     @labels = labels.uniq
+     if @labels.size == 1
+       raise ArgumentError.new("If labels are provided, there must be at least two.")
+     else # preset the matrix Hash
+       @labels.each do |actual|
+         @matrix[actual] = {}
+         @labels.each do |predicted|
+           @matrix[actual][predicted] = 0
+         end
+       end
+     end
+   end
+
+   # Returns a list of labels used in the matrix.
+   #
+   #   cm = ConfusionMatrix.new
+   #   cm.add_for(:pos, :neg)
+   #   cm.labels # => [:neg, :pos]
+   #
+   # @return [Array<String, Symbol>] labels used in the matrix
+   def labels
+     if @labels.size >= 2 # if we defined some labels, return them
+       @labels
+     else
+       result = []
+
+       @matrix.each_pair do |actual, predictions|
+         result << actual
+         predictions.each_key do |predicted|
+           result << predicted
+         end
+       end
+
+       result.uniq.sort
+     end
+   end
+
+   # Returns the count for an (actual, prediction) pair.
+   #
+   #   cm = ConfusionMatrix.new
+   #   cm.add_for(:pos, :neg)
+   #   cm.count_for(:pos, :neg) # => 1
+   #
+   # @param actual [String, Symbol] the actual class of the instance,
+   #   which we expect the classifier to predict
+   # @param prediction [String, Symbol] the predicted class of the instance,
+   #   as output from the classifier
+   # @return [Integer] number of observations of the (actual, prediction) pair
+   # @raise [ArgumentError] if +actual+ or +prediction+ is not one of any
+   #   pre-defined labels in the matrix
+   def count_for(actual, prediction)
+     validate_label actual, prediction
+     predictions = @matrix.fetch(actual, {})
+     predictions.fetch(prediction, 0)
+   end
+
+   # Adds one or more results to the matrix for a given (actual, prediction)
+   # pair of labels.
+   # If the matrix was given a pre-defined list of labels on construction, then
+   # the given labels must be from the pre-defined list.
+   # If no pre-defined list of labels was used in constructing the matrix, then
+   # labels will be added to the matrix.
+   #
+   # Class labels may be any hashable value, though ideally they are strings or symbols.
+   #
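+   # For example (+count_for+ reads the accumulated total back):
+   #
+   #   cm = ConfusionMatrix.new
+   #   cm.add_for(:pos, :neg)     # record a single result
+   #   cm.add_for(:pos, :neg, 5)  # record five more at once
+   #   cm.count_for(:pos, :neg)   # => 6
+   #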
+   # @param actual [String, Symbol] the actual class of the instance,
+   #   which we expect the classifier to predict
+   # @param prediction [String, Symbol] the predicted class of the instance,
+   #   as output from the classifier
+   # @param n [Integer] number of observations to add
+   # @raise [ArgumentError] if +n+ is not a positive Integer
+   # @raise [ArgumentError] if +actual+ or +prediction+ is not one of any
+   #   pre-defined labels in the matrix
+   def add_for(actual, prediction, n = 1)
+     validate_label actual, prediction
+     unless n.is_a?(Integer) and n.positive?
+       raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
+     end
+
+     @matrix[actual] = {} unless @matrix.has_key?(actual)
+     predictions = @matrix[actual]
+     predictions[prediction] = 0 unless predictions.has_key?(prediction)
+
+     @matrix[actual][prediction] += n
+   end
+
+   # Returns the number of instances of the given class label which
+   # are incorrectly classified.
+   #
+   #   false_negative(:positive) = b
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Integer] number of false negatives
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def false_negative(label = @labels.first)
+     validate_label label
+     predictions = @matrix.fetch(label, {})
+     total = 0
+
+     predictions.each_pair do |key, count|
+       if key != label
+         total += count
+       end
+     end
+
+     total
+   end
+
+   # Returns the number of instances incorrectly classified with the given
+   # class label.
+   #
+   #   false_positive(:positive) = c
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Integer] number of false positives
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def false_positive(label = @labels.first)
+     validate_label label
+     total = 0
+
+     @matrix.each_pair do |key, predictions|
+       if key != label
+         total += predictions.fetch(label, 0)
+       end
+     end
+
+     total
+   end
+
+   # The false rate for a given class label is the proportion of instances
+   # incorrectly classified as that label, out of all those instances
+   # not originally of that label.
+   #
+   #   false_rate(:positive) = c/(c+d)
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of false rate
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def false_rate(label = @labels.first)
+     validate_label label
+     fp = false_positive(label)
+     tn = true_negative(label)
+
+     divide(fp, fp+tn)
+   end
+
+   # The F-measure for a given label is the harmonic mean of the precision
+   # and recall for that label.
+   #
+   #   F = 2*(precision*recall)/(precision+recall)
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of F-measure
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def f_measure(label = @labels.first)
+     validate_label label
+     divide(2*precision(label)*recall(label), precision(label) + recall(label))
+   end
+
+   # The geometric mean is the nth-root of the product of the true_rate for
+   # each label. In the two-class case:
+   #
+   #   a1 = a/(a+b)
+   #   a2 = d/(c+d)
+   #   geometric_mean = Math.sqrt(a1*a2)
+   #
+   # @return [Float] value of geometric mean
+   def geometric_mean
+     product = 1
+
+     @matrix.each_key do |key|
+       product *= true_rate(key)
+     end
+
+     product**(1.0/@matrix.size)
+   end
+
+   # The Kappa statistic compares the observed accuracy with an expected
+   # accuracy.
+   #
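+   # In terms of the two-class table, the computation is (a sketch of what
+   # the method evaluates, using the class's safe divide):
+   #
+   #   observed = (a+d)/(a+b+c+d)
+   #   expected = ((a+b)*(a+c) + (c+d)*(b+d))/(a+b+c+d)**2
+   #   kappa    = (observed - expected)/(1 - expected)
+   #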
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of Cohen's Kappa statistic
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def kappa(label = @labels.first)
+     validate_label label
+     tp = true_positive(label)
+     fn = false_negative(label)
+     fp = false_positive(label)
+     tn = true_negative(label)
+     total = tp+fn+fp+tn
+
+     total_accuracy = divide(tp+tn, total)
+     random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
+
+     divide(total_accuracy - random_accuracy, 1 - random_accuracy)
+   end
+
+   # Matthews Correlation Coefficient is a measure of the quality of binary
+   # classifications.
+   #
+   #   matthews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
+   #
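+   # For example, with a=10, b=3, c=5, d=20 (the matrix from the README):
+   #
+   #   (10*20 - 5*3) / Math.sqrt(15*13*25*23)  # => 0.5524850114241865
+   #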
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of Matthews Correlation Coefficient
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def matthews_correlation(label = @labels.first)
+     validate_label label
+     tp = true_positive(label)
+     fn = false_negative(label)
+     fp = false_positive(label)
+     tn = true_negative(label)
+
+     divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
+   end
+
+   # The overall accuracy is the proportion of instances which are
+   # correctly labelled.
+   #
+   #   overall_accuracy = (a+d)/(a+b+c+d)
+   #
+   # @return [Float] value of overall accuracy
+   def overall_accuracy
+     total_correct = 0
+
+     @matrix.each_key do |key|
+       total_correct += true_positive(key)
+     end
+
+     divide(total_correct, total)
+   end
+
+   # The precision for a given class label is the proportion of instances
+   # classified as that class which are correct.
+   #
+   #   precision(:positive) = a/(a+c)
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of precision
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def precision(label = @labels.first)
+     validate_label label
+     tp = true_positive(label)
+     fp = false_positive(label)
+
+     divide(tp, tp+fp)
+   end
+
+   # The prevalence for a given class label is the proportion of instances
+   # actually of that label, out of the total.
+   #
+   #   prevalence(:positive) = (a+b)/(a+b+c+d)
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of prevalence
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def prevalence(label = @labels.first)
+     validate_label label
+     tp = true_positive(label)
+     fn = false_negative(label)
+     fp = false_positive(label)
+     tn = true_negative(label)
+     total = tp+fn+fp+tn
+
+     divide(tp+fn, total)
+   end
+
+   # Recall is another name for the true rate.
+   #
+   # @see true_rate
+   # @param (see #true_rate)
+   # @return (see #true_rate)
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def recall(label = @labels.first)
+     validate_label label
+     true_rate(label)
+   end
+
+   # Sensitivity is another name for the true rate.
+   #
+   # @see true_rate
+   # @param (see #true_rate)
+   # @return (see #true_rate)
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def sensitivity(label = @labels.first)
+     validate_label label
+     true_rate(label)
+   end
+
+   # The specificity for a given class label is 1 - false_rate(label).
+   #
+   # In the two-class case, specificity = 1 - false_positive_rate.
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] value of specificity
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def specificity(label = @labels.first)
+     validate_label label
+     1 - false_rate(label)
+   end
+
+   # Returns the matrix in a string format, representing the entries as a
+   # printable table.
+   #
+   # @return [String] representation as a printable table
+   def to_s
+     ls = labels
+     result = ""
+
+     title_line = "Predicted "
+     label_line = ""
+     ls.each { |l| label_line << "#{l} " }
+     label_line << " " while label_line.size < title_line.size
+     title_line << " " while title_line.size < label_line.size
+     result << title_line << "|\n" << label_line << "| Actual\n"
+     result << "-"*title_line.size << "+-------\n"
+
+     ls.each do |l|
+       count_line = ""
+       ls.each_with_index do |m, i|
+         count_line << "#{count_for(l, m)}".rjust(ls[i].size) << " "
+       end
+       result << count_line.ljust(title_line.size) << "| #{l}\n"
+     end
+
+     result
+   end
+
+   # Returns the total number of instances referenced in the matrix.
+   #
+   #   total = a+b+c+d
+   #
+   # @return [Integer] total number of instances referenced in the matrix
+   def total
+     total = 0
+
+     @matrix.each_value do |predictions|
+       predictions.each_value do |count|
+         total += count
+       end
+     end
+
+     total
+   end
+
+   # Returns the number of instances NOT of the given class label which
+   # are correctly classified (i.e. classified as their own actual class).
+   #
+   #   true_negative(:positive) = d
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Integer] number of instances not of given label which are correctly classified
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def true_negative(label = @labels.first)
+     validate_label label
+     total = 0
+
+     @matrix.each_pair do |key, predictions|
+       if key != label
+         total += predictions.fetch(key, 0)
+       end
+     end
+
+     total
+   end
+
+   # Returns the number of instances of the given class label which are
+   # correctly classified.
+   #
+   #   true_positive(:positive) = a
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Integer] number of instances of given label which are correctly classified
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def true_positive(label = @labels.first)
+     validate_label label
+     predictions = @matrix.fetch(label, {})
+     predictions.fetch(label, 0)
+   end
+
+   # The true rate for a given class label is the proportion of instances of
+   # that class which are correctly classified.
+   #
+   #   true_rate(:positive) = a/(a+b)
+   #
+   # @param label [String, Symbol] class to use, defaults to first of any pre-defined labels in matrix
+   # @return [Float] proportion of instances which are correctly classified
+   # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
+   def true_rate(label = @labels.first)
+     validate_label label
+     tp = true_positive(label)
+     fn = false_negative(label)
+
+     divide(tp, tp+fn)
+   end
+
+   private
+
+   # A form of "safe divide":
+   # checks if the divisor is zero, and returns 0.0 if so.
+   # This avoids a run-time error, and also ensures
+   # floating-point division is done.
+   def divide(x, y)
+     if y.zero?
+       0.0
+     else
+       x.to_f/y
+     end
+   end
+
+   # Checks that the given label(s) are non-nil and in @labels, or that @labels is empty.
+   # Raises ArgumentError if not.
+   def validate_label(*labels)
+     return true if @labels.empty?
+     labels.each do |label|
+       unless label and @labels.include?(label)
+         raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
+       end
+     end
+   end
+ end
+
+
data/test/matrix_test.rb ADDED
@@ -0,0 +1,160 @@
+ require 'confusion_matrix'
+ require 'minitest/autorun'
+
+ class TestConfusionMatrix < Minitest::Test
+   def test_empty_case
+     cm = ConfusionMatrix.new
+     assert_equal(0, cm.total)
+     assert_equal(0, cm.true_positive(:none))
+     assert_equal(0, cm.false_negative(:none))
+     assert_equal(0, cm.false_positive(:none))
+     assert_equal(0, cm.true_negative(:none))
+     assert_in_delta(0, cm.true_rate(:none))
+   end
+
+   def test_two_classes
+     cm = ConfusionMatrix.new
+     10.times { cm.add_for(:pos, :pos) }
+     5.times { cm.add_for(:pos, :neg) }
+     20.times { cm.add_for(:neg, :neg) }
+     5.times { cm.add_for(:neg, :pos) }
+
+     assert_equal([:neg, :pos], cm.labels)
+     assert_equal(10, cm.count_for(:pos, :pos))
+     assert_equal(5, cm.count_for(:pos, :neg))
+     assert_equal(20, cm.count_for(:neg, :neg))
+     assert_equal(5, cm.count_for(:neg, :pos))
+
+     assert_equal(40, cm.total)
+     assert_equal(10, cm.true_positive(:pos))
+     assert_equal(5, cm.false_negative(:pos))
+     assert_equal(5, cm.false_positive(:pos))
+     assert_equal(20, cm.true_negative(:pos))
+     assert_equal(20, cm.true_positive(:neg))
+     assert_equal(5, cm.false_negative(:neg))
+     assert_equal(5, cm.false_positive(:neg))
+     assert_equal(10, cm.true_negative(:neg))
+
+     assert_in_delta(0.6667, cm.true_rate(:pos))
+     assert_in_delta(0.8, cm.true_rate(:neg))
+     assert_in_delta(0.2, cm.false_rate(:pos))
+     assert_in_delta(0.3333, cm.false_rate(:neg))
+     assert_in_delta(0.6667, cm.precision(:pos))
+     assert_in_delta(0.8, cm.precision(:neg))
+     assert_in_delta(0.6667, cm.recall(:pos))
+     assert_in_delta(0.8, cm.recall(:neg))
+     assert_in_delta(0.6667, cm.sensitivity(:pos))
+     assert_in_delta(0.8, cm.sensitivity(:neg))
+     assert_in_delta(0.75, cm.overall_accuracy)
+     assert_in_delta(0.6667, cm.f_measure(:pos))
+     assert_in_delta(0.8, cm.f_measure(:neg))
+     assert_in_delta(0.7303, cm.geometric_mean)
+   end
+
+   # Example from:
+   # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
+   def test_two_classes_2
+     cm = ConfusionMatrix.new
+     5.times { cm.add_for(:pos, :pos) }
+     1.times { cm.add_for(:pos, :neg) }
+     3.times { cm.add_for(:neg, :neg) }
+     2.times { cm.add_for(:neg, :pos) }
+
+     assert_equal(11, cm.total)
+     assert_equal(5, cm.true_positive(:pos))
+     assert_equal(1, cm.false_negative(:pos))
+     assert_equal(2, cm.false_positive(:pos))
+     assert_equal(3, cm.true_negative(:pos))
+
+     assert_in_delta(0.7142, cm.precision(:pos))
+     assert_in_delta(0.8333, cm.recall(:pos))
+     assert_in_delta(0.7272, cm.overall_accuracy)
+     assert_in_delta(0.7692, cm.f_measure(:pos))
+     assert_in_delta(0.8333, cm.sensitivity(:pos))
+     assert_in_delta(0.6, cm.specificity(:pos))
+     assert_in_delta(0.4407, cm.kappa(:pos))
+     assert_in_delta(0.5454, cm.prevalence(:pos))
+   end
+
+   # Examples from:
+   # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
+   def two_class_case(tp, fn, tn, fp, mcc, precision, recall, f_measure, kappa)
+     cm = ConfusionMatrix.new
+     tp.times { cm.add_for(:pos, :pos) }
+     fn.times { cm.add_for(:pos, :neg) }
+     tn.times { cm.add_for(:neg, :neg) }
+     fp.times { cm.add_for(:neg, :pos) }
+
+     assert_in_delta(mcc, cm.matthews_correlation(:pos))
+     assert_in_delta(precision, cm.precision(:pos))
+     assert_in_delta(recall, cm.recall(:pos))
+     assert_in_delta(f_measure, cm.f_measure(:pos))
+     assert_in_delta(kappa, cm.kappa(:pos))
+   end
+
+   def test_two_classes_3
+     two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
+     two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
+     two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
+   end
+
+   def test_three_classes
+     cm = ConfusionMatrix.new
+     10.times { cm.add_for(:red, :red) }
+     7.times { cm.add_for(:red, :blue) }
+     5.times { cm.add_for(:red, :green) }
+     20.times { cm.add_for(:blue, :red) }
+     5.times { cm.add_for(:blue, :blue) }
+     15.times { cm.add_for(:blue, :green) }
+     30.times { cm.add_for(:green, :red) }
+     12.times { cm.add_for(:green, :blue) }
+     8.times { cm.add_for(:green, :green) }
+
+     assert_equal([:blue, :green, :red], cm.labels)
+     assert_equal(112, cm.total)
+     assert_equal(10, cm.true_positive(:red))
+     assert_equal(12, cm.false_negative(:red))
+     assert_equal(50, cm.false_positive(:red))
+     assert_equal(13, cm.true_negative(:red))
+     assert_equal(5, cm.true_positive(:blue))
+     assert_equal(35, cm.false_negative(:blue))
+     assert_equal(19, cm.false_positive(:blue))
+     assert_equal(18, cm.true_negative(:blue))
+     assert_equal(8, cm.true_positive(:green))
+     assert_equal(42, cm.false_negative(:green))
+     assert_equal(20, cm.false_positive(:green))
+     assert_equal(15, cm.true_negative(:green))
+   end
+
+   def test_add_for_n
+     cm = ConfusionMatrix.new
+     cm.add_for(:pos, :pos, 3)
+     cm.add_for(:pos, :neg)
+     cm.add_for(:neg, :pos, 2)
+     cm.add_for(:neg, :neg, 1)
+     assert_equal(7, cm.total)
+     assert_equal(3, cm.count_for(:pos, :pos))
+     # - check errors
+     assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
+     assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
+     assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
+   end
+
+   def test_use_labels
+     # - check errors
+     assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
+     assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }
+     # - check created matrix
+     cm = ConfusionMatrix.new(:pos, :neg)
+     assert_equal([:pos, :neg], cm.labels)
+     assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }
+     cm.add_for(:pos, :neg, 3)
+     cm.add_for(:neg, :pos, 2)
+     assert_equal(2, cm.false_negative(:neg))
+     assert_equal(3, cm.false_negative(:pos))
+     assert_equal(3, cm.false_negative())
+     assert_raises(ArgumentError) { cm.false_negative(:nothing) }
+     assert_raises(ArgumentError) { cm.false_negative(nil) }
+   end
+ end
+
metadata ADDED
@@ -0,0 +1,55 @@
+ --- !ruby/object:Gem::Specification
+ name: confusion_matrix
+ version: !ruby/object:Gem::Version
+   version: 1.1.0
+ platform: ruby
+ authors:
+ - Peter Lane
+ autorequire:
+ bindir: bin
+ cert_chain: []
+ date: 2023-02-04 00:00:00.000000000 Z
+ dependencies: []
+ description: "A confusion matrix is used in data-mining as a summary of the performance
+   of a\nclassification algorithm. This library allows the user to incrementally add
+   \nresults to a confusion matrix, and then retrieve statistical information.\n"
+ email: peterlane@gmx.com
+ executables: []
+ extensions: []
+ extra_rdoc_files:
+ - README.rdoc
+ - LICENSE.rdoc
+ files:
+ - LICENSE.rdoc
+ - README.rdoc
+ - lib/confusion_matrix.rb
+ - test/matrix_test.rb
+ homepage:
+ licenses:
+ - MIT
+ metadata: {}
+ post_install_message:
+ rdoc_options:
+ - "-m"
+ - README.rdoc
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '2.5'
+   - - "<"
+     - !ruby/object:Gem::Version
+       version: '4.0'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.4.5
+ signing_key:
+ specification_version: 4
+ summary: Construct a confusion matrix and retrieve statistical information from it.
+ test_files: []