orca-sdk 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +186 -43
- orca_sdk/_shared/metrics_test.py +99 -6
- orca_sdk/_utils/data_parsing_test.py +1 -1
- orca_sdk/async_client.py +52 -14
- orca_sdk/classification_model.py +107 -30
- orca_sdk/classification_model_test.py +327 -8
- orca_sdk/client.py +52 -14
- orca_sdk/conftest.py +140 -21
- orca_sdk/embedding_model.py +0 -2
- orca_sdk/memoryset.py +141 -26
- orca_sdk/memoryset_test.py +253 -4
- orca_sdk/regression_model.py +73 -16
- orca_sdk/regression_model_test.py +213 -0
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/RECORD +16 -16
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
@@ -12,11 +12,10 @@ from ._utils.common import UNSET, CreateMode, DropMode
 from .async_client import OrcaAsyncClient
 from .client import (
     BootstrapClassificationModelMeta,
-
-    ClassificationEvaluationRequest,
+    BootstrapLabeledMemoryDataResult,
     ClassificationModelMetadata,
+    ClassificationPredictionRequest,
     OrcaClient,
-    PostClassificationModelByModelNameOrIdEvaluationParams,
     PredictiveModelUpdate,
     RACHeadType,
 )
@@ -42,7 +41,7 @@ class BootstrappedClassificationModel:
     datasource: Datasource | None
     memoryset: LabeledMemoryset | None
    classification_model: ClassificationModel | None
-    agent_output:
+    agent_output: BootstrapLabeledMemoryDataResult | None

     def __init__(self, metadata: BootstrapClassificationModelMeta):
         self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
@@ -358,6 +357,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | list[str | None] | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
         use_gpu: bool = True,
     ) -> list[ClassificationPrediction]:
         pass
@@ -374,6 +377,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
         use_gpu: bool = True,
     ) -> ClassificationPrediction:
         pass
@@ -389,6 +396,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | None | list[str | None] = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
         use_gpu: bool = True,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
@@ -410,6 +421,12 @@ class ClassificationModel:
             timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
             ignore_unlabeled: If True, only use labeled memories during lookup.
                 If False (default), allow unlabeled memories when necessary.
+            partition_id: Optional partition ID(s) to use during memory lookup
+            partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
+                * `"ignore_partitions"`: Ignore partitions
+                * `"include_global"`: Include global memories
+                * `"exclude_global"`: Exclude global memories
+                * `"only_global"`: Only include global memories
             use_gpu: Whether to use GPU for the prediction (defaults to True)

         Returns:
@@ -463,21 +480,26 @@ class ClassificationModel:

         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaClient._resolve_client()
+        request_json: ClassificationPredictionRequest = {
+            "input_values": value if isinstance(value, list) else [value],
+            "memoryset_override_name_or_id": self._memoryset_override_id,
+            "expected_labels": expected_labels,
+            "tags": list(tags or set()),
+            "save_telemetry": telemetry_on,
+            "save_telemetry_synchronously": telemetry_sync,
+            "filters": cast(list[FilterItem], parsed_filters),
+            "prompt": prompt,
+            "use_lookup_cache": use_lookup_cache,
+            "ignore_unlabeled": ignore_unlabeled,
+            "partition_filter_mode": partition_filter_mode,
+        }
+        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
+        if partition_filter_mode != "ignore_partitions":
+            request_json["partition_ids"] = partition_id
         response = client.POST(
             endpoint,
             params={"name_or_id": self.id},
-            json={
-                "input_values": value if isinstance(value, list) else [value],
-                "memoryset_override_name_or_id": self._memoryset_override_id,
-                "expected_labels": expected_labels,
-                "tags": list(tags or set()),
-                "save_telemetry": telemetry_on,
-                "save_telemetry_synchronously": telemetry_sync,
-                "filters": cast(list[FilterItem], parsed_filters),
-                "prompt": prompt,
-                "use_lookup_cache": use_lookup_cache,
-                "ignore_unlabeled": ignore_unlabeled,
-            },
+            json=request_json,
             timeout=timeout_seconds,
         )

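The hunks above add two lookup-scoping parameters to `ClassificationModel.predict`. A minimal usage sketch, assuming `model` is an already-created `ClassificationModel` handle (the parameter names, accepted modes, and default come from the diff; the inputs and partition IDs are illustrative):

# Single-partition lookup; global memories are still included ("include_global" is the default).
prediction = model.predict("is soup good?", partition_id="p1")

# One partition per input in a batch, with global memories excluded.
predictions = model.predict(
    ["soup is good", "cats are cute"],
    partition_id=["p1", "p2"],
    partition_filter_mode="exclude_global",
)

# With "ignore_partitions", partition_ids is not sent to the server at all.
prediction = model.predict("soup", partition_filter_mode="ignore_partitions")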
@@ -515,6 +537,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | list[str | None] | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> list[ClassificationPrediction]:
         pass

@@ -530,6 +556,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> ClassificationPrediction:
         pass

@@ -544,6 +574,10 @@ class ClassificationModel:
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
         ignore_unlabeled: bool = False,
+        partition_id: str | None | list[str | None] = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -564,7 +598,12 @@ class ClassificationModel:
             timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
             ignore_unlabeled: If True, only use labeled memories during lookup.
                 If False (default), allow unlabeled memories when necessary.
-
+            partition_id: Optional partition ID(s) to use during memory lookup
+            partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
+                * `"ignore_partitions"`: Ignore partitions
+                * `"include_global"`: Include global memories
+                * `"exclude_global"`: Exclude global memories
+                * `"only_global"`: Only include global memories
         Returns:
             Label prediction or list of label predictions.

@@ -611,21 +650,26 @@ class ClassificationModel:

         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaAsyncClient._resolve_client()
+        request_json: ClassificationPredictionRequest = {
+            "input_values": value if isinstance(value, list) else [value],
+            "memoryset_override_name_or_id": self._memoryset_override_id,
+            "expected_labels": expected_labels,
+            "tags": list(tags or set()),
+            "save_telemetry": telemetry_on,
+            "save_telemetry_synchronously": telemetry_sync,
+            "filters": cast(list[FilterItem], parsed_filters),
+            "prompt": prompt,
+            "use_lookup_cache": use_lookup_cache,
+            "ignore_unlabeled": ignore_unlabeled,
+            "partition_filter_mode": partition_filter_mode,
+        }
+        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
+        if partition_filter_mode != "ignore_partitions":
+            request_json["partition_ids"] = partition_id
         response = await client.POST(
             "/gpu/classification_model/{name_or_id}/prediction",
             params={"name_or_id": self.id},
-            json={
-                "input_values": value if isinstance(value, list) else [value],
-                "memoryset_override_name_or_id": self._memoryset_override_id,
-                "expected_labels": expected_labels,
-                "tags": list(tags or set()),
-                "save_telemetry": telemetry_on,
-                "save_telemetry_synchronously": telemetry_sync,
-                "filters": cast(list[FilterItem], parsed_filters),
-                "prompt": prompt,
-                "use_lookup_cache": use_lookup_cache,
-                "ignore_unlabeled": ignore_unlabeled,
-            },
+            json=request_json,
             timeout=timeout_seconds,
         )

@@ -730,6 +774,10 @@ class ClassificationModel:
         subsample: int | float | None,
         background: bool = False,
         ignore_unlabeled: bool = False,
+        partition_column: str | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         client = OrcaClient._resolve_client()
         response = client.POST(
@@ -744,6 +792,8 @@ class ClassificationModel:
                 "telemetry_tags": list(tags) if tags else None,
                 "subsample": subsample,
                 "ignore_unlabeled": ignore_unlabeled,
+                "datasource_partition_column": partition_column,
+                "partition_filter_mode": partition_filter_mode,
             },
         )

@@ -780,6 +830,10 @@ class ClassificationModel:
         tags: set[str],
         batch_size: int,
         ignore_unlabeled: bool,
+        partition_column: str | None = None,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> ClassificationMetrics:
         if len(dataset) == 0:
             raise ValueError("Evaluation dataset cannot be empty")
@@ -796,6 +850,8 @@ class ClassificationModel:
                 tags=tags,
                 save_telemetry="sync" if record_predictions else "off",
                 ignore_unlabeled=ignore_unlabeled,
+                partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
+                partition_filter_mode=partition_filter_mode,
             )
         ]

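In `_evaluate_dataset`, the new `partition_id=dataset[i : i + batch_size][partition_column]` argument works because slicing a Hugging Face `datasets.Dataset` returns a dict of column name to list, so the expression yields one partition ID per row in the batch, matching the `list[str | None]` form of `partition_id`. A small sketch of that behavior (the rows below are illustrative; the tests in this diff build data the same way via `Dataset.from_list`):

from datasets import Dataset

rows = [
    {"value": "soup is good", "label": 0, "partition_id": "p1"},
    {"value": "cats are cute", "label": 1, "partition_id": "p2"},
]
dataset = Dataset.from_list(rows)

# Slicing returns {"value": [...], "label": [...], "partition_id": [...]},
# so indexing the partition column gives one ID per row in the batch.
batch = dataset[0:2]
print(batch["partition_id"])  # ['p1', 'p2']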
@@ -813,12 +869,16 @@ class ClassificationModel:
         *,
         value_column: str = "value",
         label_column: str = "label",
+        partition_column: str | None = None,
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
         subsample: int | float | None = None,
         background: Literal[True],
         ignore_unlabeled: bool = False,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> Job[ClassificationMetrics]:
         pass

@@ -829,12 +889,16 @@ class ClassificationModel:
         *,
         value_column: str = "value",
         label_column: str = "label",
+        partition_column: str | None = None,
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
         subsample: int | float | None = None,
         background: Literal[False] = False,
         ignore_unlabeled: bool = False,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> ClassificationMetrics:
         pass

@@ -844,12 +908,16 @@ class ClassificationModel:
         *,
         value_column: str = "value",
         label_column: str = "label",
+        partition_column: str | None = None,
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
         subsample: int | float | None = None,
         background: bool = False,
         ignore_unlabeled: bool = False,
+        partition_filter_mode: Literal[
+            "ignore_partitions", "include_global", "exclude_global", "only_global"
+        ] = "include_global",
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource
@@ -858,13 +926,18 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
+            partition_column: Optional name of the column that contains the partition IDs
             record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
             tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
             subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
             background: Whether to run the operation in the background and return a job handle
             ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
-
+            partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
+                * `"ignore_partitions"`: Ignore partitions
+                * `"include_global"`: Include global memories
+                * `"exclude_global"`: Exclude global memories
+                * `"only_global"`: Only include global memories
         Returns:
             EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics

@@ -888,6 +961,8 @@ class ClassificationModel:
                 subsample=subsample,
                 background=background,
                 ignore_unlabeled=ignore_unlabeled,
+                partition_column=partition_column,
+                partition_filter_mode=partition_filter_mode,
             )
         elif isinstance(data, Dataset):
             return self._evaluate_dataset(
@@ -898,6 +973,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
                 ignore_unlabeled=ignore_unlabeled,
+                partition_column=partition_column,
+                partition_filter_mode=partition_filter_mode,
             )
         else:
             raise ValueError(f"Invalid data type: {type(data)}")
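Taken together, the `evaluate` changes let per-row partitions flow from a column in the evaluation data into each prediction request. A minimal sketch, assuming `model` is an existing `ClassificationModel` whose memoryset uses partitions (the keyword arguments come from the diff; the rows and column name are illustrative):

from datasets import Dataset

eval_data = Dataset.from_list(
    [
        {"value": "soup is good", "label": 0, "partition_id": "p1"},
        {"value": "cats are cute", "label": 1, "partition_id": "p2"},
    ]
)

# Each row is looked up only within its own partition; global memories are excluded.
metrics = model.evaluate(
    eval_data,
    partition_column="partition_id",
    partition_filter_mode="exclude_global",
)
print(metrics.accuracy, metrics.f1_score)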
orca_sdk/classification_model_test.py
CHANGED

@@ -187,18 +187,24 @@ def test_evaluate(classification_model, eval_datasource: Datasource, eval_datase
     assert -1.0 <= result.anomaly_score_variance <= 1.0

     assert result.pr_auc is not None
-    assert np.allclose(result.pr_auc, 0.
+    assert np.allclose(result.pr_auc, 0.83333)
     assert result.pr_curve is not None
-    assert np.allclose(
-
-
+    assert np.allclose(
+        result.pr_curve["thresholds"],
+        [0.0, 0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364],
+    )
+    assert np.allclose(result.pr_curve["precisions"], [0.5, 0.666666, 0.5, 1.0, 1.0])
+    assert np.allclose(result.pr_curve["recalls"], [1.0, 1.0, 0.5, 0.5, 0.0])

     assert result.roc_auc is not None
-    assert np.allclose(result.roc_auc, 0.
+    assert np.allclose(result.roc_auc, 0.75)
     assert result.roc_curve is not None
-    assert np.allclose(
-
-
+    assert np.allclose(
+        result.roc_curve["thresholds"],
+        [0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364, 1.0],
+    )
+    assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.5, 0.0, 0.0])
+    assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 1.0, 0.5, 0.5, 0.0])


 def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
@@ -221,6 +227,139 @@ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval
     assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))


+def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column on a Dataset"""
+    # Create a test dataset with partition_id column
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+            {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
+            {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
+        ]
+    )
+
+    # Evaluate with partition_column
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="exclude_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+    assert isinstance(result.accuracy, float)
+    assert isinstance(result.f1_score, float)
+    assert isinstance(result.loss, float)
+
+
+def test_evaluate_with_partition_column_include_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column and include_global mode"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        ]
+    )
+
+    # Evaluate with partition_column and include_global (default)
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="include_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_exclude_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column and exclude_global mode"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        ]
+    )
+
+    # Evaluate with partition_column and exclude_global
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="exclude_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_only_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_filter_mode only_global"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "cats are independent animals", "label": 1, "partition_id": None},
+            {"value": "i love the beach", "label": 1, "partition_id": None},
+        ]
+    )
+
+    # Evaluate with only_global mode
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="only_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_ignore_partitions(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_filter_mode ignore_partitions"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p2"},
+        ]
+    )
+
+    # Evaluate with ignore_partitions mode
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="ignore_partitions",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate_with_partition_column_datasource(partitioned_classification_model: ClassificationModel, data_type):
+    """Test evaluate with partition_column on a Datasource"""
+    # Create a test datasource with partition_id column
+    eval_data_with_partition = [
+        {"value": "soup is good", "label": 0, "partition_id": "p1"},
+        {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
+        {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
+    ]
+
+    if data_type == "dataset":
+        eval_data = Dataset.from_list(eval_data_with_partition)
+        result = partitioned_classification_model.evaluate(
+            eval_data,
+            partition_column="partition_id",
+            partition_filter_mode="exclude_global",
+        )
+    else:
+        eval_datasource = Datasource.from_list("eval_datasource_with_partition", eval_data_with_partition)
+        result = partitioned_classification_model.evaluate(
+            eval_datasource,
+            partition_column="partition_id",
+            partition_filter_mode="exclude_global",
+        )
+
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+    assert isinstance(result.accuracy, float)
+    assert isinstance(result.f1_score, float)
+
+
 def test_predict(classification_model: ClassificationModel, label_names: list[str]):
     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert len(predictions) == 2
@@ -284,6 +423,186 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
         model.predict("test")


+def test_predict_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
+    """Test predict with a specific partition_id"""
+    # Predict with partition_id p1 - should use memories from p1
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+    # Predict with partition_id p2 - should use memories from p2
+    prediction_p2 = partitioned_classification_model.predict(
+        "cats", partition_id="p2", partition_filter_mode="exclude_global"
+    )
+    assert prediction_p2.label is not None
+    assert prediction_p2.label_name in label_names
+    assert 0 <= prediction_p2.confidence <= 1
+
+
+def test_predict_with_partition_id_include_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and include_global mode (default)"""
+    # Predict with partition_id p1 and include_global (default) - should include both p1 and global memories
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="include_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_exclude_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and exclude_global mode"""
+    # Predict with partition_id p1 and exclude_global - should only use p1 memories
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_only_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_filter_mode only_global"""
+    # Predict with only_global mode - should only use global memories
+    prediction = partitioned_classification_model.predict("cats", partition_filter_mode="only_global")
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_ignore_partitions(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_filter_mode ignore_partitions"""
+    # Predict with ignore_partitions mode - should ignore partition filtering
+    prediction = partitioned_classification_model.predict("soup", partition_filter_mode="ignore_partitions")
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_batch_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
+    """Test batch predict with partition_id"""
+    # Batch predict with partition_id p1
+    predictions = partitioned_classification_model.predict(
+        ["soup is good", "cats are cute"],
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+
+def test_predict_with_partition_id_and_filters(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and filters"""
+    # Predict with partition_id and filters
+    prediction = partitioned_classification_model.predict(
+        "soup",
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+        filters=[("key", "==", "g1")],
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_batch_with_list_of_partition_ids(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test batch predict with a list of partition_ids (one for each query input)"""
+    # Batch predict with a list of partition_ids - one for each input
+    # First input uses p1, second input uses p2
+    predictions = partitioned_classification_model.predict(
+        ["soup is good", "cats are cute"],
+        partition_id=["p1", "p2"],
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+    # Verify that predictions were made using the correct partitions
+    # Each prediction should use memories from its respective partition
+    assert predictions[0].input_value == "soup is good"
+    assert predictions[1].input_value == "cats are cute"
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_partition_id(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async predict with partition_id"""
+    # Async predict with partition_id p1
+    prediction = await partitioned_classification_model.apredict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch_with_partition_id(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async batch predict with partition_id"""
+    # Async batch predict with partition_id p1
+    predictions = await partitioned_classification_model.apredict(
+        ["soup is good", "cats are cute"],
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch_with_list_of_partition_ids(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async batch predict with a list of partition_ids (one for each query input)"""
+    # Async batch predict with a list of partition_ids - one for each input
+    # First input uses p1, second input uses p2
+    predictions = await partitioned_classification_model.apredict(
+        ["soup is good", "cats are cute"],
+        partition_id=["p1", "p2"],
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+    # Verify that predictions were made using the correct partitions
+    # Each prediction should use memories from its respective partition
+    assert predictions[0].input_value == "soup is good"
+    assert predictions[1].input_value == "cats are cute"
+
+
 def test_record_prediction_feedback(classification_model: ClassificationModel):
     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     expected_labels = [0, 1]