orca-sdk 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- orca_sdk/_utils/analysis_ui.py +1 -1
- orca_sdk/_utils/data_parsing.py +16 -12
- orca_sdk/_utils/data_parsing_test.py +8 -8
- orca_sdk/async_client.py +96 -28
- orca_sdk/classification_model.py +184 -104
- orca_sdk/classification_model_test.py +8 -4
- orca_sdk/client.py +96 -28
- orca_sdk/credentials.py +8 -10
- orca_sdk/datasource.py +3 -3
- orca_sdk/memoryset.py +64 -38
- orca_sdk/memoryset_test.py +5 -3
- orca_sdk/regression_model.py +124 -67
- orca_sdk/regression_model_test.py +8 -4
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/METADATA +4 -4
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/RECORD +16 -16
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
@@ -15,6 +15,7 @@ from .client import (
     BootstrapLabeledMemoryDataResult,
     ClassificationModelMetadata,
     ClassificationPredictionRequest,
+    ListPredictionsRequest,
     OrcaClient,
     PredictiveModelUpdate,
     RACHeadType,
@@ -363,6 +364,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         pass
 
@@ -383,6 +385,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> ClassificationPrediction:
         pass
 
@@ -402,6 +405,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -429,6 +433,7 @@ class ClassificationModel:
                 * `"exclude_global"`: Exclude global memories
                 * `"only_global"`: Only include global memories
             use_gpu: Whether to use GPU for the prediction (defaults to True)
+            batch_size: Number of values to process in a single API call
 
         Returns:
             Label prediction or list of label predictions
@@ -456,6 +461,8 @@ class ClassificationModel:
 
         if timeout_seconds <= 0:
             raise ValueError("timeout_seconds must be a positive integer")
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
 
         parsed_filters = [
             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
@@ -464,10 +471,17 @@ class ClassificationModel:
         if any(_is_metric_column(filter[0]) for filter in filters):
             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
+        # Convert to list for batching
+        values = value if isinstance(value, list) else [value]
+        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+        if isinstance(partition_id, list) and len(partition_id) != len(values):
+            raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
         if isinstance(expected_labels, int):
-            expected_labels = [expected_labels]
+            expected_labels = [expected_labels] * len(values)
         elif isinstance(expected_labels, str):
-            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
         elif isinstance(expected_labels, list):
             expected_labels = [
                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
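Note on the broadcasting rule introduced above: a scalar `expected_labels` (or `partition_id`) is now repeated to match the input list, while a list argument must already match its length. A minimal usage sketch; the `model` handle and label names are illustrative, not taken from this diff:

```python
values = ["Do you love soup?", "Are cats cute?"]

# A scalar expected label is broadcast to every input value:
predictions = model.predict(values, expected_labels="positive")

# A list must contain exactly one entry per input value; mismatched
# lengths now raise ValueError instead of being sent to the API:
predictions = model.predict(values, expected_labels=["positive", "negative"])
```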
@@ -481,47 +495,56 @@ class ClassificationModel:
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaClient._resolve_client()
-        request_json: ClassificationPredictionRequest = {
-            "input_values": value if isinstance(value, list) else [value],
-            "memoryset_override_name_or_id": self._memoryset_override_id,
-            "expected_labels": expected_labels,
-            "tags": list(tags or set()),
-            "save_telemetry": telemetry_on,
-            "save_telemetry_synchronously": telemetry_sync,
-            "filters": cast(list[FilterItem], parsed_filters),
-            "prompt": prompt,
-            "use_lookup_cache": use_lookup_cache,
-            "ignore_unlabeled": ignore_unlabeled,
-            "partition_filter_mode": partition_filter_mode,
-        }
-        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
-        if partition_filter_mode != "ignore_partitions":
-            request_json["partition_ids"] = partition_id
-        response = client.POST(
-            endpoint,
-            params={"name_or_id": self.id},
-            json=request_json,
-            timeout=timeout_seconds,
-        )
 
-
-
+        predictions: list[ClassificationPrediction] = []
+        for i in range(0, len(values), batch_size):
+            batch_values = values[i : i + batch_size]
+            batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
 
-
-
-
-
-
-
-
-
-
-
-
-
+            request_json: ClassificationPredictionRequest = {
+                "input_values": batch_values,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": batch_expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
+                "partition_filter_mode": partition_filter_mode,
+            }
+            if partition_filter_mode != "ignore_partitions":
+                request_json["partition_ids"] = (
+                    partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                )
+
+            response = client.POST(
+                endpoint,
+                params={"name_or_id": self.id},
+                json=request_json,
+                timeout=timeout_seconds,
             )
-
-
+
+            if telemetry_on and any(p["prediction_id"] is None for p in response):
+                raise RuntimeError("Failed to save some prediction to database.")
+
+            predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    logits=prediction["logits"],
+                    input_value=input_value,
+                )
+                for prediction, input_value in zip(response, batch_values)
+            )
+
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
         return predictions if isinstance(value, list) else predictions[0]
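The loop above replaces the previous single POST: inputs are sliced into `batch_size` chunks, each chunk is sent as its own request, and the per-batch predictions are concatenated in input order. A hedged usage sketch; the model handle and input values are illustrative:

```python
# Assumes an existing ClassificationModel handle named `model`.
values = [f"example input {i}" for i in range(1000)]

# 1000 inputs with batch_size=100 are sent as 10 sequential API calls,
# but callers still receive one flat, order-preserving list:
predictions = model.predict(values, batch_size=100)
assert len(predictions) == len(values)

# batch_size is validated up front; anything outside 1..500 raises:
# model.predict(values, batch_size=0)    -> ValueError
# model.predict(values, batch_size=1000) -> ValueError
```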
@@ -542,6 +565,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         pass
 
@@ -561,6 +585,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> ClassificationPrediction:
         pass
 
@@ -579,6 +604,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -605,6 +631,8 @@ class ClassificationModel:
                 * `"include_global"`: Include global memories
                 * `"exclude_global"`: Exclude global memories
                 * `"only_global"`: Only include global memories
+            batch_size: Number of values to process in a single API call
+
         Returns:
             Label prediction or list of label predictions.
 
@@ -631,6 +659,8 @@ class ClassificationModel:
 
         if timeout_seconds <= 0:
             raise ValueError("timeout_seconds must be a positive integer")
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
 
         parsed_filters = [
             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
@@ -639,10 +669,17 @@ class ClassificationModel:
         if any(_is_metric_column(filter[0]) for filter in filters):
             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
+        # Convert to list for batching
+        values = value if isinstance(value, list) else [value]
+        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+        if isinstance(partition_id, list) and len(partition_id) != len(values):
+            raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
         if isinstance(expected_labels, int):
-            expected_labels = [expected_labels]
+            expected_labels = [expected_labels] * len(values)
         elif isinstance(expected_labels, str):
-            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
         elif isinstance(expected_labels, list):
             expected_labels = [
                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
@@ -651,75 +688,89 @@ class ClassificationModel:
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaAsyncClient._resolve_client()
-        request_json: ClassificationPredictionRequest = {
-            "input_values": value if isinstance(value, list) else [value],
-            "memoryset_override_name_or_id": self._memoryset_override_id,
-            "expected_labels": expected_labels,
-            "tags": list(tags or set()),
-            "save_telemetry": telemetry_on,
-            "save_telemetry_synchronously": telemetry_sync,
-            "filters": cast(list[FilterItem], parsed_filters),
-            "prompt": prompt,
-            "use_lookup_cache": use_lookup_cache,
-            "ignore_unlabeled": ignore_unlabeled,
-            "partition_filter_mode": partition_filter_mode,
-        }
-        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
-        if partition_filter_mode != "ignore_partitions":
-            request_json["partition_ids"] = partition_id
-        response = await client.POST(
-            "/gpu/classification_model/{name_or_id}/prediction",
-            params={"name_or_id": self.id},
-            json=request_json,
-            timeout=timeout_seconds,
-        )
 
-
-
+        predictions: list[ClassificationPrediction] = []
+        for i in range(0, len(values), batch_size):
+            batch_values = values[i : i + batch_size]
+            batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
 
-
-
-
-
-
-
-
-
-
-
-
-
+            request_json: ClassificationPredictionRequest = {
+                "input_values": batch_values,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": batch_expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
+                "partition_filter_mode": partition_filter_mode,
+            }
+            if partition_filter_mode != "ignore_partitions":
+                request_json["partition_ids"] = (
+                    partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                )
+            response = await client.POST(
+                "/gpu/classification_model/{name_or_id}/prediction",
+                params={"name_or_id": self.id},
+                json=request_json,
+                timeout=timeout_seconds,
             )
-
-
+
+            if telemetry_on and any(p["prediction_id"] is None for p in response):
+                raise RuntimeError("Failed to save some prediction to database.")
+
+            predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    logits=prediction["logits"],
+                    input_value=input_value,
+                )
+                for prediction, input_value in zip(response, batch_values)
+            )
+
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
         return predictions if isinstance(value, list) else predictions[0]
 
     def predictions(
         self,
-        limit: int =
+        limit: int | None = None,
         offset: int = 0,
         tag: str | None = None,
         sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
         expected_label_match: bool | None = None,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
         Params:
-            limit:
+            limit: Maximum number of predictions to return. If `None`, returns all predictions
+                by automatically paginating through results.
             offset: Optional offset of the first prediction to return
             tag: Optional tag to filter predictions by
             sort: Optional list of columns and directions to sort the predictions by.
                 Predictions can be sorted by `timestamp` or `confidence`.
             expected_label_match: Optional filter to only include predictions where the expected
                 label does (`True`) or doesn't (`False`) match the predicted label
+            batch_size: Number of predictions to fetch in a single API call
 
         Returns:
             List of label predictions
 
         Examples:
+            Get all predictions with a specific tag:
+            >>> predictions = model.predictions(tag="evaluation")
+
            Get the last 3 predictions:
            >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
            [
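Before the predictions-listing changes below, note that the asynchronous predict in the preceding hunk batches identically to the synchronous path; only the resolved client and the awaited POST differ. A hedged calling sketch; the exact method name and model handle are assumptions, since this excerpt shows only the implementation body:

```python
import asyncio

async def main() -> None:
    # Assumes an async-capable ClassificationModel handle named `model`;
    # batching and the 1..500 batch_size validation behave as in the sync path.
    predictions = await model.predict(
        ["Do you love soup?", "Are cats cute?"],
        batch_size=100,
    )
    for p in predictions:
        print(p.label_name, p.confidence)

asyncio.run(main())
```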
@@ -737,33 +788,61 @@ class ClassificationModel:
             >>> predictions = model.predictions(expected_label_match=False)
             [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
+        if limit == 0:
+            return []
+
         client = OrcaClient._resolve_client()
-
-
-
+        all_predictions: list[ClassificationPrediction] = []
+
+        if limit is not None and limit < batch_size:
+            pages = [(offset, limit)]
+        else:
+            # automatically paginate the requests if necessary
+            total = client.POST(
+                "/telemetry/prediction/count",
+                json={
+                    "model_id": self.id,
+                    "tag": tag,
+                    "expected_label_match": expected_label_match,
+                },
+            )
+            max_limit = max(total - offset, 0)
+            limit = min(limit, max_limit) if limit is not None else max_limit
+            pages = [(o, min(batch_size, limit - (o - offset))) for o in range(offset, offset + limit, batch_size)]
+
+        for current_offset, current_limit in pages:
+            request_json: ListPredictionsRequest = {
                 "model_id": self.id,
-            "limit":
-            "offset":
-            "sort": [list(sort_item) for sort_item in sort],
+                "limit": current_limit,
+                "offset": current_offset,
                 "tag": tag,
                 "expected_label_match": expected_label_match,
-        }
-
-
-
-
-
-                label_name=prediction["label_name"],
-                score=None,
-                confidence=prediction["confidence"],
-                anomaly_score=prediction["anomaly_score"],
-                memoryset=self.memoryset,
-                model=self,
-                telemetry=prediction,
+            }
+            if sort:
+                request_json["sort"] = sort
+            response = client.POST(
+                "/telemetry/prediction",
+                json=request_json,
            )
-
-
-
+            all_predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    telemetry=prediction,
+                )
+                for prediction in response
+                if "label" in prediction
+            )
+
+        return all_predictions
 
     def _evaluate_datasource(
         self,
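`predictions()` now auto-paginates: unless `limit` already fits in a single page, it first fetches the matching-prediction count from `/telemetry/prediction/count` and then issues one request per `batch_size` page. A sketch of the resulting call pattern; the stored-prediction count is illustrative:

```python
# Assumes a model with 1250 stored predictions.

# limit=None now means "fetch everything", paginated transparently:
all_preds = model.predictions(batch_size=500)  # count request + 3 pages (500/500/250)

# A limit smaller than batch_size skips the count request and issues one page:
recent = model.predictions(limit=3, sort=[("timestamp", "desc")])

# limit=0 short-circuits to an empty list without touching the API:
assert model.predictions(limit=0) == []
```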
@@ -861,6 +940,7 @@ class ClassificationModel:
             logits=[p.logits for p in predictions],
             anomaly_scores=[p.anomaly_score for p in predictions],
             include_curves=True,
+            include_confusion_matrix=True,
         )
 
     @overload
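The evaluation path now also requests a confusion matrix alongside the metric curves. The server-side computation is not part of this diff; for intuition only, a minimal pure-Python equivalent of the idea (not SDK code) might look like this:

```python
from collections import Counter

def confusion_matrix(expected: list[int], predicted: list[int], num_labels: int) -> list[list[int]]:
    """matrix[i][j] counts examples with expected label i that were predicted as j."""
    counts = Counter(zip(expected, predicted))
    return [[counts[i, j] for j in range(num_labels)] for i in range(num_labels)]

# Two classes, four predictions: one true negative, one false positive, two true positives.
assert confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1], 2) == [[1, 1], [0, 2]]
```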
orca_sdk/classification_model_test.py
CHANGED
@@ -218,13 +218,17 @@ def test_evaluate_dataset_with_nones_raises_error(classification_model: Classifi
 
 
 def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
-    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"}, batch_size=2)
     assert result is not None
     assert isinstance(result, ClassificationMetrics)
-    predictions = classification_model.predictions(tag="test")
+    predictions = classification_model.predictions(tag="test", batch_size=100, sort=[("timestamp", "asc")])
     assert len(predictions) == 4
     assert all(p.tags == {"test"} for p in predictions)
-
+    prediction_expected_labels = [p.expected_label if p.expected_label is not None else -1 for p in predictions]
+    eval_expected_labels = list(eval_dataset["label"])
+    assert all(
+        p == l for p, l in zip(prediction_expected_labels, eval_expected_labels)
+    ), f"Prediction expected labels: {prediction_expected_labels} do not match eval expected labels: {eval_expected_labels}"
 
 
 def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
@@ -361,7 +365,7 @@ def test_evaluate_with_partition_column_datasource(partitioned_classification_mo
 
 
 def test_predict(classification_model: ClassificationModel, label_names: list[str]):
-    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], batch_size=1)
     assert len(predictions) == 2
     assert predictions[0].prediction_id is not None
     assert predictions[1].prediction_id is not None