orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orca_sdk/conftest.py CHANGED
@@ -24,15 +24,6 @@ os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:15
 os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
 
 
-def skip_in_prod(reason: str):
-    """Custom decorator to skip tests when running against production API"""
-    PROD_API_URLs = ["https://api.orcadb.ai", "https://api.staging.orcadb.ai"]
-    return pytest.mark.skipif(
-        os.environ["ORCA_API_URL"] in PROD_API_URLs,
-        reason=reason,
-    )
-
-
 def skip_in_ci(reason: str):
     """Custom decorator to skip tests when running in CI"""
     return pytest.mark.skipif(
@@ -201,6 +192,11 @@ SAMPLE_DATA = [
 ]
 
 
+@pytest.fixture(scope="session")
+def data() -> list[dict]:
+    return SAMPLE_DATA
+
+
 @pytest.fixture(scope="session")
 def hf_dataset(label_names: list[str]) -> Dataset:
     return Dataset.from_list(
@@ -232,6 +228,11 @@ EVAL_DATASET = [
 ]
 
 
+@pytest.fixture(scope="session")
+def eval_data() -> list[dict]:
+    return EVAL_DATASET
+
+
 @pytest.fixture(scope="session")
 def eval_datasource() -> Datasource:
     eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
@@ -288,6 +289,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
         datasource=datasource,
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
         source_id_column="source_id",
+        partition_id_column="partition_id",
         max_seq_length_override=32,
         if_exists="open",
     )
@@ -297,13 +299,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
     # Restore the memoryset to a clean state for the next test.
     with OrcaClient(api_key=api_key).use():
         if LabeledMemoryset.exists("test_writable_memoryset"):
-            memoryset.refresh()
-
-            memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
-
-            if memory_ids:
-                memoryset.delete(memory_ids)
-                memoryset.refresh()
+            memoryset.truncate()
             assert len(memoryset) == 0
             memoryset.insert(SAMPLE_DATA)
         # If the test dropped the memoryset, do nothing — it will be recreated on the next use.
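A minimal sketch of the reset pattern introduced in this hunk, assuming an open handle to the writable test memoryset and the SAMPLE_DATA list defined earlier in this conftest (the helper name is illustrative, not part of the diff):

def reset_writable_memoryset(memoryset, sample_data: list[dict]) -> None:
    # hypothetical helper mirroring the fixture teardown above
    memoryset.truncate()            # drop every memory in a single call instead of deleting by id
    assert len(memoryset) == 0      # the handle now reports an empty memoryset
    memoryset.insert(sample_data)   # repopulate so the next test starts from a known state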
@@ -380,3 +376,88 @@ def partitioned_regression_model(readonly_partitioned_scored_memoryset: ScoredMe
         description="test_partitioned_regression_description",
     )
     return model
+
+
+@pytest.fixture(scope="function")
+def fully_partitioned_classification_resources() -> (
+    Generator[tuple[Datasource, LabeledMemoryset, ClassificationModel], None, None]
+):
+    data = [
+        {"value": "i love soup", "label": 0, "partition_id": "p1"},
+        {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        {"value": "soup is good", "label": 0, "partition_id": "p1"},
+        {"value": "i love cats", "label": 1, "partition_id": "p2"},
+        {"value": "everyone loves cats", "label": 1, "partition_id": "p2"},
+        {"value": "soup is good", "label": 0, "partition_id": "p1"},
+        {"value": "cats are amazing animals", "label": 1, "partition_id": "p2"},
+        {"value": "tomato soup is delicious", "label": 0, "partition_id": "p1"},
+        {"value": "cats love to play", "label": 1, "partition_id": "p2"},
+        {"value": "i enjoy eating soup", "label": 0, "partition_id": "p1"},
+        {"value": "my cat is fluffy", "label": 1, "partition_id": "p2"},
+        {"value": "chicken soup is tasty", "label": 0, "partition_id": "p1"},
+        {"value": "cats are playful pets", "label": 1, "partition_id": "p2"},
+        {"value": "soup warms the soul", "label": 0, "partition_id": "p1"},
+        {"value": "cats have soft fur", "label": 1, "partition_id": "p2"},
+        {"value": "vegetable soup is healthy", "label": 0, "partition_id": "p1"},
+    ]
+
+    datasource = None
+    memoryset = None
+    classification_model = None
+    try:
+        datasource = Datasource.from_list("fully_partitioned_classification_datasource", data)
+        memoryset = LabeledMemoryset.create(
+            "fully_partitioned_classification_memoryset",
+            datasource=datasource,
+            label_names=["soup", "cats"],
+            partition_id_column="partition_id",
+        )
+        classification_model = ClassificationModel.create("fully_partitioned_classification_model", memoryset=memoryset)
+        yield (datasource, memoryset, classification_model)
+    finally:
+        # Clean up in reverse order of creation
+        ClassificationModel.drop("fully_partitioned_classification_model", if_not_exists="ignore")
+        LabeledMemoryset.drop("fully_partitioned_classification_memoryset", if_not_exists="ignore")
+        Datasource.drop("fully_partitioned_classification_datasource", if_not_exists="ignore")
+
+
+@pytest.fixture(scope="function")
+def fully_partitioned_regression_resources() -> (
+    Generator[tuple[Datasource, ScoredMemoryset, RegressionModel], None, None]
+):
+    data = [
+        {"value": "i love soup", "score": 0.1, "partition_id": "p1"},
+        {"value": "cats are cute", "score": 0.9, "partition_id": "p1"},
+        {"value": "soup is good", "score": 0.1, "partition_id": "p1"},
+        {"value": "i love cats", "score": 0.9, "partition_id": "p2"},
+        {"value": "everyone loves cats", "score": 0.9, "partition_id": "p2"},
+        {"value": "soup is good", "score": 0.1, "partition_id": "p1"},
+        {"value": "cats are amazing animals", "score": 0.9, "partition_id": "p2"},
+        {"value": "tomato soup is delicious", "score": 0.1, "partition_id": "p1"},
+        {"value": "cats love to play", "score": 0.9, "partition_id": "p2"},
+        {"value": "i enjoy eating soup", "score": 0.1, "partition_id": "p1"},
+        {"value": "my cat is fluffy", "score": 0.9, "partition_id": "p2"},
+        {"value": "chicken soup is tasty", "score": 0.1, "partition_id": "p1"},
+        {"value": "cats are playful pets", "score": 0.9, "partition_id": "p2"},
+        {"value": "soup warms the soul", "score": 0.1, "partition_id": "p1"},
+        {"value": "cats have soft fur", "score": 0.9, "partition_id": "p2"},
+        {"value": "vegetable soup is healthy", "score": 0.1, "partition_id": "p1"},
+    ]
+
+    datasource = None
+    memoryset = None
+    regression_model = None
+    try:
+        datasource = Datasource.from_list("fully_partitioned_regression_datasource", data)
+        memoryset = ScoredMemoryset.create(
+            "fully_partitioned_regression_memoryset",
+            datasource=datasource,
+            partition_id_column="partition_id",
+        )
+        regression_model = RegressionModel.create("fully_partitioned_regression_model", memoryset=memoryset)
+        yield (datasource, memoryset, regression_model)
+    finally:
+        # Clean up in reverse order of creation
+        RegressionModel.drop("fully_partitioned_regression_model", if_not_exists="ignore")
+        ScoredMemoryset.drop("fully_partitioned_regression_memoryset", if_not_exists="ignore")
+        Datasource.drop("fully_partitioned_regression_datasource", if_not_exists="ignore")
@@ -75,7 +75,7 @@ def test_create_api_key_already_exists():
         OrcaCredentials.create_api_key("orca_sdk_test")
 
 
-def test_set_api_key(api_key):
+def test_use_client(api_key):
     client = OrcaClient(api_key=str(uuid4()))
     with client.use():
         assert not OrcaCredentials.is_authenticated()
@@ -91,17 +91,14 @@ def test_set_base_url(api_key):
     assert client.base_url == "http://localhost:1583"
 
 
-# deprecated methods:
-
-
-def test_deprecated_set_api_key(api_key):
+def test_set_api_key(api_key):
     with OrcaClient(api_key=str(uuid4())).use():
         assert not OrcaCredentials.is_authenticated()
         OrcaCredentials.set_api_key(api_key)
         assert OrcaCredentials.is_authenticated()
 
 
-def test_deprecated_set_invalid_api_key(api_key):
+def test_set_invalid_api_key(api_key):
     with OrcaClient(api_key=api_key).use():
         assert OrcaCredentials.is_authenticated()
         with pytest.raises(ValueError, match="Invalid API key"):
@@ -109,13 +106,13 @@ def test_deprecated_set_invalid_api_key(api_key):
         assert not OrcaCredentials.is_authenticated()
 
 
-def test_deprecated_set_api_url(api_key):
+def test_set_api_url(api_key):
     with OrcaClient(api_key=api_key).use():
         OrcaCredentials.set_api_url("http://api.orcadb.ai")
         assert str(OrcaClient._resolve_client().base_url) == "http://api.orcadb.ai"
 
 
-def test_deprecated_set_invalid_api_url(api_key):
+def test_set_invalid_api_url(api_key):
     with OrcaClient(api_key=api_key).use():
         with pytest.raises(ValueError, match="No API found at http://localhost:1582"):
             OrcaCredentials.set_api_url("http://localhost:1582")
orca_sdk/datasource.py CHANGED
@@ -1,28 +1,30 @@
 from __future__ import annotations
 
-import logging
 import tempfile
 import zipfile
 from datetime import datetime
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import Any, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Literal, Union, cast
 
-import pandas as pd
-import pyarrow as pa
-from datasets import Dataset, DatasetDict
 from httpx._types import FileTypes  # type: ignore
-from pyarrow import parquet
-from torch.utils.data import DataLoader as TorchDataLoader
-from torch.utils.data import Dataset as TorchDataset
 from tqdm.auto import tqdm
 
-from ._utils.common import CreateMode, DropMode
-from ._utils.data_parsing import hf_dataset_from_torch
+from ._utils.common import CreateMode, DropMode, logger
+from ._utils.torch_parsing import list_from_torch
 from ._utils.tqdm_file_reader import TqdmFileReader
 from .client import DatasourceMetadata, OrcaClient
 
+if TYPE_CHECKING:
+    # These are peer dependencies that are used for types only
+    from datasets import Dataset as HFDataset  # type: ignore
+    from datasets import DatasetDict as HFDatasetDict  # type: ignore
+    from pandas import DataFrame as PandasDataFrame  # type: ignore
+    from pyarrow import Table as PyArrowTable  # type: ignore
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+    from torch.utils.data import Dataset as TorchDataset  # type: ignore
+
 
 def _upload_files_to_datasource(
     name: str,
@@ -144,7 +146,7 @@ class Datasource:
 
     @classmethod
     def from_hf_dataset(
-        cls, name: str, dataset: Dataset, if_exists: CreateMode = "error", description: str | None = None
+        cls, name: str, dataset: HFDataset, if_exists: CreateMode = "error", description: str | None = None
     ) -> Datasource:
         """
         Create a new datasource from a Hugging Face Dataset
@@ -181,7 +183,7 @@ class Datasource:
     def from_hf_dataset_dict(
         cls,
         name: str,
-        dataset_dict: DatasetDict,
+        dataset_dict: HFDatasetDict,
         if_exists: CreateMode = "error",
         description: dict[str, str | None] | str | None = None,
     ) -> dict[str, Datasource]:
@@ -237,8 +239,8 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-        hf_dataset = hf_dataset_from_torch(torch_data, column_names=column_names)
-        return cls.from_hf_dataset(name, hf_dataset, if_exists=if_exists, description=description)
+        data_list = list_from_torch(torch_data, column_names=column_names)
+        return cls.from_list(name, data_list, if_exists=if_exists, description=description)
 
     @classmethod
     def from_list(
@@ -312,7 +314,7 @@ class Datasource:
 
     @classmethod
     def from_pandas(
-        cls, name: str, dataframe: pd.DataFrame, if_exists: CreateMode = "error", description: str | None = None
+        cls, name: str, dataframe: PandasDataFrame, if_exists: CreateMode = "error", description: str | None = None
     ) -> Datasource:
         """
         Create a new datasource from a pandas DataFrame
@@ -324,18 +326,28 @@ class Datasource:
                 `"error"`. Other option is `"open"` to open the existing datasource.
             description: Optional description for the datasource
 
+        Notes:
+            Data type precision may be lost during upload unless the [`datasets`][datasets] library is installed.
+
         Returns:
             A handle to the new datasource in the OrcaCloud
 
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
+            ImportError: If the upload dependency group is not installed
         """
-        dataset = Dataset.from_pandas(dataframe)
-        return cls.from_hf_dataset(name, dataset, if_exists=if_exists, description=description)
+        try:
+            from datasets import Dataset  # type: ignore
+
+            return cls.from_hf_dataset(
+                name, Dataset.from_pandas(dataframe), if_exists=if_exists, description=description
+            )
+        except ImportError:
+            return cls.from_dict(name, dataframe.to_dict(orient="list"), if_exists=if_exists, description=description)
 
     @classmethod
     def from_arrow(
-        cls, name: str, pyarrow_table: pa.Table, if_exists: CreateMode = "error", description: str | None = None
+        cls, name: str, pyarrow_table: PyArrowTable, if_exists: CreateMode = "error", description: str | None = None
     ) -> Datasource:
         """
         Create a new datasource from a pyarrow Table
@@ -358,6 +370,9 @@ class Datasource:
         if existing is not None:
             return existing
 
+        # peer dependency that is guaranteed to exist if the user provided a pyarrow table
+        from pyarrow import parquet  # type: ignore
+
         # Write to bytes buffer
         buffer = BytesIO()
         parquet.write_table(pyarrow_table, buffer)
@@ -399,6 +414,7 @@ class Datasource:
 
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
+            ImportError: If the path is a directory and [`datasets`][datasets] is not installed
         """
         # Check if datasource already exists and handle accordingly
         existing = _handle_existing_datasource(name, if_exists)
@@ -409,6 +425,13 @@ class Datasource:
 
         # For dataset directories, use the upload endpoint with multiple files
         if file_path.is_dir():
+            try:
+                from datasets import Dataset  # type: ignore
+            except ImportError as e:
+                raise ImportError(
+                    "The path is a directory, we only support uploading directories that contain saved HuggingFace datasets but datasets is not installed."
+                ) from e
+
             return cls.from_hf_dataset(
                 name, Dataset.load_from_disk(file_path), if_exists=if_exists, description=description
             )
@@ -479,7 +502,7 @@ class Datasource:
         try:
             client = OrcaClient._resolve_client()
             client.DELETE("/datasource/{name_or_id}", params={"name_or_id": name_or_id})
-            logging.info(f"Deleted datasource {name_or_id}")
+            logger.info(f"Deleted datasource {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
                 raise
@@ -561,9 +584,9 @@ class Datasource:
             with zipfile.ZipFile(output_path, "r") as zip_ref:
                 zip_ref.extractall(extract_dir)
             output_path.unlink()  # Remove the zip file after extraction
-            logging.info(f"Downloaded {extract_dir}")
+            logger.info(f"Downloaded {extract_dir}")
         else:
-            logging.info(f"Downloaded {output_path}")
+            logger.info(f"Downloaded {output_path}")
 
     def to_list(self) -> list[dict]:
         """
@@ -5,8 +5,6 @@ from typing import cast
 from uuid import uuid4
 
 import numpy as np
-import pandas as pd
-import pyarrow as pa
 import pytest
 from datasets import Dataset
 
@@ -137,6 +135,8 @@ def test_from_dict():
 
 
 def test_from_pandas():
+    pd = pytest.importorskip("pandas")
+
     # Test creating datasource from pandas DataFrame
     df = pd.DataFrame(
         {
@@ -152,6 +152,8 @@ def test_from_pandas():
 
 
 def test_from_arrow():
+    pa = pytest.importorskip("pyarrow")
+
     # Test creating datasource from pyarrow Table
     table = pa.table(
         {
@@ -205,6 +207,8 @@ def test_from_dict_already_exists():
 
 
 def test_from_pandas_already_exists():
+    pd = pytest.importorskip("pandas")
+
     # Test the if_exists parameter with from_pandas
     df = pd.DataFrame({"column1": [1], "column2": ["a"]})
    name = f"test_pandas_exists_{uuid4()}"
@@ -224,6 +228,8 @@
 
 
 def test_from_arrow_already_exists():
+    pa = pytest.importorskip("pyarrow")
+
     # Test the if_exists parameter with from_arrow
     table = pa.table({"column1": [1], "column2": ["a"]})
     name = f"test_arrow_exists_{uuid4()}"
@@ -4,8 +4,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
 
-from ._shared.metrics import ClassificationMetrics, RegressionMetrics
-from ._utils.common import UNSET, CreateMode, DropMode
+from ._utils.common import CreateMode, DropMode
 from .client import (
     EmbeddingEvaluationRequest,
     EmbeddingFinetuningMethod,
@@ -20,7 +19,9 @@ from .datasource import Datasource
 from .job import Job, Status
 
 if TYPE_CHECKING:
+    from .classification_model import ClassificationMetrics
     from .memoryset import LabeledMemoryset, ScoredMemoryset
+    from .regression_model import RegressionMetrics
 
 
 class EmbeddingModelBase(ABC):
@@ -230,6 +231,9 @@ class EmbeddingModelBase(ABC):
            raise ValueError("Invalid embedding model")
 
        def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
+            from .classification_model import ClassificationMetrics
+            from .regression_model import RegressionMetrics
+
            client = OrcaClient._resolve_client()
            if isinstance(self, PretrainedEmbeddingModel):
                res = client.GET(
@@ -244,34 +248,7 @@ class EmbeddingModelBase(ABC):
            else:
                raise ValueError("Invalid embedding model")
            assert res is not None
-            return (
-                RegressionMetrics(
-                    coverage=res.get("coverage"),
-                    mse=res.get("mse"),
-                    rmse=res.get("rmse"),
-                    mae=res.get("mae"),
-                    r2=res.get("r2"),
-                    explained_variance=res.get("explained_variance"),
-                    loss=res.get("loss"),
-                    anomaly_score_mean=res.get("anomaly_score_mean"),
-                    anomaly_score_median=res.get("anomaly_score_median"),
-                    anomaly_score_variance=res.get("anomaly_score_variance"),
-                )
-                if "mse" in res
-                else ClassificationMetrics(
-                    coverage=res.get("coverage"),
-                    f1_score=res.get("f1_score"),
-                    accuracy=res.get("accuracy"),
-                    loss=res.get("loss"),
-                    anomaly_score_mean=res.get("anomaly_score_mean"),
-                    anomaly_score_median=res.get("anomaly_score_median"),
-                    anomaly_score_variance=res.get("anomaly_score_variance"),
-                    roc_auc=res.get("roc_auc"),
-                    pr_auc=res.get("pr_auc"),
-                    pr_curve=res.get("pr_curve"),
-                    roc_curve=res.get("roc_curve"),
-                )
-            )
+            return RegressionMetrics(res) if "mse" in res else ClassificationMetrics(res)
 
        job = Job(response["job_id"], lambda: get_result(response["job_id"]))
        return job if background else job.result()
@@ -404,7 +381,7 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
        return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name
 
    def __repr__(self) -> str:
-        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params/1000000:.0f}M}})"
+        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params / 1000000:.0f}M}})"
 
    @classmethod
    def all(cls) -> list[PretrainedEmbeddingModel]:
@@ -691,21 +668,26 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
        return False
 
    @classmethod
-    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
+    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error", cascade: bool = False):
        """
        Delete the finetuned embedding model from the OrcaCloud
 
        Params:
            name_or_id: The name or id of the finetuned embedding model
+            if_not_exists: What to do if the finetuned embedding model does not exist, defaults to `"error"`.
+                Other option is `"ignore"` to do nothing if the model does not exist.
+            cascade: If True, also delete all associated memorysets and their predictive models.
+                Defaults to False.
 
        Raises:
            LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
+            RuntimeError: If the model has associated memorysets and cascade is False
        """
        try:
            client = OrcaClient._resolve_client()
            client.DELETE(
                "/finetuned_embedding_model/{name_or_id}",
-                params={"name_or_id": name_or_id},
+                params={"name_or_id": name_or_id, "cascade": cascade},
            )
        except LookupError:
            if if_not_exists == "error":
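A minimal sketch of the new cascade flag on FinetunedEmbeddingModel.drop, using an illustrative helper and model name; per the docstring above, the call raises RuntimeError when memorysets still reference the model and cascade is False:

from orca_sdk.embedding_model import FinetunedEmbeddingModel  # module name taken from the test imports below


def drop_model_and_dependents(name_or_id: str) -> None:
    try:
        FinetunedEmbeddingModel.drop(name_or_id, cascade=False)
    except RuntimeError:
        # memorysets still reference the model; delete them and their predictive models too
        FinetunedEmbeddingModel.drop(name_or_id, cascade=True)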
@@ -4,9 +4,9 @@ from uuid import uuid4
 
 import pytest
 
+from .classification_model import ClassificationMetrics
 from .datasource import Datasource
 from .embedding_model import (
-    ClassificationMetrics,
     FinetunedEmbeddingModel,
     PretrainedEmbeddingModel,
     PretrainedEmbeddingModelName,
@@ -172,6 +172,35 @@ def test_drop_finetuned_model(datasource: Datasource):
        FinetunedEmbeddingModel.open("finetuned_model_to_delete")
 
 
+def test_drop_finetuned_model_with_memoryset_cascade(datasource: Datasource):
+    """Test that cascade=False prevents deletion and cascade=True allows it."""
+    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_cascade_delete", datasource)
+    memoryset = LabeledMemoryset.create(
+        "test_memoryset_for_finetuned_model_cascade",
+        datasource=datasource,
+        embedding_model=finetuned_model,
+    )
+
+    # Verify memoryset exists and uses the finetuned model
+    assert LabeledMemoryset.open(memoryset.name) is not None
+    assert memoryset.embedding_model == finetuned_model
+
+    # Without cascade, deletion should fail
+    with pytest.raises(RuntimeError):
+        FinetunedEmbeddingModel.drop(finetuned_model.id, cascade=False)
+
+    # Model and memoryset should still exist
+    assert FinetunedEmbeddingModel.exists(finetuned_model.name)
+    assert LabeledMemoryset.exists(memoryset.name)
+
+    # With cascade, deletion should succeed
+    FinetunedEmbeddingModel.drop(finetuned_model.id, cascade=True)
+
+    # Both model and memoryset should be deleted
+    assert not FinetunedEmbeddingModel.exists(finetuned_model.name)
+    assert not LabeledMemoryset.exists(memoryset.name)
+
+
 def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):