orca-sdk 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/async_client.py +448 -301
- orca_sdk/classification_model.py +53 -17
- orca_sdk/client.py +448 -301
- orca_sdk/datasource.py +45 -2
- orca_sdk/datasource_test.py +120 -0
- orca_sdk/embedding_model.py +32 -24
- orca_sdk/job.py +17 -17
- orca_sdk/memoryset.py +318 -30
- orca_sdk/memoryset_test.py +185 -1
- orca_sdk/regression_model.py +38 -4
- orca_sdk/telemetry.py +52 -13
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.4.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.4.dist-info}/RECORD +14 -14
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.4.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
|
@@ -3,14 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import logging
|
|
4
4
|
from contextlib import contextmanager
|
|
5
5
|
from datetime import datetime
|
|
6
|
-
from typing import (
|
|
7
|
-
Any,
|
|
8
|
-
Generator,
|
|
9
|
-
Iterable,
|
|
10
|
-
Literal,
|
|
11
|
-
cast,
|
|
12
|
-
overload,
|
|
13
|
-
)
|
|
6
|
+
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
14
7
|
|
|
15
8
|
from datasets import Dataset
|
|
16
9
|
|
|
@@ -20,8 +13,10 @@ from .async_client import OrcaAsyncClient
|
|
|
20
13
|
from .client import (
|
|
21
14
|
BootstrapClassificationModelMeta,
|
|
22
15
|
BootstrapClassificationModelResult,
|
|
16
|
+
ClassificationEvaluationRequest,
|
|
23
17
|
ClassificationModelMetadata,
|
|
24
18
|
OrcaClient,
|
|
19
|
+
PostClassificationModelByModelNameOrIdEvaluationParams,
|
|
25
20
|
PredictiveModelUpdate,
|
|
26
21
|
RACHeadType,
|
|
27
22
|
)
|
|
@@ -207,7 +202,12 @@ class ClassificationModel:
|
|
|
207
202
|
raise ValueError(f"Model with name {name} already exists")
|
|
208
203
|
elif if_exists == "open":
|
|
209
204
|
existing = cls.open(name)
|
|
210
|
-
for attribute in {
|
|
205
|
+
for attribute in {
|
|
206
|
+
"head_type",
|
|
207
|
+
"memory_lookup_count",
|
|
208
|
+
"num_classes",
|
|
209
|
+
"min_memory_weight",
|
|
210
|
+
}:
|
|
211
211
|
local_attribute = locals()[attribute]
|
|
212
212
|
existing_attribute = getattr(existing, attribute)
|
|
213
213
|
if local_attribute is not None and local_attribute != existing_attribute:
|
|
@@ -357,6 +357,8 @@ class ClassificationModel:
|
|
|
357
357
|
prompt: str | None = None,
|
|
358
358
|
use_lookup_cache: bool = True,
|
|
359
359
|
timeout_seconds: int = 10,
|
|
360
|
+
ignore_unlabeled: bool = False,
|
|
361
|
+
use_gpu: bool = True,
|
|
360
362
|
) -> list[ClassificationPrediction]:
|
|
361
363
|
pass
|
|
362
364
|
|
|
@@ -371,6 +373,8 @@ class ClassificationModel:
|
|
|
371
373
|
prompt: str | None = None,
|
|
372
374
|
use_lookup_cache: bool = True,
|
|
373
375
|
timeout_seconds: int = 10,
|
|
376
|
+
ignore_unlabeled: bool = False,
|
|
377
|
+
use_gpu: bool = True,
|
|
374
378
|
) -> ClassificationPrediction:
|
|
375
379
|
pass
|
|
376
380
|
|
|
@@ -384,6 +388,8 @@ class ClassificationModel:
|
|
|
384
388
|
prompt: str | None = None,
|
|
385
389
|
use_lookup_cache: bool = True,
|
|
386
390
|
timeout_seconds: int = 10,
|
|
391
|
+
ignore_unlabeled: bool = False,
|
|
392
|
+
use_gpu: bool = True,
|
|
387
393
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
388
394
|
"""
|
|
389
395
|
Predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -402,6 +408,9 @@ class ClassificationModel:
|
|
|
402
408
|
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
403
409
|
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
404
410
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
411
|
+
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
412
|
+
If False (default), allow unlabeled memories when necessary.
|
|
413
|
+
use_gpu: Whether to use GPU for the prediction (defaults to True)
|
|
405
414
|
|
|
406
415
|
Returns:
|
|
407
416
|
Label prediction or list of label predictions
|
|
@@ -447,10 +456,15 @@ class ClassificationModel:
|
|
|
447
456
|
for label in expected_labels
|
|
448
457
|
]
|
|
449
458
|
|
|
459
|
+
if use_gpu:
|
|
460
|
+
endpoint = "/gpu/classification_model/{name_or_id}/prediction"
|
|
461
|
+
else:
|
|
462
|
+
endpoint = "/classification_model/{name_or_id}/prediction"
|
|
463
|
+
|
|
450
464
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
451
465
|
client = OrcaClient._resolve_client()
|
|
452
466
|
response = client.POST(
|
|
453
|
-
|
|
467
|
+
endpoint,
|
|
454
468
|
params={"name_or_id": self.id},
|
|
455
469
|
json={
|
|
456
470
|
"input_values": value if isinstance(value, list) else [value],
|
|
@@ -462,6 +476,7 @@ class ClassificationModel:
|
|
|
462
476
|
"filters": cast(list[FilterItem], parsed_filters),
|
|
463
477
|
"prompt": prompt,
|
|
464
478
|
"use_lookup_cache": use_lookup_cache,
|
|
479
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
465
480
|
},
|
|
466
481
|
timeout=timeout_seconds,
|
|
467
482
|
)
|
|
@@ -499,6 +514,7 @@ class ClassificationModel:
|
|
|
499
514
|
prompt: str | None = None,
|
|
500
515
|
use_lookup_cache: bool = True,
|
|
501
516
|
timeout_seconds: int = 10,
|
|
517
|
+
ignore_unlabeled: bool = False,
|
|
502
518
|
) -> list[ClassificationPrediction]:
|
|
503
519
|
pass
|
|
504
520
|
|
|
@@ -513,6 +529,7 @@ class ClassificationModel:
|
|
|
513
529
|
prompt: str | None = None,
|
|
514
530
|
use_lookup_cache: bool = True,
|
|
515
531
|
timeout_seconds: int = 10,
|
|
532
|
+
ignore_unlabeled: bool = False,
|
|
516
533
|
) -> ClassificationPrediction:
|
|
517
534
|
pass
|
|
518
535
|
|
|
@@ -526,6 +543,7 @@ class ClassificationModel:
|
|
|
526
543
|
prompt: str | None = None,
|
|
527
544
|
use_lookup_cache: bool = True,
|
|
528
545
|
timeout_seconds: int = 10,
|
|
546
|
+
ignore_unlabeled: bool = False,
|
|
529
547
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
530
548
|
"""
|
|
531
549
|
Asynchronously predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -544,6 +562,8 @@ class ClassificationModel:
|
|
|
544
562
|
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
545
563
|
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
546
564
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
565
|
+
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
566
|
+
If False (default), allow unlabeled memories when necessary.
|
|
547
567
|
|
|
548
568
|
Returns:
|
|
549
569
|
Label prediction or list of label predictions.
|
|
@@ -604,6 +624,7 @@ class ClassificationModel:
|
|
|
604
624
|
"filters": cast(list[FilterItem], parsed_filters),
|
|
605
625
|
"prompt": prompt,
|
|
606
626
|
"use_lookup_cache": use_lookup_cache,
|
|
627
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
607
628
|
},
|
|
608
629
|
timeout=timeout_seconds,
|
|
609
630
|
)
|
|
@@ -706,7 +727,9 @@ class ClassificationModel:
|
|
|
706
727
|
label_column: str,
|
|
707
728
|
record_predictions: bool,
|
|
708
729
|
tags: set[str] | None,
|
|
730
|
+
subsample: int | float | None,
|
|
709
731
|
background: bool = False,
|
|
732
|
+
ignore_unlabeled: bool = False,
|
|
710
733
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
711
734
|
client = OrcaClient._resolve_client()
|
|
712
735
|
response = client.POST(
|
|
@@ -719,14 +742,16 @@ class ClassificationModel:
|
|
|
719
742
|
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
720
743
|
"record_telemetry": record_predictions,
|
|
721
744
|
"telemetry_tags": list(tags) if tags else None,
|
|
745
|
+
"subsample": subsample,
|
|
746
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
722
747
|
},
|
|
723
748
|
)
|
|
724
749
|
|
|
725
750
|
def get_value():
|
|
726
751
|
client = OrcaClient._resolve_client()
|
|
727
752
|
res = client.GET(
|
|
728
|
-
"/classification_model/{model_name_or_id}/evaluation/{task_id}",
|
|
729
|
-
params={"model_name_or_id": self.id, "task_id": response["task_id"]},
|
|
753
|
+
"/classification_model/{model_name_or_id}/evaluation/{job_id}",
|
|
754
|
+
params={"model_name_or_id": self.id, "job_id": response["job_id"]},
|
|
730
755
|
)
|
|
731
756
|
assert res["result"] is not None
|
|
732
757
|
return ClassificationMetrics(
|
|
@@ -743,7 +768,7 @@ class ClassificationModel:
|
|
|
743
768
|
roc_curve=res["result"].get("roc_curve"),
|
|
744
769
|
)
|
|
745
770
|
|
|
746
|
-
job = Job(response["task_id"], get_value)
|
|
771
|
+
job = Job(response["job_id"], get_value)
|
|
747
772
|
return job if background else job.result()
|
|
748
773
|
|
|
749
774
|
def _evaluate_dataset(
|
|
@@ -754,6 +779,7 @@ class ClassificationModel:
|
|
|
754
779
|
record_predictions: bool,
|
|
755
780
|
tags: set[str],
|
|
756
781
|
batch_size: int,
|
|
782
|
+
ignore_unlabeled: bool,
|
|
757
783
|
) -> ClassificationMetrics:
|
|
758
784
|
if len(dataset) == 0:
|
|
759
785
|
raise ValueError("Evaluation dataset cannot be empty")
|
|
@@ -769,6 +795,7 @@ class ClassificationModel:
|
|
|
769
795
|
expected_labels=dataset[i : i + batch_size][label_column],
|
|
770
796
|
tags=tags,
|
|
771
797
|
save_telemetry="sync" if record_predictions else "off",
|
|
798
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
772
799
|
)
|
|
773
800
|
]
|
|
774
801
|
|
|
@@ -789,7 +816,9 @@ class ClassificationModel:
|
|
|
789
816
|
record_predictions: bool = False,
|
|
790
817
|
tags: set[str] = {"evaluation"},
|
|
791
818
|
batch_size: int = 100,
|
|
819
|
+
subsample: int | float | None = None,
|
|
792
820
|
background: Literal[True],
|
|
821
|
+
ignore_unlabeled: bool = False,
|
|
793
822
|
) -> Job[ClassificationMetrics]:
|
|
794
823
|
pass
|
|
795
824
|
|
|
@@ -803,7 +832,9 @@ class ClassificationModel:
|
|
|
803
832
|
record_predictions: bool = False,
|
|
804
833
|
tags: set[str] = {"evaluation"},
|
|
805
834
|
batch_size: int = 100,
|
|
835
|
+
subsample: int | float | None = None,
|
|
806
836
|
background: Literal[False] = False,
|
|
837
|
+
ignore_unlabeled: bool = False,
|
|
807
838
|
) -> ClassificationMetrics:
|
|
808
839
|
pass
|
|
809
840
|
|
|
@@ -816,7 +847,9 @@ class ClassificationModel:
|
|
|
816
847
|
record_predictions: bool = False,
|
|
817
848
|
tags: set[str] = {"evaluation"},
|
|
818
849
|
batch_size: int = 100,
|
|
850
|
+
subsample: int | float | None = None,
|
|
819
851
|
background: bool = False,
|
|
852
|
+
ignore_unlabeled: bool = False,
|
|
820
853
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
821
854
|
"""
|
|
822
855
|
Evaluate the classification model on a given dataset or datasource
|
|
@@ -828,7 +861,9 @@ class ClassificationModel:
|
|
|
828
861
|
record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
|
|
829
862
|
tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
|
|
830
863
|
batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
|
|
864
|
+
subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
|
|
831
865
|
background: Whether to run the operation in the background and return a job handle
|
|
866
|
+
ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
|
|
832
867
|
|
|
833
868
|
Returns:
|
|
834
869
|
EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
|
|
@@ -850,7 +885,9 @@ class ClassificationModel:
|
|
|
850
885
|
label_column=label_column,
|
|
851
886
|
record_predictions=record_predictions,
|
|
852
887
|
tags=tags,
|
|
888
|
+
subsample=subsample,
|
|
853
889
|
background=background,
|
|
890
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
854
891
|
)
|
|
855
892
|
elif isinstance(data, Dataset):
|
|
856
893
|
return self._evaluate_dataset(
|
|
@@ -860,6 +897,7 @@ class ClassificationModel:
|
|
|
860
897
|
record_predictions=record_predictions,
|
|
861
898
|
tags=tags,
|
|
862
899
|
batch_size=batch_size,
|
|
900
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
863
901
|
)
|
|
864
902
|
else:
|
|
865
903
|
raise ValueError(f"Invalid data type: {type(data)}")
|
|
@@ -961,11 +999,9 @@ class ClassificationModel:
|
|
|
961
999
|
|
|
962
1000
|
def get_result() -> BootstrappedClassificationModel:
|
|
963
1001
|
client = OrcaClient._resolve_client()
|
|
964
|
-
res = client.GET(
|
|
965
|
-
"/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
|
|
966
|
-
)
|
|
1002
|
+
res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
|
|
967
1003
|
assert res["result"] is not None
|
|
968
1004
|
return BootstrappedClassificationModel(res["result"])
|
|
969
1005
|
|
|
970
|
-
job = Job(response["task_id"], get_result)
|
|
1006
|
+
job = Job(response["job_id"], get_result)
|
|
971
1007
|
return job if background else job.result()
|