confusion_matrix 1.1.0 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE.rdoc +22 -22
- data/README.rdoc +88 -88
- data/lib/confusion_matrix.rb +383 -451
- data/test/matrix_test.rb +160 -160
- metadata +10 -15
data/lib/confusion_matrix.rb
CHANGED
|
@@ -1,451 +1,383 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# from the
|
|
5
|
-
#
|
|
6
|
-
#
|
|
7
|
-
#
|
|
8
|
-
#
|
|
9
|
-
#
|
|
10
|
-
#
|
|
11
|
-
#
|
|
12
|
-
#
|
|
13
|
-
#
|
|
14
|
-
#
|
|
15
|
-
#
|
|
16
|
-
|
|
17
|
-
#
|
|
18
|
-
|
|
19
|
-
#
|
|
20
|
-
#
|
|
21
|
-
#
|
|
22
|
-
#
|
|
23
|
-
#
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@
|
|
27
|
-
@labels
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
#
|
|
41
|
-
#
|
|
42
|
-
# cm
|
|
43
|
-
# cm.
|
|
44
|
-
#
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
#
|
|
65
|
-
#
|
|
66
|
-
# cm
|
|
67
|
-
#
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
#
|
|
75
|
-
#
|
|
76
|
-
#
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
#
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
end
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
end
|
|
115
|
-
|
|
116
|
-
# Returns the number of
|
|
117
|
-
#
|
|
118
|
-
#
|
|
119
|
-
#
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
#
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
#
|
|
160
|
-
#
|
|
161
|
-
#
|
|
162
|
-
#
|
|
163
|
-
#
|
|
164
|
-
#
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
#
|
|
177
|
-
#
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
#
|
|
193
|
-
#
|
|
194
|
-
#
|
|
195
|
-
#
|
|
196
|
-
#
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
# accuracy
|
|
209
|
-
#
|
|
210
|
-
#
|
|
211
|
-
#
|
|
212
|
-
#
|
|
213
|
-
def
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
#
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
#
|
|
267
|
-
#
|
|
268
|
-
#
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
#
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
#
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
#
|
|
319
|
-
#
|
|
320
|
-
#
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
end
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
#
|
|
360
|
-
#
|
|
361
|
-
#
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
total = 0
|
|
385
|
-
|
|
386
|
-
@matrix.each_pair do |key, predictions|
|
|
387
|
-
if key != label
|
|
388
|
-
total += predictions.fetch(key, 0)
|
|
389
|
-
end
|
|
390
|
-
end
|
|
391
|
-
|
|
392
|
-
total
|
|
393
|
-
end
|
|
394
|
-
|
|
395
|
-
# Returns the number of instances of the given class label which are
|
|
396
|
-
# correctly classified.
|
|
397
|
-
#
|
|
398
|
-
# true_positive(:positive) = a
|
|
399
|
-
#
|
|
400
|
-
# @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
|
|
401
|
-
# @return [Integer] number of instances of given label which are correctly classified
|
|
402
|
-
# @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
|
|
403
|
-
def true_positive(label = @labels.first)
|
|
404
|
-
validate_label label
|
|
405
|
-
predictions = @matrix.fetch(label, {})
|
|
406
|
-
predictions.fetch(label, 0)
|
|
407
|
-
end
|
|
408
|
-
|
|
409
|
-
# The true rate for a given class label is the proportion of instances of
|
|
410
|
-
# that class which are correctly classified.
|
|
411
|
-
#
|
|
412
|
-
# true_rate(:positive) = a/(a+b)
|
|
413
|
-
#
|
|
414
|
-
# @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
|
|
415
|
-
# @return [Float] proportion of instances which are correctly classified
|
|
416
|
-
# @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
|
|
417
|
-
def true_rate(label = @labels.first)
|
|
418
|
-
validate_label label
|
|
419
|
-
tp = true_positive(label)
|
|
420
|
-
fn = false_negative(label)
|
|
421
|
-
|
|
422
|
-
divide(tp, tp+fn)
|
|
423
|
-
end
|
|
424
|
-
|
|
425
|
-
private
|
|
426
|
-
|
|
427
|
-
# A form of "safe divide".
|
|
428
|
-
# Checks if divisor is zero, and returns 0.0 if so.
|
|
429
|
-
# This avoids a run-time error.
|
|
430
|
-
# Also, ensures floating point division is done.
|
|
431
|
-
def divide(x, y)
|
|
432
|
-
if y.zero?
|
|
433
|
-
0.0
|
|
434
|
-
else
|
|
435
|
-
x.to_f/y
|
|
436
|
-
end
|
|
437
|
-
end
|
|
438
|
-
|
|
439
|
-
# Checks if given label(s) is non-nil and in @labels, or if @labels is empty
|
|
440
|
-
# Raises ArgumentError if not
|
|
441
|
-
def validate_label *labels
|
|
442
|
-
return true if @labels.empty?
|
|
443
|
-
labels.each do |label|
|
|
444
|
-
unless label and @labels.include?(label)
|
|
445
|
-
raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
|
|
446
|
-
end
|
|
447
|
-
end
|
|
448
|
-
end
|
|
449
|
-
end
|
|
450
|
-
|
|
451
|
-
|
|
1
|
+
# Instances of this class hold the confusion matrix information.
|
|
2
|
+
# The object is designed to be called incrementally, as results are
|
|
3
|
+
# received.
|
|
4
|
+
# At any point, statistics may be obtained from the current results.
|
|
5
|
+
#
|
|
6
|
+
# A two-label confusion matrix example is:
|
|
7
|
+
#
|
|
8
|
+
# Observed Observed |
|
|
9
|
+
# Positive Negative | Predicted
|
|
10
|
+
# ------------------------------+------------
|
|
11
|
+
# a b | Positive
|
|
12
|
+
# c d | Negative
|
|
13
|
+
#
|
|
14
|
+
# Statistical methods will be described with reference to this example.
|
|
15
|
+
#
|
|
16
|
+
class ConfusionMatrix
  # Creates a new, empty confusion matrix.
  #
  # labels:: optional list of class labels (strings or symbols). When
  #          provided, the first label acts as the default argument for the
  #          statistical methods, and every label passed to a method must be
  #          one of these pre-defined labels.
  #
  # Raises ArgumentError if labels are provided but fewer than two are unique.
  def initialize(*labels)
    @matrix = {}
    @labels = labels.uniq
    if @labels.size == 1
      raise ArgumentError, "If labels are provided, there must be at least two."
    else
      # Preset every (predicted, observed) cell to zero so statistics work
      # immediately, even before any results are added.
      @labels.each do |predicted|
        @matrix[predicted] = {}
        @labels.each do |observed|
          @matrix[predicted][observed] = 0
        end
      end
    end
  end

  # Returns the pre-defined labels (in their given order) when the matrix
  # was constructed with some; otherwise a sorted list of every label seen
  # so far.
  #
  #    cm = ConfusionMatrix.new
  #    cm.add_for(:pos, :neg)
  #    cm.labels # => [:neg, :pos]
  #
  def labels
    if @labels.size >= 2 # pre-defined labels take precedence
      @labels
    else
      result = []

      # Collect every label seen as either a predicted or an observed value.
      # (Distinct block-parameter names avoid shadowing the outer variable.)
      @matrix.each_pair do |predicted, observations|
        result << predicted
        observations.each_key do |observed|
          result << observed
        end
      end

      result.uniq.sort
    end
  end

  # Returns the count recorded for a given (predicted, observed) pair.
  #
  #    cm = ConfusionMatrix.new
  #    cm.add_for(:pos, :neg)
  #    cm.count_for(:pos, :neg) # => 1
  #
  # Raises ArgumentError if either label is not in a pre-defined list.
  def count_for(predicted, observed)
    validate_label predicted, observed
    @matrix.fetch(predicted, {}).fetch(observed, 0)
  end

  # Adds +n+ results (default 1) to the matrix for a given
  # (predicted, observed) pair of labels.
  #
  # If the matrix was given a pre-defined list of labels on construction,
  # the given labels must come from that list. Otherwise, new labels are
  # added to the matrix on first use. Labels may be any hashable value,
  # although ideally they are strings or symbols.
  #
  # Raises ArgumentError for an unknown label, or if +n+ is not a positive
  # Integer. All validation happens before any mutation, so an invalid
  # call leaves the matrix untouched.
  def add_for(predicted, observed, n = 1)
    validate_label predicted, observed
    unless n.is_a?(Integer) && n.positive?
      raise ArgumentError, "add_for requires n to be a positive Integer, but got #{n}"
    end

    @matrix[predicted] ||= {}
    @matrix[predicted][observed] ||= 0
    @matrix[predicted][observed] += n
  end

  # Returns the number of observations of the given label which are incorrect.
  #
  # For example matrix, <code>false_negative(:positive)</code> is +b+.
  #
  def false_negative(label = @labels.first)
    validate_label label
    observations = @matrix.fetch(label, {})
    # Sum this row's counts, excluding the correct (diagonal) cell.
    observations.sum { |observed, count| observed == label ? 0 : count }
  end

  # Returns the number of observations incorrect with the given label.
  #
  # For example matrix, <code>false_positive(:positive)</code> is +c+.
  #
  def false_positive(label = @labels.first)
    validate_label label
    # Sum the column for +label+ across every other row.
    @matrix.sum do |predicted, observations|
      predicted == label ? 0 : observations.fetch(label, 0)
    end
  end

  # The false rate for a given label is the proportion of observations
  # incorrect for that label, out of all those observations
  # not originally of that label.
  #
  # For example matrix, <code>false_rate(:positive)</code> is <code>c/(c+d)</code>.
  #
  def false_rate(label = @labels.first)
    validate_label label
    fp = false_positive(label)
    tn = true_negative(label)

    divide(fp, fp + tn)
  end

  # The F-measure for a given label is the harmonic mean of the precision
  # and recall for that label:
  #
  #    F = 2*(precision*recall)/(precision+recall)
  #
  # Returns 0.0 when both precision and recall are zero, where a direct
  # division would yield NaN.
  def f_measure(label = @labels.first)
    validate_label label
    p = precision(label)
    r = recall(label)

    divide(2 * p * r, p + r)
  end

  # The geometric mean is the nth-root of the product of the true_rate for
  # each label.
  #
  # For example:
  # - a1 = a/(a+b)
  # - a2 = d/(c+d)
  # - geometric mean = Math.sqrt(a1*a2)
  #
  def geometric_mean
    product = 1

    @matrix.each_key do |key|
      product *= true_rate(key)
    end

    product**(1.0 / @matrix.size)
  end

  # The Kappa statistic compares the observed accuracy with the accuracy
  # expected by chance.
  #
  def kappa(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp + fn + fp + tn

    total_accuracy = divide(tp + tn, tp + tn + fp + fn)
    random_accuracy = divide((tn + fp) * (tn + fn) + (fn + tp) * (fp + tp), total * total)

    divide(total_accuracy - random_accuracy, 1 - random_accuracy)
  end

  # Matthews Correlation Coefficient is a measure of the quality of binary
  # classifications.
  #
  # For example matrix, <code>matthews_correlation(:positive)</code> is
  # <code>(a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))</code>.
  #
  def matthews_correlation(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)

    divide(tp * tn - fp * fn, Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
  end

  # The overall accuracy is the proportion of observations which are
  # correctly labelled.
  #
  # For example matrix, <code>overall_accuracy</code> is <code>(a+d)/(a+b+c+d)</code>.
  #
  def overall_accuracy
    # Correct observations lie on the diagonal: one true-positive count per label.
    total_correct = @matrix.keys.sum { |key| true_positive(key) }

    divide(total_correct, total)
  end

  # The precision for a given label is the proportion of observations
  # observed as that label which are correct.
  #
  # For example matrix, <code>precision(:positive)</code> is <code>a/(a+c)</code>.
  #
  def precision(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fp = false_positive(label)

    divide(tp, tp + fp)
  end

  # The prevalence for a given label is the proportion of observations
  # which were observed as of that label, out of the total.
  #
  # For example matrix, <code>prevalence(:positive)</code> is <code>(a+c)/(a+b+c+d)</code>.
  #
  def prevalence(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp + fn + fp + tn

    divide(tp + fn, total)
  end

  # The recall is another name for the true rate.
  #
  def recall(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # Sensitivity is another name for the true rate.
  #
  def sensitivity(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # The specificity for a given label is 1 - false_rate(label)
  #
  # In two-class case, specificity = 1 - false_positive_rate
  #
  def specificity(label = @labels.first)
    validate_label label
    1 - false_rate(label)
  end

  # Returns the table in a string format, representing the entries as a
  # printable table.
  #
  def to_s
    ls = labels
    result = ""

    title_line = "Observed "
    label_line = ""
    ls.each { |l| label_line << "#{l} " }
    # Pad whichever header line is shorter so the two columns align.
    label_line << " " while label_line.size < title_line.size
    title_line << " " while title_line.size < label_line.size
    result << title_line << "|\n" << label_line << "| Predicted\n"
    result << "-" * title_line.size << "+----------\n"

    ls.each do |l|
      count_line = ""
      ls.each_with_index do |m, i|
        # Right-justify each count within its column's label width.
        count_line << "#{count_for(l, m)}".rjust(ls[i].size) << " "
      end
      result << count_line.ljust(title_line.size) << "| #{l}\n"
    end

    result
  end

  # Returns the total number of observations referenced in the matrix.
  #
  # For example matrix, +total+ is <code>a+b+c+d</code>.
  #
  def total
    @matrix.values.sum { |observations| observations.values.sum }
  end

  # Returns the number of observations NOT of the given label which are correct.
  #
  # For example matrix, <code>true_negative(:positive)</code> is +d+.
  #
  def true_negative(label = @labels.first)
    validate_label label
    # Sum the diagonal cells for every label other than the given one.
    @matrix.sum do |predicted, observations|
      predicted == label ? 0 : observations.fetch(predicted, 0)
    end
  end

  # Returns the number of observations of the given label which are correct.
  #
  # For example matrix, <code>true_positive(:positive)</code> is +a+.
  #
  def true_positive(label = @labels.first)
    validate_label label
    @matrix.fetch(label, {}).fetch(label, 0)
  end

  # The true rate for a given label is the proportion of observations of
  # that label which are correct.
  #
  # For example matrix, <code>true_rate(:positive)</code> is <code>a/(a+b)</code>.
  #
  def true_rate(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)

    divide(tp, tp + fn)
  end

  private

  # A form of "safe divide".
  # Returns 0.0 when the divisor is zero (avoiding a run-time error),
  # and always performs floating-point division.
  def divide(x, y)
    return 0.0 if y.zero?

    x.to_f / y
  end

  # Checks each given label is non-nil and in @labels; always passes when
  # no labels were pre-defined.
  # Raises ArgumentError if not.
  def validate_label(*labels)
    return true if @labels.empty?

    labels.each do |label|
      unless label && @labels.include?(label)
        raise ArgumentError, "Given label (#{label}) is not in predefined list (#{@labels.join(',')})"
      end
    end
  end
end
|
|
382
|
+
|
|
383
|
+
|