orca-sdk 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +120 -18
- orca_sdk/_shared/metrics_test.py +204 -0
- orca_sdk/async_client.py +105 -25
- orca_sdk/classification_model.py +4 -5
- orca_sdk/client.py +105 -25
- orca_sdk/embedding_model.py +19 -14
- orca_sdk/embedding_model_test.py +1 -1
- orca_sdk/memoryset.py +1093 -231
- orca_sdk/memoryset_test.py +109 -2
- orca_sdk/regression_model.py +2 -3
- {orca_sdk-0.1.5.dist-info → orca_sdk-0.1.7.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.5.dist-info → orca_sdk-0.1.7.dist-info}/RECORD +13 -13
- {orca_sdk-0.1.5.dist-info → orca_sdk-0.1.7.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics.py
CHANGED
@@ -8,7 +8,8 @@ IMPORTANT:
 
 """
 
-
+import logging
+from dataclasses import dataclass, field
 from typing import Any, Literal, Sequence, TypedDict, cast
 
 import numpy as np
@@ -20,7 +21,9 @@ from numpy.typing import NDArray
 def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
     shifted = logits - np.max(logits, axis=axis, keepdims=True)
     exps = np.exp(shifted)
-
+    sums = np.sum(exps, axis=axis, keepdims=True)
+    # Guard against division by zero (can happen if all logits are -inf or NaN)
+    return exps / np.where(sums > 0, sums, 1.0)
 
 
 # We don't want to depend on transformers just for the eval_pred type in orca_sdk
@@ -240,6 +243,12 @@ class ClassificationMetrics:
     roc_curve: ROCCurve | None = None
     """Receiver operating characteristic curve"""
 
+    confusion_matrix: list[list[int]] | None = None
+    """Confusion matrix where confusion_matrix[i][j] is the count of samples with true label i predicted as label j"""
+
+    warnings: list[str] = field(default_factory=list)
+    """Human-readable warnings about skipped or adjusted metrics"""
+
     def __repr__(self) -> str:
         return (
             "ClassificationMetrics({\n"
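The `warnings` field above uses `field(default_factory=list)` (hence the new `dataclasses` import earlier in this file), so every `ClassificationMetrics` instance gets its own list instead of sharing one mutable default. A minimal, SDK-independent sketch of that dataclass pattern:

```python
from dataclasses import dataclass, field

@dataclass
class Report:
    # Each instance gets a fresh list; a plain `= []` default would be rejected by dataclasses.
    warnings: list[str] = field(default_factory=list)

a, b = Report(), Report()
a.warnings.append("only in a")
print(b.warnings)  # [] -- the two instances do not share state
```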
@@ -300,7 +309,9 @@ def convert_logits_to_probabilities(logits: NDArray[np.float32]) -> NDArray[np.f
         probabilities = cast(NDArray[np.float32], softmax(logits))
     elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
         # Rows don't sum to 1: normalize to probabilities
-
+        row_sums = logits.sum(-1, keepdims=True)
+        # Guard against division by zero (can happen if all values in a row are 0 or NaN)
+        probabilities = cast(NDArray[np.float32], logits / np.where(row_sums > 0, row_sums, 1.0))
     else:
         # Already normalized probabilities
         probabilities = logits
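Both normalization paths above lean on the same `np.where` guard to avoid dividing by a non-positive row sum. A quick standalone check of the pattern (plain NumPy, independent of the SDK) shows the effect on a degenerate all-zero row:

```python
import numpy as np

rows = np.array([[0.2, 0.3, 0.5],
                 [0.0, 0.0, 0.0]])  # all-zero row would otherwise produce 0/0 -> NaN

row_sums = rows.sum(-1, keepdims=True)

# Divide by 1.0 wherever the sum is not strictly positive, as in the hunk above.
guarded = rows / np.where(row_sums > 0, row_sums, 1.0)
print(guarded)  # [[0.2 0.3 0.5]
                #  [0.  0.  0. ]]  -- finite, no RuntimeWarning
```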
@@ -317,7 +328,9 @@ def calculate_classification_metrics(
     average: Literal["micro", "macro", "weighted", "binary"] | None = None,
     multi_class: Literal["ovr", "ovo"] = "ovr",
     include_curves: bool = False,
+    include_confusion_matrix: bool = False,
 ) -> ClassificationMetrics:
+    warnings: list[str] = []
     references = np.array(expected_labels)
 
     # Convert to numpy array, handling None values
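With the new keyword argument and the `warnings` accumulator in place, a caller of this internal helper would opt in to the extra output roughly as below. This is a hedged sketch: the import path simply mirrors the file layout above, and the inputs are illustrative.

```python
import numpy as np
from orca_sdk._shared.metrics import calculate_classification_metrics

y_true = [0, 1, 1, 0]
probabilities = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])

metrics = calculate_classification_metrics(y_true, probabilities, include_confusion_matrix=True)
print(metrics.confusion_matrix)  # [[2, 0], [0, 2]] for these toy inputs
print(metrics.warnings)          # empty here; populated when a metric is skipped or adjusted
```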
@@ -338,6 +351,7 @@ def calculate_classification_metrics(
             pr_auc=None,
             pr_curve=None,
             roc_curve=None,
+            confusion_matrix=None,
         )
 
     # Convert logits to probabilities
@@ -347,41 +361,102 @@ def calculate_classification_metrics(
     predictions[np.isnan(probabilities).all(axis=-1)] = -1  # set predictions to -1 for all nan logits
 
     num_classes_references = len(set(references))
-    num_classes_predictions =
+    num_classes_predictions = probabilities.shape[1]  # Number of probability columns (model's known classes)
     num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
-    coverage = 1 - num_none_predictions / len(probabilities)
+    coverage = 1 - (num_none_predictions / len(probabilities) if len(probabilities) > 0 else 0)
+    if num_none_predictions > 0:
+        warnings.append(f"Some predictions were missing (coverage={coverage:.3f}); loss and AUC metrics were skipped.")
 
     if average is None:
         average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
 
     accuracy = sklearn.metrics.accuracy_score(references, predictions)
     f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+
+    # Check for unknown classes early (before log_loss)
+    classes_in_references = np.unique(references)
+    has_unknown_classes = np.max(classes_in_references) >= num_classes_predictions
+    if has_unknown_classes:
+        logging.warning(
+            f"Test labels contain classes not in the model's predictions. "
+            f"Model has {num_classes_predictions} classes (0 - {num_classes_predictions - 1}), "
+            f"but test labels contain class {np.max(classes_in_references)}. "
+            f"ROC AUC and PR AUC cannot be calculated."
+        )
+        warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+
     # Ensure sklearn sees the full class set corresponding to probability columns
     # to avoid errors when y_true does not contain all classes.
+    # Skip log_loss if there are unknown classes (would cause ValueError)
     loss = (
         sklearn.metrics.log_loss(
             references,
             probabilities,
             labels=list(range(probabilities.shape[1])),
         )
-        if num_none_predictions == 0
+        if num_none_predictions == 0 and not has_unknown_classes
        else None
     )
 
-
-
-    if
-
-
-
-
-
-
-
-        roc_auc =
+    # Calculate ROC AUC with filtering for class mismatch
+    if num_none_predictions == 0:
+        # Check if y_true contains classes not in the model (unknown classes)
+        if has_unknown_classes:
+            # Unknown classes present - can't calculate meaningful ROC AUC
+            logging.warning(
+                "Cannot calculate ROC AUC and PR AUC: test labels contain classes not in the model's predictions."
+            )
+            if "y_true contains classes unknown to the model" not in " ".join(warnings):
+                warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+            roc_auc = None
+            pr_auc = None
+            pr_curve = None
             roc_curve = None
+        elif len(classes_in_references) < 2:
+            # Need at least 2 classes for ROC AUC
+            logging.warning(
+                f"Cannot calculate ROC AUC and PR AUC: need at least 2 classes, but only {len(classes_in_references)} class(es) found in test labels."
+            )
+            roc_auc = None
             pr_auc = None
             pr_curve = None
+            roc_curve = None
+            warnings.append("ROC AUC requires at least 2 classes; metric was skipped.")
+        else:
+            # Filter probabilities to only classes present in references
+            if len(classes_in_references) < num_classes_predictions:
+                # Subset and renormalize probabilities
+                probabilities_filtered = probabilities[:, classes_in_references]
+                # Safe renormalization: guard against zero denominators
+                row_sums = probabilities_filtered.sum(axis=1, keepdims=True)
+                probabilities_filtered = probabilities_filtered / np.where(row_sums > 0, row_sums, 1.0)
+
+                # Remap references to filtered indices
+                class_mapping = {cls: idx for idx, cls in enumerate(classes_in_references)}
+                references_remapped = np.array([class_mapping[y] for y in references])
+                warnings.append(
+                    f"ROC AUC computed only on classes present in y_true: {classes_in_references.tolist()}."
+                )
+            else:
+                # All classes present, no filtering needed
+                probabilities_filtered = probabilities
+                references_remapped = references
+
+            # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+            if len(classes_in_references) == 2:
+                # Use probabilities[:, 1] which is guaranteed to be 2D
+                probabilities_positive = cast(NDArray[np.float32], probabilities_filtered[:, 1].astype(np.float32))
+                roc_auc = sklearn.metrics.roc_auc_score(references_remapped, probabilities_positive)
+                roc_curve = calculate_roc_curve(references_remapped, probabilities_positive) if include_curves else None
+                pr_auc = sklearn.metrics.average_precision_score(references_remapped, probabilities_positive)
+                pr_curve = calculate_pr_curve(references_remapped, probabilities_positive) if include_curves else None
+            else:
+                roc_auc = sklearn.metrics.roc_auc_score(
+                    references_remapped, probabilities_filtered, multi_class=multi_class
+                )
+                roc_curve = None
+                pr_auc = None
+                pr_curve = None
     else:
         roc_auc = None
         pr_auc = None
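The subset-renormalize-remap strategy in this hunk can be reproduced outside the SDK with a few lines of NumPy and scikit-learn. The sketch below mirrors the binary branch on toy data (an illustration of the approach, not the SDK's code):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# y_true only uses classes 0 and 1, but the model emits three probability columns.
y_true = np.array([0, 1, 0, 1])
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
])

present = np.unique(y_true)                                  # [0, 1]
filtered = probs[:, present]                                 # drop columns for absent classes
row_sums = filtered.sum(axis=1, keepdims=True)
filtered = filtered / np.where(row_sums > 0, row_sums, 1.0)  # safe renormalization
remapped = np.array([{c: i for i, c in enumerate(present)}[y] for y in y_true])

# Binary case: score the positive-class column, as the hunk does.
print(roc_auc_score(remapped, filtered[:, 1]))               # 1.0 on this toy data
```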
@@ -391,6 +466,31 @@ def calculate_classification_metrics(
     # Calculate anomaly score statistics
     anomaly_score_mean, anomaly_score_median, anomaly_score_variance = calculate_anomaly_score_stats(anomaly_scores)
 
+    # Calculate confusion matrix if requested
+    confusion_matrix: list[list[int]] | None = None
+    if include_confusion_matrix:
+        # Get the number of classes from the probabilities shape
+        num_classes = probabilities.shape[1]
+        labels = list(range(num_classes))
+        # Filter out NaN predictions (which are set to -1) before computing confusion matrix
+        valid_mask = predictions != -1
+        num_filtered = (~valid_mask).sum()
+        if num_filtered > 0:
+            warning_msg = (
+                f"Confusion matrix computation: filtered out {num_filtered} samples with NaN predictions "
+                f"({num_filtered}/{len(predictions)} = {num_filtered / len(predictions):.1%})"
+            )
+            logging.warning(warning_msg)
+            warnings.append(warning_msg)
+
+        if np.any(valid_mask):
+            # Compute confusion matrix with explicit labels to ensure consistent shape
+            cm = sklearn.metrics.confusion_matrix(references[valid_mask], predictions[valid_mask], labels=labels)
+        else:
+            # No valid predictions; return an all-zero confusion matrix
+            cm = np.zeros((num_classes, num_classes), dtype=int)
+        confusion_matrix = cast(list[list[int]], cm.tolist())
+
     return ClassificationMetrics(
         coverage=coverage,
         accuracy=float(accuracy),
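Passing an explicit `labels` list to `sklearn.metrics.confusion_matrix` is what pins the matrix shape to the model's full class count even when some classes never appear in the evaluated samples. A minimal scikit-learn illustration:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])

# Without `labels` this would be a 2x2 matrix; pinning it to three classes
# keeps row/column i aligned with class i, matching the code above.
print(confusion_matrix(y_true, y_pred, labels=[0, 1, 2]))
# [[2 0 0]
#  [1 1 0]
#  [0 0 0]]
```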
@@ -403,6 +503,8 @@ def calculate_classification_metrics(
         pr_auc=float(pr_auc) if pr_auc is not None else None,
         pr_curve=pr_curve,
         roc_curve=roc_curve,
+        confusion_matrix=confusion_matrix,
+        warnings=warnings,
     )
 
 
@@ -503,7 +605,7 @@ def calculate_regression_metrics(
     # Filter out NaN values from predictions (expected_scores are already validated to be non-NaN)
     valid_mask = ~np.isnan(predictions)
     num_none_predictions = (~valid_mask).sum()
-    coverage = 1 - num_none_predictions / len(predictions)
+    coverage = 1 - (num_none_predictions / len(predictions) if len(predictions) > 0 else 0)
     if num_none_predictions > 0:
         references = references[valid_mask]
         predictions = predictions[valid_mask]
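The regression-side change applies the same empty-input guard to coverage, so an empty prediction array yields a coverage of 1 rather than a NaN-producing division by zero. A tiny NumPy check of the guarded expression:

```python
import numpy as np

predictions = np.array([])                  # edge case: nothing was predicted
valid_mask = ~np.isnan(predictions)
num_none_predictions = (~valid_mask).sum()  # 0

# Guarded form from the hunk: the ratio defaults to 0 when there is nothing to divide by.
coverage = 1 - (num_none_predictions / len(predictions) if len(predictions) > 0 else 0)
print(coverage)  # 1
```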
orca_sdk/_shared/metrics_test.py
CHANGED
@@ -364,3 +364,207 @@ def test_regression_metrics_all_predictions_nan():
     assert metrics.anomaly_score_mean is None
     assert metrics.anomaly_score_median is None
     assert metrics.anomaly_score_variance is None
+
+
+def test_roc_auc_handles_missing_classes_in_y_true():
+    """Test that ROC AUC is calculated with filtering when test set has fewer classes than model predictions."""
+    # Model trained on classes [0, 1, 2], but test set only has [0, 1]
+    y_true = np.array([0, 1, 0, 1])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.6, 0.3, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should calculate ROC AUC by filtering to classes [0, 1]
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc == 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.loss is not None
+
+
+def test_roc_auc_with_all_classes_present():
+    """Test that ROC AUC works when all classes are present in test set."""
+    # Model trained on classes [0, 1, 2], test set has all three
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.1, 0.1, 0.8],  # Predicts class 2
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+            [0.1, 0.2, 0.7],  # Predicts class 2
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # ROC AUC should be calculated when all classes present
+    assert metrics.roc_auc is not None
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+
+
+def test_roc_auc_handles_subset_of_many_classes():
+    """Test ROC AUC where model knows 15 classes, test has 10."""
+    # Simulate the actual error scenario from the bug report
+    num_model_classes = 15
+    num_test_classes = 10
+    num_samples = 50
+
+    # Test set only uses classes 0-9
+    y_true = np.random.randint(0, num_test_classes, size=num_samples)
+
+    # Model produces predictions for all 15 classes
+    y_score = np.random.rand(num_samples, num_model_classes)
+    y_score = y_score / y_score.sum(axis=1, keepdims=True)  # Normalize to probabilities
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should calculate ROC AUC by filtering to classes 0-9
+    assert metrics.roc_auc is not None
+    assert 0.0 <= metrics.roc_auc <= 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_handles_unknown_classes_in_y_true():
+    """Test that metrics handle when y_true contains classes not in y_score."""
+    # Model trained on classes [0, 1, 2], but test set has class 3
+    y_true = np.array([0, 1, 2, 3])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],
+            [0.1, 0.8, 0.1],
+            [0.1, 0.1, 0.8],
+            [0.3, 0.4, 0.3],  # Unknown class 3
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should skip ROC AUC and loss when unknown classes present
+    assert metrics.roc_auc is None
+    assert metrics.loss is None  # Loss also skipped to avoid ValueError
+    assert any("unknown" in w for w in metrics.warnings)
+    # Other metrics should still work (they handle extra classes)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+
+
+def test_roc_auc_handles_zero_probability_on_present_classes():
+    """Test ROC AUC when a sample has zero probability on all present classes (edge case for renormalization)."""
+    # Model trained on classes [0, 1, 2, 3], test set only has [0, 1, 2]
+    # One sample has ALL probability mass on excluded class 3 (zero on [0, 1, 2])
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.08, 0.02],
+            [0.1, 0.8, 0.08, 0.02],
+            [0.1, 0.1, 0.78, 0.02],
+            [0.6, 0.3, 0.08, 0.02],
+            [0.0, 0.0, 0.0, 1.0],  # zero denominator
+            [0.1, 0.1, 0.78, 0.02],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should still calculate ROC AUC despite zero-denominator case
+    # The safe renormalization should prevent NaN/inf
+    assert metrics.roc_auc is not None
+    assert not np.isnan(metrics.roc_auc)
+    assert not np.isinf(metrics.roc_auc)
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_skipped_for_single_class():
+    """Test that ROC AUC is skipped when only one class is present in y_true."""
+    # Model trained on classes [0, 1, 2], but test set only has class 0
+    y_true = np.array([0, 0, 0, 0])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],
+            [0.8, 0.1, 0.1],
+            [0.85, 0.1, 0.05],
+            [0.9, 0.05, 0.05],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # ROC AUC requires at least 2 classes
+    assert metrics.roc_auc is None
+    assert metrics.accuracy == 1.0
+    assert metrics.loss is not None
+    assert any("requires at least 2 classes" in w for w in metrics.warnings)
+
+
+# Confusion Matrix Tests
+def test_confusion_matrix_binary_classification():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    expected_cm = sklearn.metrics.confusion_matrix(y_true, [0, 1, 1, 0, 0], labels=[0, 1])
+    assert metrics.confusion_matrix == expected_cm.tolist()
+
+
+def test_confusion_matrix_multiclass():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # All predictions correct
+    assert metrics.confusion_matrix == [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
+
+
+def test_confusion_matrix_with_misclassifications():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # Class 0: 1 correct (index 0), 1 predicted as class 1 (index 3)
+    # Class 1: 2 correct (indices 1, 4)
+    # Class 2: 1 predicted as class 1 (index 2), 1 correct (index 5)
+    assert metrics.confusion_matrix == [[1, 1, 0], [0, 2, 0], [0, 1, 1]]
+
+
+def test_confusion_matrix_handles_nan_logits():
+    logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
+    expected_labels = [0, 1, 0, 1]
+    metrics = calculate_classification_metrics(expected_labels, logits, include_confusion_matrix=True)
+
+    # NaN predictions are set to -1, so they won't match any true label
+    # Only the last 2 predictions are valid: pred=[1, 1], true=[0, 1]
+    assert metrics.confusion_matrix is not None
+    # With NaN handling, predictions become [-1, -1, 1, 1]
+    # Only position 3 is correct (true=1, pred=1)
+    # Position 2 is wrong (true=0, pred=1)
+    assert len(metrics.confusion_matrix) == 2  # 2 classes
+    assert len(metrics.confusion_matrix[0]) == 2