orca-sdk 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,570 +0,0 @@
- """
- IMPORTANT:
- - This is a shared file between OrcaLib and the OrcaSDK.
- - Please ensure that it does not have any dependencies on the OrcaLib code.
- - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
- """
-
- from typing import Literal
-
- import numpy as np
- import pytest
- import sklearn.metrics
-
- from .metrics import (
-     calculate_classification_metrics,
-     calculate_pr_curve,
-     calculate_regression_metrics,
-     calculate_roc_curve,
-     softmax,
- )
-
-
- def test_binary_metrics():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-
-     assert metrics.accuracy == 0.8
-     assert metrics.f1_score == 0.8
-     assert metrics.roc_auc is not None
-     assert metrics.roc_auc > 0.8
-     assert metrics.roc_auc < 1.0
-     assert metrics.pr_auc is not None
-     assert metrics.pr_auc > 0.8
-     assert metrics.pr_auc < 1.0
-     assert metrics.loss is not None
-     assert metrics.loss > 0.0
-
-
- def test_multiclass_metrics_with_2_classes():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-
-     assert metrics.accuracy == 0.8
-     assert metrics.f1_score == 0.8
-     assert metrics.roc_auc is not None
-     assert metrics.roc_auc > 0.8
-     assert metrics.roc_auc < 1.0
-     assert metrics.pr_auc is not None
-     assert metrics.pr_auc > 0.8
-     assert metrics.pr_auc < 1.0
-     assert metrics.loss is not None
-     assert metrics.loss > 0.0
-
-
- @pytest.mark.parametrize(
-     "average, multiclass",
-     [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
- )
- def test_multiclass_metrics_with_3_classes(
-     average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
- ):
-     y_true = np.array([0, 1, 1, 0, 2])
-     y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
-
-     metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
-
-     assert metrics.accuracy == 1.0
-     assert metrics.f1_score == 1.0
-     assert metrics.roc_auc is not None
-     assert metrics.roc_auc > 0.8
-     assert metrics.pr_auc is None
-     assert metrics.loss is not None
-     assert metrics.loss > 0.0
-
-
- def test_does_not_modify_logits_unless_necessary():
-     logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
-     expected_labels = [0, 1, 0, 1]
-     loss = calculate_classification_metrics(expected_labels, logits).loss
-     assert loss is not None
-     assert np.allclose(
-         loss,
-         sklearn.metrics.log_loss(expected_labels, logits),
-         atol=1e-6,
-     )
-
-
- def test_normalizes_logits_if_necessary():
-     logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
-     expected_labels = [0, 1, 0, 1]
-     loss = calculate_classification_metrics(expected_labels, logits).loss
-     assert loss is not None
-     assert np.allclose(
-         loss,
-         sklearn.metrics.log_loss(expected_labels, logits / logits.sum(axis=1, keepdims=True)),
-         atol=1e-6,
-     )
-
-
- def test_softmaxes_logits_if_necessary():
-     logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
-     expected_labels = [0, 1, 0, 1]
-     loss = calculate_classification_metrics(expected_labels, logits).loss
-     assert loss is not None
-     assert np.allclose(
-         loss,
-         sklearn.metrics.log_loss(expected_labels, softmax(logits)),
-         atol=1e-6,
-     )
-
-
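The three tests above pin down how the shared metrics code is expected to prepare logits before computing log loss: rows that already form a probability distribution pass through untouched, non-negative rows are rescaled to sum to one, and rows with negative entries go through softmax. The following is only a sketch of that decision rule under those assumptions (the helper name is made up here; the real logic lives in the package's .metrics module, which is not part of this diff):

    import numpy as np

    def _as_probabilities(logits: np.ndarray) -> np.ndarray:
        # Rows that already sum to 1 are passed through unchanged.
        if np.allclose(logits.sum(axis=1), 1.0):
            return logits
        # Non-negative rows can simply be rescaled to sum to 1.
        if (logits >= 0).all():
            return logits / logits.sum(axis=1, keepdims=True)
        # Otherwise fall back to a numerically stable softmax.
        shifted = logits - logits.max(axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / exps.sum(axis=1, keepdims=True)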
- def test_handles_nan_logits():
-     logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
-     expected_labels = [0, 1, 0, 1]
-     metrics = calculate_classification_metrics(expected_labels, logits)
-     assert metrics.loss is None
-     assert metrics.accuracy == 0.25
-     assert metrics.f1_score == 0.25
-     assert metrics.roc_auc is None
-     assert metrics.pr_auc is None
-     assert metrics.pr_curve is None
-     assert metrics.roc_curve is None
-     assert metrics.coverage == 0.5
-
-
- def test_precision_recall_curve():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
-
-     pr_curve = calculate_pr_curve(y_true, y_score)
-
-     assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
-     assert np.allclose(pr_curve["precisions"][0], 0.6)
-     assert np.allclose(pr_curve["recalls"][0], 1.0)
-     assert np.allclose(pr_curve["precisions"][-1], 1.0)
-     assert np.allclose(pr_curve["recalls"][-1], 0.0)
-
-     # test that thresholds are sorted
-     assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
-
-
- def test_roc_curve():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
-
-     roc_curve = calculate_roc_curve(y_true, y_score)
-
-     assert (
-         len(roc_curve["false_positive_rates"])
-         == len(roc_curve["true_positive_rates"])
-         == len(roc_curve["thresholds"])
-         == 6
-     )
-     assert roc_curve["false_positive_rates"][0] == 1.0
-     assert roc_curve["true_positive_rates"][0] == 1.0
-     assert roc_curve["false_positive_rates"][-1] == 0.0
-     assert roc_curve["true_positive_rates"][-1] == 0.0
-
-     # test that thresholds are sorted
-     assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
-
-
- def test_log_loss_handles_missing_classes_in_y_true():
-     # y_true contains only a subset of classes, but predictions include an extra class column
-     y_true = np.array([0, 1, 0, 1])
-     y_score = np.array(
-         [
-             [0.7, 0.2, 0.1],
-             [0.1, 0.8, 0.1],
-             [0.6, 0.3, 0.1],
-             [0.2, 0.7, 0.1],
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-     expected_loss = sklearn.metrics.log_loss(y_true, y_score, labels=[0, 1, 2])
-
-     assert metrics.loss is not None
-     assert np.allclose(metrics.loss, expected_loss)
-
-
- def test_precision_recall_curve_max_length():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
-
-     pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
-     assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
-
-     assert np.allclose(pr_curve["precisions"][0], 0.6)
-     assert np.allclose(pr_curve["recalls"][0], 1.0)
-     assert np.allclose(pr_curve["precisions"][-1], 1.0)
-     assert np.allclose(pr_curve["recalls"][-1], 0.0)
-
-     # test that thresholds are sorted
-     assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
-
-
- def test_roc_curve_max_length():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
-
-     roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
-     assert (
-         len(roc_curve["false_positive_rates"])
-         == len(roc_curve["true_positive_rates"])
-         == len(roc_curve["thresholds"])
-         == 5
-     )
-     assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
-     assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
-     assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
-     assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
-
-     # test that thresholds are sorted
-     assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
-
-
- # Regression Metrics Tests
- def test_perfect_regression_predictions():
-     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
-     y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     assert metrics.mse == 0.0
-     assert metrics.rmse == 0.0
-     assert metrics.mae == 0.0
-     assert metrics.r2 == 1.0
-     assert metrics.explained_variance == 1.0
-     assert metrics.loss == 0.0
-     assert metrics.anomaly_score_mean is None
-     assert metrics.anomaly_score_median is None
-     assert metrics.anomaly_score_variance is None
-
-
- def test_basic_regression_metrics():
-     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
-     y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     # Check that all metrics are reasonable
-     assert metrics.mse > 0.0
-     assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
-     assert metrics.mae > 0.0
-     assert 0.0 <= metrics.r2 <= 1.0
-     assert 0.0 <= metrics.explained_variance <= 1.0
-     assert metrics.loss == metrics.mse
-
-     # Check specific values based on the data
-     expected_mse = np.mean((y_true - y_pred) ** 2)
-     assert metrics.mse == pytest.approx(expected_mse)
-
-     expected_mae = np.mean(np.abs(y_true - y_pred))
-     assert metrics.mae == pytest.approx(expected_mae)
-
-
- def test_regression_metrics_with_anomaly_scores():
-     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
-     y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
-     anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
-
-     metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
-
-     assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
-     assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
-     assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
-
-
- def test_regression_metrics_handles_nans():
-     y_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
-     y_pred = np.array([1.1, 1.9, np.nan], dtype=np.float32)
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     assert np.allclose(metrics.coverage, 0.6666666666666666)
-     assert metrics.mse > 0.0
-     assert metrics.rmse > 0.0
-     assert metrics.mae > 0.0
-     assert 0.0 <= metrics.r2 <= 1.0
-     assert 0.0 <= metrics.explained_variance <= 1.0
-
-
- def test_regression_metrics_handles_none_values():
-     # Test with lists containing None values
-     y_true = [1.0, 2.0, 3.0, 4.0, 5.0]
-     y_pred = [1.1, 1.9, None, 3.8, np.nan]
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     # Coverage should be 0.6 (3 out of 5 predictions are valid)
-     # Positions with None/NaN predictions (indices 2 and 4) are filtered out
-     assert np.allclose(metrics.coverage, 0.6)
-
-     # Metrics should be calculated only on valid pairs (indices 0, 1, 3)
-     # Valid pairs: (1.0, 1.1), (2.0, 1.9), and (4.0, 3.8)
-     expected_mse = np.mean([(1.0 - 1.1) ** 2, (2.0 - 1.9) ** 2, (4.0 - 3.8) ** 2])
-     expected_mae = np.mean([abs(1.0 - 1.1), abs(2.0 - 1.9), abs(4.0 - 3.8)])
-
-     assert metrics.mse == pytest.approx(expected_mse)
-     assert metrics.mae == pytest.approx(expected_mae)
-     assert metrics.rmse == pytest.approx(np.sqrt(expected_mse))
-     assert 0.0 <= metrics.r2 <= 1.0
-     assert 0.0 <= metrics.explained_variance <= 1.0
-
-
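The assertions above imply a filtering step that runs before any error metric is computed: prediction slots that are None or NaN are dropped, coverage is the fraction that survives, and MSE/MAE/RMSE are taken over the remaining pairs. A minimal sketch of that filtering, with a hypothetical helper name chosen purely for illustration:

    import numpy as np

    def _filter_valid_pairs(y_true, y_pred):
        # None becomes NaN so a single mask covers both missing-value cases.
        true_arr = np.array([np.nan if v is None else v for v in y_true], dtype=np.float64)
        pred_arr = np.array([np.nan if v is None else v for v in y_pred], dtype=np.float64)
        valid = ~np.isnan(pred_arr)
        coverage = float(valid.mean())
        return true_arr[valid], pred_arr[valid], coverage

    # For y_pred = [1.1, 1.9, None, 3.8, float("nan")] this keeps indices 0, 1, and 3
    # and reports coverage == 0.6, matching the test above.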
- def test_regression_metrics_rejects_none_expected_scores():
-     # Test that None values in expected_scores are rejected
-     y_true = [1.0, 2.0, None, 4.0, 5.0]
-     y_pred = [1.1, 1.9, 3.2, 3.8, 5.1]
-
-     with pytest.raises(ValueError, match="expected_scores must not contain None or NaN values"):
-         calculate_regression_metrics(y_true, y_pred)
-
-
- def test_regression_metrics_rejects_nan_expected_scores():
-     # Test that NaN values in expected_scores are rejected
-     y_true = np.array([1.0, 2.0, np.nan, 4.0, 5.0], dtype=np.float32)
-     y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
-
-     with pytest.raises(ValueError, match="expected_scores must not contain None or NaN values"):
-         calculate_regression_metrics(y_true, y_pred)
-
-
- def test_regression_metrics_all_predictions_none():
-     # Test with all predictions being None
-     y_true = [1.0, 2.0, 3.0, 4.0, 5.0]
-     y_pred = [None, None, None, None, None]
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     # When all predictions are None, coverage should be 0.0 and all metrics should be 0.0
-     assert metrics.coverage == 0.0
-     assert metrics.mse == 0.0
-     assert metrics.rmse == 0.0
-     assert metrics.mae == 0.0
-     assert metrics.r2 == 0.0
-     assert metrics.explained_variance == 0.0
-     assert metrics.loss == 0.0
-     assert metrics.anomaly_score_mean is None
-     assert metrics.anomaly_score_median is None
-     assert metrics.anomaly_score_variance is None
-
-
- def test_regression_metrics_all_predictions_nan():
-     # Test with all predictions being NaN
-     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
-     y_pred = np.array([np.nan, np.nan, np.nan, np.nan, np.nan], dtype=np.float32)
-
-     metrics = calculate_regression_metrics(y_true, y_pred)
-
-     # When all predictions are NaN, coverage should be 0.0 and all metrics should be 0.0
-     assert metrics.coverage == 0.0
-     assert metrics.mse == 0.0
-     assert metrics.rmse == 0.0
-     assert metrics.mae == 0.0
-     assert metrics.r2 == 0.0
-     assert metrics.explained_variance == 0.0
-     assert metrics.loss == 0.0
-     assert metrics.anomaly_score_mean is None
-     assert metrics.anomaly_score_median is None
-     assert metrics.anomaly_score_variance is None
-
-
- def test_roc_auc_handles_missing_classes_in_y_true():
-     """Test that ROC AUC is calculated with filtering when the test set has fewer classes than the model predicts."""
-     # Model trained on classes [0, 1, 2], but test set only has [0, 1]
-     y_true = np.array([0, 1, 0, 1])
-     y_score = np.array(
-         [
-             [0.7, 0.2, 0.1],  # Predicts class 0
-             [0.1, 0.8, 0.1],  # Predicts class 1
-             [0.6, 0.3, 0.1],  # Predicts class 0
-             [0.2, 0.7, 0.1],  # Predicts class 1
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-
-     # Should calculate ROC AUC by filtering to classes [0, 1]
-     assert metrics.roc_auc is not None
-     assert metrics.roc_auc == 1.0
-     assert any("computed only on classes present" in w for w in metrics.warnings)
-     # Other metrics should still work
-     assert metrics.accuracy == 1.0
-     assert metrics.f1_score == 1.0
-     assert metrics.loss is not None
-
-
- def test_roc_auc_with_all_classes_present():
-     """Test that ROC AUC works when all classes are present in the test set."""
-     # Model trained on classes [0, 1, 2], test set has all three
-     y_true = np.array([0, 1, 2, 0, 1, 2])
-     y_score = np.array(
-         [
-             [0.9, 0.05, 0.05],  # Predicts class 0
-             [0.1, 0.8, 0.1],  # Predicts class 1
-             [0.1, 0.1, 0.8],  # Predicts class 2
-             [0.7, 0.2, 0.1],  # Predicts class 0
-             [0.2, 0.7, 0.1],  # Predicts class 1
-             [0.1, 0.2, 0.7],  # Predicts class 2
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
-
-     # ROC AUC should be calculated when all classes present
-     assert metrics.roc_auc is not None
-     assert metrics.accuracy == 1.0
-     assert metrics.f1_score == 1.0
-
-
- def test_roc_auc_handles_subset_of_many_classes():
-     """Test ROC AUC where the model knows 15 classes but the test set only has 10."""
-     # Simulate the actual error scenario from the bug report
-     num_model_classes = 15
-     num_test_classes = 10
-     num_samples = 50
-
-     # Test set only uses classes 0-9
-     y_true = np.random.randint(0, num_test_classes, size=num_samples)
-
-     # Model produces predictions for all 15 classes
-     y_score = np.random.rand(num_samples, num_model_classes)
-     y_score = y_score / y_score.sum(axis=1, keepdims=True)  # Normalize to probabilities
-
-     metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
-
-     # Should calculate ROC AUC by filtering to classes 0-9
-     assert metrics.roc_auc is not None
-     assert 0.0 <= metrics.roc_auc <= 1.0
-     assert any("computed only on classes present" in w for w in metrics.warnings)
-     # Other metrics should still work
-     assert metrics.accuracy is not None
-     assert metrics.f1_score is not None
-     assert metrics.loss is not None
-
-
- def test_roc_auc_handles_unknown_classes_in_y_true():
-     """Test that metrics handle the case where y_true contains classes not in y_score."""
-     # Model trained on classes [0, 1, 2], but test set has class 3
-     y_true = np.array([0, 1, 2, 3])
-     y_score = np.array(
-         [
-             [0.7, 0.2, 0.1],
-             [0.1, 0.8, 0.1],
-             [0.1, 0.1, 0.8],
-             [0.3, 0.4, 0.3],  # Unknown class 3
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-
-     # Should skip ROC AUC and loss when unknown classes present
-     assert metrics.roc_auc is None
-     assert metrics.loss is None  # Loss also skipped to avoid ValueError
-     assert any("unknown" in w for w in metrics.warnings)
-     # Other metrics should still work (they handle extra classes)
-     assert metrics.accuracy is not None
-     assert metrics.f1_score is not None
-
-
- def test_roc_auc_handles_zero_probability_on_present_classes():
-     """Test ROC AUC when a sample has zero probability on all present classes (edge case for renormalization)."""
-     # Model trained on classes [0, 1, 2, 3], test set only has [0, 1, 2]
-     # One sample has ALL probability mass on excluded class 3 (zero on [0, 1, 2])
-     y_true = np.array([0, 1, 2, 0, 1, 2])
-     y_score = np.array(
-         [
-             [0.7, 0.2, 0.08, 0.02],
-             [0.1, 0.8, 0.08, 0.02],
-             [0.1, 0.1, 0.78, 0.02],
-             [0.6, 0.3, 0.08, 0.02],
-             [0.0, 0.0, 0.0, 1.0],  # zero denominator
-             [0.1, 0.1, 0.78, 0.02],
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
-
-     # Should still calculate ROC AUC despite zero-denominator case
-     # The safe renormalization should prevent NaN/inf
-     assert metrics.roc_auc is not None
-     assert not np.isnan(metrics.roc_auc)
-     assert not np.isinf(metrics.roc_auc)
-     assert any("computed only on classes present" in w for w in metrics.warnings)
-     assert metrics.accuracy is not None
-     assert metrics.f1_score is not None
-     assert metrics.loss is not None
-
-
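The ROC AUC tests above describe a recovery path for score matrices that are wider than the set of classes in y_true: keep only the columns for classes that actually occur, renormalize each row, and guard rows whose kept probabilities sum to zero. A sketch of that renormalization under those assumptions (the function name is illustrative, not the package's internal API):

    import numpy as np

    def _filter_scores_to_present_classes(y_true: np.ndarray, y_score: np.ndarray) -> np.ndarray:
        present = np.unique(y_true)  # e.g. [0, 1, 2] when class 3 never occurs in the test set
        kept = y_score[:, present]   # drop columns for classes absent from y_true
        row_sums = kept.sum(axis=1, keepdims=True)
        # Safe renormalization: rows with zero remaining mass fall back to a uniform
        # distribution instead of producing NaN or inf.
        safe_sums = np.where(row_sums > 0, row_sums, 1.0)
        uniform = np.full_like(kept, 1.0 / kept.shape[1])
        return np.where(row_sums > 0, kept / safe_sums, uniform)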
- def test_roc_auc_skipped_for_single_class():
-     """Test that ROC AUC is skipped when only one class is present in y_true."""
-     # Model trained on classes [0, 1, 2], but test set only has class 0
-     y_true = np.array([0, 0, 0, 0])
-     y_score = np.array(
-         [
-             [0.9, 0.05, 0.05],
-             [0.8, 0.1, 0.1],
-             [0.85, 0.1, 0.05],
-             [0.9, 0.05, 0.05],
-         ]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score)
-
-     # ROC AUC requires at least 2 classes
-     assert metrics.roc_auc is None
-     assert metrics.accuracy == 1.0
-     assert metrics.loss is not None
-     assert any("requires at least 2 classes" in w for w in metrics.warnings)
-
-
- # Confusion Matrix Tests
- def test_confusion_matrix_binary_classification():
-     y_true = np.array([0, 1, 1, 0, 1])
-     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
-
-     metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
-
-     assert metrics.confusion_matrix is not None
-     expected_cm = sklearn.metrics.confusion_matrix(y_true, [0, 1, 1, 0, 0], labels=[0, 1])
-     assert metrics.confusion_matrix == expected_cm.tolist()
-
-
- def test_confusion_matrix_multiclass():
-     y_true = np.array([0, 1, 2, 0, 1, 2])
-     y_score = np.array(
-         [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
-
-     assert metrics.confusion_matrix is not None
-     # All predictions correct
-     assert metrics.confusion_matrix == [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
-
-
- def test_confusion_matrix_with_misclassifications():
-     y_true = np.array([0, 1, 2, 0, 1, 2])
-     y_score = np.array(
-         [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
-     )
-
-     metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
-
-     assert metrics.confusion_matrix is not None
-     # Class 0: 1 correct (index 0), 1 predicted as class 1 (index 3)
-     # Class 1: 2 correct (indices 1, 4)
-     # Class 2: 1 predicted as class 1 (index 2), 1 correct (index 5)
-     assert metrics.confusion_matrix == [[1, 1, 0], [0, 2, 0], [0, 1, 1]]
-
-
- def test_confusion_matrix_handles_nan_logits():
-     logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
-     expected_labels = [0, 1, 0, 1]
-     metrics = calculate_classification_metrics(expected_labels, logits, include_confusion_matrix=True)
-
-     # NaN predictions are set to -1, so they won't match any true label
-     # Only the last 2 predictions are valid: pred=[1, 1], true=[0, 1]
-     assert metrics.confusion_matrix is not None
-     # With NaN handling, predictions become [-1, -1, 1, 1]
-     # Only position 3 is correct (true=1, pred=1)
-     # Position 2 is wrong (true=0, pred=1)
-     assert len(metrics.confusion_matrix) == 2  # 2 classes
-     assert len(metrics.confusion_matrix[0]) == 2
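The comments in the last test assume a specific NaN policy: a row of all-NaN logits is mapped to a sentinel prediction of -1, so it can never match a true label or land in a real confusion-matrix cell. A small sketch of that mapping, stated here only as the behaviour the assertions rely on:

    import numpy as np

    def _predict_labels(logits: np.ndarray) -> np.ndarray:
        preds = np.full(len(logits), -1, dtype=int)  # sentinel for unusable rows
        valid = ~np.isnan(logits).any(axis=1)
        preds[valid] = logits[valid].argmax(axis=1)
        return preds

    # For the logits above this yields [-1, -1, 1, 1]; only the last prediction matches
    # its label, which is consistent with accuracy == 0.25 and coverage == 0.5 earlier.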
@@ -1,137 +0,0 @@
- from __future__ import annotations
-
- import pickle
- from dataclasses import asdict, is_dataclass
- from os import PathLike
- from typing import TYPE_CHECKING, Any, cast
-
- from datasets import Dataset
- from datasets.exceptions import DatasetGenerationError
-
- if TYPE_CHECKING:
-     # peer dependencies that are used for types only
-     from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
-     from torch.utils.data import Dataset as TorchDataset  # type: ignore
-
-
- def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
-     if isinstance(item, dict):
-         return item
-
-     if isinstance(item, tuple):
-         if column_names is not None:
-             if len(item) != len(column_names):
-                 raise ValueError(
-                     f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
-                 )
-             return {column_names[i]: item[i] for i in range(len(item))}
-         elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
-             return {field: getattr(item, field) for field in item._fields}  # type: ignore
-         else:
-             raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
-
-     if is_dataclass(item) and not isinstance(item, type):
-         return asdict(item)
-
-     raise ValueError(f"Cannot parse {type(item)}")
-
-
- def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
-     if isinstance(batch, list):
-         return [parse_dict_like(item, column_names) for item in batch]
-
-     batch = parse_dict_like(batch, column_names)
-     keys = list(batch.keys())
-     batch_size = len(batch[keys[0]])
-     for key in keys:
-         if not len(batch[key]) == batch_size:
-             raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
-     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
-
-
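Taken together, parse_dict_like and parse_batch normalize whatever a loader yields, whether dicts, named or unnamed tuples, dataclasses, or columnar dict-of-lists batches, into a list of row dictionaries. A brief usage sketch with made-up values:

    from collections import namedtuple

    Row = namedtuple("Row", ["text", "label"])

    # Named tuples carry their own column names.
    parse_dict_like(Row("hello", 1))  # {'text': 'hello', 'label': 1}

    # Unnamed tuples need explicit column names.
    parse_dict_like(("hello", 1), column_names=["text", "label"])

    # A columnar batch is split back into one dict per row.
    parse_batch({"text": ["a", "b"], "label": [0, 1]})
    # [{'text': 'a', 'label': 0}, {'text': 'b', 'label': 1}]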
- def hf_dataset_from_torch(
-     torch_data: TorchDataLoader | TorchDataset,
-     column_names: list[str] | None = None,
- ) -> Dataset:
-     """
-     Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
-
-     NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
-     cached results can mask changes you've made to tests. This can make a test appear to succeed
-     when it's actually broken, or vice versa.
-
-     Params:
-         torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
-         column_names: Optional list of column names to use for the dataset. If not provided,
-             the column names will be inferred from the data.
-     Returns:
-         A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
-     """
-     # peer dependency that is guaranteed to exist if the user provided a torch dataset
-     from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
-
-     if isinstance(torch_data, TorchDataLoader):
-         dataloader = torch_data
-     else:
-         dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
-
-     # Collect data from the dataloader into a list to avoid serialization issues
-     # with Dataset.from_generator in Python 3.14 (see datasets issue #7839)
-     data_list = []
-     try:
-         for batch in dataloader:
-             data_list.extend(parse_batch(batch, column_names=column_names))
-     except ValueError as e:
-         raise DatasetGenerationError(str(e)) from e
-
-     ds = Dataset.from_list(data_list)
-
-     if not isinstance(ds, Dataset):
-         raise ValueError(f"Failed to create dataset from list: {type(ds)}")
-     return ds
-
-
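For reference, a usage sketch assuming torch is installed as a peer dependency; a map-style dataset that returns dicts needs no column_names, while one that returns unnamed tuples does:

    from torch.utils.data import Dataset as TorchDataset

    class ToyDataset(TorchDataset):
        def __init__(self):
            self.rows = [{"text": "a", "label": 0}, {"text": "b", "label": 1}]

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, idx):
            return self.rows[idx]

    ds = hf_dataset_from_torch(ToyDataset())
    print(ds.column_names)  # ['text', 'label']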
- def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
-     """
-     Load a dataset from disk into a HuggingFace Dataset object.
-
-     Params:
-         file_path: Path to the file on disk to create the memoryset from. The file type will
-             be inferred from the file extension. The following file types are supported:
-
-             - .pkl: [`Pickle`][pickle] files containing lists of dictionaries or dictionaries of columns
-             - .json/.jsonl: [`JSON`][json] and JSON Lines files
-             - .csv: [`CSV`][csv] files
-             - .parquet: [`Parquet`](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile) files
-             - dataset directory: Directory containing a saved HuggingFace [`Dataset`][datasets.Dataset]
-
-     Returns:
-         A HuggingFace Dataset object containing the loaded data.
-
-     Raises:
-         [`ValueError`][ValueError]: If the pickle file contains unsupported data types or if
-             loading the dataset fails for any reason.
-     """
-     if str(file_path).endswith(".pkl"):
-         data = pickle.load(open(file_path, "rb"))
-         if isinstance(data, list):
-             return Dataset.from_list(data)
-         elif isinstance(data, dict):
-             return Dataset.from_dict(data)
-         else:
-             raise ValueError(f"Unsupported pickle file: {file_path}")
-     elif str(file_path).endswith(".json"):
-         hf_dataset = Dataset.from_json(file_path)
-     elif str(file_path).endswith(".jsonl"):
-         hf_dataset = Dataset.from_json(file_path)
-     elif str(file_path).endswith(".csv"):
-         hf_dataset = Dataset.from_csv(file_path)
-     elif str(file_path).endswith(".parquet"):
-         hf_dataset = Dataset.from_parquet(file_path)
-     else:
-         try:
-             hf_dataset = Dataset.load_from_disk(file_path)
-         except Exception as e:
-             raise ValueError(f"Failed to load dataset from disk: {e}")
-
-     return cast(Dataset, hf_dataset)
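A short usage sketch for the loader above; the file paths are placeholders for the example:

    from datasets import Dataset

    # Round-trip a small dataset through CSV and load it back by extension.
    Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}).to_csv("examples.csv")
    csv_ds = hf_dataset_from_disk("examples.csv")
    # csv_ds.column_names -> ['text', 'label']

    # A directory written by Dataset.save_to_disk is handled by the fallback branch.
    csv_ds.save_to_disk("examples_dataset")
    dir_ds = hf_dataset_from_disk("examples_dataset")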