orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/analysis_ui.py +4 -1
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/prediction_result_ui.py +4 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/_utils/value_parser.py +44 -17
- orca_sdk/_utils/value_parser_test.py +6 -5
- orca_sdk/async_client.py +234 -22
- orca_sdk/classification_model.py +203 -66
- orca_sdk/classification_model_test.py +85 -25
- orca_sdk/client.py +234 -20
- orca_sdk/conftest.py +97 -16
- orca_sdk/credentials_test.py +5 -8
- orca_sdk/datasource.py +44 -21
- orca_sdk/datasource_test.py +8 -2
- orca_sdk/embedding_model.py +15 -33
- orca_sdk/embedding_model_test.py +30 -1
- orca_sdk/memoryset.py +558 -425
- orca_sdk/memoryset_test.py +120 -185
- orca_sdk/regression_model.py +186 -65
- orca_sdk/regression_model_test.py +62 -3
- orca_sdk/telemetry.py +16 -7
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +4 -8
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -129
- orca_sdk/_utils/data_parsing_test.py +0 -244
- orca_sdk-0.1.10.dist-info/RECORD +0 -41
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
orca_sdk/conftest.py
CHANGED
|
@@ -24,15 +24,6 @@ os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:15
|
|
|
24
24
|
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def skip_in_prod(reason: str):
|
|
28
|
-
"""Custom decorator to skip tests when running against production API"""
|
|
29
|
-
PROD_API_URLs = ["https://api.orcadb.ai", "https://api.staging.orcadb.ai"]
|
|
30
|
-
return pytest.mark.skipif(
|
|
31
|
-
os.environ["ORCA_API_URL"] in PROD_API_URLs,
|
|
32
|
-
reason=reason,
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
|
|
36
27
|
def skip_in_ci(reason: str):
|
|
37
28
|
"""Custom decorator to skip tests when running in CI"""
|
|
38
29
|
return pytest.mark.skipif(
|
|
@@ -201,6 +192,11 @@ SAMPLE_DATA = [
|
|
|
201
192
|
]
|
|
202
193
|
|
|
203
194
|
|
|
195
|
+
@pytest.fixture(scope="session")
|
|
196
|
+
def data() -> list[dict]:
|
|
197
|
+
return SAMPLE_DATA
|
|
198
|
+
|
|
199
|
+
|
|
204
200
|
@pytest.fixture(scope="session")
|
|
205
201
|
def hf_dataset(label_names: list[str]) -> Dataset:
|
|
206
202
|
return Dataset.from_list(
|
|
@@ -232,6 +228,11 @@ EVAL_DATASET = [
|
|
|
232
228
|
]
|
|
233
229
|
|
|
234
230
|
|
|
231
|
+
@pytest.fixture(scope="session")
|
|
232
|
+
def eval_data() -> list[dict]:
|
|
233
|
+
return EVAL_DATASET
|
|
234
|
+
|
|
235
|
+
|
|
235
236
|
@pytest.fixture(scope="session")
|
|
236
237
|
def eval_datasource() -> Datasource:
|
|
237
238
|
eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
|
|
@@ -288,6 +289,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
|
|
|
288
289
|
datasource=datasource,
|
|
289
290
|
embedding_model=PretrainedEmbeddingModel.GTE_BASE,
|
|
290
291
|
source_id_column="source_id",
|
|
292
|
+
partition_id_column="partition_id",
|
|
291
293
|
max_seq_length_override=32,
|
|
292
294
|
if_exists="open",
|
|
293
295
|
)
|
|
@@ -297,13 +299,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
|
|
|
297
299
|
# Restore the memoryset to a clean state for the next test.
|
|
298
300
|
with OrcaClient(api_key=api_key).use():
|
|
299
301
|
if LabeledMemoryset.exists("test_writable_memoryset"):
|
|
300
|
-
memoryset.
|
|
301
|
-
|
|
302
|
-
memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
|
|
303
|
-
|
|
304
|
-
if memory_ids:
|
|
305
|
-
memoryset.delete(memory_ids)
|
|
306
|
-
memoryset.refresh()
|
|
302
|
+
memoryset.truncate()
|
|
307
303
|
assert len(memoryset) == 0
|
|
308
304
|
memoryset.insert(SAMPLE_DATA)
|
|
309
305
|
# If the test dropped the memoryset, do nothing — it will be recreated on the next use.
|
|
@@ -380,3 +376,88 @@ def partitioned_regression_model(readonly_partitioned_scored_memoryset: ScoredMe
|
|
|
380
376
|
description="test_partitioned_regression_description",
|
|
381
377
|
)
|
|
382
378
|
return model
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@pytest.fixture(scope="function")
|
|
382
|
+
def fully_partitioned_classification_resources() -> (
|
|
383
|
+
Generator[tuple[Datasource, LabeledMemoryset, ClassificationModel], None, None]
|
|
384
|
+
):
|
|
385
|
+
data = [
|
|
386
|
+
{"value": "i love soup", "label": 0, "partition_id": "p1"},
|
|
387
|
+
{"value": "cats are cute", "label": 1, "partition_id": "p1"},
|
|
388
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
389
|
+
{"value": "i love cats", "label": 1, "partition_id": "p2"},
|
|
390
|
+
{"value": "everyone loves cats", "label": 1, "partition_id": "p2"},
|
|
391
|
+
{"value": "soup is good", "label": 0, "partition_id": "p1"},
|
|
392
|
+
{"value": "cats are amazing animals", "label": 1, "partition_id": "p2"},
|
|
393
|
+
{"value": "tomato soup is delicious", "label": 0, "partition_id": "p1"},
|
|
394
|
+
{"value": "cats love to play", "label": 1, "partition_id": "p2"},
|
|
395
|
+
{"value": "i enjoy eating soup", "label": 0, "partition_id": "p1"},
|
|
396
|
+
{"value": "my cat is fluffy", "label": 1, "partition_id": "p2"},
|
|
397
|
+
{"value": "chicken soup is tasty", "label": 0, "partition_id": "p1"},
|
|
398
|
+
{"value": "cats are playful pets", "label": 1, "partition_id": "p2"},
|
|
399
|
+
{"value": "soup warms the soul", "label": 0, "partition_id": "p1"},
|
|
400
|
+
{"value": "cats have soft fur", "label": 1, "partition_id": "p2"},
|
|
401
|
+
{"value": "vegetable soup is healthy", "label": 0, "partition_id": "p1"},
|
|
402
|
+
]
|
|
403
|
+
|
|
404
|
+
datasource = None
|
|
405
|
+
memoryset = None
|
|
406
|
+
classification_model = None
|
|
407
|
+
try:
|
|
408
|
+
datasource = Datasource.from_list("fully_partitioned_classification_datasource", data)
|
|
409
|
+
memoryset = LabeledMemoryset.create(
|
|
410
|
+
"fully_partitioned_classification_memoryset",
|
|
411
|
+
datasource=datasource,
|
|
412
|
+
label_names=["soup", "cats"],
|
|
413
|
+
partition_id_column="partition_id",
|
|
414
|
+
)
|
|
415
|
+
classification_model = ClassificationModel.create("fully_partitioned_classification_model", memoryset=memoryset)
|
|
416
|
+
yield (datasource, memoryset, classification_model)
|
|
417
|
+
finally:
|
|
418
|
+
# Clean up in reverse order of creation
|
|
419
|
+
ClassificationModel.drop("fully_partitioned_classification_model", if_not_exists="ignore")
|
|
420
|
+
LabeledMemoryset.drop("fully_partitioned_classification_memoryset", if_not_exists="ignore")
|
|
421
|
+
Datasource.drop("fully_partitioned_classification_datasource", if_not_exists="ignore")
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@pytest.fixture(scope="function")
|
|
425
|
+
def fully_partitioned_regression_resources() -> (
|
|
426
|
+
Generator[tuple[Datasource, ScoredMemoryset, RegressionModel], None, None]
|
|
427
|
+
):
|
|
428
|
+
data = [
|
|
429
|
+
{"value": "i love soup", "score": 0.1, "partition_id": "p1"},
|
|
430
|
+
{"value": "cats are cute", "score": 0.9, "partition_id": "p1"},
|
|
431
|
+
{"value": "soup is good", "score": 0.1, "partition_id": "p1"},
|
|
432
|
+
{"value": "i love cats", "score": 0.9, "partition_id": "p2"},
|
|
433
|
+
{"value": "everyone loves cats", "score": 0.9, "partition_id": "p2"},
|
|
434
|
+
{"value": "soup is good", "score": 0.1, "partition_id": "p1"},
|
|
435
|
+
{"value": "cats are amazing animals", "score": 0.9, "partition_id": "p2"},
|
|
436
|
+
{"value": "tomato soup is delicious", "score": 0.1, "partition_id": "p1"},
|
|
437
|
+
{"value": "cats love to play", "score": 0.9, "partition_id": "p2"},
|
|
438
|
+
{"value": "i enjoy eating soup", "score": 0.1, "partition_id": "p1"},
|
|
439
|
+
{"value": "my cat is fluffy", "score": 0.9, "partition_id": "p2"},
|
|
440
|
+
{"value": "chicken soup is tasty", "score": 0.1, "partition_id": "p1"},
|
|
441
|
+
{"value": "cats are playful pets", "score": 0.9, "partition_id": "p2"},
|
|
442
|
+
{"value": "soup warms the soul", "score": 0.1, "partition_id": "p1"},
|
|
443
|
+
{"value": "cats have soft fur", "score": 0.9, "partition_id": "p2"},
|
|
444
|
+
{"value": "vegetable soup is healthy", "score": 0.1, "partition_id": "p1"},
|
|
445
|
+
]
|
|
446
|
+
|
|
447
|
+
datasource = None
|
|
448
|
+
memoryset = None
|
|
449
|
+
regression_model = None
|
|
450
|
+
try:
|
|
451
|
+
datasource = Datasource.from_list("fully_partitioned_regression_datasource", data)
|
|
452
|
+
memoryset = ScoredMemoryset.create(
|
|
453
|
+
"fully_partitioned_regression_memoryset",
|
|
454
|
+
datasource=datasource,
|
|
455
|
+
partition_id_column="partition_id",
|
|
456
|
+
)
|
|
457
|
+
regression_model = RegressionModel.create("fully_partitioned_regression_model", memoryset=memoryset)
|
|
458
|
+
yield (datasource, memoryset, regression_model)
|
|
459
|
+
finally:
|
|
460
|
+
# Clean up in reverse order of creation
|
|
461
|
+
RegressionModel.drop("fully_partitioned_regression_model", if_not_exists="ignore")
|
|
462
|
+
ScoredMemoryset.drop("fully_partitioned_regression_memoryset", if_not_exists="ignore")
|
|
463
|
+
Datasource.drop("fully_partitioned_regression_datasource", if_not_exists="ignore")
|
orca_sdk/credentials_test.py
CHANGED
|
@@ -75,7 +75,7 @@ def test_create_api_key_already_exists():
|
|
|
75
75
|
OrcaCredentials.create_api_key("orca_sdk_test")
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
def
|
|
78
|
+
def test_use_client(api_key):
|
|
79
79
|
client = OrcaClient(api_key=str(uuid4()))
|
|
80
80
|
with client.use():
|
|
81
81
|
assert not OrcaCredentials.is_authenticated()
|
|
@@ -91,17 +91,14 @@ def test_set_base_url(api_key):
|
|
|
91
91
|
assert client.base_url == "http://localhost:1583"
|
|
92
92
|
|
|
93
93
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def test_deprecated_set_api_key(api_key):
|
|
94
|
+
def test_set_api_key(api_key):
|
|
98
95
|
with OrcaClient(api_key=str(uuid4())).use():
|
|
99
96
|
assert not OrcaCredentials.is_authenticated()
|
|
100
97
|
OrcaCredentials.set_api_key(api_key)
|
|
101
98
|
assert OrcaCredentials.is_authenticated()
|
|
102
99
|
|
|
103
100
|
|
|
104
|
-
def
|
|
101
|
+
def test_set_invalid_api_key(api_key):
|
|
105
102
|
with OrcaClient(api_key=api_key).use():
|
|
106
103
|
assert OrcaCredentials.is_authenticated()
|
|
107
104
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
@@ -109,13 +106,13 @@ def test_deprecated_set_invalid_api_key(api_key):
|
|
|
109
106
|
assert not OrcaCredentials.is_authenticated()
|
|
110
107
|
|
|
111
108
|
|
|
112
|
-
def
|
|
109
|
+
def test_set_api_url(api_key):
|
|
113
110
|
with OrcaClient(api_key=api_key).use():
|
|
114
111
|
OrcaCredentials.set_api_url("http://api.orcadb.ai")
|
|
115
112
|
assert str(OrcaClient._resolve_client().base_url) == "http://api.orcadb.ai"
|
|
116
113
|
|
|
117
114
|
|
|
118
|
-
def
|
|
115
|
+
def test_set_invalid_api_url(api_key):
|
|
119
116
|
with OrcaClient(api_key=api_key).use():
|
|
120
117
|
with pytest.raises(ValueError, match="No API found at http://localhost:1582"):
|
|
121
118
|
OrcaCredentials.set_api_url("http://localhost:1582")
|
orca_sdk/datasource.py
CHANGED
|
@@ -1,28 +1,30 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import logging
|
|
4
3
|
import tempfile
|
|
5
4
|
import zipfile
|
|
6
5
|
from datetime import datetime
|
|
7
6
|
from io import BytesIO
|
|
8
7
|
from os import PathLike
|
|
9
8
|
from pathlib import Path
|
|
10
|
-
from typing import Any, Literal, Union, cast
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Literal, Union, cast
|
|
11
10
|
|
|
12
|
-
import pandas as pd
|
|
13
|
-
import pyarrow as pa
|
|
14
|
-
from datasets import Dataset, DatasetDict
|
|
15
11
|
from httpx._types import FileTypes # type: ignore
|
|
16
|
-
from pyarrow import parquet
|
|
17
|
-
from torch.utils.data import DataLoader as TorchDataLoader
|
|
18
|
-
from torch.utils.data import Dataset as TorchDataset
|
|
19
12
|
from tqdm.auto import tqdm
|
|
20
13
|
|
|
21
|
-
from ._utils.common import CreateMode, DropMode
|
|
22
|
-
from ._utils.
|
|
14
|
+
from ._utils.common import CreateMode, DropMode, logger
|
|
15
|
+
from ._utils.torch_parsing import list_from_torch
|
|
23
16
|
from ._utils.tqdm_file_reader import TqdmFileReader
|
|
24
17
|
from .client import DatasourceMetadata, OrcaClient
|
|
25
18
|
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
# These are peer dependencies that are used for types only
|
|
21
|
+
from datasets import Dataset as HFDataset # type: ignore
|
|
22
|
+
from datasets import DatasetDict as HFDatasetDict # type: ignore
|
|
23
|
+
from pandas import DataFrame as PandasDataFrame # type: ignore
|
|
24
|
+
from pyarrow import Table as PyArrowTable # type: ignore
|
|
25
|
+
from torch.utils.data import DataLoader as TorchDataLoader # type: ignore
|
|
26
|
+
from torch.utils.data import Dataset as TorchDataset # type: ignore
|
|
27
|
+
|
|
26
28
|
|
|
27
29
|
def _upload_files_to_datasource(
|
|
28
30
|
name: str,
|
|
@@ -144,7 +146,7 @@ class Datasource:
|
|
|
144
146
|
|
|
145
147
|
@classmethod
|
|
146
148
|
def from_hf_dataset(
|
|
147
|
-
cls, name: str, dataset:
|
|
149
|
+
cls, name: str, dataset: HFDataset, if_exists: CreateMode = "error", description: str | None = None
|
|
148
150
|
) -> Datasource:
|
|
149
151
|
"""
|
|
150
152
|
Create a new datasource from a Hugging Face Dataset
|
|
@@ -181,7 +183,7 @@ class Datasource:
|
|
|
181
183
|
def from_hf_dataset_dict(
|
|
182
184
|
cls,
|
|
183
185
|
name: str,
|
|
184
|
-
dataset_dict:
|
|
186
|
+
dataset_dict: HFDatasetDict,
|
|
185
187
|
if_exists: CreateMode = "error",
|
|
186
188
|
description: dict[str, str | None] | str | None = None,
|
|
187
189
|
) -> dict[str, Datasource]:
|
|
@@ -237,8 +239,8 @@ class Datasource:
|
|
|
237
239
|
Raises:
|
|
238
240
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
239
241
|
"""
|
|
240
|
-
|
|
241
|
-
return cls.
|
|
242
|
+
data_list = list_from_torch(torch_data, column_names=column_names)
|
|
243
|
+
return cls.from_list(name, data_list, if_exists=if_exists, description=description)
|
|
242
244
|
|
|
243
245
|
@classmethod
|
|
244
246
|
def from_list(
|
|
@@ -312,7 +314,7 @@ class Datasource:
|
|
|
312
314
|
|
|
313
315
|
@classmethod
|
|
314
316
|
def from_pandas(
|
|
315
|
-
cls, name: str, dataframe:
|
|
317
|
+
cls, name: str, dataframe: PandasDataFrame, if_exists: CreateMode = "error", description: str | None = None
|
|
316
318
|
) -> Datasource:
|
|
317
319
|
"""
|
|
318
320
|
Create a new datasource from a pandas DataFrame
|
|
@@ -324,18 +326,28 @@ class Datasource:
|
|
|
324
326
|
`"error"`. Other option is `"open"` to open the existing datasource.
|
|
325
327
|
description: Optional description for the datasource
|
|
326
328
|
|
|
329
|
+
Notes:
|
|
330
|
+
Data type precision may be lost during upload unless the [`datasets`][datasets] library is installed.
|
|
331
|
+
|
|
327
332
|
Returns:
|
|
328
333
|
A handle to the new datasource in the OrcaCloud
|
|
329
334
|
|
|
330
335
|
Raises:
|
|
331
336
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
337
|
+
ImportError: If the upload dependency group is not installed
|
|
332
338
|
"""
|
|
333
|
-
|
|
334
|
-
|
|
339
|
+
try:
|
|
340
|
+
from datasets import Dataset # type: ignore
|
|
341
|
+
|
|
342
|
+
return cls.from_hf_dataset(
|
|
343
|
+
name, Dataset.from_pandas(dataframe), if_exists=if_exists, description=description
|
|
344
|
+
)
|
|
345
|
+
except ImportError:
|
|
346
|
+
return cls.from_dict(name, dataframe.to_dict(orient="list"), if_exists=if_exists, description=description)
|
|
335
347
|
|
|
336
348
|
@classmethod
|
|
337
349
|
def from_arrow(
|
|
338
|
-
cls, name: str, pyarrow_table:
|
|
350
|
+
cls, name: str, pyarrow_table: PyArrowTable, if_exists: CreateMode = "error", description: str | None = None
|
|
339
351
|
) -> Datasource:
|
|
340
352
|
"""
|
|
341
353
|
Create a new datasource from a pyarrow Table
|
|
@@ -358,6 +370,9 @@ class Datasource:
|
|
|
358
370
|
if existing is not None:
|
|
359
371
|
return existing
|
|
360
372
|
|
|
373
|
+
# peer dependency that is guaranteed to exist if the user provided a pyarrow table
|
|
374
|
+
from pyarrow import parquet # type: ignore
|
|
375
|
+
|
|
361
376
|
# Write to bytes buffer
|
|
362
377
|
buffer = BytesIO()
|
|
363
378
|
parquet.write_table(pyarrow_table, buffer)
|
|
@@ -399,6 +414,7 @@ class Datasource:
|
|
|
399
414
|
|
|
400
415
|
Raises:
|
|
401
416
|
ValueError: If the datasource already exists and if_exists is `"error"`
|
|
417
|
+
ImportError: If the path is a directory and [`datasets`][datasets] is not installed
|
|
402
418
|
"""
|
|
403
419
|
# Check if datasource already exists and handle accordingly
|
|
404
420
|
existing = _handle_existing_datasource(name, if_exists)
|
|
@@ -409,6 +425,13 @@ class Datasource:
|
|
|
409
425
|
|
|
410
426
|
# For dataset directories, use the upload endpoint with multiple files
|
|
411
427
|
if file_path.is_dir():
|
|
428
|
+
try:
|
|
429
|
+
from datasets import Dataset # type: ignore
|
|
430
|
+
except ImportError as e:
|
|
431
|
+
raise ImportError(
|
|
432
|
+
"The path is a directory, we only support uploading directories that contain saved HuggingFace datasets but datasets is not installed."
|
|
433
|
+
) from e
|
|
434
|
+
|
|
412
435
|
return cls.from_hf_dataset(
|
|
413
436
|
name, Dataset.load_from_disk(file_path), if_exists=if_exists, description=description
|
|
414
437
|
)
|
|
@@ -479,7 +502,7 @@ class Datasource:
|
|
|
479
502
|
try:
|
|
480
503
|
client = OrcaClient._resolve_client()
|
|
481
504
|
client.DELETE("/datasource/{name_or_id}", params={"name_or_id": name_or_id})
|
|
482
|
-
|
|
505
|
+
logger.info(f"Deleted datasource {name_or_id}")
|
|
483
506
|
except LookupError:
|
|
484
507
|
if if_not_exists == "error":
|
|
485
508
|
raise
|
|
@@ -561,9 +584,9 @@ class Datasource:
|
|
|
561
584
|
with zipfile.ZipFile(output_path, "r") as zip_ref:
|
|
562
585
|
zip_ref.extractall(extract_dir)
|
|
563
586
|
output_path.unlink() # Remove the zip file after extraction
|
|
564
|
-
|
|
587
|
+
logger.info(f"Downloaded {extract_dir}")
|
|
565
588
|
else:
|
|
566
|
-
|
|
589
|
+
logger.info(f"Downloaded {output_path}")
|
|
567
590
|
|
|
568
591
|
def to_list(self) -> list[dict]:
|
|
569
592
|
"""
|
orca_sdk/datasource_test.py
CHANGED
|
@@ -5,8 +5,6 @@ from typing import cast
|
|
|
5
5
|
from uuid import uuid4
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
import pandas as pd
|
|
9
|
-
import pyarrow as pa
|
|
10
8
|
import pytest
|
|
11
9
|
from datasets import Dataset
|
|
12
10
|
|
|
@@ -137,6 +135,8 @@ def test_from_dict():
|
|
|
137
135
|
|
|
138
136
|
|
|
139
137
|
def test_from_pandas():
|
|
138
|
+
pd = pytest.importorskip("pandas")
|
|
139
|
+
|
|
140
140
|
# Test creating datasource from pandas DataFrame
|
|
141
141
|
df = pd.DataFrame(
|
|
142
142
|
{
|
|
@@ -152,6 +152,8 @@ def test_from_pandas():
|
|
|
152
152
|
|
|
153
153
|
|
|
154
154
|
def test_from_arrow():
|
|
155
|
+
pa = pytest.importorskip("pyarrow")
|
|
156
|
+
|
|
155
157
|
# Test creating datasource from pyarrow Table
|
|
156
158
|
table = pa.table(
|
|
157
159
|
{
|
|
@@ -205,6 +207,8 @@ def test_from_dict_already_exists():
|
|
|
205
207
|
|
|
206
208
|
|
|
207
209
|
def test_from_pandas_already_exists():
|
|
210
|
+
pd = pytest.importorskip("pandas")
|
|
211
|
+
|
|
208
212
|
# Test the if_exists parameter with from_pandas
|
|
209
213
|
df = pd.DataFrame({"column1": [1], "column2": ["a"]})
|
|
210
214
|
name = f"test_pandas_exists_{uuid4()}"
|
|
@@ -224,6 +228,8 @@ def test_from_pandas_already_exists():
|
|
|
224
228
|
|
|
225
229
|
|
|
226
230
|
def test_from_arrow_already_exists():
|
|
231
|
+
pa = pytest.importorskip("pyarrow")
|
|
232
|
+
|
|
227
233
|
# Test the if_exists parameter with from_arrow
|
|
228
234
|
table = pa.table({"column1": [1], "column2": ["a"]})
|
|
229
235
|
name = f"test_arrow_exists_{uuid4()}"
|
orca_sdk/embedding_model.py
CHANGED
|
@@ -4,8 +4,7 @@ from abc import ABC, abstractmethod
|
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
|
|
6
6
|
|
|
7
|
-
from .
|
|
8
|
-
from ._utils.common import UNSET, CreateMode, DropMode
|
|
7
|
+
from ._utils.common import CreateMode, DropMode
|
|
9
8
|
from .client import (
|
|
10
9
|
EmbeddingEvaluationRequest,
|
|
11
10
|
EmbeddingFinetuningMethod,
|
|
@@ -20,7 +19,9 @@ from .datasource import Datasource
|
|
|
20
19
|
from .job import Job, Status
|
|
21
20
|
|
|
22
21
|
if TYPE_CHECKING:
|
|
22
|
+
from .classification_model import ClassificationMetrics
|
|
23
23
|
from .memoryset import LabeledMemoryset, ScoredMemoryset
|
|
24
|
+
from .regression_model import RegressionMetrics
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class EmbeddingModelBase(ABC):
|
|
@@ -230,6 +231,9 @@ class EmbeddingModelBase(ABC):
|
|
|
230
231
|
raise ValueError("Invalid embedding model")
|
|
231
232
|
|
|
232
233
|
def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
234
|
+
from .classification_model import ClassificationMetrics
|
|
235
|
+
from .regression_model import RegressionMetrics
|
|
236
|
+
|
|
233
237
|
client = OrcaClient._resolve_client()
|
|
234
238
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
235
239
|
res = client.GET(
|
|
@@ -244,34 +248,7 @@ class EmbeddingModelBase(ABC):
|
|
|
244
248
|
else:
|
|
245
249
|
raise ValueError("Invalid embedding model")
|
|
246
250
|
assert res is not None
|
|
247
|
-
return (
|
|
248
|
-
RegressionMetrics(
|
|
249
|
-
coverage=res.get("coverage"),
|
|
250
|
-
mse=res.get("mse"),
|
|
251
|
-
rmse=res.get("rmse"),
|
|
252
|
-
mae=res.get("mae"),
|
|
253
|
-
r2=res.get("r2"),
|
|
254
|
-
explained_variance=res.get("explained_variance"),
|
|
255
|
-
loss=res.get("loss"),
|
|
256
|
-
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
257
|
-
anomaly_score_median=res.get("anomaly_score_median"),
|
|
258
|
-
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
259
|
-
)
|
|
260
|
-
if "mse" in res
|
|
261
|
-
else ClassificationMetrics(
|
|
262
|
-
coverage=res.get("coverage"),
|
|
263
|
-
f1_score=res.get("f1_score"),
|
|
264
|
-
accuracy=res.get("accuracy"),
|
|
265
|
-
loss=res.get("loss"),
|
|
266
|
-
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
267
|
-
anomaly_score_median=res.get("anomaly_score_median"),
|
|
268
|
-
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
269
|
-
roc_auc=res.get("roc_auc"),
|
|
270
|
-
pr_auc=res.get("pr_auc"),
|
|
271
|
-
pr_curve=res.get("pr_curve"),
|
|
272
|
-
roc_curve=res.get("roc_curve"),
|
|
273
|
-
)
|
|
274
|
-
)
|
|
251
|
+
return RegressionMetrics(res) if "mse" in res else ClassificationMetrics(res)
|
|
275
252
|
|
|
276
253
|
job = Job(response["job_id"], lambda: get_result(response["job_id"]))
|
|
277
254
|
return job if background else job.result()
|
|
@@ -404,7 +381,7 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
404
381
|
return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name
|
|
405
382
|
|
|
406
383
|
def __repr__(self) -> str:
|
|
407
|
-
return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params/1000000:.0f}M}})"
|
|
384
|
+
return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params / 1000000:.0f}M}})"
|
|
408
385
|
|
|
409
386
|
@classmethod
|
|
410
387
|
def all(cls) -> list[PretrainedEmbeddingModel]:
|
|
@@ -691,21 +668,26 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
691
668
|
return False
|
|
692
669
|
|
|
693
670
|
@classmethod
|
|
694
|
-
def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
|
|
671
|
+
def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error", cascade: bool = False):
|
|
695
672
|
"""
|
|
696
673
|
Delete the finetuned embedding model from the OrcaCloud
|
|
697
674
|
|
|
698
675
|
Params:
|
|
699
676
|
name_or_id: The name or id of the finetuned embedding model
|
|
677
|
+
if_not_exists: What to do if the finetuned embedding model does not exist, defaults to `"error"`.
|
|
678
|
+
Other option is `"ignore"` to do nothing if the model does not exist.
|
|
679
|
+
cascade: If True, also delete all associated memorysets and their predictive models.
|
|
680
|
+
Defaults to False.
|
|
700
681
|
|
|
701
682
|
Raises:
|
|
702
683
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
684
|
+
RuntimeError: If the model has associated memorysets and cascade is False
|
|
703
685
|
"""
|
|
704
686
|
try:
|
|
705
687
|
client = OrcaClient._resolve_client()
|
|
706
688
|
client.DELETE(
|
|
707
689
|
"/finetuned_embedding_model/{name_or_id}",
|
|
708
|
-
params={"name_or_id": name_or_id},
|
|
690
|
+
params={"name_or_id": name_or_id, "cascade": cascade},
|
|
709
691
|
)
|
|
710
692
|
except LookupError:
|
|
711
693
|
if if_not_exists == "error":
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -4,9 +4,9 @@ from uuid import uuid4
|
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
6
|
|
|
7
|
+
from .classification_model import ClassificationMetrics
|
|
7
8
|
from .datasource import Datasource
|
|
8
9
|
from .embedding_model import (
|
|
9
|
-
ClassificationMetrics,
|
|
10
10
|
FinetunedEmbeddingModel,
|
|
11
11
|
PretrainedEmbeddingModel,
|
|
12
12
|
PretrainedEmbeddingModelName,
|
|
@@ -172,6 +172,35 @@ def test_drop_finetuned_model(datasource: Datasource):
|
|
|
172
172
|
FinetunedEmbeddingModel.open("finetuned_model_to_delete")
|
|
173
173
|
|
|
174
174
|
|
|
175
|
+
def test_drop_finetuned_model_with_memoryset_cascade(datasource: Datasource):
|
|
176
|
+
"""Test that cascade=False prevents deletion and cascade=True allows it."""
|
|
177
|
+
finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_cascade_delete", datasource)
|
|
178
|
+
memoryset = LabeledMemoryset.create(
|
|
179
|
+
"test_memoryset_for_finetuned_model_cascade",
|
|
180
|
+
datasource=datasource,
|
|
181
|
+
embedding_model=finetuned_model,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
# Verify memoryset exists and uses the finetuned model
|
|
185
|
+
assert LabeledMemoryset.open(memoryset.name) is not None
|
|
186
|
+
assert memoryset.embedding_model == finetuned_model
|
|
187
|
+
|
|
188
|
+
# Without cascade, deletion should fail
|
|
189
|
+
with pytest.raises(RuntimeError):
|
|
190
|
+
FinetunedEmbeddingModel.drop(finetuned_model.id, cascade=False)
|
|
191
|
+
|
|
192
|
+
# Model and memoryset should still exist
|
|
193
|
+
assert FinetunedEmbeddingModel.exists(finetuned_model.name)
|
|
194
|
+
assert LabeledMemoryset.exists(memoryset.name)
|
|
195
|
+
|
|
196
|
+
# With cascade, deletion should succeed
|
|
197
|
+
FinetunedEmbeddingModel.drop(finetuned_model.id, cascade=True)
|
|
198
|
+
|
|
199
|
+
# Both model and memoryset should be deleted
|
|
200
|
+
assert not FinetunedEmbeddingModel.exists(finetuned_model.name)
|
|
201
|
+
assert not LabeledMemoryset.exists(memoryset.name)
|
|
202
|
+
|
|
203
|
+
|
|
175
204
|
def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
|
|
176
205
|
with unauthenticated_client.use():
|
|
177
206
|
with pytest.raises(ValueError, match="Invalid API key"):
|