orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -1,28 +1,39 @@
  from __future__ import annotations

- import logging
  from contextlib import contextmanager
  from datetime import datetime
- from typing import Any, Generator, Iterable, Literal, cast, overload
-
- from datasets import Dataset
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Generator,
+     Iterable,
+     Literal,
+     Sequence,
+     cast,
+     overload,
+ )

- from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
- from ._utils.common import UNSET, CreateMode, DropMode
+ from ._utils.common import UNSET, CreateMode, DropMode, logger
  from .async_client import OrcaAsyncClient
  from .client import (
      BootstrapClassificationModelMeta,
      BootstrapLabeledMemoryDataResult,
+ )
+ from .client import ClassificationMetrics as ClassificationMetricsResponse
+ from .client import (
      ClassificationModelMetadata,
      ClassificationPredictionRequest,
      ListPredictionsRequest,
      OrcaClient,
+     PRCurve,
      PredictiveModelUpdate,
      RACHeadType,
+     ROCCurve,
  )
  from .datasource import Datasource
  from .job import Job
  from .memoryset import (
+     ConsistencyLevel,
      FilterItem,
      FilterItemTuple,
      LabeledMemoryset,
@@ -36,6 +47,115 @@ from .telemetry import (
      _parse_feedback,
  )

+ if TYPE_CHECKING:
+     # Peer dependency - user has datasets if they have a Dataset object
+     from datasets import Dataset as HFDataset  # type: ignore
+     from pandas import DataFrame as PandasDataFrame  # type: ignore
+
+
+ class ClassificationMetrics:
+     """
+     Metrics for evaluating classification model performance.
+
+     Attributes:
+         coverage: Percentage of predictions that are not none
+         f1_score: F1 score of the predictions
+         accuracy: Accuracy of the predictions
+         loss: Cross-entropy loss of the logits
+         anomaly_score_mean: Mean of anomaly scores across the dataset
+         anomaly_score_median: Median of anomaly scores across the dataset
+         anomaly_score_variance: Variance of anomaly scores across the dataset
+         roc_auc: Receiver operating characteristic area under the curve
+         pr_auc: Average precision (area under the precision-recall curve)
+         pr_curve: Precision-recall curve
+         roc_curve: Receiver operating characteristic curve
+         confusion_matrix: Confusion matrix where entry (i, j) is count of samples with true label i predicted as j
+     """
+
+     coverage: float
+     f1_score: float
+     accuracy: float
+     loss: float | None
+     anomaly_score_mean: float | None
+     anomaly_score_median: float | None
+     anomaly_score_variance: float | None
+     roc_auc: float | None
+     pr_auc: float | None
+     pr_curve: PRCurve | None
+     roc_curve: ROCCurve | None
+     confusion_matrix: list[list[int]] | None
+
+     def __init__(self, response: ClassificationMetricsResponse):
+         self.coverage = response["coverage"]
+         self.f1_score = response["f1_score"]
+         self.accuracy = response["accuracy"]
+         self.loss = response.get("loss")
+         self.anomaly_score_mean = response.get("anomaly_score_mean")
+         self.anomaly_score_median = response.get("anomaly_score_median")
+         self.anomaly_score_variance = response.get("anomaly_score_variance")
+         self.roc_auc = response.get("roc_auc")
+         self.pr_auc = response.get("pr_auc")
+         self.pr_curve = response.get("pr_curve")
+         self.roc_curve = response.get("roc_curve")
+         self.confusion_matrix = response.get("confusion_matrix")
+         for warning in response.get("warnings", []):
+             logger.warning(warning)
+
+     def __repr__(self) -> str:
+         return (
+             "ClassificationMetrics({\n"
+             + f" accuracy: {self.accuracy:.4f},\n"
+             + f" f1_score: {self.f1_score:.4f},\n"
+             + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
+             + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
+             + (
+                 f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                 if self.anomaly_score_mean
+                 else ""
+             )
+             + "})"
+         )
+
+     @classmethod
+     def compute(
+         cls,
+         predictions: Sequence[ClassificationPrediction],
+     ) -> ClassificationMetrics:
+         """
+         Compute classification metrics from a list of predictions.
+
+         Params:
+             predictions: List of ClassificationPrediction objects with expected_label set
+
+         Returns:
+             ClassificationMetrics with computed metrics
+
+         Raises:
+             ValueError: If any prediction is missing expected_label or logits
+         """
+         if len(predictions) > 100_000:
+             raise ValueError("Too many predictions, maximum is 100,000")
+         logits = [p.logits for p in predictions]
+         if any(p.expected_label is None for p in predictions):
+             raise ValueError("All predictions must have expected_labels")
+         expected_labels = [cast(int, cp.expected_label) for cp in predictions]
+         anomaly_scores = (
+             None
+             if any(p.anomaly_score is None for p in predictions)
+             else [cast(float, p.anomaly_score) for p in predictions]
+         )
+
+         client = OrcaClient._resolve_client()
+         response = client.POST(
+             "/classification_model/metrics",
+             json={
+                 "expected_labels": expected_labels,
+                 "logits": logits,
+                 "anomaly_scores": anomaly_scores,
+             },
+         )
+         return cls(response)
+

  class BootstrappedClassificationModel:

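For orientation, a hedged sketch of how the new client-side ClassificationMetrics.compute() might be used; the model name, inputs, and import path are assumptions, not taken from the package:

    # Sketch only: assumes these names are importable from orca_sdk and that a
    # model called "my_model" already exists in the OrcaCloud.
    from orca_sdk import ClassificationMetrics, ClassificationModel

    model = ClassificationModel.open("my_model")
    predictions = model.predict(
        ["Do you love soup?", "Do you love cats?"],
        expected_labels=[0, 1],
    )
    metrics = ClassificationMetrics.compute(predictions)
    print(metrics.accuracy, metrics.f1_score)
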
@@ -137,7 +257,7 @@ class ClassificationModel:
          is raised.
          """
          if self._last_prediction_was_batch:
-             logging.warning(
+             logger.warning(
                  "Last prediction was part of a batch prediction, returning the last prediction from the batch"
              )
          if self._last_prediction is None:
@@ -279,7 +399,7 @@ class ClassificationModel:
              List of handles to all classification models in the OrcaCloud
          """
          client = OrcaClient._resolve_client()
-         return [cls(metadata) for metadata in client.GET("/classification_model")]
+         return [cls(metadata) for metadata in client.GET("/classification_model", params={})]

      @classmethod
      def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -300,7 +420,7 @@ class ClassificationModel:
          try:
              client = OrcaClient._resolve_client()
              client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
-             logging.info(f"Deleted model {name_or_id}")
+             logger.info(f"Deleted model {name_or_id}")
          except LookupError:
              if if_not_exists == "error":
                  raise
@@ -365,6 +485,7 @@ class ClassificationModel:
          ] = "include_global",
          use_gpu: bool = True,
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> list[ClassificationPrediction]:
          pass

@@ -386,6 +507,7 @@ class ClassificationModel:
          ] = "include_global",
          use_gpu: bool = True,
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> ClassificationPrediction:
          pass

@@ -406,6 +528,7 @@ class ClassificationModel:
          ] = "include_global",
          use_gpu: bool = True,
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> list[ClassificationPrediction] | ClassificationPrediction:
          """
          Predict label(s) for the given input value(s) grounded in similar memories
@@ -433,6 +556,7 @@ class ClassificationModel:
                  * `"exclude_global"`: Exclude global memories
                  * `"only_global"`: Only include global memories
              use_gpu: Whether to use GPU for the prediction (defaults to True)
+             consistency_level: Consistency level to use for the prediction(s)
              batch_size: Number of values to process in a single API call

          Returns:
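A minimal sketch of the new consistency_level parameter threaded through the predict overloads; the model handle is assumed to exist, and "Bounded" is simply the default shown in the signatures above:

    # Sketch only: pass consistency_level explicitly on a single prediction.
    prediction = model.predict("Do you love soup?", consistency_level="Bounded")
    print(prediction.label, prediction.label_name)
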
@@ -472,7 +596,7 @@ class ClassificationModel:
              raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")

          # Convert to list for batching
-         values = value if isinstance(value, list) else [value]
+         values = [value] if isinstance(value, str) else list(value)
          if isinstance(expected_labels, list) and len(expected_labels) != len(values):
              raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
          if isinstance(partition_id, list) and len(partition_id) != len(values):
@@ -482,7 +606,7 @@ class ClassificationModel:
              expected_labels = [expected_labels] * len(values)
          elif isinstance(expected_labels, str):
              expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
-         elif isinstance(expected_labels, list):
+         elif expected_labels is not None:
              expected_labels = [
                  self.memoryset.label_names.index(label) if isinstance(label, str) else label
                  for label in expected_labels
@@ -513,6 +637,7 @@ class ClassificationModel:
                  "use_lookup_cache": use_lookup_cache,
                  "ignore_unlabeled": ignore_unlabeled,
                  "partition_filter_mode": partition_filter_mode,
+                 "consistency_level": consistency_level,
              }
              if partition_filter_mode != "ignore_partitions":
                  request_json["partition_ids"] = (
@@ -529,6 +654,7 @@ class ClassificationModel:
              if telemetry_on and any(p["prediction_id"] is None for p in response):
                  raise RuntimeError("Failed to save some prediction to database.")

+             batch_expected = batch_expected_labels or [None] * len(batch_values)
              predictions.extend(
                  ClassificationPrediction(
                      prediction_id=prediction["prediction_id"],
@@ -541,8 +667,9 @@ class ClassificationModel:
                      model=self,
                      logits=prediction["logits"],
                      input_value=input_value,
+                     expected_label=exp_label,
                  )
-                 for prediction, input_value in zip(response, batch_values)
+                 for prediction, input_value, exp_label in zip(response, batch_values, batch_expected)
              )

          self._last_prediction_was_batch = isinstance(value, list)
@@ -566,6 +693,7 @@ class ClassificationModel:
              "ignore_partitions", "include_global", "exclude_global", "only_global"
          ] = "include_global",
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> list[ClassificationPrediction]:
          pass

@@ -586,6 +714,7 @@ class ClassificationModel:
              "ignore_partitions", "include_global", "exclude_global", "only_global"
          ] = "include_global",
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> ClassificationPrediction:
          pass

@@ -605,6 +734,7 @@ class ClassificationModel:
              "ignore_partitions", "include_global", "exclude_global", "only_global"
          ] = "include_global",
          batch_size: int = 100,
+         consistency_level: ConsistencyLevel = "Bounded",
      ) -> list[ClassificationPrediction] | ClassificationPrediction:
          """
          Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -632,6 +762,7 @@ class ClassificationModel:
                  * `"exclude_global"`: Exclude global memories
                  * `"only_global"`: Only include global memories
              batch_size: Number of values to process in a single API call
+             consistency_level: Consistency level to use for the prediction(s)

          Returns:
              Label prediction or list of label predictions.
@@ -670,7 +801,7 @@ class ClassificationModel:
              raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")

          # Convert to list for batching
-         values = value if isinstance(value, list) else [value]
+         values = [value] if isinstance(value, str) else list(value)
          if isinstance(expected_labels, list) and len(expected_labels) != len(values):
              raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
          if isinstance(partition_id, list) and len(partition_id) != len(values):
@@ -680,7 +811,7 @@ class ClassificationModel:
              expected_labels = [expected_labels] * len(values)
          elif isinstance(expected_labels, str):
              expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
-         elif isinstance(expected_labels, list):
+         elif expected_labels is not None:
              expected_labels = [
                  self.memoryset.label_names.index(label) if isinstance(label, str) else label
                  for label in expected_labels
@@ -706,6 +837,7 @@ class ClassificationModel:
                  "use_lookup_cache": use_lookup_cache,
                  "ignore_unlabeled": ignore_unlabeled,
                  "partition_filter_mode": partition_filter_mode,
+                 "consistency_level": consistency_level,
              }
              if partition_filter_mode != "ignore_partitions":
                  request_json["partition_ids"] = (
@@ -721,6 +853,7 @@ class ClassificationModel:
              if telemetry_on and any(p["prediction_id"] is None for p in response):
                  raise RuntimeError("Failed to save some prediction to database.")

+             batch_expected = batch_expected_labels or [None] * len(batch_values)
              predictions.extend(
                  ClassificationPrediction(
                      prediction_id=prediction["prediction_id"],
@@ -733,8 +866,9 @@ class ClassificationModel:
                      model=self,
                      logits=prediction["logits"],
                      input_value=input_value,
+                     expected_label=exp_label,
                  )
-                 for prediction, input_value in zip(response, batch_values)
+                 for prediction, input_value, exp_label in zip(response, batch_values, batch_expected)
              )

          self._last_prediction_was_batch = isinstance(value, list)
@@ -884,26 +1018,14 @@ class ClassificationModel:
                  params={"model_name_or_id": self.id, "job_id": response["job_id"]},
              )
              assert res["result"] is not None
-             return ClassificationMetrics(
-                 coverage=res["result"].get("coverage"),
-                 f1_score=res["result"].get("f1_score"),
-                 accuracy=res["result"].get("accuracy"),
-                 loss=res["result"].get("loss"),
-                 anomaly_score_mean=res["result"].get("anomaly_score_mean"),
-                 anomaly_score_median=res["result"].get("anomaly_score_median"),
-                 anomaly_score_variance=res["result"].get("anomaly_score_variance"),
-                 roc_auc=res["result"].get("roc_auc"),
-                 pr_auc=res["result"].get("pr_auc"),
-                 pr_curve=res["result"].get("pr_curve"),
-                 roc_curve=res["result"].get("roc_curve"),
-             )
+             return ClassificationMetrics(res["result"])

          job = Job(response["job_id"], get_value)
          return job if background else job.result()

-     def _evaluate_dataset(
+     def _evaluate_local(
          self,
-         dataset: Dataset,
+         data: Iterable[dict[str, Any]],
          value_column: str,
          label_column: str,
          record_predictions: bool,
@@ -915,38 +1037,41 @@ class ClassificationModel:
              "ignore_partitions", "include_global", "exclude_global", "only_global"
          ] = "include_global",
      ) -> ClassificationMetrics:
-         if len(dataset) == 0:
-             raise ValueError("Evaluation dataset cannot be empty")
-
-         if any(x is None for x in dataset[label_column]):
-             raise ValueError("Evaluation dataset cannot contain None values in the label column")
-
-         predictions = [
-             prediction
-             for i in range(0, len(dataset), batch_size)
-             for prediction in self.predict(
-                 dataset[i : i + batch_size][value_column],
-                 expected_labels=dataset[i : i + batch_size][label_column],
-                 tags=tags,
-                 save_telemetry="sync" if record_predictions else "off",
-                 ignore_unlabeled=ignore_unlabeled,
-                 partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
-                 partition_filter_mode=partition_filter_mode,
-             )
-         ]
-
-         return calculate_classification_metrics(
-             expected_labels=dataset[label_column],
-             logits=[p.logits for p in predictions],
-             anomaly_scores=[p.anomaly_score for p in predictions],
-             include_curves=True,
-             include_confusion_matrix=True,
+         values: list[str] = []
+         expected_labels: list[int] | list[str] = []
+         partition_ids: list[str | None] | None = [] if partition_column else None
+
+         for sample in data:
+             if len(values) >= 100_000:
+                 raise ValueError("Upload a Datasource to evaluate against more than 100,000 samples.")
+             values.append(sample[value_column])
+             expected_label = sample[label_column]
+             if expected_label is None:
+                 raise ValueError("Expected label is required for all samples")
+             expected_labels.append(expected_label)
+             if partition_ids is not None and partition_column:
+                 partition_ids.append(sample[partition_column])
+
+         if not values:
+             raise ValueError("Evaluation data cannot be empty")
+
+         predictions = self.predict(
+             values,
+             expected_labels=expected_labels,
+             tags=tags,
+             save_telemetry="sync" if record_predictions else "off",
+             ignore_unlabeled=ignore_unlabeled,
+             partition_id=partition_ids,
+             partition_filter_mode=partition_filter_mode,
+             batch_size=batch_size,
          )

+         return ClassificationMetrics.compute(predictions)
+
      @overload
      def evaluate(
          self,
-         data: Datasource | Dataset,
+         data: Datasource,
          *,
          value_column: str = "value",
          label_column: str = "label",
@@ -966,7 +1091,7 @@ class ClassificationModel:
      @overload
      def evaluate(
          self,
-         data: Datasource | Dataset,
+         data: Datasource | HFDataset | PandasDataFrame | Iterable[dict[str, Any]],
          *,
          value_column: str = "value",
          label_column: str = "label",
@@ -985,7 +1110,7 @@ class ClassificationModel:

      def evaluate(
          self,
-         data: Datasource | Dataset,
+         data: Datasource | HFDataset | PandasDataFrame | Iterable[dict[str, Any]],
          *,
          value_column: str = "value",
          label_column: str = "label",
@@ -1004,13 +1129,14 @@ class ClassificationModel:
          Evaluate the classification model on a given dataset or datasource

          Params:
-             data: Dataset or Datasource to evaluate the model on
+             data: the data to evaluate the model on. This can be an Orca [`Datasource`][orca_sdk.datasource.Datasource],
+                 a Hugging Face [`Dataset`][datasets.Dataset], a pandas [`DataFrame`][pandas.DataFrame], or an iterable of dictionaries.
              value_column: Name of the column that contains the input values to the model
              label_column: Name of the column containing the expected labels
              partition_column: Optional name of the column that contains the partition IDs
              record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
              tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
-             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+             batch_size: Batch size for processing the data inputs (not used for Datasource inputs)
              subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
              background: Whether to run the operation in the background and return a job handle
              ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
@@ -1045,9 +1171,22 @@ class ClassificationModel:
                  partition_column=partition_column,
                  partition_filter_mode=partition_filter_mode,
              )
-         elif isinstance(data, Dataset):
-             return self._evaluate_dataset(
-                 dataset=data,
+         else:
+             if background:
+                 raise ValueError("Background evaluation is only supported for Datasource inputs")
+             # Convert to Iterable[dict] - DataFrame needs conversion, others are assumed iterable
+             try:
+                 import pandas as pd  # type: ignore
+
+                 if isinstance(data, pd.DataFrame):
+                     data = data.to_dict(orient="records")  # type: ignore
+             except ImportError:
+                 pass
+             if not hasattr(data, "__iter__"):
+                 raise ValueError(f"Invalid data type: {type(data).__name__}. ")
+
+             return self._evaluate_local(
+                 data=cast(Iterable[dict[str, Any]], data),
                  value_column=value_column,
                  label_column=label_column,
                  record_predictions=record_predictions,
@@ -1057,8 +1196,6 @@ class ClassificationModel:
                  partition_column=partition_column,
                  partition_filter_mode=partition_filter_mode,
              )
-         else:
-             raise ValueError(f"Invalid data type: {type(data)}")

      def finetune(self, datasource: Datasource):
          # do not document until implemented
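An illustrative call against the widened evaluate() signature, which now accepts an iterable of dicts or a pandas DataFrame in addition to a Datasource; the rows below are made up:

    # Sketch only: local evaluation over an in-memory list of dicts.
    rows = [
        {"value": "Do you love soup?", "label": 1},
        {"value": "Do you love cats?", "label": 0},
    ]
    metrics = model.evaluate(rows, value_column="value", label_column="label")
    print(metrics.accuracy)
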
@@ -108,6 +108,14 @@ def test_list_models_unauthorized(unauthorized_client, classification_model: Cla
      assert ClassificationModel.all() == []


+ def test_memoryset_classification_models_property(
+     classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset
+ ):
+     models = readonly_memoryset.classification_models
+     assert len(models) > 0
+     assert any(model.id == classification_model.id for model in models)
+
+
  def test_update_model_attributes(classification_model: ClassificationModel):
      classification_model.description = "New description"
      assert classification_model.description == "New description"
@@ -162,12 +170,41 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
      LabeledMemoryset.drop(memoryset.id)


- @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
- def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+ def test_delete_memoryset_with_model_cascade(hf_dataset):
+     """Test that cascade=False prevents deletion and cascade=True allows it."""
+     memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_cascade_delete", hf_dataset)
+     model = ClassificationModel.create("test_model_cascade_delete", memoryset)
+
+     # Verify model exists
+     assert ClassificationModel.open(model.name) is not None
+
+     # Without cascade, deletion should fail
+     with pytest.raises(RuntimeError):
+         LabeledMemoryset.drop(memoryset.id, cascade=False)
+
+     # Model should still exist
+     assert ClassificationModel.exists(model.name)
+
+     # With cascade, deletion should succeed
+     LabeledMemoryset.drop(memoryset.id, cascade=True)
+
+     # Model should be deleted along with the memoryset
+     assert not ClassificationModel.exists(model.name)
+     assert not LabeledMemoryset.exists(memoryset.name)
+
+
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource", "list"])
+ def test_evaluate(
+     classification_model, eval_data: list[dict], eval_datasource: Datasource, eval_dataset: Dataset, data_type
+ ):
      result = (
          classification_model.evaluate(eval_dataset)
          if data_type == "dataset"
-         else classification_model.evaluate(eval_datasource)
+         else (
+             classification_model.evaluate(eval_datasource)
+             if data_type == "datasource"
+             else classification_model.evaluate(eval_data)
+         )
      )

      assert result is not None
@@ -660,6 +697,13 @@ def test_predict_with_expected_labels(classification_model: ClassificationModel)
      assert prediction.expected_label == 1


+ def test_predict_with_expected_labels_no_telemetry(classification_model: ClassificationModel):
+     """Test that expected_label is available even when telemetry is disabled"""
+     prediction = classification_model.predict("Do you love soup?", expected_labels=1, save_telemetry="off")
+     assert prediction.prediction_id is None  # telemetry is off
+     assert prediction.expected_label == 1  # but expected_label should still be available
+
+
  def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
      # invalid number of expected labels for batch prediction
      with pytest.raises(ValueError, match=r"Invalid input.*"):
@@ -683,28 +727,27 @@ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
          num_classes=2,
          memory_lookup_count=3,
      )
-
-     prediction = model.predict("Do you love soup?")
-     assert prediction.label == 0
-     assert prediction.label_name == "soup"
-
-     # insert new memories
-     writable_memoryset.insert(
-         [
-             {"value": "Do you love soup?", "label": 1, "key": "g1"},
-             {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
-             {"value": "Do you love crackers?", "label": 1, "key": "g2"},
-             {"value": "Do you love broth?", "label": 1, "key": "g2"},
-             {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
-             {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
-             {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
-         ],
-     )
-     prediction = model.predict("Do you love soup?")
-     assert prediction.label == 1
-     assert prediction.label_name == "cats"
-
-     ClassificationModel.drop("test_predict_with_memoryset_update")
+     try:
+         prediction = model.predict("Do you love soup?", partition_filter_mode="ignore_partitions")
+         assert prediction.label == 0
+         assert prediction.label_name == "soup"
+         # insert new memories
+         writable_memoryset.insert(
+             [
+                 {"value": "Do you love soup?", "label": 1, "key": "g1"},
+                 {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
+                 {"value": "Do you love crackers?", "label": 1, "key": "g2"},
+                 {"value": "Do you love broth?", "label": 1, "key": "g2"},
+                 {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
+                 {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+                 {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+             ],
+         )
+         prediction = model.predict("Do you love soup?")
+         assert prediction.label == 1
+         assert prediction.label_name == "cats"
+     finally:
+         ClassificationModel.drop("test_predict_with_memoryset_update")


  def test_last_prediction_with_batch(classification_model: ClassificationModel):
@@ -828,6 +871,23 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
      assert prediction_without_prompt.label is not None


+ def test_predict_with_empty_partition(fully_partitioned_classification_resources):
+     datasource, memoryset, classification_model = fully_partitioned_classification_resources
+
+     assert memoryset.length == 15
+
+     with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+         classification_model.predict("i love cats", partition_filter_mode="only_global")
+
+     with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+         classification_model.predict(
+             "i love cats", partition_filter_mode="exclude_global", partition_id="p_does_not_exist"
+         )
+
+     with pytest.raises(RuntimeError, match="lookup failed to return the correct number of memories"):
+         classification_model.evaluate(datasource, partition_filter_mode="only_global")
+
+
  @pytest.mark.asyncio
  async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
      """Test async prediction with a single value"""