orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/analysis_ui.py +4 -1
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/prediction_result_ui.py +4 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/_utils/value_parser.py +44 -17
- orca_sdk/_utils/value_parser_test.py +6 -5
- orca_sdk/async_client.py +234 -22
- orca_sdk/classification_model.py +203 -66
- orca_sdk/classification_model_test.py +85 -25
- orca_sdk/client.py +234 -20
- orca_sdk/conftest.py +97 -16
- orca_sdk/credentials_test.py +5 -8
- orca_sdk/datasource.py +44 -21
- orca_sdk/datasource_test.py +8 -2
- orca_sdk/embedding_model.py +15 -33
- orca_sdk/embedding_model_test.py +30 -1
- orca_sdk/memoryset.py +558 -425
- orca_sdk/memoryset_test.py +120 -185
- orca_sdk/regression_model.py +186 -65
- orca_sdk/regression_model_test.py +62 -3
- orca_sdk/telemetry.py +16 -7
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +4 -8
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -129
- orca_sdk/_utils/data_parsing_test.py +0 -244
- orca_sdk-0.1.10.dist-info/RECORD +0 -41
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
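The headline API changes in this release land in classification_model.py: a new consistency_level parameter threaded through every predict overload, and a client-side ClassificationMetrics class. As orientation before the diff, here is a minimal sketch of the new parameter; the top-level import path and the model name "my-model" are illustrative assumptions, not taken from this diff.

# Hedged sketch: calling predict() with the consistency_level parameter added in 0.1.12.
# The import path and the model name "my-model" are assumptions.
from orca_sdk import ClassificationModel

model = ClassificationModel.open("my-model")  # assumes this model already exists
prediction = model.predict(
    "Do you love soup?",
    consistency_level="Bounded",  # new parameter; "Bounded" is the default in the diff below
)
print(prediction.label, prediction.label_name)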
orca_sdk/classification_model.py
CHANGED
@@ -1,28 +1,39 @@
 from __future__ import annotations
 
-import logging
 from contextlib import contextmanager
 from datetime import datetime
-from typing import
-
-
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    Iterable,
+    Literal,
+    Sequence,
+    cast,
+    overload,
+)
 
-from .
-from ._utils.common import UNSET, CreateMode, DropMode
+from ._utils.common import UNSET, CreateMode, DropMode, logger
 from .async_client import OrcaAsyncClient
 from .client import (
     BootstrapClassificationModelMeta,
     BootstrapLabeledMemoryDataResult,
+)
+from .client import ClassificationMetrics as ClassificationMetricsResponse
+from .client import (
     ClassificationModelMetadata,
     ClassificationPredictionRequest,
     ListPredictionsRequest,
     OrcaClient,
+    PRCurve,
     PredictiveModelUpdate,
     RACHeadType,
+    ROCCurve,
 )
 from .datasource import Datasource
 from .job import Job
 from .memoryset import (
+    ConsistencyLevel,
     FilterItem,
     FilterItemTuple,
     LabeledMemoryset,
@@ -36,6 +47,115 @@ from .telemetry import (
     _parse_feedback,
 )
 
+if TYPE_CHECKING:
+    # Peer dependency - user has datasets if they have a Dataset object
+    from datasets import Dataset as HFDataset  # type: ignore
+    from pandas import DataFrame as PandasDataFrame  # type: ignore
+
+
+class ClassificationMetrics:
+    """
+    Metrics for evaluating classification model performance.
+
+    Attributes:
+        coverage: Percentage of predictions that are not none
+        f1_score: F1 score of the predictions
+        accuracy: Accuracy of the predictions
+        loss: Cross-entropy loss of the logits
+        anomaly_score_mean: Mean of anomaly scores across the dataset
+        anomaly_score_median: Median of anomaly scores across the dataset
+        anomaly_score_variance: Variance of anomaly scores across the dataset
+        roc_auc: Receiver operating characteristic area under the curve
+        pr_auc: Average precision (area under the precision-recall curve)
+        pr_curve: Precision-recall curve
+        roc_curve: Receiver operating characteristic curve
+        confusion_matrix: Confusion matrix where entry (i, j) is count of samples with true label i predicted as j
+    """
+
+    coverage: float
+    f1_score: float
+    accuracy: float
+    loss: float | None
+    anomaly_score_mean: float | None
+    anomaly_score_median: float | None
+    anomaly_score_variance: float | None
+    roc_auc: float | None
+    pr_auc: float | None
+    pr_curve: PRCurve | None
+    roc_curve: ROCCurve | None
+    confusion_matrix: list[list[int]] | None
+
+    def __init__(self, response: ClassificationMetricsResponse):
+        self.coverage = response["coverage"]
+        self.f1_score = response["f1_score"]
+        self.accuracy = response["accuracy"]
+        self.loss = response.get("loss")
+        self.anomaly_score_mean = response.get("anomaly_score_mean")
+        self.anomaly_score_median = response.get("anomaly_score_median")
+        self.anomaly_score_variance = response.get("anomaly_score_variance")
+        self.roc_auc = response.get("roc_auc")
+        self.pr_auc = response.get("pr_auc")
+        self.pr_curve = response.get("pr_curve")
+        self.roc_curve = response.get("roc_curve")
+        self.confusion_matrix = response.get("confusion_matrix")
+        for warning in response.get("warnings", []):
+            logger.warning(warning)
+
+    def __repr__(self) -> str:
+        return (
+            "ClassificationMetrics({\n"
+            + f" accuracy: {self.accuracy:.4f},\n"
+            + f" f1_score: {self.f1_score:.4f},\n"
+            + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
+            + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
+            + (
+                f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                if self.anomaly_score_mean
+                else ""
+            )
+            + "})"
+        )
+
+    @classmethod
+    def compute(
+        cls,
+        predictions: Sequence[ClassificationPrediction],
+    ) -> ClassificationMetrics:
+        """
+        Compute classification metrics from a list of predictions.
+
+        Params:
+            predictions: List of ClassificationPrediction objects with expected_label set
+
+        Returns:
+            ClassificationMetrics with computed metrics
+
+        Raises:
+            ValueError: If any prediction is missing expected_label or logits
+        """
+        if len(predictions) > 100_000:
+            raise ValueError("Too many predictions, maximum is 100,000")
+        logits = [p.logits for p in predictions]
+        if any(p.expected_label is None for p in predictions):
+            raise ValueError("All predictions must have expected_labels")
+        expected_labels = [cast(int, cp.expected_label) for cp in predictions]
+        anomaly_scores = (
+            None
+            if any(p.anomaly_score is None for p in predictions)
+            else [cast(float, p.anomaly_score) for p in predictions]
+        )
+
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/classification_model/metrics",
+            json={
+                "expected_labels": expected_labels,
+                "logits": logits,
+                "anomaly_scores": anomaly_scores,
+            },
+        )
+        return cls(response)
+
 
 class BootstrappedClassificationModel:
 
@@ -137,7 +257,7 @@ class ClassificationModel:
         is raised.
         """
         if self._last_prediction_was_batch:
-            logging.warning(
+            logger.warning(
                 "Last prediction was part of a batch prediction, returning the last prediction from the batch"
             )
         if self._last_prediction is None:
@@ -279,7 +399,7 @@ class ClassificationModel:
            List of handles to all classification models in the OrcaCloud
        """
        client = OrcaClient._resolve_client()
-       return [cls(metadata) for metadata in client.GET("/classification_model")]
+       return [cls(metadata) for metadata in client.GET("/classification_model", params={})]
 
    @classmethod
    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -300,7 +420,7 @@ class ClassificationModel:
        try:
            client = OrcaClient._resolve_client()
            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
-           logging.info(f"Deleted model {name_or_id}")
+           logger.info(f"Deleted model {name_or_id}")
        except LookupError:
            if if_not_exists == "error":
                raise
@@ -365,6 +485,7 @@ class ClassificationModel:
        ] = "include_global",
        use_gpu: bool = True,
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> list[ClassificationPrediction]:
        pass
 
@@ -386,6 +507,7 @@ class ClassificationModel:
        ] = "include_global",
        use_gpu: bool = True,
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> ClassificationPrediction:
        pass
 
@@ -406,6 +528,7 @@ class ClassificationModel:
        ] = "include_global",
        use_gpu: bool = True,
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> list[ClassificationPrediction] | ClassificationPrediction:
        """
        Predict label(s) for the given input value(s) grounded in similar memories
@@ -433,6 +556,7 @@ class ClassificationModel:
                * `"exclude_global"`: Exclude global memories
                * `"only_global"`: Only include global memories
            use_gpu: Whether to use GPU for the prediction (defaults to True)
+           consistency_level: Consistency level to use for the prediction(s)
            batch_size: Number of values to process in a single API call
 
        Returns:
@@ -472,7 +596,7 @@ class ClassificationModel:
            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
        # Convert to list for batching
-       values = value if isinstance(value,
+       values = [value] if isinstance(value, str) else list(value)
        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
        if isinstance(partition_id, list) and len(partition_id) != len(values):
@@ -482,7 +606,7 @@ class ClassificationModel:
            expected_labels = [expected_labels] * len(values)
        elif isinstance(expected_labels, str):
            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
-       elif
+       elif expected_labels is not None:
            expected_labels = [
                self.memoryset.label_names.index(label) if isinstance(label, str) else label
                for label in expected_labels
@@ -513,6 +637,7 @@ class ClassificationModel:
                "use_lookup_cache": use_lookup_cache,
                "ignore_unlabeled": ignore_unlabeled,
                "partition_filter_mode": partition_filter_mode,
+               "consistency_level": consistency_level,
            }
            if partition_filter_mode != "ignore_partitions":
                request_json["partition_ids"] = (
@@ -529,6 +654,7 @@ class ClassificationModel:
            if telemetry_on and any(p["prediction_id"] is None for p in response):
                raise RuntimeError("Failed to save some prediction to database.")
 
+           batch_expected = batch_expected_labels or [None] * len(batch_values)
            predictions.extend(
                ClassificationPrediction(
                    prediction_id=prediction["prediction_id"],
@@ -541,8 +667,9 @@ class ClassificationModel:
                    model=self,
                    logits=prediction["logits"],
                    input_value=input_value,
+                   expected_label=exp_label,
                )
-               for prediction, input_value in zip(response, batch_values)
+               for prediction, input_value, exp_label in zip(response, batch_values, batch_expected)
            )
 
        self._last_prediction_was_batch = isinstance(value, list)
@@ -566,6 +693,7 @@ class ClassificationModel:
            "ignore_partitions", "include_global", "exclude_global", "only_global"
        ] = "include_global",
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> list[ClassificationPrediction]:
        pass
 
@@ -586,6 +714,7 @@ class ClassificationModel:
            "ignore_partitions", "include_global", "exclude_global", "only_global"
        ] = "include_global",
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> ClassificationPrediction:
        pass
 
@@ -605,6 +734,7 @@ class ClassificationModel:
            "ignore_partitions", "include_global", "exclude_global", "only_global"
        ] = "include_global",
        batch_size: int = 100,
+       consistency_level: ConsistencyLevel = "Bounded",
    ) -> list[ClassificationPrediction] | ClassificationPrediction:
        """
        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -632,6 +762,7 @@ class ClassificationModel:
                * `"exclude_global"`: Exclude global memories
                * `"only_global"`: Only include global memories
            batch_size: Number of values to process in a single API call
+           consistency_level: Consistency level to use for the prediction(s)
 
        Returns:
            Label prediction or list of label predictions.
@@ -670,7 +801,7 @@ class ClassificationModel:
            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
        # Convert to list for batching
-       values = value if isinstance(value,
+       values = [value] if isinstance(value, str) else list(value)
        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
        if isinstance(partition_id, list) and len(partition_id) != len(values):
@@ -680,7 +811,7 @@ class ClassificationModel:
            expected_labels = [expected_labels] * len(values)
        elif isinstance(expected_labels, str):
            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
-       elif
+       elif expected_labels is not None:
            expected_labels = [
                self.memoryset.label_names.index(label) if isinstance(label, str) else label
                for label in expected_labels
@@ -706,6 +837,7 @@ class ClassificationModel:
                "use_lookup_cache": use_lookup_cache,
                "ignore_unlabeled": ignore_unlabeled,
                "partition_filter_mode": partition_filter_mode,
+               "consistency_level": consistency_level,
            }
            if partition_filter_mode != "ignore_partitions":
                request_json["partition_ids"] = (
@@ -721,6 +853,7 @@ class ClassificationModel:
            if telemetry_on and any(p["prediction_id"] is None for p in response):
                raise RuntimeError("Failed to save some prediction to database.")
 
+           batch_expected = batch_expected_labels or [None] * len(batch_values)
            predictions.extend(
                ClassificationPrediction(
                    prediction_id=prediction["prediction_id"],
@@ -733,8 +866,9 @@ class ClassificationModel:
                    model=self,
                    logits=prediction["logits"],
                    input_value=input_value,
+                   expected_label=exp_label,
                )
-               for prediction, input_value in zip(response, batch_values)
+               for prediction, input_value, exp_label in zip(response, batch_values, batch_expected)
            )
 
        self._last_prediction_was_batch = isinstance(value, list)
@@ -884,26 +1018,14 @@ class ClassificationModel:
                params={"model_name_or_id": self.id, "job_id": response["job_id"]},
            )
            assert res["result"] is not None
-           return ClassificationMetrics(
-               coverage=res["result"].get("coverage"),
-               f1_score=res["result"].get("f1_score"),
-               accuracy=res["result"].get("accuracy"),
-               loss=res["result"].get("loss"),
-               anomaly_score_mean=res["result"].get("anomaly_score_mean"),
-               anomaly_score_median=res["result"].get("anomaly_score_median"),
-               anomaly_score_variance=res["result"].get("anomaly_score_variance"),
-               roc_auc=res["result"].get("roc_auc"),
-               pr_auc=res["result"].get("pr_auc"),
-               pr_curve=res["result"].get("pr_curve"),
-               roc_curve=res["result"].get("roc_curve"),
-           )
+           return ClassificationMetrics(res["result"])
 
        job = Job(response["job_id"], get_value)
        return job if background else job.result()
 
-   def
+   def _evaluate_local(
        self,
-
+       data: Iterable[dict[str, Any]],
        value_column: str,
        label_column: str,
        record_predictions: bool,
@@ -915,38 +1037,41 @@ class ClassificationModel:
            "ignore_partitions", "include_global", "exclude_global", "only_global"
        ] = "include_global",
    ) -> ClassificationMetrics:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+       values: list[str] = []
+       expected_labels: list[int] | list[str] = []
+       partition_ids: list[str | None] | None = [] if partition_column else None
+
+       for sample in data:
+           if len(values) >= 100_000:
+               raise ValueError("Upload a Datasource to evaluate against more than 100,000 samples.")
+           values.append(sample[value_column])
+           expected_label = sample[label_column]
+           if expected_label is None:
+               raise ValueError("Expected label is required for all samples")
+           expected_labels.append(expected_label)
+           if partition_ids is not None and partition_column:
+               partition_ids.append(sample[partition_column])
+
+       if not values:
+           raise ValueError("Evaluation data cannot be empty")
+
+       predictions = self.predict(
+           values,
+           expected_labels=expected_labels,
+           tags=tags,
+           save_telemetry="sync" if record_predictions else "off",
+           ignore_unlabeled=ignore_unlabeled,
+           partition_id=partition_ids,
+           partition_filter_mode=partition_filter_mode,
+           batch_size=batch_size,
        )
 
+       return ClassificationMetrics.compute(predictions)
+
    @overload
    def evaluate(
        self,
-       data: Datasource
+       data: Datasource,
        *,
        value_column: str = "value",
        label_column: str = "label",
@@ -966,7 +1091,7 @@ class ClassificationModel:
    @overload
    def evaluate(
        self,
-       data: Datasource |
+       data: Datasource | HFDataset | PandasDataFrame | Iterable[dict[str, Any]],
        *,
        value_column: str = "value",
        label_column: str = "label",
@@ -985,7 +1110,7 @@ class ClassificationModel:
 
    def evaluate(
        self,
-       data: Datasource |
+       data: Datasource | HFDataset | PandasDataFrame | Iterable[dict[str, Any]],
        *,
        value_column: str = "value",
        label_column: str = "label",
@@ -1004,13 +1129,14 @@ class ClassificationModel:
        Evaluate the classification model on a given dataset or datasource
 
        Params:
-           data:
+           data: the data to evaluate the model on. This can be an Orca [`Datasource`][orca_sdk.datasource.Datasource],
+               a Hugging Face [`Dataset`][datasets.Dataset], a pandas [`DataFrame`][pandas.DataFrame], or an iterable of dictionaries.
            value_column: Name of the column that contains the input values to the model
            label_column: Name of the column containing the expected labels
            partition_column: Optional name of the column that contains the partition IDs
            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
-           batch_size: Batch size for processing
+           batch_size: Batch size for processing the data inputs (not used for Datasource inputs)
            subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
            background: Whether to run the operation in the background and return a job handle
            ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
@@ -1045,9 +1171,22 @@ class ClassificationModel:
                partition_column=partition_column,
                partition_filter_mode=partition_filter_mode,
            )
-
-
-
+       else:
+           if background:
+               raise ValueError("Background evaluation is only supported for Datasource inputs")
+           # Convert to Iterable[dict] - DataFrame needs conversion, others are assumed iterable
+           try:
+               import pandas as pd  # type: ignore
+
+               if isinstance(data, pd.DataFrame):
+                   data = data.to_dict(orient="records")  # type: ignore
+           except ImportError:
+               pass
+           if not hasattr(data, "__iter__"):
+               raise ValueError(f"Invalid data type: {type(data).__name__}. ")
+
+           return self._evaluate_local(
+               data=cast(Iterable[dict[str, Any]], data),
                value_column=value_column,
                label_column=label_column,
                record_predictions=record_predictions,
@@ -1057,8 +1196,6 @@ class ClassificationModel:
                partition_column=partition_column,
                partition_filter_mode=partition_filter_mode,
            )
-       else:
-           raise ValueError(f"Invalid data type: {type(data)}")
 
    def finetune(self, datasource: Datasource):
        # do not document until implemented
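The practical effect of the evaluate() changes above: the method now accepts in-memory data (a Hugging Face Dataset, a pandas DataFrame, or any iterable of dicts) and routes it through the new _evaluate_local path, which predicts client-side and scores the results with ClassificationMetrics.compute. A minimal sketch of that path, assuming a ClassificationModel handle named model from the earlier example and fewer than 100,000 rows; the rows themselves are illustrative:

# Hedged sketch of the new in-memory evaluation path shown above.
rows = [
    {"value": "Do you love soup?", "label": 1},
    {"value": "Do you love cats?", "label": 0},
]
metrics = model.evaluate(rows, value_column="value", label_column="label")
print(metrics.accuracy, metrics.f1_score)  # attributes of the new ClassificationMetrics class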
orca_sdk/classification_model_test.py
CHANGED
@@ -108,6 +108,14 @@ def test_list_models_unauthorized(unauthorized_client, classification_model: Cla
    assert ClassificationModel.all() == []
 
 
+def test_memoryset_classification_models_property(
+    classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset
+):
+    models = readonly_memoryset.classification_models
+    assert len(models) > 0
+    assert any(model.id == classification_model.id for model in models)
+
+
 def test_update_model_attributes(classification_model: ClassificationModel):
    classification_model.description = "New description"
    assert classification_model.description == "New description"
@@ -162,12 +170,41 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
        LabeledMemoryset.drop(memoryset.id)
 
 
-
-
+def test_delete_memoryset_with_model_cascade(hf_dataset):
+    """Test that cascade=False prevents deletion and cascade=True allows it."""
+    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_cascade_delete", hf_dataset)
+    model = ClassificationModel.create("test_model_cascade_delete", memoryset)
+
+    # Verify model exists
+    assert ClassificationModel.open(model.name) is not None
+
+    # Without cascade, deletion should fail
+    with pytest.raises(RuntimeError):
+        LabeledMemoryset.drop(memoryset.id, cascade=False)
+
+    # Model should still exist
+    assert ClassificationModel.exists(model.name)
+
+    # With cascade, deletion should succeed
+    LabeledMemoryset.drop(memoryset.id, cascade=True)
+
+    # Model should be deleted along with the memoryset
+    assert not ClassificationModel.exists(model.name)
+    assert not LabeledMemoryset.exists(memoryset.name)
+
+
+@pytest.mark.parametrize("data_type", ["dataset", "datasource", "list"])
+def test_evaluate(
+    classification_model, eval_data: list[dict], eval_datasource: Datasource, eval_dataset: Dataset, data_type
+):
    result = (
        classification_model.evaluate(eval_dataset)
        if data_type == "dataset"
-       else
+       else (
+           classification_model.evaluate(eval_datasource)
+           if data_type == "datasource"
+           else classification_model.evaluate(eval_data)
+       )
    )
 
    assert result is not None
@@ -660,6 +697,13 @@ def test_predict_with_expected_labels(classification_model: ClassificationModel)
    assert prediction.expected_label == 1
 
 
+def test_predict_with_expected_labels_no_telemetry(classification_model: ClassificationModel):
+    """Test that expected_label is available even when telemetry is disabled"""
+    prediction = classification_model.predict("Do you love soup?", expected_labels=1, save_telemetry="off")
+    assert prediction.prediction_id is None  # telemetry is off
+    assert prediction.expected_label == 1  # but expected_label should still be available
+
+
 def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
    # invalid number of expected labels for batch prediction
    with pytest.raises(ValueError, match=r"Invalid input.*"):
@@ -683,28 +727,27 @@ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
        num_classes=2,
        memory_lookup_count=3,
    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-   ClassificationModel.drop("test_predict_with_memoryset_update")
+   try:
+       prediction = model.predict("Do you love soup?", partition_filter_mode="ignore_partitions")
+       assert prediction.label == 0
+       assert prediction.label_name == "soup"
+       # insert new memories
+       writable_memoryset.insert(
+           [
+               {"value": "Do you love soup?", "label": 1, "key": "g1"},
+               {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
+               {"value": "Do you love crackers?", "label": 1, "key": "g2"},
+               {"value": "Do you love broth?", "label": 1, "key": "g2"},
+               {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
+               {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+               {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+           ],
+       )
+       prediction = model.predict("Do you love soup?")
+       assert prediction.label == 1
+       assert prediction.label_name == "cats"
+   finally:
+       ClassificationModel.drop("test_predict_with_memoryset_update")
 
 
 def test_last_prediction_with_batch(classification_model: ClassificationModel):
@@ -828,6 +871,23 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
    assert prediction_without_prompt.label is not None
 
 
+def test_predict_with_empty_partition(fully_partitioned_classification_resources):
+    datasource, memoryset, classification_model = fully_partitioned_classification_resources
+
+    assert memoryset.length == 15
+
+    with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+        classification_model.predict("i love cats", partition_filter_mode="only_global")
+
+    with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+        classification_model.predict(
+            "i love cats", partition_filter_mode="exclude_global", partition_id="p_does_not_exist"
+        )
+
+    with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+        classification_model.evaluate(datasource, partition_filter_mode="only_global")
+
+
 @pytest.mark.asyncio
 async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
    """Test async prediction with a single value"""