orca-sdk 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/async_client.py +156 -4
- orca_sdk/classification_model.py +202 -65
- orca_sdk/classification_model_test.py +16 -3
- orca_sdk/client.py +156 -4
- orca_sdk/conftest.py +10 -9
- orca_sdk/datasource.py +31 -13
- orca_sdk/embedding_model.py +8 -31
- orca_sdk/embedding_model_test.py +1 -1
- orca_sdk/memoryset.py +236 -321
- orca_sdk/memoryset_test.py +39 -13
- orca_sdk/regression_model.py +185 -64
- orca_sdk/regression_model_test.py +18 -3
- orca_sdk/telemetry.py +15 -6
- {orca_sdk-0.1.11.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +3 -5
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -137
- orca_sdk/_utils/data_parsing_disk_test.py +0 -91
- orca_sdk/_utils/data_parsing_torch_test.py +0 -159
- orca_sdk-0.1.11.dist-info/RECORD +0 -42
- {orca_sdk-0.1.11.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
orca_sdk/client.py
CHANGED
|
@@ -85,7 +85,7 @@ class BaseLabelPredictionResult(TypedDict):
|
|
|
85
85
|
anomaly_score: float | None
|
|
86
86
|
label: int | None
|
|
87
87
|
label_name: str | None
|
|
88
|
-
logits: list[float]
|
|
88
|
+
logits: list[float] | None
|
|
89
89
|
|
|
90
90
|
|
|
91
91
|
class BaseModel(TypedDict):
|
|
@@ -160,6 +160,18 @@ The type of a column in a datasource
|
|
|
160
160
|
"""
|
|
161
161
|
|
|
162
162
|
|
|
163
|
+
class ComputeClassificationMetricsRequest(TypedDict):
|
|
164
|
+
expected_labels: list[int]
|
|
165
|
+
logits: list[list[float] | None]
|
|
166
|
+
anomaly_scores: NotRequired[list[float] | None]
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class ComputeRegressionMetricsRequest(TypedDict):
|
|
170
|
+
expected_scores: list[float]
|
|
171
|
+
predicted_scores: list[float | None]
|
|
172
|
+
anomaly_scores: NotRequired[list[float] | None]
|
|
173
|
+
|
|
174
|
+
|
|
163
175
|
class ConstraintViolationErrorResponse(TypedDict):
|
|
164
176
|
status_code: Literal[409]
|
|
165
177
|
constraint: str
|
|
@@ -322,6 +334,7 @@ class GetDatasourceRowsRequest(TypedDict):
|
|
|
322
334
|
|
|
323
335
|
class GetMemoriesRequest(TypedDict):
|
|
324
336
|
memory_ids: list[str]
|
|
337
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
325
338
|
|
|
326
339
|
|
|
327
340
|
class HealthyResponse(TypedDict):
|
|
@@ -392,6 +405,7 @@ class ListMemoriesRequest(TypedDict):
|
|
|
392
405
|
offset: NotRequired[int]
|
|
393
406
|
limit: NotRequired[int]
|
|
394
407
|
filters: NotRequired[list[FilterItem]]
|
|
408
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
395
409
|
|
|
396
410
|
|
|
397
411
|
class LookupRequest(TypedDict):
|
|
@@ -400,6 +414,7 @@ class LookupRequest(TypedDict):
|
|
|
400
414
|
prompt: NotRequired[str | None]
|
|
401
415
|
partition_id: NotRequired[str | list[str | None] | None]
|
|
402
416
|
partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
|
|
417
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
403
418
|
|
|
404
419
|
|
|
405
420
|
class LookupScoreMetrics(TypedDict):
|
|
@@ -570,8 +585,17 @@ class OrgPlan(TypedDict):
|
|
|
570
585
|
|
|
571
586
|
class PRCurve(TypedDict):
|
|
572
587
|
thresholds: list[float]
|
|
588
|
+
"""
|
|
589
|
+
Threshold values for the curve
|
|
590
|
+
"""
|
|
573
591
|
precisions: list[float]
|
|
592
|
+
"""
|
|
593
|
+
Precision values at each threshold
|
|
594
|
+
"""
|
|
574
595
|
recalls: list[float]
|
|
596
|
+
"""
|
|
597
|
+
Recall values at each threshold
|
|
598
|
+
"""
|
|
575
599
|
|
|
576
600
|
|
|
577
601
|
class PredictionFeedback(TypedDict):
|
|
@@ -642,8 +666,17 @@ RARHeadType: TypeAlias = Literal["MMOE", "KNN"]
|
|
|
642
666
|
|
|
643
667
|
class ROCCurve(TypedDict):
|
|
644
668
|
thresholds: list[float]
|
|
669
|
+
"""
|
|
670
|
+
Threshold values for the curve
|
|
671
|
+
"""
|
|
645
672
|
false_positive_rates: list[float]
|
|
673
|
+
"""
|
|
674
|
+
False positive rate values at each threshold
|
|
675
|
+
"""
|
|
646
676
|
true_positive_rates: list[float]
|
|
677
|
+
"""
|
|
678
|
+
True positive rate values at each threshold
|
|
679
|
+
"""
|
|
647
680
|
|
|
648
681
|
|
|
649
682
|
class ReadyResponse(TypedDict):
|
|
@@ -666,15 +699,49 @@ class RegressionEvaluationRequest(TypedDict):
|
|
|
666
699
|
|
|
667
700
|
class RegressionMetrics(TypedDict):
|
|
668
701
|
coverage: float
|
|
702
|
+
"""
|
|
703
|
+
Percentage of predictions that are not none
|
|
704
|
+
"""
|
|
669
705
|
mse: float
|
|
706
|
+
"""
|
|
707
|
+
Mean squared error of the predictions
|
|
708
|
+
"""
|
|
670
709
|
rmse: float
|
|
710
|
+
"""
|
|
711
|
+
Root mean squared error of the predictions
|
|
712
|
+
"""
|
|
671
713
|
mae: float
|
|
714
|
+
"""
|
|
715
|
+
Mean absolute error of the predictions
|
|
716
|
+
"""
|
|
672
717
|
r2: float
|
|
718
|
+
"""
|
|
719
|
+
R-squared score (coefficient of determination) of the predictions
|
|
720
|
+
"""
|
|
673
721
|
explained_variance: float
|
|
722
|
+
"""
|
|
723
|
+
Explained variance score of the predictions
|
|
724
|
+
"""
|
|
674
725
|
loss: float
|
|
726
|
+
"""
|
|
727
|
+
Mean squared error loss of the predictions
|
|
728
|
+
"""
|
|
675
729
|
anomaly_score_mean: NotRequired[float | None]
|
|
730
|
+
"""
|
|
731
|
+
Mean of anomaly scores across the dataset
|
|
732
|
+
"""
|
|
676
733
|
anomaly_score_median: NotRequired[float | None]
|
|
734
|
+
"""
|
|
735
|
+
Median of anomaly scores across the dataset
|
|
736
|
+
"""
|
|
677
737
|
anomaly_score_variance: NotRequired[float | None]
|
|
738
|
+
"""
|
|
739
|
+
Variance of anomaly scores across the dataset
|
|
740
|
+
"""
|
|
741
|
+
warnings: NotRequired[list[str]]
|
|
742
|
+
"""
|
|
743
|
+
Human-readable warnings about skipped or adjusted metrics
|
|
744
|
+
"""
|
|
678
745
|
|
|
679
746
|
|
|
680
747
|
class RegressionModelMetadata(TypedDict):
|
|
@@ -703,7 +770,7 @@ class RegressionPredictionRequest(TypedDict):
|
|
|
703
770
|
save_telemetry_synchronously: NotRequired[bool]
|
|
704
771
|
prompt: NotRequired[str | None]
|
|
705
772
|
use_lookup_cache: NotRequired[bool]
|
|
706
|
-
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]
|
|
773
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
707
774
|
ignore_unlabeled: NotRequired[bool]
|
|
708
775
|
partition_ids: NotRequired[str | list[str | None] | None]
|
|
709
776
|
partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
|
|
@@ -927,6 +994,7 @@ class GetMemorysetByNameOrIdMemoryByMemoryIdParams(TypedDict):
|
|
|
927
994
|
"""
|
|
928
995
|
ID of the memory
|
|
929
996
|
"""
|
|
997
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
930
998
|
|
|
931
999
|
|
|
932
1000
|
class DeleteMemorysetByNameOrIdMemoryByMemoryIdParams(TypedDict):
|
|
@@ -1304,18 +1372,57 @@ class BootstrapLabeledMemoryDataResult(TypedDict):
|
|
|
1304
1372
|
|
|
1305
1373
|
class ClassificationMetrics(TypedDict):
|
|
1306
1374
|
coverage: float
|
|
1375
|
+
"""
|
|
1376
|
+
Percentage of predictions that are not none
|
|
1377
|
+
"""
|
|
1307
1378
|
f1_score: float
|
|
1379
|
+
"""
|
|
1380
|
+
F1 score of the predictions
|
|
1381
|
+
"""
|
|
1308
1382
|
accuracy: float
|
|
1383
|
+
"""
|
|
1384
|
+
Accuracy of the predictions
|
|
1385
|
+
"""
|
|
1309
1386
|
loss: float | None
|
|
1387
|
+
"""
|
|
1388
|
+
Cross-entropy loss of the logits
|
|
1389
|
+
"""
|
|
1310
1390
|
anomaly_score_mean: NotRequired[float | None]
|
|
1391
|
+
"""
|
|
1392
|
+
Mean of anomaly scores across the dataset
|
|
1393
|
+
"""
|
|
1311
1394
|
anomaly_score_median: NotRequired[float | None]
|
|
1395
|
+
"""
|
|
1396
|
+
Median of anomaly scores across the dataset
|
|
1397
|
+
"""
|
|
1312
1398
|
anomaly_score_variance: NotRequired[float | None]
|
|
1399
|
+
"""
|
|
1400
|
+
Variance of anomaly scores across the dataset
|
|
1401
|
+
"""
|
|
1313
1402
|
roc_auc: NotRequired[float | None]
|
|
1403
|
+
"""
|
|
1404
|
+
Receiver operating characteristic area under the curve
|
|
1405
|
+
"""
|
|
1314
1406
|
pr_auc: NotRequired[float | None]
|
|
1407
|
+
"""
|
|
1408
|
+
Average precision (area under the curve of the precision-recall curve)
|
|
1409
|
+
"""
|
|
1315
1410
|
pr_curve: NotRequired[PRCurve | None]
|
|
1411
|
+
"""
|
|
1412
|
+
Precision-recall curve
|
|
1413
|
+
"""
|
|
1316
1414
|
roc_curve: NotRequired[ROCCurve | None]
|
|
1415
|
+
"""
|
|
1416
|
+
Receiver operating characteristic curve
|
|
1417
|
+
"""
|
|
1317
1418
|
confusion_matrix: NotRequired[list[list[int]] | None]
|
|
1419
|
+
"""
|
|
1420
|
+
Confusion matrix where the entry at row i, column j is the count of samples with true label i predicted as label j
|
|
1421
|
+
"""
|
|
1318
1422
|
warnings: NotRequired[list[str]]
|
|
1423
|
+
"""
|
|
1424
|
+
Human-readable warnings about skipped or adjusted metrics
|
|
1425
|
+
"""
|
|
1319
1426
|
|
|
1320
1427
|
|
|
1321
1428
|
class ClassificationModelMetadata(TypedDict):
|
|
@@ -1348,7 +1455,7 @@ class ClassificationPredictionRequest(TypedDict):
|
|
|
1348
1455
|
save_telemetry_synchronously: NotRequired[bool]
|
|
1349
1456
|
prompt: NotRequired[str | None]
|
|
1350
1457
|
use_lookup_cache: NotRequired[bool]
|
|
1351
|
-
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]
|
|
1458
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
1352
1459
|
ignore_unlabeled: NotRequired[bool]
|
|
1353
1460
|
partition_ids: NotRequired[str | list[str | None] | None]
|
|
1354
1461
|
partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
|
|
@@ -1362,6 +1469,7 @@ class CloneMemorysetRequest(TypedDict):
|
|
|
1362
1469
|
finetuned_embedding_model_name_or_id: NotRequired[str | None]
|
|
1363
1470
|
max_seq_length_override: NotRequired[int | None]
|
|
1364
1471
|
prompt: NotRequired[str]
|
|
1472
|
+
is_partitioned: NotRequired[bool | None]
|
|
1365
1473
|
|
|
1366
1474
|
|
|
1367
1475
|
class ColumnInfo(TypedDict):
|
|
@@ -1409,6 +1517,7 @@ class CreateMemorysetFromDatasourceRequest(TypedDict):
|
|
|
1409
1517
|
prompt: NotRequired[str]
|
|
1410
1518
|
hidden: NotRequired[bool]
|
|
1411
1519
|
memory_type: NotRequired[MemoryType | None]
|
|
1520
|
+
is_partitioned: NotRequired[bool]
|
|
1412
1521
|
datasource_name_or_id: str
|
|
1413
1522
|
datasource_label_column: NotRequired[str | None]
|
|
1414
1523
|
datasource_score_column: NotRequired[str | None]
|
|
@@ -1433,6 +1542,7 @@ class CreateMemorysetRequest(TypedDict):
|
|
|
1433
1542
|
prompt: NotRequired[str]
|
|
1434
1543
|
hidden: NotRequired[bool]
|
|
1435
1544
|
memory_type: NotRequired[MemoryType | None]
|
|
1545
|
+
is_partitioned: NotRequired[bool]
|
|
1436
1546
|
|
|
1437
1547
|
|
|
1438
1548
|
class CreateRegressionModelRequest(TypedDict):
|
|
@@ -1590,7 +1700,7 @@ class LabelPredictionWithMemoriesAndFeedback(TypedDict):
|
|
|
1590
1700
|
anomaly_score: float | None
|
|
1591
1701
|
label: int | None
|
|
1592
1702
|
label_name: str | None
|
|
1593
|
-
logits: list[float]
|
|
1703
|
+
logits: list[float] | None
|
|
1594
1704
|
timestamp: str
|
|
1595
1705
|
input_value: str | bytes
|
|
1596
1706
|
input_embedding: list[float]
|
|
@@ -1746,6 +1856,7 @@ class TelemetryMemoriesRequest(TypedDict):
|
|
|
1746
1856
|
limit: NotRequired[int]
|
|
1747
1857
|
filters: NotRequired[list[FilterItem | TelemetryFilterItem]]
|
|
1748
1858
|
sort: NotRequired[list[TelemetrySortOptions] | None]
|
|
1859
|
+
consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
|
|
1749
1860
|
|
|
1750
1861
|
|
|
1751
1862
|
class WorkerInfo(TypedDict):
|
|
@@ -1812,6 +1923,7 @@ class MemorysetMetadata(TypedDict):
|
|
|
1812
1923
|
document_prompt_override: str | None
|
|
1813
1924
|
query_prompt_override: str | None
|
|
1814
1925
|
hidden: bool
|
|
1926
|
+
is_partitioned: bool
|
|
1815
1927
|
insertion_task_id: str | None
|
|
1816
1928
|
|
|
1817
1929
|
|
|
@@ -3660,6 +3772,46 @@ class OrcaClient(Client):
|
|
|
3660
3772
|
) -> EvaluationResponse:
|
|
3661
3773
|
pass
|
|
3662
3774
|
|
|
3775
|
+
@overload
|
|
3776
|
+
def POST(
|
|
3777
|
+
self,
|
|
3778
|
+
path: Literal["/classification_model/metrics"],
|
|
3779
|
+
*,
|
|
3780
|
+
params: None = None,
|
|
3781
|
+
json: ComputeClassificationMetricsRequest,
|
|
3782
|
+
data: None = None,
|
|
3783
|
+
files: None = None,
|
|
3784
|
+
content: None = None,
|
|
3785
|
+
parse_as: Literal["json"] = "json",
|
|
3786
|
+
headers: HeaderTypes | None = None,
|
|
3787
|
+
cookies: CookieTypes | None = None,
|
|
3788
|
+
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3789
|
+
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3790
|
+
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3791
|
+
extensions: RequestExtensions | None = None,
|
|
3792
|
+
) -> ClassificationMetrics:
|
|
3793
|
+
pass
|
|
3794
|
+
|
|
3795
|
+
@overload
|
|
3796
|
+
def POST(
|
|
3797
|
+
self,
|
|
3798
|
+
path: Literal["/regression_model/metrics"],
|
|
3799
|
+
*,
|
|
3800
|
+
params: None = None,
|
|
3801
|
+
json: ComputeRegressionMetricsRequest,
|
|
3802
|
+
data: None = None,
|
|
3803
|
+
files: None = None,
|
|
3804
|
+
content: None = None,
|
|
3805
|
+
parse_as: Literal["json"] = "json",
|
|
3806
|
+
headers: HeaderTypes | None = None,
|
|
3807
|
+
cookies: CookieTypes | None = None,
|
|
3808
|
+
auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3809
|
+
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3810
|
+
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
3811
|
+
extensions: RequestExtensions | None = None,
|
|
3812
|
+
) -> RegressionMetrics:
|
|
3813
|
+
pass
|
|
3814
|
+
|
|
3663
3815
|
@overload
|
|
3664
3816
|
def POST(
|
|
3665
3817
|
self,
|
orca_sdk/conftest.py
CHANGED
|
@@ -24,15 +24,6 @@ os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:15
|
|
|
24
24
|
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def skip_in_prod(reason: str):
|
|
28
|
-
"""Custom decorator to skip tests when running against production API"""
|
|
29
|
-
PROD_API_URLs = ["https://api.orcadb.ai", "https://api.staging.orcadb.ai"]
|
|
30
|
-
return pytest.mark.skipif(
|
|
31
|
-
os.environ["ORCA_API_URL"] in PROD_API_URLs,
|
|
32
|
-
reason=reason,
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
|
|
36
27
|
def skip_in_ci(reason: str):
|
|
37
28
|
"""Custom decorator to skip tests when running in CI"""
|
|
38
29
|
return pytest.mark.skipif(
|
|
@@ -201,6 +192,11 @@ SAMPLE_DATA = [
|
|
|
201
192
|
]
|
|
202
193
|
|
|
203
194
|
|
|
195
|
+
@pytest.fixture(scope="session")
|
|
196
|
+
def data() -> list[dict]:
|
|
197
|
+
return SAMPLE_DATA
|
|
198
|
+
|
|
199
|
+
|
|
204
200
|
@pytest.fixture(scope="session")
|
|
205
201
|
def hf_dataset(label_names: list[str]) -> Dataset:
|
|
206
202
|
return Dataset.from_list(
|
|
@@ -232,6 +228,11 @@ EVAL_DATASET = [
|
|
|
232
228
|
]
|
|
233
229
|
|
|
234
230
|
|
|
231
|
+
@pytest.fixture(scope="session")
|
|
232
|
+
def eval_data() -> list[dict]:
|
|
233
|
+
return EVAL_DATASET
|
|
234
|
+
|
|
235
|
+
|
|
235
236
|
@pytest.fixture(scope="session")
|
|
236
237
|
def eval_datasource() -> Datasource:
|
|
237
238
|
eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
|
orca_sdk/datasource.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import logging
|
|
4
3
|
import tempfile
|
|
5
4
|
import zipfile
|
|
6
5
|
from datetime import datetime
|
|
@@ -9,17 +8,18 @@ from os import PathLike
|
|
|
9
8
|
from pathlib import Path
|
|
10
9
|
from typing import TYPE_CHECKING, Any, Literal, Union, cast
|
|
11
10
|
|
|
12
|
-
from datasets import Dataset, DatasetDict
|
|
13
11
|
from httpx._types import FileTypes # type: ignore
|
|
14
12
|
from tqdm.auto import tqdm
|
|
15
13
|
|
|
16
|
-
from ._utils.common import CreateMode, DropMode
|
|
17
|
-
from ._utils.
|
|
14
|
+
from ._utils.common import CreateMode, DropMode, logger
|
|
15
|
+
from ._utils.torch_parsing import list_from_torch
|
|
18
16
|
from ._utils.tqdm_file_reader import TqdmFileReader
|
|
19
17
|
from .client import DatasourceMetadata, OrcaClient
|
|
20
18
|
|
|
21
19
|
if TYPE_CHECKING:
|
|
22
20
|
# These are peer dependencies that are used for types only
|
|
21
|
+
from datasets import Dataset as HFDataset # type: ignore
|
|
22
|
+
from datasets import DatasetDict as HFDatasetDict # type: ignore
|
|
23
23
|
from pandas import DataFrame as PandasDataFrame # type: ignore
|
|
24
24
|
from pyarrow import Table as PyArrowTable # type: ignore
|
|
25
25
|
from torch.utils.data import DataLoader as TorchDataLoader # type: ignore
|
|
@@ -146,7 +146,7 @@ class Datasource:
|
|
|
146
146
|
|
|
147
147
|
@classmethod
|
|
148
148
|
def from_hf_dataset(
|
|
149
|
-
cls, name: str, dataset:
|
|
149
|
+
cls, name: str, dataset: HFDataset, if_exists: CreateMode = "error", description: str | None = None
|
|
150
150
|
) -> Datasource:
|
|
151
151
|
"""
|
|
152
152
|
Create a new datasource from a Hugging Face Dataset
|
|
@@ -183,7 +183,7 @@ class Datasource:
|
|
|
183
183
|
def from_hf_dataset_dict(
|
|
184
184
|
cls,
|
|
185
185
|
name: str,
|
|
186
|
-
dataset_dict:
|
|
186
|
+
dataset_dict: HFDatasetDict,
|
|
187
187
|
if_exists: CreateMode = "error",
|
|
188
188
|
description: dict[str, str | None] | str | None = None,
|
|
189
189
|
) -> dict[str, Datasource]:
|
|
@@ -239,8 +239,8 @@ class Datasource:
|
|
|
239
239
|
Raises:
|
|
240
240
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
241
241
|
"""
|
|
242
|
-
|
|
243
|
-
return cls.
|
|
242
|
+
data_list = list_from_torch(torch_data, column_names=column_names)
|
|
243
|
+
return cls.from_list(name, data_list, if_exists=if_exists, description=description)
|
|
244
244
|
|
|
245
245
|
@classmethod
|
|
246
246
|
def from_list(
|
|
@@ -326,14 +326,24 @@ class Datasource:
|
|
|
326
326
|
`"error"`. Other option is `"open"` to open the existing datasource.
|
|
327
327
|
description: Optional description for the datasource
|
|
328
328
|
|
|
329
|
+
Notes:
|
|
330
|
+
Data type precision may be lost during upload unless the [`datasets`][datasets] library is installed.
|
|
331
|
+
|
|
329
332
|
Returns:
|
|
330
333
|
A handle to the new datasource in the OrcaCloud
|
|
331
334
|
|
|
332
335
|
Raises:
|
|
333
336
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
337
|
+
ImportError: If the upload dependency group is not installed
|
|
334
338
|
"""
|
|
335
|
-
|
|
336
|
-
|
|
339
|
+
try:
|
|
340
|
+
from datasets import Dataset # type: ignore
|
|
341
|
+
|
|
342
|
+
return cls.from_hf_dataset(
|
|
343
|
+
name, Dataset.from_pandas(dataframe), if_exists=if_exists, description=description
|
|
344
|
+
)
|
|
345
|
+
except ImportError:
|
|
346
|
+
return cls.from_dict(name, dataframe.to_dict(orient="list"), if_exists=if_exists, description=description)
|
|
337
347
|
|
|
338
348
|
@classmethod
|
|
339
349
|
def from_arrow(
|
|
@@ -404,6 +414,7 @@ class Datasource:
|
|
|
404
414
|
|
|
405
415
|
Raises:
|
|
406
416
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
417
|
+
ImportError: If the path is a directory and [`datasets`][datasets] is not installed
|
|
407
418
|
"""
|
|
408
419
|
# Check if datasource already exists and handle accordingly
|
|
409
420
|
existing = _handle_existing_datasource(name, if_exists)
|
|
@@ -414,6 +425,13 @@ class Datasource:
|
|
|
414
425
|
|
|
415
426
|
# For dataset directories, use the upload endpoint with multiple files
|
|
416
427
|
if file_path.is_dir():
|
|
428
|
+
try:
|
|
429
|
+
from datasets import Dataset # type: ignore
|
|
430
|
+
except ImportError as e:
|
|
431
|
+
raise ImportError(
|
|
432
|
+
"The path is a directory, we only support uploading directories that contain saved HuggingFace datasets but datasets is not installed."
|
|
433
|
+
) from e
|
|
434
|
+
|
|
417
435
|
return cls.from_hf_dataset(
|
|
418
436
|
name, Dataset.load_from_disk(file_path), if_exists=if_exists, description=description
|
|
419
437
|
)
|
|
@@ -484,7 +502,7 @@ class Datasource:
|
|
|
484
502
|
try:
|
|
485
503
|
client = OrcaClient._resolve_client()
|
|
486
504
|
client.DELETE("/datasource/{name_or_id}", params={"name_or_id": name_or_id})
|
|
487
|
-
|
|
505
|
+
logger.info(f"Deleted datasource {name_or_id}")
|
|
488
506
|
except LookupError:
|
|
489
507
|
if if_not_exists == "error":
|
|
490
508
|
raise
|
|
@@ -566,9 +584,9 @@ class Datasource:
|
|
|
566
584
|
with zipfile.ZipFile(output_path, "r") as zip_ref:
|
|
567
585
|
zip_ref.extractall(extract_dir)
|
|
568
586
|
output_path.unlink() # Remove the zip file after extraction
|
|
569
|
-
|
|
587
|
+
logger.info(f"Downloaded {extract_dir}")
|
|
570
588
|
else:
|
|
571
|
-
|
|
589
|
+
logger.info(f"Downloaded {output_path}")
|
|
572
590
|
|
|
573
591
|
def to_list(self) -> list[dict]:
|
|
574
592
|
"""
|
orca_sdk/embedding_model.py
CHANGED
|
@@ -4,8 +4,7 @@ from abc import ABC, abstractmethod
|
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
|
|
6
6
|
|
|
7
|
-
from .
|
|
8
|
-
from ._utils.common import UNSET, CreateMode, DropMode
|
|
7
|
+
from ._utils.common import CreateMode, DropMode
|
|
9
8
|
from .client import (
|
|
10
9
|
EmbeddingEvaluationRequest,
|
|
11
10
|
EmbeddingFinetuningMethod,
|
|
@@ -20,7 +19,9 @@ from .datasource import Datasource
|
|
|
20
19
|
from .job import Job, Status
|
|
21
20
|
|
|
22
21
|
if TYPE_CHECKING:
|
|
22
|
+
from .classification_model import ClassificationMetrics
|
|
23
23
|
from .memoryset import LabeledMemoryset, ScoredMemoryset
|
|
24
|
+
from .regression_model import RegressionMetrics
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class EmbeddingModelBase(ABC):
|
|
@@ -230,6 +231,9 @@ class EmbeddingModelBase(ABC):
|
|
|
230
231
|
raise ValueError("Invalid embedding model")
|
|
231
232
|
|
|
232
233
|
def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
234
|
+
from .classification_model import ClassificationMetrics
|
|
235
|
+
from .regression_model import RegressionMetrics
|
|
236
|
+
|
|
233
237
|
client = OrcaClient._resolve_client()
|
|
234
238
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
235
239
|
res = client.GET(
|
|
@@ -244,34 +248,7 @@ class EmbeddingModelBase(ABC):
|
|
|
244
248
|
else:
|
|
245
249
|
raise ValueError("Invalid embedding model")
|
|
246
250
|
assert res is not None
|
|
247
|
-
return (
|
|
248
|
-
RegressionMetrics(
|
|
249
|
-
coverage=res.get("coverage"),
|
|
250
|
-
mse=res.get("mse"),
|
|
251
|
-
rmse=res.get("rmse"),
|
|
252
|
-
mae=res.get("mae"),
|
|
253
|
-
r2=res.get("r2"),
|
|
254
|
-
explained_variance=res.get("explained_variance"),
|
|
255
|
-
loss=res.get("loss"),
|
|
256
|
-
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
257
|
-
anomaly_score_median=res.get("anomaly_score_median"),
|
|
258
|
-
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
259
|
-
)
|
|
260
|
-
if "mse" in res
|
|
261
|
-
else ClassificationMetrics(
|
|
262
|
-
coverage=res.get("coverage"),
|
|
263
|
-
f1_score=res.get("f1_score"),
|
|
264
|
-
accuracy=res.get("accuracy"),
|
|
265
|
-
loss=res.get("loss"),
|
|
266
|
-
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
267
|
-
anomaly_score_median=res.get("anomaly_score_median"),
|
|
268
|
-
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
269
|
-
roc_auc=res.get("roc_auc"),
|
|
270
|
-
pr_auc=res.get("pr_auc"),
|
|
271
|
-
pr_curve=res.get("pr_curve"),
|
|
272
|
-
roc_curve=res.get("roc_curve"),
|
|
273
|
-
)
|
|
274
|
-
)
|
|
251
|
+
return RegressionMetrics(res) if "mse" in res else ClassificationMetrics(res)
|
|
275
252
|
|
|
276
253
|
job = Job(response["job_id"], lambda: get_result(response["job_id"]))
|
|
277
254
|
return job if background else job.result()
|
|
@@ -404,7 +381,7 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
404
381
|
return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name
|
|
405
382
|
|
|
406
383
|
def __repr__(self) -> str:
|
|
407
|
-
return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params/1000000:.0f}M}})"
|
|
384
|
+
return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params / 1000000:.0f}M}})"
|
|
408
385
|
|
|
409
386
|
@classmethod
|
|
410
387
|
def all(cls) -> list[PretrainedEmbeddingModel]:
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -4,9 +4,9 @@ from uuid import uuid4
|
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
6
|
|
|
7
|
+
from .classification_model import ClassificationMetrics
|
|
7
8
|
from .datasource import Datasource
|
|
8
9
|
from .embedding_model import (
|
|
9
|
-
ClassificationMetrics,
|
|
10
10
|
FinetunedEmbeddingModel,
|
|
11
11
|
PretrainedEmbeddingModel,
|
|
12
12
|
PretrainedEmbeddingModelName,
|