orca-sdk 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
@@ -12,11 +12,10 @@ from ._utils.common import UNSET, CreateMode, DropMode
12
12
  from .async_client import OrcaAsyncClient
13
13
  from .client import (
14
14
  BootstrapClassificationModelMeta,
15
- BootstrapClassificationModelResult,
16
- ClassificationEvaluationRequest,
15
+ BootstrapLabeledMemoryDataResult,
17
16
  ClassificationModelMetadata,
17
+ ClassificationPredictionRequest,
18
18
  OrcaClient,
19
- PostClassificationModelByModelNameOrIdEvaluationParams,
20
19
  PredictiveModelUpdate,
21
20
  RACHeadType,
22
21
  )
@@ -42,7 +41,7 @@ class BootstrappedClassificationModel:
42
41
  datasource: Datasource | None
43
42
  memoryset: LabeledMemoryset | None
44
43
  classification_model: ClassificationModel | None
45
- agent_output: BootstrapClassificationModelResult | None
44
+ agent_output: BootstrapLabeledMemoryDataResult | None
46
45
 
47
46
  def __init__(self, metadata: BootstrapClassificationModelMeta):
48
47
  self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
@@ -358,6 +357,10 @@ class ClassificationModel:
358
357
  use_lookup_cache: bool = True,
359
358
  timeout_seconds: int = 10,
360
359
  ignore_unlabeled: bool = False,
360
+ partition_id: str | list[str | None] | None = None,
361
+ partition_filter_mode: Literal[
362
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
363
+ ] = "include_global",
361
364
  use_gpu: bool = True,
362
365
  ) -> list[ClassificationPrediction]:
363
366
  pass
@@ -374,6 +377,10 @@ class ClassificationModel:
374
377
  use_lookup_cache: bool = True,
375
378
  timeout_seconds: int = 10,
376
379
  ignore_unlabeled: bool = False,
380
+ partition_id: str | None = None,
381
+ partition_filter_mode: Literal[
382
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
383
+ ] = "include_global",
377
384
  use_gpu: bool = True,
378
385
  ) -> ClassificationPrediction:
379
386
  pass
@@ -389,6 +396,10 @@ class ClassificationModel:
389
396
  use_lookup_cache: bool = True,
390
397
  timeout_seconds: int = 10,
391
398
  ignore_unlabeled: bool = False,
399
+ partition_id: str | None | list[str | None] = None,
400
+ partition_filter_mode: Literal[
401
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
402
+ ] = "include_global",
392
403
  use_gpu: bool = True,
393
404
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
394
405
  """
@@ -410,6 +421,12 @@ class ClassificationModel:
410
421
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
411
422
  ignore_unlabeled: If True, only use labeled memories during lookup.
412
423
  If False (default), allow unlabeled memories when necessary.
424
+ partition_id: Optional partition ID(s) to use during memory lookup
425
+ partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
426
+ * `"ignore_partitions"`: Ignore partitions
427
+ * `"include_global"`: Include global memories
428
+ * `"exclude_global"`: Exclude global memories
429
+ * `"only_global"`: Only include global memories
413
430
  use_gpu: Whether to use GPU for the prediction (defaults to True)
414
431
 
415
432
  Returns:
@@ -463,21 +480,26 @@ class ClassificationModel:
463
480
 
464
481
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
465
482
  client = OrcaClient._resolve_client()
483
+ request_json: ClassificationPredictionRequest = {
484
+ "input_values": value if isinstance(value, list) else [value],
485
+ "memoryset_override_name_or_id": self._memoryset_override_id,
486
+ "expected_labels": expected_labels,
487
+ "tags": list(tags or set()),
488
+ "save_telemetry": telemetry_on,
489
+ "save_telemetry_synchronously": telemetry_sync,
490
+ "filters": cast(list[FilterItem], parsed_filters),
491
+ "prompt": prompt,
492
+ "use_lookup_cache": use_lookup_cache,
493
+ "ignore_unlabeled": ignore_unlabeled,
494
+ "partition_filter_mode": partition_filter_mode,
495
+ }
496
+ # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
497
+ if partition_filter_mode != "ignore_partitions":
498
+ request_json["partition_ids"] = partition_id
466
499
  response = client.POST(
467
500
  endpoint,
468
501
  params={"name_or_id": self.id},
469
- json={
470
- "input_values": value if isinstance(value, list) else [value],
471
- "memoryset_override_name_or_id": self._memoryset_override_id,
472
- "expected_labels": expected_labels,
473
- "tags": list(tags or set()),
474
- "save_telemetry": telemetry_on,
475
- "save_telemetry_synchronously": telemetry_sync,
476
- "filters": cast(list[FilterItem], parsed_filters),
477
- "prompt": prompt,
478
- "use_lookup_cache": use_lookup_cache,
479
- "ignore_unlabeled": ignore_unlabeled,
480
- },
502
+ json=request_json,
481
503
  timeout=timeout_seconds,
482
504
  )
483
505
 
@@ -515,6 +537,10 @@ class ClassificationModel:
515
537
  use_lookup_cache: bool = True,
516
538
  timeout_seconds: int = 10,
517
539
  ignore_unlabeled: bool = False,
540
+ partition_id: str | list[str | None] | None = None,
541
+ partition_filter_mode: Literal[
542
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
543
+ ] = "include_global",
518
544
  ) -> list[ClassificationPrediction]:
519
545
  pass
520
546
 
@@ -530,6 +556,10 @@ class ClassificationModel:
530
556
  use_lookup_cache: bool = True,
531
557
  timeout_seconds: int = 10,
532
558
  ignore_unlabeled: bool = False,
559
+ partition_id: str | None = None,
560
+ partition_filter_mode: Literal[
561
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
562
+ ] = "include_global",
533
563
  ) -> ClassificationPrediction:
534
564
  pass
535
565
 
@@ -544,6 +574,10 @@ class ClassificationModel:
544
574
  use_lookup_cache: bool = True,
545
575
  timeout_seconds: int = 10,
546
576
  ignore_unlabeled: bool = False,
577
+ partition_id: str | None | list[str | None] = None,
578
+ partition_filter_mode: Literal[
579
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
580
+ ] = "include_global",
547
581
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
548
582
  """
549
583
  Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -564,7 +598,12 @@ class ClassificationModel:
564
598
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
565
599
  ignore_unlabeled: If True, only use labeled memories during lookup.
566
600
  If False (default), allow unlabeled memories when necessary.
567
-
601
+ partition_id: Optional partition ID(s) to use during memory lookup
602
+ partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
603
+ * `"ignore_partitions"`: Ignore partitions
604
+ * `"include_global"`: Include global memories
605
+ * `"exclude_global"`: Exclude global memories
606
+ * `"only_global"`: Only include global memories
568
607
  Returns:
569
608
  Label prediction or list of label predictions.
570
609
 
@@ -611,21 +650,26 @@ class ClassificationModel:
611
650
 
612
651
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
613
652
  client = OrcaAsyncClient._resolve_client()
653
+ request_json: ClassificationPredictionRequest = {
654
+ "input_values": value if isinstance(value, list) else [value],
655
+ "memoryset_override_name_or_id": self._memoryset_override_id,
656
+ "expected_labels": expected_labels,
657
+ "tags": list(tags or set()),
658
+ "save_telemetry": telemetry_on,
659
+ "save_telemetry_synchronously": telemetry_sync,
660
+ "filters": cast(list[FilterItem], parsed_filters),
661
+ "prompt": prompt,
662
+ "use_lookup_cache": use_lookup_cache,
663
+ "ignore_unlabeled": ignore_unlabeled,
664
+ "partition_filter_mode": partition_filter_mode,
665
+ }
666
+ # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
667
+ if partition_filter_mode != "ignore_partitions":
668
+ request_json["partition_ids"] = partition_id
614
669
  response = await client.POST(
615
670
  "/gpu/classification_model/{name_or_id}/prediction",
616
671
  params={"name_or_id": self.id},
617
- json={
618
- "input_values": value if isinstance(value, list) else [value],
619
- "memoryset_override_name_or_id": self._memoryset_override_id,
620
- "expected_labels": expected_labels,
621
- "tags": list(tags or set()),
622
- "save_telemetry": telemetry_on,
623
- "save_telemetry_synchronously": telemetry_sync,
624
- "filters": cast(list[FilterItem], parsed_filters),
625
- "prompt": prompt,
626
- "use_lookup_cache": use_lookup_cache,
627
- "ignore_unlabeled": ignore_unlabeled,
628
- },
672
+ json=request_json,
629
673
  timeout=timeout_seconds,
630
674
  )
631
675
 
@@ -730,6 +774,10 @@ class ClassificationModel:
730
774
  subsample: int | float | None,
731
775
  background: bool = False,
732
776
  ignore_unlabeled: bool = False,
777
+ partition_column: str | None = None,
778
+ partition_filter_mode: Literal[
779
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
780
+ ] = "include_global",
733
781
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
734
782
  client = OrcaClient._resolve_client()
735
783
  response = client.POST(
@@ -744,6 +792,8 @@ class ClassificationModel:
744
792
  "telemetry_tags": list(tags) if tags else None,
745
793
  "subsample": subsample,
746
794
  "ignore_unlabeled": ignore_unlabeled,
795
+ "datasource_partition_column": partition_column,
796
+ "partition_filter_mode": partition_filter_mode,
747
797
  },
748
798
  )
749
799
 
@@ -780,6 +830,10 @@ class ClassificationModel:
780
830
  tags: set[str],
781
831
  batch_size: int,
782
832
  ignore_unlabeled: bool,
833
+ partition_column: str | None = None,
834
+ partition_filter_mode: Literal[
835
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
836
+ ] = "include_global",
783
837
  ) -> ClassificationMetrics:
784
838
  if len(dataset) == 0:
785
839
  raise ValueError("Evaluation dataset cannot be empty")
@@ -796,6 +850,8 @@ class ClassificationModel:
796
850
  tags=tags,
797
851
  save_telemetry="sync" if record_predictions else "off",
798
852
  ignore_unlabeled=ignore_unlabeled,
853
+ partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
854
+ partition_filter_mode=partition_filter_mode,
799
855
  )
800
856
  ]
801
857
 
@@ -813,12 +869,16 @@ class ClassificationModel:
813
869
  *,
814
870
  value_column: str = "value",
815
871
  label_column: str = "label",
872
+ partition_column: str | None = None,
816
873
  record_predictions: bool = False,
817
874
  tags: set[str] = {"evaluation"},
818
875
  batch_size: int = 100,
819
876
  subsample: int | float | None = None,
820
877
  background: Literal[True],
821
878
  ignore_unlabeled: bool = False,
879
+ partition_filter_mode: Literal[
880
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
881
+ ] = "include_global",
822
882
  ) -> Job[ClassificationMetrics]:
823
883
  pass
824
884
 
@@ -829,12 +889,16 @@ class ClassificationModel:
829
889
  *,
830
890
  value_column: str = "value",
831
891
  label_column: str = "label",
892
+ partition_column: str | None = None,
832
893
  record_predictions: bool = False,
833
894
  tags: set[str] = {"evaluation"},
834
895
  batch_size: int = 100,
835
896
  subsample: int | float | None = None,
836
897
  background: Literal[False] = False,
837
898
  ignore_unlabeled: bool = False,
899
+ partition_filter_mode: Literal[
900
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
901
+ ] = "include_global",
838
902
  ) -> ClassificationMetrics:
839
903
  pass
840
904
 
@@ -844,12 +908,16 @@ class ClassificationModel:
844
908
  *,
845
909
  value_column: str = "value",
846
910
  label_column: str = "label",
911
+ partition_column: str | None = None,
847
912
  record_predictions: bool = False,
848
913
  tags: set[str] = {"evaluation"},
849
914
  batch_size: int = 100,
850
915
  subsample: int | float | None = None,
851
916
  background: bool = False,
852
917
  ignore_unlabeled: bool = False,
918
+ partition_filter_mode: Literal[
919
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
920
+ ] = "include_global",
853
921
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
854
922
  """
855
923
  Evaluate the classification model on a given dataset or datasource
@@ -858,13 +926,18 @@ class ClassificationModel:
858
926
  data: Dataset or Datasource to evaluate the model on
859
927
  value_column: Name of the column that contains the input values to the model
860
928
  label_column: Name of the column containing the expected labels
929
+ partition_column: Optional name of the column that contains the partition IDs
861
930
  record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
862
931
  tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
863
932
  batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
864
933
  subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
865
934
  background: Whether to run the operation in the background and return a job handle
866
935
  ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
867
-
936
+ partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
937
+ * `"ignore_partitions"`: Ignore partitions
938
+ * `"include_global"`: Include global memories
939
+ * `"exclude_global"`: Exclude global memories
940
+ * `"only_global"`: Only include global memories
868
941
  Returns:
869
942
  EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
870
943
 
@@ -888,6 +961,8 @@ class ClassificationModel:
888
961
  subsample=subsample,
889
962
  background=background,
890
963
  ignore_unlabeled=ignore_unlabeled,
964
+ partition_column=partition_column,
965
+ partition_filter_mode=partition_filter_mode,
891
966
  )
892
967
  elif isinstance(data, Dataset):
893
968
  return self._evaluate_dataset(
@@ -898,6 +973,8 @@ class ClassificationModel:
898
973
  tags=tags,
899
974
  batch_size=batch_size,
900
975
  ignore_unlabeled=ignore_unlabeled,
976
+ partition_column=partition_column,
977
+ partition_filter_mode=partition_filter_mode,
901
978
  )
902
979
  else:
903
980
  raise ValueError(f"Invalid data type: {type(data)}")
@@ -187,18 +187,24 @@ def test_evaluate(classification_model, eval_datasource: Datasource, eval_datase
187
187
  assert -1.0 <= result.anomaly_score_variance <= 1.0
188
188
 
189
189
  assert result.pr_auc is not None
190
- assert np.allclose(result.pr_auc, 0.75)
190
+ assert np.allclose(result.pr_auc, 0.83333)
191
191
  assert result.pr_curve is not None
192
- assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
193
- assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
194
- assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
192
+ assert np.allclose(
193
+ result.pr_curve["thresholds"],
194
+ [0.0, 0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364],
195
+ )
196
+ assert np.allclose(result.pr_curve["precisions"], [0.5, 0.666666, 0.5, 1.0, 1.0])
197
+ assert np.allclose(result.pr_curve["recalls"], [1.0, 1.0, 0.5, 0.5, 0.0])
195
198
 
196
199
  assert result.roc_auc is not None
197
- assert np.allclose(result.roc_auc, 0.625)
200
+ assert np.allclose(result.roc_auc, 0.75)
198
201
  assert result.roc_curve is not None
199
- assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
200
- assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
201
- assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
202
+ assert np.allclose(
203
+ result.roc_curve["thresholds"],
204
+ [0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364, 1.0],
205
+ )
206
+ assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.5, 0.0, 0.0])
207
+ assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 1.0, 0.5, 0.5, 0.0])
202
208
 
203
209
 
204
210
  def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
@@ -221,6 +227,139 @@ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval
221
227
  assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
222
228
 
223
229
 
230
+ def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
231
+ """Test evaluate with partition_column on a Dataset"""
232
+ # Create a test dataset with partition_id column
233
+ eval_dataset_with_partition = Dataset.from_list(
234
+ [
235
+ {"value": "soup is good", "label": 0, "partition_id": "p1"},
236
+ {"value": "cats are cute", "label": 1, "partition_id": "p1"},
237
+ {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
238
+ {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
239
+ ]
240
+ )
241
+
242
+ # Evaluate with partition_column
243
+ result = partitioned_classification_model.evaluate(
244
+ eval_dataset_with_partition,
245
+ partition_column="partition_id",
246
+ partition_filter_mode="exclude_global",
247
+ )
248
+ assert result is not None
249
+ assert isinstance(result, ClassificationMetrics)
250
+ assert isinstance(result.accuracy, float)
251
+ assert isinstance(result.f1_score, float)
252
+ assert isinstance(result.loss, float)
253
+
254
+
255
+ def test_evaluate_with_partition_column_include_global(partitioned_classification_model: ClassificationModel):
256
+ """Test evaluate with partition_column and include_global mode"""
257
+ eval_dataset_with_partition = Dataset.from_list(
258
+ [
259
+ {"value": "soup is good", "label": 0, "partition_id": "p1"},
260
+ {"value": "cats are cute", "label": 1, "partition_id": "p1"},
261
+ ]
262
+ )
263
+
264
+ # Evaluate with partition_column and include_global (default)
265
+ result = partitioned_classification_model.evaluate(
266
+ eval_dataset_with_partition,
267
+ partition_column="partition_id",
268
+ partition_filter_mode="include_global",
269
+ )
270
+ assert result is not None
271
+ assert isinstance(result, ClassificationMetrics)
272
+
273
+
274
+ def test_evaluate_with_partition_column_exclude_global(partitioned_classification_model: ClassificationModel):
275
+ """Test evaluate with partition_column and exclude_global mode"""
276
+ eval_dataset_with_partition = Dataset.from_list(
277
+ [
278
+ {"value": "soup is good", "label": 0, "partition_id": "p1"},
279
+ {"value": "cats are cute", "label": 1, "partition_id": "p1"},
280
+ ]
281
+ )
282
+
283
+ # Evaluate with partition_column and exclude_global
284
+ result = partitioned_classification_model.evaluate(
285
+ eval_dataset_with_partition,
286
+ partition_column="partition_id",
287
+ partition_filter_mode="exclude_global",
288
+ )
289
+ assert result is not None
290
+ assert isinstance(result, ClassificationMetrics)
291
+
292
+
293
+ def test_evaluate_with_partition_column_only_global(partitioned_classification_model: ClassificationModel):
294
+ """Test evaluate with partition_filter_mode only_global"""
295
+ eval_dataset_with_partition = Dataset.from_list(
296
+ [
297
+ {"value": "cats are independent animals", "label": 1, "partition_id": None},
298
+ {"value": "i love the beach", "label": 1, "partition_id": None},
299
+ ]
300
+ )
301
+
302
+ # Evaluate with only_global mode
303
+ result = partitioned_classification_model.evaluate(
304
+ eval_dataset_with_partition,
305
+ partition_column="partition_id",
306
+ partition_filter_mode="only_global",
307
+ )
308
+ assert result is not None
309
+ assert isinstance(result, ClassificationMetrics)
310
+
311
+
312
+ def test_evaluate_with_partition_column_ignore_partitions(partitioned_classification_model: ClassificationModel):
313
+ """Test evaluate with partition_filter_mode ignore_partitions"""
314
+ eval_dataset_with_partition = Dataset.from_list(
315
+ [
316
+ {"value": "soup is good", "label": 0, "partition_id": "p1"},
317
+ {"value": "cats are cute", "label": 1, "partition_id": "p2"},
318
+ ]
319
+ )
320
+
321
+ # Evaluate with ignore_partitions mode
322
+ result = partitioned_classification_model.evaluate(
323
+ eval_dataset_with_partition,
324
+ partition_column="partition_id",
325
+ partition_filter_mode="ignore_partitions",
326
+ )
327
+ assert result is not None
328
+ assert isinstance(result, ClassificationMetrics)
329
+
330
+
331
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
332
+ def test_evaluate_with_partition_column_datasource(partitioned_classification_model: ClassificationModel, data_type):
333
+ """Test evaluate with partition_column on a Datasource"""
334
+ # Create a test datasource with partition_id column
335
+ eval_data_with_partition = [
336
+ {"value": "soup is good", "label": 0, "partition_id": "p1"},
337
+ {"value": "cats are cute", "label": 1, "partition_id": "p1"},
338
+ {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
339
+ {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
340
+ ]
341
+
342
+ if data_type == "dataset":
343
+ eval_data = Dataset.from_list(eval_data_with_partition)
344
+ result = partitioned_classification_model.evaluate(
345
+ eval_data,
346
+ partition_column="partition_id",
347
+ partition_filter_mode="exclude_global",
348
+ )
349
+ else:
350
+ eval_datasource = Datasource.from_list("eval_datasource_with_partition", eval_data_with_partition)
351
+ result = partitioned_classification_model.evaluate(
352
+ eval_datasource,
353
+ partition_column="partition_id",
354
+ partition_filter_mode="exclude_global",
355
+ )
356
+
357
+ assert result is not None
358
+ assert isinstance(result, ClassificationMetrics)
359
+ assert isinstance(result.accuracy, float)
360
+ assert isinstance(result.f1_score, float)
361
+
362
+
224
363
  def test_predict(classification_model: ClassificationModel, label_names: list[str]):
225
364
  predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
226
365
  assert len(predictions) == 2
@@ -284,6 +423,186 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
284
423
  model.predict("test")
285
424
 
286
425
 
426
+ def test_predict_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
427
+ """Test predict with a specific partition_id"""
428
+ # Predict with partition_id p1 - should use memories from p1
429
+ prediction = partitioned_classification_model.predict(
430
+ "soup", partition_id="p1", partition_filter_mode="exclude_global"
431
+ )
432
+ assert prediction.label is not None
433
+ assert prediction.label_name in label_names
434
+ assert 0 <= prediction.confidence <= 1
435
+ assert prediction.logits is not None
436
+ assert len(prediction.logits) == 2
437
+
438
+ # Predict with partition_id p2 - should use memories from p2
439
+ prediction_p2 = partitioned_classification_model.predict(
440
+ "cats", partition_id="p2", partition_filter_mode="exclude_global"
441
+ )
442
+ assert prediction_p2.label is not None
443
+ assert prediction_p2.label_name in label_names
444
+ assert 0 <= prediction_p2.confidence <= 1
445
+
446
+
447
+ def test_predict_with_partition_id_include_global(
448
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
449
+ ):
450
+ """Test predict with partition_id and include_global mode (default)"""
451
+ # Predict with partition_id p1 and include_global (default) - should include both p1 and global memories
452
+ prediction = partitioned_classification_model.predict(
453
+ "soup", partition_id="p1", partition_filter_mode="include_global"
454
+ )
455
+ assert prediction.label is not None
456
+ assert prediction.label_name in label_names
457
+ assert 0 <= prediction.confidence <= 1
458
+
459
+
460
+ def test_predict_with_partition_id_exclude_global(
461
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
462
+ ):
463
+ """Test predict with partition_id and exclude_global mode"""
464
+ # Predict with partition_id p1 and exclude_global - should only use p1 memories
465
+ prediction = partitioned_classification_model.predict(
466
+ "soup", partition_id="p1", partition_filter_mode="exclude_global"
467
+ )
468
+ assert prediction.label is not None
469
+ assert prediction.label_name in label_names
470
+ assert 0 <= prediction.confidence <= 1
471
+
472
+
473
+ def test_predict_with_partition_id_only_global(
474
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
475
+ ):
476
+ """Test predict with partition_filter_mode only_global"""
477
+ # Predict with only_global mode - should only use global memories
478
+ prediction = partitioned_classification_model.predict("cats", partition_filter_mode="only_global")
479
+ assert prediction.label is not None
480
+ assert prediction.label_name in label_names
481
+ assert 0 <= prediction.confidence <= 1
482
+
483
+
484
+ def test_predict_with_partition_id_ignore_partitions(
485
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
486
+ ):
487
+ """Test predict with partition_filter_mode ignore_partitions"""
488
+ # Predict with ignore_partitions mode - should ignore partition filtering
489
+ prediction = partitioned_classification_model.predict("soup", partition_filter_mode="ignore_partitions")
490
+ assert prediction.label is not None
491
+ assert prediction.label_name in label_names
492
+ assert 0 <= prediction.confidence <= 1
493
+
494
+
495
+ def test_predict_batch_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
496
+ """Test batch predict with partition_id"""
497
+ # Batch predict with partition_id p1
498
+ predictions = partitioned_classification_model.predict(
499
+ ["soup is good", "cats are cute"],
500
+ partition_id="p1",
501
+ partition_filter_mode="exclude_global",
502
+ )
503
+ assert len(predictions) == 2
504
+ assert all(p.label is not None for p in predictions)
505
+ assert all(p.label_name in label_names for p in predictions)
506
+ assert all(0 <= p.confidence <= 1 for p in predictions)
507
+ assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
508
+
509
+
510
+ def test_predict_with_partition_id_and_filters(
511
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
512
+ ):
513
+ """Test predict with partition_id and filters"""
514
+ # Predict with partition_id and filters
515
+ prediction = partitioned_classification_model.predict(
516
+ "soup",
517
+ partition_id="p1",
518
+ partition_filter_mode="exclude_global",
519
+ filters=[("key", "==", "g1")],
520
+ )
521
+ assert prediction.label is not None
522
+ assert prediction.label_name in label_names
523
+ assert 0 <= prediction.confidence <= 1
524
+
525
+
526
+ def test_predict_batch_with_list_of_partition_ids(
527
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
528
+ ):
529
+ """Test batch predict with a list of partition_ids (one for each query input)"""
530
+ # Batch predict with a list of partition_ids - one for each input
531
+ # First input uses p1, second input uses p2
532
+ predictions = partitioned_classification_model.predict(
533
+ ["soup is good", "cats are cute"],
534
+ partition_id=["p1", "p2"],
535
+ partition_filter_mode="exclude_global",
536
+ )
537
+ assert len(predictions) == 2
538
+ assert all(p.label is not None for p in predictions)
539
+ assert all(p.label_name in label_names for p in predictions)
540
+ assert all(0 <= p.confidence <= 1 for p in predictions)
541
+ assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
542
+
543
+ # Verify that predictions were made using the correct partitions
544
+ # Each prediction should use memories from its respective partition
545
+ assert predictions[0].input_value == "soup is good"
546
+ assert predictions[1].input_value == "cats are cute"
547
+
548
+
549
+ @pytest.mark.asyncio
550
+ async def test_predict_async_with_partition_id(
551
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
552
+ ):
553
+ """Test async predict with partition_id"""
554
+ # Async predict with partition_id p1
555
+ prediction = await partitioned_classification_model.apredict(
556
+ "soup", partition_id="p1", partition_filter_mode="exclude_global"
557
+ )
558
+ assert prediction.label is not None
559
+ assert prediction.label_name in label_names
560
+ assert 0 <= prediction.confidence <= 1
561
+ assert prediction.logits is not None
562
+ assert len(prediction.logits) == 2
563
+
564
+
565
+ @pytest.mark.asyncio
566
+ async def test_predict_async_batch_with_partition_id(
567
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
568
+ ):
569
+ """Test async batch predict with partition_id"""
570
+ # Async batch predict with partition_id p1
571
+ predictions = await partitioned_classification_model.apredict(
572
+ ["soup is good", "cats are cute"],
573
+ partition_id="p1",
574
+ partition_filter_mode="exclude_global",
575
+ )
576
+ assert len(predictions) == 2
577
+ assert all(p.label is not None for p in predictions)
578
+ assert all(p.label_name in label_names for p in predictions)
579
+ assert all(0 <= p.confidence <= 1 for p in predictions)
580
+
581
+
582
+ @pytest.mark.asyncio
583
+ async def test_predict_async_batch_with_list_of_partition_ids(
584
+ partitioned_classification_model: ClassificationModel, label_names: list[str]
585
+ ):
586
+ """Test async batch predict with a list of partition_ids (one for each query input)"""
587
+ # Async batch predict with a list of partition_ids - one for each input
588
+ # First input uses p1, second input uses p2
589
+ predictions = await partitioned_classification_model.apredict(
590
+ ["soup is good", "cats are cute"],
591
+ partition_id=["p1", "p2"],
592
+ partition_filter_mode="exclude_global",
593
+ )
594
+ assert len(predictions) == 2
595
+ assert all(p.label is not None for p in predictions)
596
+ assert all(p.label_name in label_names for p in predictions)
597
+ assert all(0 <= p.confidence <= 1 for p in predictions)
598
+ assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
599
+
600
+ # Verify that predictions were made using the correct partitions
601
+ # Each prediction should use memories from its respective partition
602
+ assert predictions[0].input_value == "soup is good"
603
+ assert predictions[1].input_value == "cats are cute"
604
+
605
+
287
606
  def test_record_prediction_feedback(classification_model: ClassificationModel):
288
607
  predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
289
608
  expected_labels = [0, 1]