orca-sdk 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +179 -40
- orca_sdk/_shared/metrics_test.py +99 -6
- orca_sdk/_utils/data_parsing_test.py +1 -1
- orca_sdk/async_client.py +14 -0
- orca_sdk/classification_model.py +105 -26
- orca_sdk/classification_model_test.py +327 -8
- orca_sdk/client.py +14 -0
- orca_sdk/conftest.py +140 -21
- orca_sdk/memoryset.py +141 -26
- orca_sdk/memoryset_test.py +253 -4
- orca_sdk/regression_model.py +73 -16
- orca_sdk/regression_model_test.py +213 -0
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.5.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.5.dist-info}/RECORD +15 -15
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.5.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
|
@@ -15,6 +15,7 @@ from .client import (
|
|
|
15
15
|
BootstrapClassificationModelResult,
|
|
16
16
|
ClassificationEvaluationRequest,
|
|
17
17
|
ClassificationModelMetadata,
|
|
18
|
+
ClassificationPredictionRequest,
|
|
18
19
|
OrcaClient,
|
|
19
20
|
PostClassificationModelByModelNameOrIdEvaluationParams,
|
|
20
21
|
PredictiveModelUpdate,
|
|
@@ -358,6 +359,10 @@ class ClassificationModel:
|
|
|
358
359
|
use_lookup_cache: bool = True,
|
|
359
360
|
timeout_seconds: int = 10,
|
|
360
361
|
ignore_unlabeled: bool = False,
|
|
362
|
+
partition_id: str | list[str | None] | None = None,
|
|
363
|
+
partition_filter_mode: Literal[
|
|
364
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
365
|
+
] = "include_global",
|
|
361
366
|
use_gpu: bool = True,
|
|
362
367
|
) -> list[ClassificationPrediction]:
|
|
363
368
|
pass
|
|
@@ -374,6 +379,10 @@ class ClassificationModel:
|
|
|
374
379
|
use_lookup_cache: bool = True,
|
|
375
380
|
timeout_seconds: int = 10,
|
|
376
381
|
ignore_unlabeled: bool = False,
|
|
382
|
+
partition_id: str | None = None,
|
|
383
|
+
partition_filter_mode: Literal[
|
|
384
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
385
|
+
] = "include_global",
|
|
377
386
|
use_gpu: bool = True,
|
|
378
387
|
) -> ClassificationPrediction:
|
|
379
388
|
pass
|
|
@@ -389,6 +398,10 @@ class ClassificationModel:
|
|
|
389
398
|
use_lookup_cache: bool = True,
|
|
390
399
|
timeout_seconds: int = 10,
|
|
391
400
|
ignore_unlabeled: bool = False,
|
|
401
|
+
partition_id: str | None | list[str | None] = None,
|
|
402
|
+
partition_filter_mode: Literal[
|
|
403
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
404
|
+
] = "include_global",
|
|
392
405
|
use_gpu: bool = True,
|
|
393
406
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
394
407
|
"""
|
|
@@ -410,6 +423,12 @@ class ClassificationModel:
|
|
|
410
423
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
411
424
|
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
412
425
|
If False (default), allow unlabeled memories when necessary.
|
|
426
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
427
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
428
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
429
|
+
* `"include_global"`: Include global memories
|
|
430
|
+
* `"exclude_global"`: Exclude global memories
|
|
431
|
+
* `"only_global"`: Only include global memories
|
|
413
432
|
use_gpu: Whether to use GPU for the prediction (defaults to True)
|
|
414
433
|
|
|
415
434
|
Returns:
|
|
@@ -463,21 +482,26 @@ class ClassificationModel:
|
|
|
463
482
|
|
|
464
483
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
465
484
|
client = OrcaClient._resolve_client()
|
|
485
|
+
request_json: ClassificationPredictionRequest = {
|
|
486
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
487
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
488
|
+
"expected_labels": expected_labels,
|
|
489
|
+
"tags": list(tags or set()),
|
|
490
|
+
"save_telemetry": telemetry_on,
|
|
491
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
492
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
493
|
+
"prompt": prompt,
|
|
494
|
+
"use_lookup_cache": use_lookup_cache,
|
|
495
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
496
|
+
"partition_filter_mode": partition_filter_mode,
|
|
497
|
+
}
|
|
498
|
+
# Don't send partition_ids when partition_filter_mode is "ignore_partitions"
|
|
499
|
+
if partition_filter_mode != "ignore_partitions":
|
|
500
|
+
request_json["partition_ids"] = partition_id
|
|
466
501
|
response = client.POST(
|
|
467
502
|
endpoint,
|
|
468
503
|
params={"name_or_id": self.id},
|
|
469
|
-
json=
|
|
470
|
-
"input_values": value if isinstance(value, list) else [value],
|
|
471
|
-
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
472
|
-
"expected_labels": expected_labels,
|
|
473
|
-
"tags": list(tags or set()),
|
|
474
|
-
"save_telemetry": telemetry_on,
|
|
475
|
-
"save_telemetry_synchronously": telemetry_sync,
|
|
476
|
-
"filters": cast(list[FilterItem], parsed_filters),
|
|
477
|
-
"prompt": prompt,
|
|
478
|
-
"use_lookup_cache": use_lookup_cache,
|
|
479
|
-
"ignore_unlabeled": ignore_unlabeled,
|
|
480
|
-
},
|
|
504
|
+
json=request_json,
|
|
481
505
|
timeout=timeout_seconds,
|
|
482
506
|
)
|
|
483
507
|
|
|
@@ -515,6 +539,10 @@ class ClassificationModel:
|
|
|
515
539
|
use_lookup_cache: bool = True,
|
|
516
540
|
timeout_seconds: int = 10,
|
|
517
541
|
ignore_unlabeled: bool = False,
|
|
542
|
+
partition_id: str | list[str | None] | None = None,
|
|
543
|
+
partition_filter_mode: Literal[
|
|
544
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
545
|
+
] = "include_global",
|
|
518
546
|
) -> list[ClassificationPrediction]:
|
|
519
547
|
pass
|
|
520
548
|
|
|
@@ -530,6 +558,10 @@ class ClassificationModel:
|
|
|
530
558
|
use_lookup_cache: bool = True,
|
|
531
559
|
timeout_seconds: int = 10,
|
|
532
560
|
ignore_unlabeled: bool = False,
|
|
561
|
+
partition_id: str | None = None,
|
|
562
|
+
partition_filter_mode: Literal[
|
|
563
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
564
|
+
] = "include_global",
|
|
533
565
|
) -> ClassificationPrediction:
|
|
534
566
|
pass
|
|
535
567
|
|
|
@@ -544,6 +576,10 @@ class ClassificationModel:
|
|
|
544
576
|
use_lookup_cache: bool = True,
|
|
545
577
|
timeout_seconds: int = 10,
|
|
546
578
|
ignore_unlabeled: bool = False,
|
|
579
|
+
partition_id: str | None | list[str | None] = None,
|
|
580
|
+
partition_filter_mode: Literal[
|
|
581
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
582
|
+
] = "include_global",
|
|
547
583
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
548
584
|
"""
|
|
549
585
|
Asynchronously predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -564,7 +600,12 @@ class ClassificationModel:
|
|
|
564
600
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
565
601
|
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
566
602
|
If False (default), allow unlabeled memories when necessary.
|
|
567
|
-
|
|
603
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
604
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
605
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
606
|
+
* `"include_global"`: Include global memories
|
|
607
|
+
* `"exclude_global"`: Exclude global memories
|
|
608
|
+
* `"only_global"`: Only include global memories
|
|
568
609
|
Returns:
|
|
569
610
|
Label prediction or list of label predictions.
|
|
570
611
|
|
|
@@ -611,21 +652,26 @@ class ClassificationModel:
|
|
|
611
652
|
|
|
612
653
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
613
654
|
client = OrcaAsyncClient._resolve_client()
|
|
655
|
+
request_json: ClassificationPredictionRequest = {
|
|
656
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
657
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
658
|
+
"expected_labels": expected_labels,
|
|
659
|
+
"tags": list(tags or set()),
|
|
660
|
+
"save_telemetry": telemetry_on,
|
|
661
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
662
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
663
|
+
"prompt": prompt,
|
|
664
|
+
"use_lookup_cache": use_lookup_cache,
|
|
665
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
666
|
+
"partition_filter_mode": partition_filter_mode,
|
|
667
|
+
}
|
|
668
|
+
# Don't send partition_ids when partition_filter_mode is "ignore_partitions"
|
|
669
|
+
if partition_filter_mode != "ignore_partitions":
|
|
670
|
+
request_json["partition_ids"] = partition_id
|
|
614
671
|
response = await client.POST(
|
|
615
672
|
"/gpu/classification_model/{name_or_id}/prediction",
|
|
616
673
|
params={"name_or_id": self.id},
|
|
617
|
-
json=
|
|
618
|
-
"input_values": value if isinstance(value, list) else [value],
|
|
619
|
-
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
620
|
-
"expected_labels": expected_labels,
|
|
621
|
-
"tags": list(tags or set()),
|
|
622
|
-
"save_telemetry": telemetry_on,
|
|
623
|
-
"save_telemetry_synchronously": telemetry_sync,
|
|
624
|
-
"filters": cast(list[FilterItem], parsed_filters),
|
|
625
|
-
"prompt": prompt,
|
|
626
|
-
"use_lookup_cache": use_lookup_cache,
|
|
627
|
-
"ignore_unlabeled": ignore_unlabeled,
|
|
628
|
-
},
|
|
674
|
+
json=request_json,
|
|
629
675
|
timeout=timeout_seconds,
|
|
630
676
|
)
|
|
631
677
|
|
|
@@ -730,6 +776,10 @@ class ClassificationModel:
|
|
|
730
776
|
subsample: int | float | None,
|
|
731
777
|
background: bool = False,
|
|
732
778
|
ignore_unlabeled: bool = False,
|
|
779
|
+
partition_column: str | None = None,
|
|
780
|
+
partition_filter_mode: Literal[
|
|
781
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
782
|
+
] = "include_global",
|
|
733
783
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
734
784
|
client = OrcaClient._resolve_client()
|
|
735
785
|
response = client.POST(
|
|
@@ -744,6 +794,8 @@ class ClassificationModel:
|
|
|
744
794
|
"telemetry_tags": list(tags) if tags else None,
|
|
745
795
|
"subsample": subsample,
|
|
746
796
|
"ignore_unlabeled": ignore_unlabeled,
|
|
797
|
+
"datasource_partition_column": partition_column,
|
|
798
|
+
"partition_filter_mode": partition_filter_mode,
|
|
747
799
|
},
|
|
748
800
|
)
|
|
749
801
|
|
|
@@ -780,6 +832,10 @@ class ClassificationModel:
|
|
|
780
832
|
tags: set[str],
|
|
781
833
|
batch_size: int,
|
|
782
834
|
ignore_unlabeled: bool,
|
|
835
|
+
partition_column: str | None = None,
|
|
836
|
+
partition_filter_mode: Literal[
|
|
837
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
838
|
+
] = "include_global",
|
|
783
839
|
) -> ClassificationMetrics:
|
|
784
840
|
if len(dataset) == 0:
|
|
785
841
|
raise ValueError("Evaluation dataset cannot be empty")
|
|
@@ -796,6 +852,8 @@ class ClassificationModel:
|
|
|
796
852
|
tags=tags,
|
|
797
853
|
save_telemetry="sync" if record_predictions else "off",
|
|
798
854
|
ignore_unlabeled=ignore_unlabeled,
|
|
855
|
+
partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
|
|
856
|
+
partition_filter_mode=partition_filter_mode,
|
|
799
857
|
)
|
|
800
858
|
]
|
|
801
859
|
|
|
@@ -813,12 +871,16 @@ class ClassificationModel:
|
|
|
813
871
|
*,
|
|
814
872
|
value_column: str = "value",
|
|
815
873
|
label_column: str = "label",
|
|
874
|
+
partition_column: str | None = None,
|
|
816
875
|
record_predictions: bool = False,
|
|
817
876
|
tags: set[str] = {"evaluation"},
|
|
818
877
|
batch_size: int = 100,
|
|
819
878
|
subsample: int | float | None = None,
|
|
820
879
|
background: Literal[True],
|
|
821
880
|
ignore_unlabeled: bool = False,
|
|
881
|
+
partition_filter_mode: Literal[
|
|
882
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
883
|
+
] = "include_global",
|
|
822
884
|
) -> Job[ClassificationMetrics]:
|
|
823
885
|
pass
|
|
824
886
|
|
|
@@ -829,12 +891,16 @@ class ClassificationModel:
|
|
|
829
891
|
*,
|
|
830
892
|
value_column: str = "value",
|
|
831
893
|
label_column: str = "label",
|
|
894
|
+
partition_column: str | None = None,
|
|
832
895
|
record_predictions: bool = False,
|
|
833
896
|
tags: set[str] = {"evaluation"},
|
|
834
897
|
batch_size: int = 100,
|
|
835
898
|
subsample: int | float | None = None,
|
|
836
899
|
background: Literal[False] = False,
|
|
837
900
|
ignore_unlabeled: bool = False,
|
|
901
|
+
partition_filter_mode: Literal[
|
|
902
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
903
|
+
] = "include_global",
|
|
838
904
|
) -> ClassificationMetrics:
|
|
839
905
|
pass
|
|
840
906
|
|
|
@@ -844,12 +910,16 @@ class ClassificationModel:
|
|
|
844
910
|
*,
|
|
845
911
|
value_column: str = "value",
|
|
846
912
|
label_column: str = "label",
|
|
913
|
+
partition_column: str | None = None,
|
|
847
914
|
record_predictions: bool = False,
|
|
848
915
|
tags: set[str] = {"evaluation"},
|
|
849
916
|
batch_size: int = 100,
|
|
850
917
|
subsample: int | float | None = None,
|
|
851
918
|
background: bool = False,
|
|
852
919
|
ignore_unlabeled: bool = False,
|
|
920
|
+
partition_filter_mode: Literal[
|
|
921
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
922
|
+
] = "include_global",
|
|
853
923
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
854
924
|
"""
|
|
855
925
|
Evaluate the classification model on a given dataset or datasource
|
|
@@ -858,13 +928,18 @@ class ClassificationModel:
|
|
|
858
928
|
data: Dataset or Datasource to evaluate the model on
|
|
859
929
|
value_column: Name of the column that contains the input values to the model
|
|
860
930
|
label_column: Name of the column containing the expected labels
|
|
931
|
+
partition_column: Optional name of the column that contains the partition IDs
|
|
861
932
|
record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
|
|
862
933
|
tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
|
|
863
934
|
batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
|
|
864
935
|
subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
|
|
865
936
|
background: Whether to run the operation in the background and return a job handle
|
|
866
937
|
ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
|
|
867
|
-
|
|
938
|
+
partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
|
|
939
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
940
|
+
* `"include_global"`: Include global memories
|
|
941
|
+
* `"exclude_global"`: Exclude global memories
|
|
942
|
+
* `"only_global"`: Only include global memories
|
|
868
943
|
Returns:
|
|
869
944
|
EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
|
|
870
945
|
|
|
@@ -888,6 +963,8 @@ class ClassificationModel:
|
|
|
888
963
|
subsample=subsample,
|
|
889
964
|
background=background,
|
|
890
965
|
ignore_unlabeled=ignore_unlabeled,
|
|
966
|
+
partition_column=partition_column,
|
|
967
|
+
partition_filter_mode=partition_filter_mode,
|
|
891
968
|
)
|
|
892
969
|
elif isinstance(data, Dataset):
|
|
893
970
|
return self._evaluate_dataset(
|
|
@@ -898,6 +975,8 @@ class ClassificationModel:
|
|
|
898
975
|
tags=tags,
|
|
899
976
|
batch_size=batch_size,
|
|
900
977
|
ignore_unlabeled=ignore_unlabeled,
|
|
978
|
+
partition_column=partition_column,
|
|
979
|
+
partition_filter_mode=partition_filter_mode,
|
|
901
980
|
)
|
|
902
981
|
else:
|
|
903
982
|
raise ValueError(f"Invalid data type: {type(data)}")
|
|
@@ -187,18 +187,24 @@ def test_evaluate(classification_model, eval_datasource: Datasource, eval_datase
|
|
|
187
187
|
assert -1.0 <= result.anomaly_score_variance <= 1.0
|
|
188
188
|
|
|
189
189
|
assert result.pr_auc is not None
|
|
190
|
-
assert np.allclose(result.pr_auc, 0.
|
|
190
|
+
assert np.allclose(result.pr_auc, 0.83333)
|
|
191
191
|
assert result.pr_curve is not None
|
|
192
|
-
assert np.allclose(
|
|
193
|
-
|
|
194
|
-
|
|
192
|
+
assert np.allclose(
|
|
193
|
+
result.pr_curve["thresholds"],
|
|
194
|
+
[0.0, 0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364],
|
|
195
|
+
)
|
|
196
|
+
assert np.allclose(result.pr_curve["precisions"], [0.5, 0.666666, 0.5, 1.0, 1.0])
|
|
197
|
+
assert np.allclose(result.pr_curve["recalls"], [1.0, 1.0, 0.5, 0.5, 0.0])
|
|
195
198
|
|
|
196
199
|
assert result.roc_auc is not None
|
|
197
|
-
assert np.allclose(result.roc_auc, 0.
|
|
200
|
+
assert np.allclose(result.roc_auc, 0.75)
|
|
198
201
|
assert result.roc_curve is not None
|
|
199
|
-
assert np.allclose(
|
|
200
|
-
|
|
201
|
-
|
|
202
|
+
assert np.allclose(
|
|
203
|
+
result.roc_curve["thresholds"],
|
|
204
|
+
[0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364, 1.0],
|
|
205
|
+
)
|
|
206
|
+
assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.5, 0.0, 0.0])
|
|
207
|
+
assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 1.0, 0.5, 0.5, 0.0])
|
|
202
208
|
|
|
203
209
|
|
|
204
210
|
def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
|
|
@@ -221,6 +227,139 @@ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval
|
|
|
221
227
|
assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
|
|
222
228
|
|
|
223
229
|
|
|
230
|
+
def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
|
|
231
|
+
"""Test evaluate with partition_column on a Dataset"""
|
|
232
|
+
# Create a test dataset with partition_id column
|
|
233
|
+
eval_dataset_with_partition = Dataset.from_list(
|
|
234
|
+
[
|
|
235
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
236
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p1"},
|
|
237
|
+
{"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
|
|
238
|
+
{"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
|
|
239
|
+
]
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
# Evaluate with partition_column
|
|
243
|
+
result = partitioned_classification_model.evaluate(
|
|
244
|
+
eval_dataset_with_partition,
|
|
245
|
+
partition_column="partition_id",
|
|
246
|
+
partition_filter_mode="exclude_global",
|
|
247
|
+
)
|
|
248
|
+
assert result is not None
|
|
249
|
+
assert isinstance(result, ClassificationMetrics)
|
|
250
|
+
assert isinstance(result.accuracy, float)
|
|
251
|
+
assert isinstance(result.f1_score, float)
|
|
252
|
+
assert isinstance(result.loss, float)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def test_evaluate_with_partition_column_include_global(partitioned_classification_model: ClassificationModel):
|
|
256
|
+
"""Test evaluate with partition_column and include_global mode"""
|
|
257
|
+
eval_dataset_with_partition = Dataset.from_list(
|
|
258
|
+
[
|
|
259
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
260
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p1"},
|
|
261
|
+
]
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Evaluate with partition_column and include_global (default)
|
|
265
|
+
result = partitioned_classification_model.evaluate(
|
|
266
|
+
eval_dataset_with_partition,
|
|
267
|
+
partition_column="partition_id",
|
|
268
|
+
partition_filter_mode="include_global",
|
|
269
|
+
)
|
|
270
|
+
assert result is not None
|
|
271
|
+
assert isinstance(result, ClassificationMetrics)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def test_evaluate_with_partition_column_exclude_global(partitioned_classification_model: ClassificationModel):
|
|
275
|
+
"""Test evaluate with partition_column and exclude_global mode"""
|
|
276
|
+
eval_dataset_with_partition = Dataset.from_list(
|
|
277
|
+
[
|
|
278
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
279
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p1"},
|
|
280
|
+
]
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
# Evaluate with partition_column and exclude_global
|
|
284
|
+
result = partitioned_classification_model.evaluate(
|
|
285
|
+
eval_dataset_with_partition,
|
|
286
|
+
partition_column="partition_id",
|
|
287
|
+
partition_filter_mode="exclude_global",
|
|
288
|
+
)
|
|
289
|
+
assert result is not None
|
|
290
|
+
assert isinstance(result, ClassificationMetrics)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def test_evaluate_with_partition_column_only_global(partitioned_classification_model: ClassificationModel):
|
|
294
|
+
"""Test evaluate with partition_filter_mode only_global"""
|
|
295
|
+
eval_dataset_with_partition = Dataset.from_list(
|
|
296
|
+
[
|
|
297
|
+
{"value": "cats are independent animals", "label": 1, "partition_id": None},
|
|
298
|
+
{"value": "i love the beach", "label": 1, "partition_id": None},
|
|
299
|
+
]
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
# Evaluate with only_global mode
|
|
303
|
+
result = partitioned_classification_model.evaluate(
|
|
304
|
+
eval_dataset_with_partition,
|
|
305
|
+
partition_column="partition_id",
|
|
306
|
+
partition_filter_mode="only_global",
|
|
307
|
+
)
|
|
308
|
+
assert result is not None
|
|
309
|
+
assert isinstance(result, ClassificationMetrics)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def test_evaluate_with_partition_column_ignore_partitions(partitioned_classification_model: ClassificationModel):
|
|
313
|
+
"""Test evaluate with partition_filter_mode ignore_partitions"""
|
|
314
|
+
eval_dataset_with_partition = Dataset.from_list(
|
|
315
|
+
[
|
|
316
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
317
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p2"},
|
|
318
|
+
]
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
# Evaluate with ignore_partitions mode
|
|
322
|
+
result = partitioned_classification_model.evaluate(
|
|
323
|
+
eval_dataset_with_partition,
|
|
324
|
+
partition_column="partition_id",
|
|
325
|
+
partition_filter_mode="ignore_partitions",
|
|
326
|
+
)
|
|
327
|
+
assert result is not None
|
|
328
|
+
assert isinstance(result, ClassificationMetrics)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
|
|
332
|
+
def test_evaluate_with_partition_column_datasource(partitioned_classification_model: ClassificationModel, data_type):
|
|
333
|
+
"""Test evaluate with partition_column on a Datasource"""
|
|
334
|
+
# Create a test datasource with partition_id column
|
|
335
|
+
eval_data_with_partition = [
|
|
336
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
337
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p1"},
|
|
338
|
+
{"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
|
|
339
|
+
{"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
|
|
340
|
+
]
|
|
341
|
+
|
|
342
|
+
if data_type == "dataset":
|
|
343
|
+
eval_data = Dataset.from_list(eval_data_with_partition)
|
|
344
|
+
result = partitioned_classification_model.evaluate(
|
|
345
|
+
eval_data,
|
|
346
|
+
partition_column="partition_id",
|
|
347
|
+
partition_filter_mode="exclude_global",
|
|
348
|
+
)
|
|
349
|
+
else:
|
|
350
|
+
eval_datasource = Datasource.from_list("eval_datasource_with_partition", eval_data_with_partition)
|
|
351
|
+
result = partitioned_classification_model.evaluate(
|
|
352
|
+
eval_datasource,
|
|
353
|
+
partition_column="partition_id",
|
|
354
|
+
partition_filter_mode="exclude_global",
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
assert result is not None
|
|
358
|
+
assert isinstance(result, ClassificationMetrics)
|
|
359
|
+
assert isinstance(result.accuracy, float)
|
|
360
|
+
assert isinstance(result.f1_score, float)
|
|
361
|
+
|
|
362
|
+
|
|
224
363
|
def test_predict(classification_model: ClassificationModel, label_names: list[str]):
|
|
225
364
|
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
226
365
|
assert len(predictions) == 2
|
|
@@ -284,6 +423,186 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
|
|
|
284
423
|
model.predict("test")
|
|
285
424
|
|
|
286
425
|
|
|
426
|
+
def test_predict_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
|
|
427
|
+
"""Test predict with a specific partition_id"""
|
|
428
|
+
# Predict with partition_id p1 - should use memories from p1
|
|
429
|
+
prediction = partitioned_classification_model.predict(
|
|
430
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global"
|
|
431
|
+
)
|
|
432
|
+
assert prediction.label is not None
|
|
433
|
+
assert prediction.label_name in label_names
|
|
434
|
+
assert 0 <= prediction.confidence <= 1
|
|
435
|
+
assert prediction.logits is not None
|
|
436
|
+
assert len(prediction.logits) == 2
|
|
437
|
+
|
|
438
|
+
# Predict with partition_id p2 - should use memories from p2
|
|
439
|
+
prediction_p2 = partitioned_classification_model.predict(
|
|
440
|
+
"cats", partition_id="p2", partition_filter_mode="exclude_global"
|
|
441
|
+
)
|
|
442
|
+
assert prediction_p2.label is not None
|
|
443
|
+
assert prediction_p2.label_name in label_names
|
|
444
|
+
assert 0 <= prediction_p2.confidence <= 1
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def test_predict_with_partition_id_include_global(
|
|
448
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
449
|
+
):
|
|
450
|
+
"""Test predict with partition_id and include_global mode (default)"""
|
|
451
|
+
# Predict with partition_id p1 and include_global (default) - should include both p1 and global memories
|
|
452
|
+
prediction = partitioned_classification_model.predict(
|
|
453
|
+
"soup", partition_id="p1", partition_filter_mode="include_global"
|
|
454
|
+
)
|
|
455
|
+
assert prediction.label is not None
|
|
456
|
+
assert prediction.label_name in label_names
|
|
457
|
+
assert 0 <= prediction.confidence <= 1
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def test_predict_with_partition_id_exclude_global(
|
|
461
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
462
|
+
):
|
|
463
|
+
"""Test predict with partition_id and exclude_global mode"""
|
|
464
|
+
# Predict with partition_id p1 and exclude_global - should only use p1 memories
|
|
465
|
+
prediction = partitioned_classification_model.predict(
|
|
466
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global"
|
|
467
|
+
)
|
|
468
|
+
assert prediction.label is not None
|
|
469
|
+
assert prediction.label_name in label_names
|
|
470
|
+
assert 0 <= prediction.confidence <= 1
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def test_predict_with_partition_id_only_global(
|
|
474
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
475
|
+
):
|
|
476
|
+
"""Test predict with partition_filter_mode only_global"""
|
|
477
|
+
# Predict with only_global mode - should only use global memories
|
|
478
|
+
prediction = partitioned_classification_model.predict("cats", partition_filter_mode="only_global")
|
|
479
|
+
assert prediction.label is not None
|
|
480
|
+
assert prediction.label_name in label_names
|
|
481
|
+
assert 0 <= prediction.confidence <= 1
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def test_predict_with_partition_id_ignore_partitions(
|
|
485
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
486
|
+
):
|
|
487
|
+
"""Test predict with partition_filter_mode ignore_partitions"""
|
|
488
|
+
# Predict with ignore_partitions mode - should ignore partition filtering
|
|
489
|
+
prediction = partitioned_classification_model.predict("soup", partition_filter_mode="ignore_partitions")
|
|
490
|
+
assert prediction.label is not None
|
|
491
|
+
assert prediction.label_name in label_names
|
|
492
|
+
assert 0 <= prediction.confidence <= 1
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def test_predict_batch_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
|
|
496
|
+
"""Test batch predict with partition_id"""
|
|
497
|
+
# Batch predict with partition_id p1
|
|
498
|
+
predictions = partitioned_classification_model.predict(
|
|
499
|
+
["soup is good", "cats are cute"],
|
|
500
|
+
partition_id="p1",
|
|
501
|
+
partition_filter_mode="exclude_global",
|
|
502
|
+
)
|
|
503
|
+
assert len(predictions) == 2
|
|
504
|
+
assert all(p.label is not None for p in predictions)
|
|
505
|
+
assert all(p.label_name in label_names for p in predictions)
|
|
506
|
+
assert all(0 <= p.confidence <= 1 for p in predictions)
|
|
507
|
+
assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def test_predict_with_partition_id_and_filters(
|
|
511
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
512
|
+
):
|
|
513
|
+
"""Test predict with partition_id and filters"""
|
|
514
|
+
# Predict with partition_id and filters
|
|
515
|
+
prediction = partitioned_classification_model.predict(
|
|
516
|
+
"soup",
|
|
517
|
+
partition_id="p1",
|
|
518
|
+
partition_filter_mode="exclude_global",
|
|
519
|
+
filters=[("key", "==", "g1")],
|
|
520
|
+
)
|
|
521
|
+
assert prediction.label is not None
|
|
522
|
+
assert prediction.label_name in label_names
|
|
523
|
+
assert 0 <= prediction.confidence <= 1
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def test_predict_batch_with_list_of_partition_ids(
|
|
527
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
528
|
+
):
|
|
529
|
+
"""Test batch predict with a list of partition_ids (one for each query input)"""
|
|
530
|
+
# Batch predict with a list of partition_ids - one for each input
|
|
531
|
+
# First input uses p1, second input uses p2
|
|
532
|
+
predictions = partitioned_classification_model.predict(
|
|
533
|
+
["soup is good", "cats are cute"],
|
|
534
|
+
partition_id=["p1", "p2"],
|
|
535
|
+
partition_filter_mode="exclude_global",
|
|
536
|
+
)
|
|
537
|
+
assert len(predictions) == 2
|
|
538
|
+
assert all(p.label is not None for p in predictions)
|
|
539
|
+
assert all(p.label_name in label_names for p in predictions)
|
|
540
|
+
assert all(0 <= p.confidence <= 1 for p in predictions)
|
|
541
|
+
assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
|
|
542
|
+
|
|
543
|
+
# Verify that predictions were made using the correct partitions
|
|
544
|
+
# Each prediction should use memories from its respective partition
|
|
545
|
+
assert predictions[0].input_value == "soup is good"
|
|
546
|
+
assert predictions[1].input_value == "cats are cute"
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@pytest.mark.asyncio
|
|
550
|
+
async def test_predict_async_with_partition_id(
|
|
551
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
552
|
+
):
|
|
553
|
+
"""Test async predict with partition_id"""
|
|
554
|
+
# Async predict with partition_id p1
|
|
555
|
+
prediction = await partitioned_classification_model.apredict(
|
|
556
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global"
|
|
557
|
+
)
|
|
558
|
+
assert prediction.label is not None
|
|
559
|
+
assert prediction.label_name in label_names
|
|
560
|
+
assert 0 <= prediction.confidence <= 1
|
|
561
|
+
assert prediction.logits is not None
|
|
562
|
+
assert len(prediction.logits) == 2
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
@pytest.mark.asyncio
|
|
566
|
+
async def test_predict_async_batch_with_partition_id(
|
|
567
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
568
|
+
):
|
|
569
|
+
"""Test async batch predict with partition_id"""
|
|
570
|
+
# Async batch predict with partition_id p1
|
|
571
|
+
predictions = await partitioned_classification_model.apredict(
|
|
572
|
+
["soup is good", "cats are cute"],
|
|
573
|
+
partition_id="p1",
|
|
574
|
+
partition_filter_mode="exclude_global",
|
|
575
|
+
)
|
|
576
|
+
assert len(predictions) == 2
|
|
577
|
+
assert all(p.label is not None for p in predictions)
|
|
578
|
+
assert all(p.label_name in label_names for p in predictions)
|
|
579
|
+
assert all(0 <= p.confidence <= 1 for p in predictions)
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
@pytest.mark.asyncio
|
|
583
|
+
async def test_predict_async_batch_with_list_of_partition_ids(
|
|
584
|
+
partitioned_classification_model: ClassificationModel, label_names: list[str]
|
|
585
|
+
):
|
|
586
|
+
"""Test async batch predict with a list of partition_ids (one for each query input)"""
|
|
587
|
+
# Async batch predict with a list of partition_ids - one for each input
|
|
588
|
+
# First input uses p1, second input uses p2
|
|
589
|
+
predictions = await partitioned_classification_model.apredict(
|
|
590
|
+
["soup is good", "cats are cute"],
|
|
591
|
+
partition_id=["p1", "p2"],
|
|
592
|
+
partition_filter_mode="exclude_global",
|
|
593
|
+
)
|
|
594
|
+
assert len(predictions) == 2
|
|
595
|
+
assert all(p.label is not None for p in predictions)
|
|
596
|
+
assert all(p.label_name in label_names for p in predictions)
|
|
597
|
+
assert all(0 <= p.confidence <= 1 for p in predictions)
|
|
598
|
+
assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
|
|
599
|
+
|
|
600
|
+
# Verify that predictions were made using the correct partitions
|
|
601
|
+
# Each prediction should use memories from its respective partition
|
|
602
|
+
assert predictions[0].input_value == "soup is good"
|
|
603
|
+
assert predictions[1].input_value == "cats are cute"
|
|
604
|
+
|
|
605
|
+
|
|
287
606
|
def test_record_prediction_feedback(classification_model: ClassificationModel):
|
|
288
607
|
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
289
608
|
expected_labels = [0, 1]
|