orca-sdk 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +112 -14
- orca_sdk/_shared/metrics_test.py +204 -0
- orca_sdk/async_client.py +67 -11
- orca_sdk/classification_model.py +2 -1
- orca_sdk/client.py +67 -11
- orca_sdk/embedding_model.py +19 -12
- orca_sdk/embedding_model_test.py +1 -1
- orca_sdk/memoryset.py +1093 -231
- orca_sdk/memoryset_test.py +109 -2
- orca_sdk/regression_model.py +2 -3
- {orca_sdk-0.1.6.dist-info → orca_sdk-0.1.7.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.6.dist-info → orca_sdk-0.1.7.dist-info}/RECORD +13 -13
- {orca_sdk-0.1.6.dist-info → orca_sdk-0.1.7.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics.py
CHANGED
|
@@ -8,7 +8,8 @@ IMPORTANT:
|
|
|
8
8
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
import logging
|
|
12
|
+
from dataclasses import dataclass, field
|
|
12
13
|
from typing import Any, Literal, Sequence, TypedDict, cast
|
|
13
14
|
|
|
14
15
|
import numpy as np
|
|
@@ -242,6 +243,12 @@ class ClassificationMetrics:
|
|
|
242
243
|
roc_curve: ROCCurve | None = None
|
|
243
244
|
"""Receiver operating characteristic curve"""
|
|
244
245
|
|
|
246
|
+
confusion_matrix: list[list[int]] | None = None
|
|
247
|
+
"""Confusion matrix where confusion_matrix[i][j] is the count of samples with true label i predicted as label j"""
|
|
248
|
+
|
|
249
|
+
warnings: list[str] = field(default_factory=list)
|
|
250
|
+
"""Human-readable warnings about skipped or adjusted metrics"""
|
|
251
|
+
|
|
245
252
|
def __repr__(self) -> str:
|
|
246
253
|
return (
|
|
247
254
|
"ClassificationMetrics({\n"
|
|
@@ -321,7 +328,9 @@ def calculate_classification_metrics(
|
|
|
321
328
|
average: Literal["micro", "macro", "weighted", "binary"] | None = None,
|
|
322
329
|
multi_class: Literal["ovr", "ovo"] = "ovr",
|
|
323
330
|
include_curves: bool = False,
|
|
331
|
+
include_confusion_matrix: bool = False,
|
|
324
332
|
) -> ClassificationMetrics:
|
|
333
|
+
warnings: list[str] = []
|
|
325
334
|
references = np.array(expected_labels)
|
|
326
335
|
|
|
327
336
|
# Convert to numpy array, handling None values
|
|
@@ -342,6 +351,7 @@ def calculate_classification_metrics(
|
|
|
342
351
|
pr_auc=None,
|
|
343
352
|
pr_curve=None,
|
|
344
353
|
roc_curve=None,
|
|
354
|
+
confusion_matrix=None,
|
|
345
355
|
)
|
|
346
356
|
|
|
347
357
|
# Convert logits to probabilities
|
|
@@ -351,41 +361,102 @@ def calculate_classification_metrics(
|
|
|
351
361
|
predictions[np.isnan(probabilities).all(axis=-1)] = -1 # set predictions to -1 for all nan logits
|
|
352
362
|
|
|
353
363
|
num_classes_references = len(set(references))
|
|
354
|
-
num_classes_predictions =
|
|
364
|
+
num_classes_predictions = probabilities.shape[1] # Number of probability columns (model's known classes)
|
|
355
365
|
num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
|
|
356
366
|
coverage = 1 - (num_none_predictions / len(probabilities) if len(probabilities) > 0 else 0)
|
|
367
|
+
if num_none_predictions > 0:
|
|
368
|
+
warnings.append(f"Some predictions were missing (coverage={coverage:.3f}); loss and AUC metrics were skipped.")
|
|
357
369
|
|
|
358
370
|
if average is None:
|
|
359
371
|
average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
|
|
360
372
|
|
|
361
373
|
accuracy = sklearn.metrics.accuracy_score(references, predictions)
|
|
362
374
|
f1 = sklearn.metrics.f1_score(references, predictions, average=average)
|
|
375
|
+
|
|
376
|
+
# Check for unknown classes early (before log_loss)
|
|
377
|
+
classes_in_references = np.unique(references)
|
|
378
|
+
has_unknown_classes = np.max(classes_in_references) >= num_classes_predictions
|
|
379
|
+
if has_unknown_classes:
|
|
380
|
+
logging.warning(
|
|
381
|
+
f"Test labels contain classes not in the model's predictions. "
|
|
382
|
+
f"Model has {num_classes_predictions} classes (0 - {num_classes_predictions - 1}), "
|
|
383
|
+
f"but test labels contain class {np.max(classes_in_references)}. "
|
|
384
|
+
f"ROC AUC and PR AUC cannot be calculated."
|
|
385
|
+
)
|
|
386
|
+
warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
|
|
387
|
+
|
|
363
388
|
# Ensure sklearn sees the full class set corresponding to probability columns
|
|
364
389
|
# to avoid errors when y_true does not contain all classes.
|
|
390
|
+
# Skip log_loss if there are unknown classes (would cause ValueError)
|
|
365
391
|
loss = (
|
|
366
392
|
sklearn.metrics.log_loss(
|
|
367
393
|
references,
|
|
368
394
|
probabilities,
|
|
369
395
|
labels=list(range(probabilities.shape[1])),
|
|
370
396
|
)
|
|
371
|
-
if num_none_predictions == 0
|
|
397
|
+
if num_none_predictions == 0 and not has_unknown_classes
|
|
372
398
|
else None
|
|
373
399
|
)
|
|
374
400
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
if
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
roc_auc =
|
|
401
|
+
# Calculate ROC AUC with filtering for class mismatch
|
|
402
|
+
if num_none_predictions == 0:
|
|
403
|
+
# Check if y_true contains classes not in the model (unknown classes)
|
|
404
|
+
if has_unknown_classes:
|
|
405
|
+
# Unknown classes present - can't calculate meaningful ROC AUC
|
|
406
|
+
logging.warning(
|
|
407
|
+
"Cannot calculate ROC AUC and PR AUC: test labels contain classes not in the model's predictions."
|
|
408
|
+
)
|
|
409
|
+
if "y_true contains classes unknown to the model" not in " ".join(warnings):
|
|
410
|
+
warnings.append("y_true contains classes unknown to the model; loss and AUC metrics were skipped.")
|
|
411
|
+
roc_auc = None
|
|
412
|
+
pr_auc = None
|
|
413
|
+
pr_curve = None
|
|
386
414
|
roc_curve = None
|
|
415
|
+
elif len(classes_in_references) < 2:
|
|
416
|
+
# Need at least 2 classes for ROC AUC
|
|
417
|
+
logging.warning(
|
|
418
|
+
f"Cannot calculate ROC AUC and PR AUC: need at least 2 classes, but only {len(classes_in_references)} class(es) found in test labels."
|
|
419
|
+
)
|
|
420
|
+
roc_auc = None
|
|
387
421
|
pr_auc = None
|
|
388
422
|
pr_curve = None
|
|
423
|
+
roc_curve = None
|
|
424
|
+
warnings.append("ROC AUC requires at least 2 classes; metric was skipped.")
|
|
425
|
+
else:
|
|
426
|
+
# Filter probabilities to only classes present in references
|
|
427
|
+
if len(classes_in_references) < num_classes_predictions:
|
|
428
|
+
# Subset and renormalize probabilities
|
|
429
|
+
probabilities_filtered = probabilities[:, classes_in_references]
|
|
430
|
+
# Safe renormalization: guard against zero denominators
|
|
431
|
+
row_sums = probabilities_filtered.sum(axis=1, keepdims=True)
|
|
432
|
+
probabilities_filtered = probabilities_filtered / np.where(row_sums > 0, row_sums, 1.0)
|
|
433
|
+
|
|
434
|
+
# Remap references to filtered indices
|
|
435
|
+
class_mapping = {cls: idx for idx, cls in enumerate(classes_in_references)}
|
|
436
|
+
references_remapped = np.array([class_mapping[y] for y in references])
|
|
437
|
+
warnings.append(
|
|
438
|
+
f"ROC AUC computed only on classes present in y_true: {classes_in_references.tolist()}."
|
|
439
|
+
)
|
|
440
|
+
else:
|
|
441
|
+
# All classes present, no filtering needed
|
|
442
|
+
probabilities_filtered = probabilities
|
|
443
|
+
references_remapped = references
|
|
444
|
+
|
|
445
|
+
# special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
|
|
446
|
+
if len(classes_in_references) == 2:
|
|
447
|
+
# Use probabilities[:, 1] which is guaranteed to be 2D
|
|
448
|
+
probabilities_positive = cast(NDArray[np.float32], probabilities_filtered[:, 1].astype(np.float32))
|
|
449
|
+
roc_auc = sklearn.metrics.roc_auc_score(references_remapped, probabilities_positive)
|
|
450
|
+
roc_curve = calculate_roc_curve(references_remapped, probabilities_positive) if include_curves else None
|
|
451
|
+
pr_auc = sklearn.metrics.average_precision_score(references_remapped, probabilities_positive)
|
|
452
|
+
pr_curve = calculate_pr_curve(references_remapped, probabilities_positive) if include_curves else None
|
|
453
|
+
else:
|
|
454
|
+
roc_auc = sklearn.metrics.roc_auc_score(
|
|
455
|
+
references_remapped, probabilities_filtered, multi_class=multi_class
|
|
456
|
+
)
|
|
457
|
+
roc_curve = None
|
|
458
|
+
pr_auc = None
|
|
459
|
+
pr_curve = None
|
|
389
460
|
else:
|
|
390
461
|
roc_auc = None
|
|
391
462
|
pr_auc = None
|
|
@@ -395,6 +466,31 @@ def calculate_classification_metrics(
|
|
|
395
466
|
# Calculate anomaly score statistics
|
|
396
467
|
anomaly_score_mean, anomaly_score_median, anomaly_score_variance = calculate_anomaly_score_stats(anomaly_scores)
|
|
397
468
|
|
|
469
|
+
# Calculate confusion matrix if requested
|
|
470
|
+
confusion_matrix: list[list[int]] | None = None
|
|
471
|
+
if include_confusion_matrix:
|
|
472
|
+
# Get the number of classes from the probabilities shape
|
|
473
|
+
num_classes = probabilities.shape[1]
|
|
474
|
+
labels = list(range(num_classes))
|
|
475
|
+
# Filter out NaN predictions (which are set to -1) before computing confusion matrix
|
|
476
|
+
valid_mask = predictions != -1
|
|
477
|
+
num_filtered = (~valid_mask).sum()
|
|
478
|
+
if num_filtered > 0:
|
|
479
|
+
warning_msg = (
|
|
480
|
+
f"Confusion matrix computation: filtered out {num_filtered} samples with NaN predictions "
|
|
481
|
+
f"({num_filtered}/{len(predictions)} = {num_filtered / len(predictions):.1%})"
|
|
482
|
+
)
|
|
483
|
+
logging.warning(warning_msg)
|
|
484
|
+
warnings.append(warning_msg)
|
|
485
|
+
|
|
486
|
+
if np.any(valid_mask):
|
|
487
|
+
# Compute confusion matrix with explicit labels to ensure consistent shape
|
|
488
|
+
cm = sklearn.metrics.confusion_matrix(references[valid_mask], predictions[valid_mask], labels=labels)
|
|
489
|
+
else:
|
|
490
|
+
# No valid predictions; return an all-zero confusion matrix
|
|
491
|
+
cm = np.zeros((num_classes, num_classes), dtype=int)
|
|
492
|
+
confusion_matrix = cast(list[list[int]], cm.tolist())
|
|
493
|
+
|
|
398
494
|
return ClassificationMetrics(
|
|
399
495
|
coverage=coverage,
|
|
400
496
|
accuracy=float(accuracy),
|
|
@@ -407,6 +503,8 @@ def calculate_classification_metrics(
|
|
|
407
503
|
pr_auc=float(pr_auc) if pr_auc is not None else None,
|
|
408
504
|
pr_curve=pr_curve,
|
|
409
505
|
roc_curve=roc_curve,
|
|
506
|
+
confusion_matrix=confusion_matrix,
|
|
507
|
+
warnings=warnings,
|
|
410
508
|
)
|
|
411
509
|
|
|
412
510
|
|
orca_sdk/_shared/metrics_test.py
CHANGED
|
@@ -364,3 +364,207 @@ def test_regression_metrics_all_predictions_nan():
|
|
|
364
364
|
assert metrics.anomaly_score_mean is None
|
|
365
365
|
assert metrics.anomaly_score_median is None
|
|
366
366
|
assert metrics.anomaly_score_variance is None
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def test_roc_auc_handles_missing_classes_in_y_true():
|
|
370
|
+
"""Test that ROC AUC is calculated with filtering when test set has fewer classes than model predictions."""
|
|
371
|
+
# Model trained on classes [0, 1, 2], but test set only has [0, 1]
|
|
372
|
+
y_true = np.array([0, 1, 0, 1])
|
|
373
|
+
y_score = np.array(
|
|
374
|
+
[
|
|
375
|
+
[0.7, 0.2, 0.1], # Predicts class 0
|
|
376
|
+
[0.1, 0.8, 0.1], # Predicts class 1
|
|
377
|
+
[0.6, 0.3, 0.1], # Predicts class 0
|
|
378
|
+
[0.2, 0.7, 0.1], # Predicts class 1
|
|
379
|
+
]
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
metrics = calculate_classification_metrics(y_true, y_score)
|
|
383
|
+
|
|
384
|
+
# Should calculate ROC AUC by filtering to classes [0, 1]
|
|
385
|
+
assert metrics.roc_auc is not None
|
|
386
|
+
assert metrics.roc_auc == 1.0
|
|
387
|
+
assert any("computed only on classes present" in w for w in metrics.warnings)
|
|
388
|
+
# Other metrics should still work
|
|
389
|
+
assert metrics.accuracy == 1.0
|
|
390
|
+
assert metrics.f1_score == 1.0
|
|
391
|
+
assert metrics.loss is not None
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def test_roc_auc_with_all_classes_present():
|
|
395
|
+
"""Test that ROC AUC works when all classes are present in test set."""
|
|
396
|
+
# Model trained on classes [0, 1, 2], test set has all three
|
|
397
|
+
y_true = np.array([0, 1, 2, 0, 1, 2])
|
|
398
|
+
y_score = np.array(
|
|
399
|
+
[
|
|
400
|
+
[0.9, 0.05, 0.05], # Predicts class 0
|
|
401
|
+
[0.1, 0.8, 0.1], # Predicts class 1
|
|
402
|
+
[0.1, 0.1, 0.8], # Predicts class 2
|
|
403
|
+
[0.7, 0.2, 0.1], # Predicts class 0
|
|
404
|
+
[0.2, 0.7, 0.1], # Predicts class 1
|
|
405
|
+
[0.1, 0.2, 0.7], # Predicts class 2
|
|
406
|
+
]
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
|
|
410
|
+
|
|
411
|
+
# ROC AUC should be calculated when all classes present
|
|
412
|
+
assert metrics.roc_auc is not None
|
|
413
|
+
assert metrics.accuracy == 1.0
|
|
414
|
+
assert metrics.f1_score == 1.0
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def test_roc_auc_handles_subset_of_many_classes():
|
|
418
|
+
"""Test ROC AUC where model knows 15 classes, test has 10."""
|
|
419
|
+
# Simulate the actual error scenario from the bug report
|
|
420
|
+
num_model_classes = 15
|
|
421
|
+
num_test_classes = 10
|
|
422
|
+
num_samples = 50
|
|
423
|
+
|
|
424
|
+
# Test set only uses classes 0-9
|
|
425
|
+
y_true = np.random.randint(0, num_test_classes, size=num_samples)
|
|
426
|
+
|
|
427
|
+
# Model produces predictions for all 15 classes
|
|
428
|
+
y_score = np.random.rand(num_samples, num_model_classes)
|
|
429
|
+
y_score = y_score / y_score.sum(axis=1, keepdims=True) # Normalize to probabilities
|
|
430
|
+
|
|
431
|
+
metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
|
|
432
|
+
|
|
433
|
+
# Should calculate ROC AUC by filtering to classes 0-9
|
|
434
|
+
assert metrics.roc_auc is not None
|
|
435
|
+
assert 0.0 <= metrics.roc_auc <= 1.0
|
|
436
|
+
assert any("computed only on classes present" in w for w in metrics.warnings)
|
|
437
|
+
# Other metrics should still work
|
|
438
|
+
assert metrics.accuracy is not None
|
|
439
|
+
assert metrics.f1_score is not None
|
|
440
|
+
assert metrics.loss is not None
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def test_roc_auc_handles_unknown_classes_in_y_true():
|
|
444
|
+
"""Test that metrics handle when y_true contains classes not in y_score."""
|
|
445
|
+
# Model trained on classes [0, 1, 2], but test set has class 3
|
|
446
|
+
y_true = np.array([0, 1, 2, 3])
|
|
447
|
+
y_score = np.array(
|
|
448
|
+
[
|
|
449
|
+
[0.7, 0.2, 0.1],
|
|
450
|
+
[0.1, 0.8, 0.1],
|
|
451
|
+
[0.1, 0.1, 0.8],
|
|
452
|
+
[0.3, 0.4, 0.3], # Unknown class 3
|
|
453
|
+
]
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
metrics = calculate_classification_metrics(y_true, y_score)
|
|
457
|
+
|
|
458
|
+
# Should skip ROC AUC and loss when unknown classes present
|
|
459
|
+
assert metrics.roc_auc is None
|
|
460
|
+
assert metrics.loss is None # Loss also skipped to avoid ValueError
|
|
461
|
+
assert any("unknown" in w for w in metrics.warnings)
|
|
462
|
+
# Other metrics should still work (they handle extra classes)
|
|
463
|
+
assert metrics.accuracy is not None
|
|
464
|
+
assert metrics.f1_score is not None
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def test_roc_auc_handles_zero_probability_on_present_classes():
|
|
468
|
+
"""Test ROC AUC when a sample has zero probability on all present classes (edge case for renormalization)."""
|
|
469
|
+
# Model trained on classes [0, 1, 2, 3], test set only has [0, 1, 2]
|
|
470
|
+
# One sample has ALL probability mass on excluded class 3 (zero on [0, 1, 2])
|
|
471
|
+
y_true = np.array([0, 1, 2, 0, 1, 2])
|
|
472
|
+
y_score = np.array(
|
|
473
|
+
[
|
|
474
|
+
[0.7, 0.2, 0.08, 0.02],
|
|
475
|
+
[0.1, 0.8, 0.08, 0.02],
|
|
476
|
+
[0.1, 0.1, 0.78, 0.02],
|
|
477
|
+
[0.6, 0.3, 0.08, 0.02],
|
|
478
|
+
[0.0, 0.0, 0.0, 1.0], # zero denominator
|
|
479
|
+
[0.1, 0.1, 0.78, 0.02],
|
|
480
|
+
]
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
metrics = calculate_classification_metrics(y_true, y_score, multi_class="ovr")
|
|
484
|
+
|
|
485
|
+
# Should still calculate ROC AUC despite zero-denominator case
|
|
486
|
+
# The safe renormalization should prevent NaN/inf
|
|
487
|
+
assert metrics.roc_auc is not None
|
|
488
|
+
assert not np.isnan(metrics.roc_auc)
|
|
489
|
+
assert not np.isinf(metrics.roc_auc)
|
|
490
|
+
assert any("computed only on classes present" in w for w in metrics.warnings)
|
|
491
|
+
assert metrics.accuracy is not None
|
|
492
|
+
assert metrics.f1_score is not None
|
|
493
|
+
assert metrics.loss is not None
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def test_roc_auc_skipped_for_single_class():
|
|
497
|
+
"""Test that ROC AUC is skipped when only one class is present in y_true."""
|
|
498
|
+
# Model trained on classes [0, 1, 2], but test set only has class 0
|
|
499
|
+
y_true = np.array([0, 0, 0, 0])
|
|
500
|
+
y_score = np.array(
|
|
501
|
+
[
|
|
502
|
+
[0.9, 0.05, 0.05],
|
|
503
|
+
[0.8, 0.1, 0.1],
|
|
504
|
+
[0.85, 0.1, 0.05],
|
|
505
|
+
[0.9, 0.05, 0.05],
|
|
506
|
+
]
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
metrics = calculate_classification_metrics(y_true, y_score)
|
|
510
|
+
|
|
511
|
+
# ROC AUC requires at least 2 classes
|
|
512
|
+
assert metrics.roc_auc is None
|
|
513
|
+
assert metrics.accuracy == 1.0
|
|
514
|
+
assert metrics.loss is not None
|
|
515
|
+
assert any("requires at least 2 classes" in w for w in metrics.warnings)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
# Confusion Matrix Tests
|
|
519
|
+
def test_confusion_matrix_binary_classification():
|
|
520
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
521
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
|
|
522
|
+
|
|
523
|
+
metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
|
|
524
|
+
|
|
525
|
+
assert metrics.confusion_matrix is not None
|
|
526
|
+
expected_cm = sklearn.metrics.confusion_matrix(y_true, [0, 1, 1, 0, 0], labels=[0, 1])
|
|
527
|
+
assert metrics.confusion_matrix == expected_cm.tolist()
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def test_confusion_matrix_multiclass():
|
|
531
|
+
y_true = np.array([0, 1, 2, 0, 1, 2])
|
|
532
|
+
y_score = np.array(
|
|
533
|
+
[[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
|
|
537
|
+
|
|
538
|
+
assert metrics.confusion_matrix is not None
|
|
539
|
+
# All predictions correct
|
|
540
|
+
assert metrics.confusion_matrix == [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def test_confusion_matrix_with_misclassifications():
|
|
544
|
+
y_true = np.array([0, 1, 2, 0, 1, 2])
|
|
545
|
+
y_score = np.array(
|
|
546
|
+
[[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]
|
|
547
|
+
)
|
|
548
|
+
|
|
549
|
+
metrics = calculate_classification_metrics(y_true, y_score, include_confusion_matrix=True)
|
|
550
|
+
|
|
551
|
+
assert metrics.confusion_matrix is not None
|
|
552
|
+
# Class 0: 1 correct (index 0), 1 predicted as class 1 (index 3)
|
|
553
|
+
# Class 1: 2 correct (indices 1, 4)
|
|
554
|
+
# Class 2: 1 predicted as class 1 (index 2), 1 correct (index 5)
|
|
555
|
+
assert metrics.confusion_matrix == [[1, 1, 0], [0, 2, 0], [0, 1, 1]]
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def test_confusion_matrix_handles_nan_logits():
|
|
559
|
+
logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
|
|
560
|
+
expected_labels = [0, 1, 0, 1]
|
|
561
|
+
metrics = calculate_classification_metrics(expected_labels, logits, include_confusion_matrix=True)
|
|
562
|
+
|
|
563
|
+
# NaN predictions are set to -1, so they won't match any true label
|
|
564
|
+
# Only the last 2 predictions are valid: pred=[1, 1], true=[0, 1]
|
|
565
|
+
assert metrics.confusion_matrix is not None
|
|
566
|
+
# With NaN handling, predictions become [-1, -1, 1, 1]
|
|
567
|
+
# Only position 3 is correct (true=1, pred=1)
|
|
568
|
+
# Position 2 is wrong (true=0, pred=1)
|
|
569
|
+
assert len(metrics.confusion_matrix) == 2 # 2 classes
|
|
570
|
+
assert len(metrics.confusion_matrix[0]) == 2
|
orca_sdk/async_client.py
CHANGED
|
@@ -62,10 +62,12 @@ class ActionRecommendation(TypedDict):
|
|
|
62
62
|
class AddMemorySuggestion(TypedDict):
|
|
63
63
|
value: str
|
|
64
64
|
label_name: str
|
|
65
|
+
similarity: NotRequired[float | None]
|
|
65
66
|
|
|
66
67
|
|
|
67
68
|
class AliveResponse(TypedDict):
|
|
68
69
|
ok: bool
|
|
70
|
+
checks: dict[str, bool]
|
|
69
71
|
|
|
70
72
|
|
|
71
73
|
class ApiKeyMetadata(TypedDict):
|
|
@@ -292,7 +294,7 @@ class JobStatusInfo(TypedDict):
|
|
|
292
294
|
class LabelClassMetrics(TypedDict):
|
|
293
295
|
label: int | None
|
|
294
296
|
label_name: NotRequired[str | None]
|
|
295
|
-
average_lookup_score: float
|
|
297
|
+
average_lookup_score: float | None
|
|
296
298
|
memory_count: int
|
|
297
299
|
|
|
298
300
|
|
|
@@ -346,7 +348,7 @@ class MemoryMetrics(TypedDict):
|
|
|
346
348
|
cluster: NotRequired[int]
|
|
347
349
|
embedding_2d: NotRequired[list]
|
|
348
350
|
anomaly_score: NotRequired[float]
|
|
349
|
-
neighbor_label_logits: NotRequired[list[float]]
|
|
351
|
+
neighbor_label_logits: NotRequired[list[float] | None]
|
|
350
352
|
neighbor_predicted_label: NotRequired[int | None]
|
|
351
353
|
neighbor_predicted_label_ambiguity: NotRequired[float]
|
|
352
354
|
neighbor_predicted_label_confidence: NotRequired[float]
|
|
@@ -1157,6 +1159,9 @@ class FieldValidationError(TypedDict):
|
|
|
1157
1159
|
|
|
1158
1160
|
class AddMemoryRecommendations(TypedDict):
|
|
1159
1161
|
memories: list[AddMemorySuggestion]
|
|
1162
|
+
attempts_used: NotRequired[int]
|
|
1163
|
+
partial: NotRequired[bool]
|
|
1164
|
+
rejection_counts: NotRequired[dict[str, int]]
|
|
1160
1165
|
|
|
1161
1166
|
|
|
1162
1167
|
class BootstrapClassificationModelRequest(TypedDict):
|
|
@@ -1272,10 +1277,19 @@ class CreateClassificationModelRequest(TypedDict):
|
|
|
1272
1277
|
num_classes: NotRequired[int | None]
|
|
1273
1278
|
|
|
1274
1279
|
|
|
1275
|
-
class CreateMemorysetRequest(TypedDict):
|
|
1280
|
+
class CreateMemorysetFromDatasourceRequest(TypedDict):
|
|
1276
1281
|
name: str
|
|
1277
1282
|
description: NotRequired[str | None]
|
|
1278
1283
|
notes: NotRequired[str | None]
|
|
1284
|
+
pretrained_embedding_model_name: NotRequired[PretrainedEmbeddingModelName | None]
|
|
1285
|
+
finetuned_embedding_model_name_or_id: NotRequired[str | None]
|
|
1286
|
+
max_seq_length_override: NotRequired[int | None]
|
|
1287
|
+
label_names: NotRequired[list[str] | None]
|
|
1288
|
+
index_type: NotRequired[Literal["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "DISKANN"]]
|
|
1289
|
+
index_params: NotRequired[dict[str, int | float | str]]
|
|
1290
|
+
prompt: NotRequired[str]
|
|
1291
|
+
hidden: NotRequired[bool]
|
|
1292
|
+
memory_type: NotRequired[MemoryType | None]
|
|
1279
1293
|
datasource_name_or_id: str
|
|
1280
1294
|
datasource_label_column: NotRequired[str | None]
|
|
1281
1295
|
datasource_score_column: NotRequired[str | None]
|
|
@@ -1283,6 +1297,14 @@ class CreateMemorysetRequest(TypedDict):
|
|
|
1283
1297
|
datasource_source_id_column: NotRequired[str | None]
|
|
1284
1298
|
datasource_partition_id_column: NotRequired[str | None]
|
|
1285
1299
|
remove_duplicates: NotRequired[bool]
|
|
1300
|
+
batch_size: NotRequired[int]
|
|
1301
|
+
subsample: NotRequired[int | float | None]
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
class CreateMemorysetRequest(TypedDict):
|
|
1305
|
+
name: str
|
|
1306
|
+
description: NotRequired[str | None]
|
|
1307
|
+
notes: NotRequired[str | None]
|
|
1286
1308
|
pretrained_embedding_model_name: NotRequired[PretrainedEmbeddingModelName | None]
|
|
1287
1309
|
finetuned_embedding_model_name_or_id: NotRequired[str | None]
|
|
1288
1310
|
max_seq_length_override: NotRequired[int | None]
|
|
@@ -1291,9 +1313,7 @@ class CreateMemorysetRequest(TypedDict):
|
|
|
1291
1313
|
index_params: NotRequired[dict[str, int | float | str]]
|
|
1292
1314
|
prompt: NotRequired[str]
|
|
1293
1315
|
hidden: NotRequired[bool]
|
|
1294
|
-
|
|
1295
|
-
subsample: NotRequired[int | float | None]
|
|
1296
|
-
memory_type: NotRequired[MemoryType]
|
|
1316
|
+
memory_type: NotRequired[MemoryType | None]
|
|
1297
1317
|
|
|
1298
1318
|
|
|
1299
1319
|
class CreateRegressionModelRequest(TypedDict):
|
|
@@ -1653,8 +1673,8 @@ class MemorysetMetadata(TypedDict):
|
|
|
1653
1673
|
created_at: str
|
|
1654
1674
|
updated_at: str
|
|
1655
1675
|
memories_updated_at: str
|
|
1656
|
-
insertion_job_id: str
|
|
1657
|
-
insertion_status: JobStatus
|
|
1676
|
+
insertion_job_id: str | None
|
|
1677
|
+
insertion_status: JobStatus | None
|
|
1658
1678
|
metrics: MemorysetMetrics
|
|
1659
1679
|
memory_type: MemoryType
|
|
1660
1680
|
label_names: list[str] | None
|
|
@@ -1664,7 +1684,7 @@ class MemorysetMetadata(TypedDict):
|
|
|
1664
1684
|
document_prompt_override: str | None
|
|
1665
1685
|
query_prompt_override: str | None
|
|
1666
1686
|
hidden: bool
|
|
1667
|
-
insertion_task_id: str
|
|
1687
|
+
insertion_task_id: str | None
|
|
1668
1688
|
|
|
1669
1689
|
|
|
1670
1690
|
class PaginatedWorkerInfo(TypedDict):
|
|
@@ -1735,6 +1755,22 @@ class OrcaAsyncClient(AsyncClient):
|
|
|
1735
1755
|
) -> AliveResponse:
|
|
1736
1756
|
pass
|
|
1737
1757
|
|
|
1758
|
+
@overload
|
|
1759
|
+
async def GET(
|
|
1760
|
+
self,
|
|
1761
|
+
path: Literal["/gpu/check/alive"],
|
|
1762
|
+
*,
|
|
1763
|
+
params: None = None,
|
|
1764
|
+
parse_as: Literal["json"] = "json",
|
|
1765
|
+
headers: HeaderTypes | None = None,
|
|
1766
|
+
cookies: CookieTypes | None = None,
|
|
1767
|
+
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1768
|
+
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1769
|
+
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1770
|
+
extensions: RequestExtensions | None = None,
|
|
1771
|
+
) -> AliveResponse:
|
|
1772
|
+
pass
|
|
1773
|
+
|
|
1738
1774
|
@overload
|
|
1739
1775
|
async def GET(
|
|
1740
1776
|
self,
|
|
@@ -1754,7 +1790,7 @@ class OrcaAsyncClient(AsyncClient):
|
|
|
1754
1790
|
@overload
|
|
1755
1791
|
async def GET(
|
|
1756
1792
|
self,
|
|
1757
|
-
path: Literal["/
|
|
1793
|
+
path: Literal["/check/healthy"],
|
|
1758
1794
|
*,
|
|
1759
1795
|
params: None = None,
|
|
1760
1796
|
parse_as: Literal["json"] = "json",
|
|
@@ -1770,7 +1806,7 @@ class OrcaAsyncClient(AsyncClient):
|
|
|
1770
1806
|
@overload
|
|
1771
1807
|
async def GET(
|
|
1772
1808
|
self,
|
|
1773
|
-
path: Literal["/check/healthy"],
|
|
1809
|
+
path: Literal["/gpu/check/healthy"],
|
|
1774
1810
|
*,
|
|
1775
1811
|
params: None = None,
|
|
1776
1812
|
parse_as: Literal["json"] = "json",
|
|
@@ -2895,6 +2931,26 @@ class OrcaAsyncClient(AsyncClient):
|
|
|
2895
2931
|
path: Literal["/memoryset"],
|
|
2896
2932
|
*,
|
|
2897
2933
|
params: None = None,
|
|
2934
|
+
json: CreateMemorysetFromDatasourceRequest,
|
|
2935
|
+
data: None = None,
|
|
2936
|
+
files: None = None,
|
|
2937
|
+
content: None = None,
|
|
2938
|
+
parse_as: Literal["json"] = "json",
|
|
2939
|
+
headers: HeaderTypes | None = None,
|
|
2940
|
+
cookies: CookieTypes | None = None,
|
|
2941
|
+
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
2942
|
+
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
2943
|
+
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
2944
|
+
extensions: RequestExtensions | None = None,
|
|
2945
|
+
) -> MemorysetMetadata:
|
|
2946
|
+
pass
|
|
2947
|
+
|
|
2948
|
+
@overload
|
|
2949
|
+
async def POST(
|
|
2950
|
+
self,
|
|
2951
|
+
path: Literal["/memoryset/empty"],
|
|
2952
|
+
*,
|
|
2953
|
+
params: None = None,
|
|
2898
2954
|
json: CreateMemorysetRequest,
|
|
2899
2955
|
data: None = None,
|
|
2900
2956
|
files: None = None,
|
orca_sdk/classification_model.py
CHANGED
|
@@ -114,13 +114,14 @@ class ClassificationModel:
|
|
|
114
114
|
return isinstance(other, ClassificationModel) and self.id == other.id
|
|
115
115
|
|
|
116
116
|
def __repr__(self):
|
|
117
|
+
memoryset_repr = self.memoryset.__repr__().replace("\n", "\n ")
|
|
117
118
|
return (
|
|
118
119
|
"ClassificationModel({\n"
|
|
119
120
|
f" name: '{self.name}',\n"
|
|
120
121
|
f" head_type: {self.head_type},\n"
|
|
121
122
|
f" num_classes: {self.num_classes},\n"
|
|
122
123
|
f" memory_lookup_count: {self.memory_lookup_count},\n"
|
|
123
|
-
f" memoryset:
|
|
124
|
+
f" memoryset: {memoryset_repr},\n"
|
|
124
125
|
"})"
|
|
125
126
|
)
|
|
126
127
|
|