orca-sdk 0.1.5__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/PKG-INFO +1 -1
  2. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_shared/metrics.py +120 -18
  3. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_shared/metrics_test.py +204 -0
  4. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/async_client.py +105 -25
  5. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/classification_model.py +4 -5
  6. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/client.py +105 -25
  7. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/embedding_model.py +19 -14
  8. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/embedding_model_test.py +1 -1
  9. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/memoryset.py +1093 -231
  10. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/memoryset_test.py +109 -2
  11. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/regression_model.py +2 -3
  12. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/pyproject.toml +1 -1
  13. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/README.md +0 -0
  14. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/__init__.py +0 -0
  15. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_shared/__init__.py +0 -0
  16. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/__init__.py +0 -0
  17. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/analysis_ui.py +0 -0
  18. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/analysis_ui_style.css +0 -0
  19. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/auth.py +0 -0
  20. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/auth_test.py +0 -0
  21. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/common.py +0 -0
  22. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/data_parsing.py +0 -0
  23. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/data_parsing_test.py +0 -0
  24. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/pagination.py +0 -0
  25. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/pagination_test.py +0 -0
  26. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/prediction_result_ui.css +0 -0
  27. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/prediction_result_ui.py +0 -0
  28. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/tqdm_file_reader.py +0 -0
  29. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/value_parser.py +0 -0
  30. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/_utils/value_parser_test.py +0 -0
  31. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/classification_model_test.py +0 -0
  32. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/conftest.py +0 -0
  33. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/credentials.py +0 -0
  34. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/credentials_test.py +0 -0
  35. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/datasource.py +0 -0
  36. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/datasource_test.py +0 -0
  37. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/job.py +0 -0
  38. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/job_test.py +0 -0
  39. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/regression_model_test.py +0 -0
  40. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/telemetry.py +0 -0
  41. {orca_sdk-0.1.5 → orca_sdk-0.1.7}/orca_sdk/telemetry_test.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: orca_sdk
-Version: 0.1.5
+Version: 0.1.7
 Summary: SDK for interacting with Orca Services
 License-Expression: Apache-2.0
 Author: Orca DB Inc.
@@ -8,7 +8,8 @@ IMPORTANT:
 
 """
 
-from dataclasses import dataclass
+import logging
+from dataclasses import dataclass, field
 from typing import Any, Literal, Sequence, TypedDict, cast
 
 import numpy as np
@@ -20,7 +21,9 @@ from numpy.typing import NDArray
 def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
     shifted = logits - np.max(logits, axis=axis, keepdims=True)
     exps = np.exp(shifted)
-    return exps / np.sum(exps, axis=axis, keepdims=True)
+    sums = np.sum(exps, axis=axis, keepdims=True)
+    # Guard against division by zero (can happen if all logits are -inf or NaN)
+    return exps / np.where(sums > 0, sums, 1.0)
 
 
 # We don't want to depend on transformers just for the eval_pred type in orca_sdk
@@ -240,6 +243,12 @@ class ClassificationMetrics:
     roc_curve: ROCCurve | None = None
     """Receiver operating characteristic curve"""
 
+    confusion_matrix: list[list[int]] | None = None
+    """Confusion matrix where confusion_matrix[i][j] is the count of samples with true label i predicted as label j"""
+
+    warnings: list[str] = field(default_factory=list)
+    """Human-readable warnings about skipped or adjusted metrics"""
+
     def __repr__(self) -> str:
         return (
             "ClassificationMetrics({\n"
@@ -300,7 +309,9 @@ def convert_logits_to_probabilities(logits: NDArray[np.float32]) -> NDArray[np.f
         probabilities = cast(NDArray[np.float32], softmax(logits))
     elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
         # Rows don't sum to 1: normalize to probabilities
-        probabilities = cast(NDArray[np.float32], logits / logits.sum(-1, keepdims=True))
+        row_sums = logits.sum(-1, keepdims=True)
+        # Guard against division by zero (can happen if all values in a row are 0 or NaN)
+        probabilities = cast(NDArray[np.float32], logits / np.where(row_sums > 0, row_sums, 1.0))
     else:
         # Already normalized probabilities
         probabilities = logits
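
The division-by-zero guard in this hunk can be seen in isolation. A minimal standalone sketch with toy data (not the SDK function itself): a row whose values sum to zero stays all-zero instead of becoming NaN.

```python
import numpy as np

# Toy logits: the second row sums to zero and would produce NaN with a plain division
logits = np.array([[2.0, 1.0, 1.0], [0.0, 0.0, 0.0]], dtype=np.float32)
row_sums = logits.sum(-1, keepdims=True)
# Rows summing to 0 divide by 1.0 instead of 0, so they stay all-zero rather than becoming NaN
probabilities = logits / np.where(row_sums > 0, row_sums, 1.0)
print(probabilities)  # [[0.5, 0.25, 0.25], [0.0, 0.0, 0.0]]
```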
@@ -317,7 +328,9 @@ def calculate_classification_metrics(
     average: Literal["micro", "macro", "weighted", "binary"] | None = None,
     multi_class: Literal["ovr", "ovo"] = "ovr",
     include_curves: bool = False,
+    include_confusion_matrix: bool = False,
 ) -> ClassificationMetrics:
+    warnings: list[str] = []
     references = np.array(expected_labels)
 
     # Convert to numpy array, handling None values
@@ -338,6 +351,7 @@ def calculate_classification_metrics(
             pr_auc=None,
             pr_curve=None,
             roc_curve=None,
+            confusion_matrix=None,
         )
 
     # Convert logits to probabilities
@@ -347,41 +361,102 @@ def calculate_classification_metrics(
     predictions[np.isnan(probabilities).all(axis=-1)] = -1  # set predictions to -1 for all nan logits
 
     num_classes_references = len(set(references))
-    num_classes_predictions = len(set(predictions))
+    num_classes_predictions = probabilities.shape[1]  # Number of probability columns (model's known classes)
     num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
-    coverage = 1 - num_none_predictions / len(probabilities)
+    coverage = 1 - (num_none_predictions / len(probabilities) if len(probabilities) > 0 else 0)
+    if num_none_predictions > 0:
+        warnings.append(f"Some predictions were missing (coverage={coverage:.3f}); loss and AUC metrics were skipped.")
 
     if average is None:
         average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
 
     accuracy = sklearn.metrics.accuracy_score(references, predictions)
     f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+
+    # Check for unknown classes early (before log_loss)
+    classes_in_references = np.unique(references)
+    has_unknown_classes = np.max(classes_in_references) >= num_classes_predictions
+    if has_unknown_classes:
+        logging.warning(
+            f"Test labels contain classes not in the model's predictions. "
+            f"Model has {num_classes_predictions} classes (0 - {num_classes_predictions - 1}), "
+            f"but test labels contain class {np.max(classes_in_references)}. "
+            f"ROC AUC and PR AUC cannot be calculated."
+        )
+        warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+
     # Ensure sklearn sees the full class set corresponding to probability columns
     # to avoid errors when y_true does not contain all classes.
+    # Skip log_loss if there are unknown classes (would cause ValueError)
     loss = (
         sklearn.metrics.log_loss(
             references,
             probabilities,
            labels=list(range(probabilities.shape[1])),
        )
-        if num_none_predictions == 0
+        if num_none_predictions == 0 and not has_unknown_classes
        else None
    )
 
-    if num_classes_references == num_classes_predictions and num_none_predictions == 0:
-        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
-        if num_classes_references == 2:
-            # Use probabilities[:, 1] which is guaranteed to be 2D
-            probabilities_positive = probabilities[:, 1]
-            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities_positive)
-            roc_curve = calculate_roc_curve(references, probabilities_positive) if include_curves else None
-            pr_auc = sklearn.metrics.average_precision_score(references, probabilities_positive)
-            pr_curve = calculate_pr_curve(references, probabilities_positive) if include_curves else None
-        else:
-            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
+    # Calculate ROC AUC with filtering for class mismatch
+    if num_none_predictions == 0:
+        # Check if y_true contains classes not in the model (unknown classes)
+        if has_unknown_classes:
+            # Unknown classes present - can't calculate meaningful ROC AUC
+            logging.warning(
+                "Cannot calculate ROC AUC and PR AUC: test labels contain classes not in the model's predictions."
+            )
+            if "y_true contains classes unknown to the model" not in " ".join(warnings):
+                warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
+            roc_auc = None
+            pr_auc = None
+            pr_curve = None
             roc_curve = None
+        elif len(classes_in_references) < 2:
+            # Need at least 2 classes for ROC AUC
+            logging.warning(
+                f"Cannot calculate ROC AUC and PR AUC: need at least 2 classes, but only {len(classes_in_references)} class(es) found in test labels."
+            )
+            roc_auc = None
             pr_auc = None
             pr_curve = None
+            roc_curve = None
+            warnings.append("ROC AUC requires at least 2 classes; metric was skipped.")
+        else:
+            # Filter probabilities to only classes present in references
+            if len(classes_in_references) < num_classes_predictions:
+                # Subset and renormalize probabilities
+                probabilities_filtered = probabilities[:, classes_in_references]
+                # Safe renormalization: guard against zero denominators
+                row_sums = probabilities_filtered.sum(axis=1, keepdims=True)
+                probabilities_filtered = probabilities_filtered / np.where(row_sums > 0, row_sums, 1.0)
+
+                # Remap references to filtered indices
+                class_mapping = {cls: idx for idx, cls in enumerate(classes_in_references)}
+                references_remapped = np.array([class_mapping[y] for y in references])
+                warnings.append(
+                    f"ROC AUC computed only on classes present in y_true: {classes_in_references.tolist()}."
+                )
+            else:
+                # All classes present, no filtering needed
+                probabilities_filtered = probabilities
+                references_remapped = references
+
+            # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+            if len(classes_in_references) == 2:
+                # Use probabilities[:, 1] which is guaranteed to be 2D
+                probabilities_positive = cast(NDArray[np.float32], probabilities_filtered[:, 1].astype(np.float32))
+                roc_auc = sklearn.metrics.roc_auc_score(references_remapped, probabilities_positive)
+                roc_curve = calculate_roc_curve(references_remapped, probabilities_positive) if include_curves else None
+                pr_auc = sklearn.metrics.average_precision_score(references_remapped, probabilities_positive)
+                pr_curve = calculate_pr_curve(references_remapped, probabilities_positive) if include_curves else None
+            else:
+                roc_auc = sklearn.metrics.roc_auc_score(
+                    references_remapped, probabilities_filtered, multi_class=multi_class
+                )
+                roc_curve = None
+                pr_auc = None
+                pr_curve = None
     else:
         roc_auc = None
         pr_auc = None
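
The filter-and-remap step above, sketched in isolation with toy data (the names `present`, `filtered`, and `remapped` are illustrative, not from the SDK): probabilities are restricted to the classes that actually occur in y_true, renormalized with the same zero-sum guard, and the labels are remapped to 0..k-1 before scoring.

```python
import numpy as np
import sklearn.metrics

y_true = np.array([0, 1, 0, 1])          # only classes 0 and 1 appear
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.6, 0.3, 0.1],
                  [0.2, 0.7, 0.1]])      # model outputs 3 probability columns

present = np.unique(y_true)              # classes actually present: [0, 1]
filtered = probs[:, present]             # drop columns for absent classes
row_sums = filtered.sum(axis=1, keepdims=True)
filtered = filtered / np.where(row_sums > 0, row_sums, 1.0)  # safe renormalization
remapped = np.searchsorted(present, y_true)                  # relabel to 0..k-1

# Binary case after filtering: score only the positive-class column
print(sklearn.metrics.roc_auc_score(remapped, filtered[:, 1]))  # 1.0 on this toy data
```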
@@ -391,6 +466,31 @@ def calculate_classification_metrics(
     # Calculate anomaly score statistics
     anomaly_score_mean, anomaly_score_median, anomaly_score_variance = calculate_anomaly_score_stats(anomaly_scores)
 
+    # Calculate confusion matrix if requested
+    confusion_matrix: list[list[int]] | None = None
+    if include_confusion_matrix:
+        # Get the number of classes from the probabilities shape
+        num_classes = probabilities.shape[1]
+        labels = list(range(num_classes))
+        # Filter out NaN predictions (which are set to -1) before computing confusion matrix
+        valid_mask = predictions != -1
+        num_filtered = (~valid_mask).sum()
+        if num_filtered > 0:
+            warning_msg = (
+                f"Confusion matrix computation: filtered out {num_filtered} samples with NaN predictions "
+                f"({num_filtered}/{len(predictions)} = {num_filtered / len(predictions):.1%})"
+            )
+            logging.warning(warning_msg)
+            warnings.append(warning_msg)
+
+        if np.any(valid_mask):
+            # Compute confusion matrix with explicit labels to ensure consistent shape
+            cm = sklearn.metrics.confusion_matrix(references[valid_mask], predictions[valid_mask], labels=labels)
+        else:
+            # No valid predictions; return an all-zero confusion matrix
+            cm = np.zeros((num_classes, num_classes), dtype=int)
+        confusion_matrix = cast(list[list[int]], cm.tolist())
+
     return ClassificationMetrics(
         coverage=coverage,
         accuracy=float(accuracy),
@@ -403,6 +503,8 @@ def calculate_classification_metrics(
         pr_auc=float(pr_auc) if pr_auc is not None else None,
         pr_curve=pr_curve,
         roc_curve=roc_curve,
+        confusion_matrix=confusion_matrix,
+        warnings=warnings,
     )
 
 
@@ -503,7 +605,7 @@ def calculate_regression_metrics(
     # Filter out NaN values from predictions (expected_scores are already validated to be non-NaN)
     valid_mask = ~np.isnan(predictions)
     num_none_predictions = (~valid_mask).sum()
-    coverage = 1 - num_none_predictions / len(predictions)
+    coverage = 1 - (num_none_predictions / len(predictions) if len(predictions) > 0 else 0)
     if num_none_predictions > 0:
         references = references[valid_mask]
         predictions = predictions[valid_mask]
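
Before the test additions below, a hedged usage sketch of the new classification options (assuming `calculate_classification_metrics` is importable from `orca_sdk._shared.metrics`, as the test module alongside it suggests):

```python
import numpy as np
from orca_sdk._shared.metrics import calculate_classification_metrics

y_true = np.array([0, 1, 2, 0, 1, 2])
y_score = np.array([[0.9, 0.05, 0.05],
                    [0.1, 0.8, 0.1],
                    [0.1, 0.1, 0.8],
                    [0.7, 0.2, 0.1],
                    [0.2, 0.7, 0.1],
                    [0.1, 0.2, 0.7]])

metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
print(metrics.confusion_matrix)  # 3x3 list of counts; rows are true labels, columns predicted labels
print(metrics.warnings)          # empty here; populated when metrics are skipped or adjusted
```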
@@ -364,3 +364,207 @@ def test_regression_metrics_all_predictions_nan():
     assert metrics.anomaly_score_mean is None
     assert metrics.anomaly_score_median is None
     assert metrics.anomaly_score_variance is None
+
+
+def test_roc_auc_handles_missing_classes_in_y_true():
+    """Test that ROC AUC is calculated with filtering when test set has fewer classes than model predictions."""
+    # Model trained on classes [0, 1, 2], but test set only has [0, 1]
+    y_true = np.array([0, 1, 0, 1])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.6, 0.3, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should calculate ROC AUC by filtering to classes [0, 1]
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc == 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.loss is not None
+
+
+def test_roc_auc_with_all_classes_present():
+    """Test that ROC AUC works when all classes are present in test set."""
+    # Model trained on classes [0, 1, 2], test set has all three
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],  # Predicts class 0
+            [0.1, 0.8, 0.1],  # Predicts class 1
+            [0.1, 0.1, 0.8],  # Predicts class 2
+            [0.7, 0.2, 0.1],  # Predicts class 0
+            [0.2, 0.7, 0.1],  # Predicts class 1
+            [0.1, 0.2, 0.7],  # Predicts class 2
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # ROC AUC should be calculated when all classes present
+    assert metrics.roc_auc is not None
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+
+
+def test_roc_auc_handles_subset_of_many_classes():
+    """Test ROC AUC where model knows 15 classes, test has 10."""
+    # Simulate the actual error scenario from the bug report
+    num_model_classes = 15
+    num_test_classes = 10
+    num_samples = 50
+
+    # Test set only uses classes 0-9
+    y_true = np.random.randint(0, num_test_classes, size=num_samples)
+
+    # Model produces predictions for all 15 classes
+    y_score = np.random.rand(num_samples, num_model_classes)
+    y_score = y_score / y_score.sum(axis=1, keepdims=True)  # Normalize to probabilities
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should calculate ROC AUC by filtering to classes 0-9
+    assert metrics.roc_auc is not None
+    assert 0.0 <= metrics.roc_auc <= 1.0
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    # Other metrics should still work
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_handles_unknown_classes_in_y_true():
+    """Test that metrics handle when y_true contains classes not in y_score."""
+    # Model trained on classes [0, 1, 2], but test set has class 3
+    y_true = np.array([0, 1, 2, 3])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],
+            [0.1, 0.8, 0.1],
+            [0.1, 0.1, 0.8],
+            [0.3, 0.4, 0.3],  # Unknown class 3
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # Should skip ROC AUC and loss when unknown classes present
+    assert metrics.roc_auc is None
+    assert metrics.loss is None  # Loss also skipped to avoid ValueError
+    assert any("unknown" in w for w in metrics.warnings)
+    # Other metrics should still work (they handle extra classes)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+
+
+def test_roc_auc_handles_zero_probability_on_present_classes():
+    """Test ROC AUC when a sample has zero probability on all present classes (edge case for renormalization)."""
+    # Model trained on classes [0, 1, 2, 3], test set only has [0, 1, 2]
+    # One sample has ALL probability mass on excluded class 3 (zero on [0, 1, 2])
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.08, 0.02],
+            [0.1, 0.8, 0.08, 0.02],
+            [0.1, 0.1, 0.78, 0.02],
+            [0.6, 0.3, 0.08, 0.02],
+            [0.0, 0.0, 0.0, 1.0],  # zero denominator
+            [0.1, 0.1, 0.78, 0.02],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
+
+    # Should still calculate ROC AUC despite zero-denominator case
+    # The safe renormalization should prevent NaN/inf
+    assert metrics.roc_auc is not None
+    assert not np.isnan(metrics.roc_auc)
+    assert not np.isinf(metrics.roc_auc)
+    assert any("computed only on classes present" in w for w in metrics.warnings)
+    assert metrics.accuracy is not None
+    assert metrics.f1_score is not None
+    assert metrics.loss is not None
+
+
+def test_roc_auc_skipped_for_single_class():
+    """Test that ROC AUC is skipped when only one class is present in y_true."""
+    # Model trained on classes [0, 1, 2], but test set only has class 0
+    y_true = np.array([0, 0, 0, 0])
+    y_score = np.array(
+        [
+            [0.9, 0.05, 0.05],
+            [0.8, 0.1, 0.1],
+            [0.85, 0.1, 0.05],
+            [0.9, 0.05, 0.05],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    # ROC AUC requires at least 2 classes
+    assert metrics.roc_auc is None
+    assert metrics.accuracy == 1.0
+    assert metrics.loss is not None
+    assert any("requires at least 2 classes" in w for w in metrics.warnings)
+
+
+# Confusion Matrix Tests
+def test_confusion_matrix_binary_classification():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    expected_cm = sklearn.metrics.confusion_matrix(y_true, [0, 1, 1, 0, 0], labels=[0, 1])
+    assert metrics.confusion_matrix == expected_cm.tolist()
+
+
+def test_confusion_matrix_multiclass():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # All predictions correct
+    assert metrics.confusion_matrix == [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
+
+
+def test_confusion_matrix_with_misclassifications():
+    y_true = np.array([0, 1, 2, 0, 1, 2])
+    y_score = np.array(
+        [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
+
+    assert metrics.confusion_matrix is not None
+    # Class 0: 1 correct (index 0), 1 predicted as class 1 (index 3)
+    # Class 1: 2 correct (indices 1, 4)
+    # Class 2: 1 predicted as class 1 (index 2), 1 correct (index 5)
+    assert metrics.confusion_matrix == [[1, 1, 0], [0, 2, 0], [0, 1, 1]]
+
+
+def test_confusion_matrix_handles_nan_logits():
+    logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
+    expected_labels = [0, 1, 0, 1]
+    metrics = calculate_classification_metrics(expected_labels, logits, include_confusion_matrix=True)
+
+    # NaN predictions are set to -1, so they won't match any true label
+    # Only the last 2 predictions are valid: pred=[1, 1], true=[0, 1]
+    assert metrics.confusion_matrix is not None
+    # With NaN handling, predictions become [-1, -1, 1, 1]
+    # Only position 3 is correct (true=1, pred=1)
+    # Position 2 is wrong (true=0, pred=1)
+    assert len(metrics.confusion_matrix) == 2  # 2 classes
+    assert len(metrics.confusion_matrix[0]) == 2