confusion_matrix 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (6)
  1. checksums.yaml +4 -4
  2. data/LICENSE.rdoc +22 -22
  3. data/README.rdoc +88 -88
  4. data/lib/confusion_matrix.rb +383 -451
  5. data/test/matrix_test.rb +160 -160
  6. metadata +10 -15
@@ -1,451 +1,383 @@
1
-
2
- # This class holds the confusion matrix information.
3
- # It is designed to be called incrementally, as results are obtained
4
- # from the classifier model.
5
- #
6
- # At any point, statistics may be obtained by calling the relevant methods.
7
- #
8
- # A two-class example is:
9
- #
10
- # Classified Classified |
11
- # Positive Negative | Actual
12
- # ------------------------------+------------
13
- # a b | Positive
14
- # c d | Negative
15
- #
16
- # Statistical methods will be described with reference to this example.
17
- #
18
- class ConfusionMatrix
19
- # Creates a new, empty instance of a confusion matrix.
20
- #
21
- # @param labels [<String, Symbol>, ...] if provided, makes the matrix
22
- # use the first label as a default label, and also check
23
- # all operations use one of the pre-defined labels.
24
- # @raise [ArgumentError] if there are not at least two unique labels, when provided.
25
- def initialize(*labels)
26
- @matrix = {}
27
- @labels = labels.uniq
28
- if @labels.size == 1
29
- raise ArgumentError.new("If labels are provided, there must be at least two.")
30
- else # preset the matrix Hash
31
- @labels.each do |actual|
32
- @matrix[actual] = {}
33
- @labels.each do |predicted|
34
- @matrix[actual][predicted] = 0
35
- end
36
- end
37
- end
38
- end
39
-
40
- # Returns a list of labels used in the matrix.
41
- #
42
- # cm = ConfusionMatrix.new
43
- # cm.add_for(:pos, :neg)
44
- # cm.labels # => [:neg, :pos]
45
- #
46
- # @return [Array<String>] labels used in the matrix.
47
- def labels
48
- if @labels.size >= 2 # if we defined some labels, return them
49
- @labels
50
- else
51
- result = []
52
-
53
- @matrix.each_pair do |key, predictions|
54
- result << key
55
- predictions.each_key do |key|
56
- result << key
57
- end
58
- end
59
-
60
- result.uniq.sort
61
- end
62
- end
63
-
64
- # Return the count for (actual,prediction) pair.
65
- #
66
- # cm = ConfusionMatrix.new
67
- # cm.add_for(:pos, :neg)
68
- # cm.count_for(:pos, :neg) # => 1
69
- #
70
- # @param actual [String, Symbol] is actual class of the instance,
71
- # which we expect the classifier to predict
72
- # @param prediction [String, Symbol] is the predicted class of the instance,
73
- # as output from the classifier
74
- # @return [Integer] number of observations of (actual, prediction) pair
75
- # @raise [ArgumentError] if +actual+ or +predicted+ are not one of any
76
- # pre-defined labels in matrix
77
- def count_for(actual, prediction)
78
- validate_label actual, prediction
79
- predictions = @matrix.fetch(actual, {})
80
- predictions.fetch(prediction, 0)
81
- end
82
-
83
- # Adds one result to the matrix for a given (actual, prediction) pair of labels.
84
- # If the matrix was given a pre-defined list of labels on construction, then
85
- # these given labels must be from the pre-defined list.
86
- # If no pre-defined list of labels was used in constructing the matrix, then
87
- # labels will be added to matrix.
88
- #
89
- # Class labels may be any hashable value, though ideally they are strings or symbols.
90
- #
91
- # @param actual [String, Symbol] is actual class of the instance,
92
- # which we expect the classifier to predict
93
- # @param prediction [String, Symbol] is the predicted class of the instance,
94
- # as output from the classifier
95
- # @param n [Integer] number of observations to add
96
- # @raise [ArgumentError] if +n+ is not an Integer
97
- # @raise [ArgumentError] if +actual+ or +predicted+ are not one of any
98
- # pre-defined labels in matrix
99
- def add_for(actual, prediction, n = 1)
100
- validate_label actual, prediction
101
- if !@matrix.has_key?(actual)
102
- @matrix[actual] = {}
103
- end
104
- predictions = @matrix[actual]
105
- if !predictions.has_key?(prediction)
106
- predictions[prediction] = 0
107
- end
108
-
109
- unless n.class == Integer and n.positive?
110
- raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
111
- end
112
-
113
- @matrix[actual][prediction] += n
114
- end
115
-
116
- # Returns the number of instances of the given class label which
117
- # are incorrectly classified.
118
- #
119
- # false_negative(:positive) = b
120
- #
121
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
122
- # @return [Float] value of false negative
123
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
124
- def false_negative(label = @labels.first)
125
- validate_label label
126
- predictions = @matrix.fetch(label, {})
127
- total = 0
128
-
129
- predictions.each_pair do |key, count|
130
- if key != label
131
- total += count
132
- end
133
- end
134
-
135
- total
136
- end
137
-
138
- # Returns the number of instances incorrectly classified with the given
139
- # class label.
140
- #
141
- # false_positive(:positive) = c
142
- #
143
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
144
- # @return [Float] value of false positive
145
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
146
- def false_positive(label = @labels.first)
147
- validate_label label
148
- total = 0
149
-
150
- @matrix.each_pair do |key, predictions|
151
- if key != label
152
- total += predictions.fetch(label, 0)
153
- end
154
- end
155
-
156
- total
157
- end
158
-
159
- # The false rate for a given class label is the proportion of instances
160
- # incorrectly classified as that label, out of all those instances
161
- # not originally of that label.
162
- #
163
- # false_rate(:positive) = c/(c+d)
164
- #
165
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
166
- # @return [Float] value of false rate
167
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
168
- def false_rate(label = @labels.first)
169
- validate_label label
170
- fp = false_positive(label)
171
- tn = true_negative(label)
172
-
173
- divide(fp, fp+tn)
174
- end
175
-
176
- # The F-measure for a given label is the harmonic mean of the precision
177
- # and recall for that label.
178
- #
179
- # F = 2*(precision*recall)/(precision+recall)
180
- #
181
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
182
- # @return [Float] value of F-measure
183
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
184
- def f_measure(label = @labels.first)
185
- validate_label label
186
- 2*precision(label)*recall(label)/(precision(label) + recall(label))
187
- end
188
-
189
- # The geometric mean is the nth-root of the product of the true_rate for
190
- # each label.
191
- #
192
- # a1 = a/(a+b)
193
- # a2 = d/(c+d)
194
- # geometric_mean = Math.sqrt(a1*a2)
195
- #
196
- # @return [Float] value of geometric mean
197
- def geometric_mean
198
- product = 1
199
-
200
- @matrix.each_key do |key|
201
- product *= true_rate(key)
202
- end
203
-
204
- product**(1.0/@matrix.size)
205
- end
206
-
207
- # The Kappa statistic compares the observed accuracy with an expected
208
- # accuracy.
209
- #
210
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
211
- # @return [Float] value of Cohen's Kappa Statistic
212
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
213
- def kappa(label = @labels.first)
214
- validate_label label
215
- tp = true_positive(label)
216
- fn = false_negative(label)
217
- fp = false_positive(label)
218
- tn = true_negative(label)
219
- total = tp+fn+fp+tn
220
-
221
- total_accuracy = divide(tp+tn, tp+tn+fp+fn)
222
- random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
223
-
224
- divide(total_accuracy - random_accuracy, 1 - random_accuracy)
225
- end
226
-
227
- # Matthews Correlation Coefficient is a measure of the quality of binary
228
- # classifications.
229
- #
230
- # mathews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
231
- #
232
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
233
- # @return [Float] value of Matthews Correlation Coefficient
234
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
235
- def matthews_correlation(label = @labels.first)
236
- validate_label label
237
- tp = true_positive(label)
238
- fn = false_negative(label)
239
- fp = false_positive(label)
240
- tn = true_negative(label)
241
-
242
- divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
243
- end
244
-
245
- # The overall accuracy is the proportion of instances which are
246
- # correctly labelled.
247
- #
248
- # overall_accuracy = (a+d)/(a+b+c+d)
249
- #
250
- # @return [Float] value of overall accuracy
251
- def overall_accuracy
252
- total_correct = 0
253
-
254
- @matrix.each_pair do |key, predictions|
255
- total_correct += true_positive(key)
256
- end
257
-
258
- divide(total_correct, total)
259
- end
260
-
261
- # The precision for a given class label is the proportion of instances
262
- # classified as that class which are correct.
263
- #
264
- # precision(:positive) = a/(a+c)
265
- #
266
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
267
- # @return [Float] value of precision
268
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
269
- def precision(label = @labels.first)
270
- validate_label label
271
- tp = true_positive(label)
272
- fp = false_positive(label)
273
-
274
- divide(tp, tp+fp)
275
- end
276
-
277
- # The prevalence for a given class label is the proportion of instances
278
- # which were classified as of that label, out of the total.
279
- #
280
- # prevalence(:positive) = (a+c)/(a+b+c+d)
281
- #
282
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
283
- # @return [Float] value of prevalence
284
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
285
- def prevalence(label = @labels.first)
286
- validate_label label
287
- tp = true_positive(label)
288
- fn = false_negative(label)
289
- fp = false_positive(label)
290
- tn = true_negative(label)
291
- total = tp+fn+fp+tn
292
-
293
- divide(tp+fn, total)
294
- end
295
-
296
- # The recall is another name for the true rate.
297
- #
298
- # @see true_rate
299
- # @param (see #true_rate)
300
- # @return (see #true_rate)
301
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
302
- def recall(label = @labels.first)
303
- validate_label label
304
- true_rate(label)
305
- end
306
-
307
- # Sensitivity is another name for the true rate.
308
- #
309
- # @see true_rate
310
- # @param (see #true_rate)
311
- # @return (see #true_rate)
312
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
313
- def sensitivity(label = @labels.first)
314
- validate_label label
315
- true_rate(label)
316
- end
317
-
318
- # The specificity for a given class label is 1 - false_rate(label)
319
- #
320
- # In two-class case, specificity = 1 - false_positive_rate
321
- #
322
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
323
- # @return [Float] value of specificity
324
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
325
- def specificity(label = @labels.first)
326
- validate_label label
327
- 1-false_rate(label)
328
- end
329
-
330
- # Returns the table in a string format, representing the entries as a
331
- # printable table.
332
- #
333
- # @return [String] representation as a printable table.
334
- def to_s
335
- ls = labels
336
- result = ""
337
-
338
- title_line = "Predicted "
339
- label_line = ""
340
- ls.each { |l| label_line << "#{l} " }
341
- label_line << " " while label_line.size < title_line.size
342
- title_line << " " while title_line.size < label_line.size
343
- result << title_line << "|\n" << label_line << "| Actual\n"
344
- result << "-"*title_line.size << "+-------\n"
345
-
346
- ls.each do |l|
347
- count_line = ""
348
- ls.each_with_index do |m, i|
349
- count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
350
- end
351
- result << count_line.ljust(title_line.size) << "| #{l}\n"
352
- end
353
-
354
- result
355
- end
356
-
357
- # Returns the total number of instances referenced in the matrix.
358
- #
359
- # total = a+b+c+d
360
- #
361
- # @return [Integer] total number of instances referenced in the matrix.
362
- def total
363
- total = 0
364
-
365
- @matrix.each_value do |predictions|
366
- predictions.each_value do |count|
367
- total += count
368
- end
369
- end
370
-
371
- total
372
- end
373
-
374
- # Returns the number of instances NOT of the given class label which
375
- # are correctly classified.
376
- #
377
- # true_negative(:positive) = d
378
- #
379
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
380
- # @return [Integer] number of instances not of given label which are correctly classified
381
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
382
- def true_negative(label = @labels.first)
383
- validate_label label
384
- total = 0
385
-
386
- @matrix.each_pair do |key, predictions|
387
- if key != label
388
- total += predictions.fetch(key, 0)
389
- end
390
- end
391
-
392
- total
393
- end
394
-
395
- # Returns the number of instances of the given class label which are
396
- # correctly classified.
397
- #
398
- # true_positive(:positive) = a
399
- #
400
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
401
- # @return [Integer] number of instances of given label which are correctly classified
402
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
403
- def true_positive(label = @labels.first)
404
- validate_label label
405
- predictions = @matrix.fetch(label, {})
406
- predictions.fetch(label, 0)
407
- end
408
-
409
- # The true rate for a given class label is the proportion of instances of
410
- # that class which are correctly classified.
411
- #
412
- # true_rate(:positive) = a/(a+b)
413
- #
414
- # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
415
- # @return [Float] proportion of instances which are correctly classified
416
- # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
417
- def true_rate(label = @labels.first)
418
- validate_label label
419
- tp = true_positive(label)
420
- fn = false_negative(label)
421
-
422
- divide(tp, tp+fn)
423
- end
424
-
425
- private
426
-
427
- # A form of "safe divide".
428
- # Checks if divisor is zero, and returns 0.0 if so.
429
- # This avoids a run-time error.
430
- # Also, ensures floating point division is done.
431
- def divide(x, y)
432
- if y.zero?
433
- 0.0
434
- else
435
- x.to_f/y
436
- end
437
- end
438
-
439
- # Checks if given label(s) is non-nil and in @labels, or if @labels is empty
440
- # Raises ArgumentError if not
441
- def validate_label *labels
442
- return true if @labels.empty?
443
- labels.each do |label|
444
- unless label and @labels.include?(label)
445
- raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
446
- end
447
- end
448
- end
449
- end
450
-
451
-
1
+ # Instances of this class hold the confusion matrix information.
2
+ # The object is designed to be called incrementally, as results are
3
+ # received.
4
+ # At any point, statistics may be obtained from the current results.
5
+ #
6
+ # A two-label confusion matrix example is:
7
+ #
8
+ # Observed Observed |
9
+ # Positive Negative | Predicted
10
+ # ------------------------------+------------
11
+ # a b | Positive
12
+ # c d | Negative
13
+ #
14
+ # Statistical methods will be described with reference to this example.
15
+ #
16
+ class ConfusionMatrix
17
+ # Creates a new, empty instance of a confusion matrix.
18
+ #
19
+ # labels:: a list of strings or labels. If provided, the first label is
20
+ # used as a default label, and all method calls must use one of the
21
+ # pre-defined labels.
22
+ #
23
+ # Raises an +ArgumentError+ if there are not at least two unique labels, when provided.
24
+ def initialize(*labels)
25
+ @matrix = {}
26
+ @labels = labels.uniq
27
+ if @labels.size == 1
28
+ raise ArgumentError.new("If labels are provided, there must be at least two.")
29
+ else # preset the matrix Hash
30
+ @labels.each do |predefined|
31
+ @matrix[predefined] = {}
32
+ @labels.each do |observed|
33
+ @matrix[predefined][observed] = 0
34
+ end
35
+ end
36
+ end
37
+ end
38
+
39
+ # Returns a list of labels used in the matrix.
40
+ #
41
+ # cm = ConfusionMatrix.new
42
+ # cm.add_for(:pos, :neg)
43
+ # cm.labels # => [:neg, :pos]
44
+ #
45
+ def labels
46
+ if @labels.size >= 2 # if we defined some labels, return them
47
+ @labels
48
+ else
49
+ result = []
50
+
51
+ @matrix.each_pair do |key, observed|
52
+ result << key
53
+ observed.each_key do |key|
54
+ result << key
55
+ end
56
+ end
57
+
58
+ result.uniq.sort
59
+ end
60
+ end
61
+
62
+ # Returns the count for a given (predicted, observed) pair.
63
+ #
64
+ # cm = ConfusionMatrix.new
65
+ # cm.add_for(:pos, :neg)
66
+ # cm.count_for(:pos, :neg) # => 1
67
+ #
68
+ def count_for(predicted, observed)
69
+ validate_label predicted, observed
70
+ observations = @matrix.fetch(predicted, {})
71
+ observations.fetch(observed, 0)
72
+ end
73
+
74
+ # Adds one result to the matrix for a given (predicted, observed) pair of labels.
75
+ #
76
+ # If the matrix was given a pre-defined list of labels on construction, then
77
+ # these given labels must be from the pre-defined list.
78
+ # If no pre-defined list of labels was used in constructing the matrix, then
79
+ # labels will be added to matrix.
80
+ # Labels may be any hashable value, although ideally they are strings or symbols.
81
+ def add_for(predicted, observed, n = 1)
82
+ validate_label predicted, observed
83
+ unless @matrix.has_key?(predicted)
84
+ @matrix[predicted] = {}
85
+ end
86
+ observations = @matrix[predicted]
87
+ unless observations.has_key?(observed)
88
+ observations[observed] = 0
89
+ end
90
+
91
+ unless n.class == Integer and n.positive?
92
+ raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
93
+ end
94
+
95
+ @matrix[predicted][observed] += n
96
+ end
97
+
98
+ # Returns the number of observations of the given label which are incorrect.
99
+ #
100
+ # For example matrix, <code>false_negative(:positive)</code> is +b+
101
+ #
102
+ def false_negative(label = @labels.first)
103
+ validate_label label
104
+ observations = @matrix.fetch(label, {})
105
+ total = 0
106
+
107
+ observations.each_pair do |key, count|
108
+ if key != label
109
+ total += count
110
+ end
111
+ end
112
+
113
+ total
114
+ end
115
+
116
+ # Returns the number of observations incorrect with the given label.
117
+ #
118
+ # For example matrix, <code>false_positive(:positive)</code> is +c+.
119
+ #
120
+ def false_positive(label = @labels.first)
121
+ validate_label label
122
+ total = 0
123
+
124
+ @matrix.each_pair do |key, observations|
125
+ if key != label
126
+ total += observations.fetch(label, 0)
127
+ end
128
+ end
129
+
130
+ total
131
+ end
132
+
133
+ # The false rate for a given label is the proportion of observations
134
+ # incorrect for that label, out of all those observations
135
+ # not originally of that label.
136
+ #
137
+ # For example matrix, <code>false_rate(:positive)</code> is <code>c/(c+d)</code>.
138
+ #
139
+ def false_rate(label = @labels.first)
140
+ validate_label label
141
+ fp = false_positive(label)
142
+ tn = true_negative(label)
143
+
144
+ divide(fp, fp+tn)
145
+ end
146
+
147
+ # The F-measure for a given label is the harmonic mean of the precision
148
+ # and recall for that label.
149
+ #
150
+ # F = 2*(precision*recall)/(precision+recall)
151
+ #
152
+ def f_measure(label = @labels.first)
153
+ validate_label label
154
+ 2*precision(label)*recall(label)/(precision(label) + recall(label))
155
+ end
156
+
157
+ # The geometric mean is the nth-root of the product of the true_rate for
158
+ # each label.
159
+ #
160
+ # For example:
161
+ # - a1 = a/(a+b)
162
+ # - a2 = d/(c+d)
163
+ # - geometric mean = Math.sqrt(a1*a2)
164
+ #
165
+ def geometric_mean
166
+ product = 1
167
+
168
+ @matrix.each_key do |key|
169
+ product *= true_rate(key)
170
+ end
171
+
172
+ product**(1.0/@matrix.size)
173
+ end
174
+
175
+ # The Kappa statistic compares the observed accuracy with an expected
176
+ # accuracy.
177
+ #
178
+ def kappa(label = @labels.first)
179
+ validate_label label
180
+ tp = true_positive(label)
181
+ fn = false_negative(label)
182
+ fp = false_positive(label)
183
+ tn = true_negative(label)
184
+ total = tp+fn+fp+tn
185
+
186
+ total_accuracy = divide(tp+tn, tp+tn+fp+fn)
187
+ random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
188
+
189
+ divide(total_accuracy - random_accuracy, 1 - random_accuracy)
190
+ end
191
+
192
+ # Matthews Correlation Coefficient is a measure of the quality of binary
193
+ # classifications.
194
+ #
195
+ # For example matrix, <code>matthews_correlation(:positive)</code> is
196
+ # <code>(a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))</code>.
197
+ #
198
+ def matthews_correlation(label = @labels.first)
199
+ validate_label label
200
+ tp = true_positive(label)
201
+ fn = false_negative(label)
202
+ fp = false_positive(label)
203
+ tn = true_negative(label)
204
+
205
+ divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
206
+ end
207
+
208
+ # The overall accuracy is the proportion of observations which are
209
+ # correctly labelled.
210
+ #
211
+ # For example matrix, <code>overall_accuracy</code> is <code>(a+d)/(a+b+c+d)</code>.
212
+ #
213
+ def overall_accuracy
214
+ total_correct = 0
215
+
216
+ @matrix.each_pair do |key, observations|
217
+ total_correct += true_positive(key)
218
+ end
219
+
220
+ divide(total_correct, total)
221
+ end
222
+
223
+ # The precision for a given label is the proportion of observations
224
+ # observed as that label which are correct.
225
+ #
226
+ # For example matrix, <code>precision(:positive)</code> is <code>a/(a+c)</code>.
227
+ #
228
+ def precision(label = @labels.first)
229
+ validate_label label
230
+ tp = true_positive(label)
231
+ fp = false_positive(label)
232
+
233
+ divide(tp, tp+fp)
234
+ end
235
+
236
+ # The prevalence for a given label is the proportion of observations
237
+ # which were observed as of that label, out of the total.
238
+ #
239
+ # For example matrix, <code>prevalence(:positive)</code> is <code>(a+c)/(a+b+c+d)</code>.
240
+ #
241
+ def prevalence(label = @labels.first)
242
+ validate_label label
243
+ tp = true_positive(label)
244
+ fn = false_negative(label)
245
+ fp = false_positive(label)
246
+ tn = true_negative(label)
247
+ total = tp+fn+fp+tn
248
+
249
+ divide(tp+fn, total)
250
+ end
251
+
252
+ # The recall is another name for the true rate.
253
+ #
254
+ def recall(label = @labels.first)
255
+ validate_label label
256
+ true_rate(label)
257
+ end
258
+
259
+ # Sensitivity is another name for the true rate.
260
+ #
261
+ def sensitivity(label = @labels.first)
262
+ validate_label label
263
+ true_rate(label)
264
+ end
265
+
266
+ # The specificity for a given label is 1 - false_rate(label)
267
+ #
268
+ # In two-class case, specificity = 1 - false_positive_rate
269
+ #
270
+ def specificity(label = @labels.first)
271
+ validate_label label
272
+ 1-false_rate(label)
273
+ end
274
+
275
+ # Returns the table in a string format, representing the entries as a
276
+ # printable table.
277
+ #
278
+ def to_s
279
+ ls = labels
280
+ result = ""
281
+
282
+ title_line = "Observed "
283
+ label_line = ""
284
+ ls.each { |l| label_line << "#{l} " }
285
+ label_line << " " while label_line.size < title_line.size
286
+ title_line << " " while title_line.size < label_line.size
287
+ result << title_line << "|\n" << label_line << "| Predicted\n"
288
+ result << "-"*title_line.size << "+----------\n"
289
+
290
+ ls.each do |l|
291
+ count_line = ""
292
+ ls.each_with_index do |m, i|
293
+ count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
294
+ end
295
+ result << count_line.ljust(title_line.size) << "| #{l}\n"
296
+ end
297
+
298
+ result
299
+ end
300
+
301
+ # Returns the total number of observations referenced in the matrix.
302
+ #
303
+ # For example matrix, +total+ is <code>a+b+c+d</code>.
304
+ #
305
+ def total
306
+ total = 0
307
+
308
+ @matrix.each_value do |observations|
309
+ observations.each_value do |count|
310
+ total += count
311
+ end
312
+ end
313
+
314
+ total
315
+ end
316
+
317
+ # Returns the number of observations NOT of the given label which are correct.
318
+ #
319
+ # For example matrix, <code>true_negative(:positive)</code> is +d+.
320
+ #
321
+ def true_negative(label = @labels.first)
322
+ validate_label label
323
+ total = 0
324
+
325
+ @matrix.each_pair do |key, observations|
326
+ if key != label
327
+ total += observations.fetch(key, 0)
328
+ end
329
+ end
330
+
331
+ total
332
+ end
333
+
334
+ # Returns the number of observations of the given label which are correct.
335
+ #
336
+ # For example matrix, <code>true_positive(:positive)</code> is +a+.
337
+ #
338
+ def true_positive(label = @labels.first)
339
+ validate_label label
340
+ observations = @matrix.fetch(label, {})
341
+ observations.fetch(label, 0)
342
+ end
343
+
344
+ # The true rate for a given label is the proportion of observations of
345
+ # that label which are correct.
346
+ #
347
+ # For example matrix, <code>true_rate(:positive)</code> is <code>a/(a+b)</code>.
348
+ #
349
+ def true_rate(label = @labels.first)
350
+ validate_label label
351
+ tp = true_positive(label)
352
+ fn = false_negative(label)
353
+
354
+ divide(tp, tp+fn)
355
+ end
356
+
357
+ private
358
+
359
+ # A form of "safe divide".
360
+ # Checks if divisor is zero, and returns 0.0 if so.
361
+ # This avoids a run-time error.
362
+ # Also, ensures floating point division is done.
363
+ def divide(x, y)
364
+ if y.zero?
365
+ 0.0
366
+ else
367
+ x.to_f/y
368
+ end
369
+ end
370
+
371
+ # Checks if given label(s) is non-nil and in @labels, or if @labels is empty
372
+ # Raises ArgumentError if not
373
+ def validate_label *labels
374
+ return true if @labels.empty?
375
+ labels.each do |label|
376
+ unless label and @labels.include?(label)
377
+ raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
378
+ end
379
+ end
380
+ end
381
+ end
382
+
383
+