orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. orca_sdk/_generated_api_client/api/__init__.py +8 -0
  2. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
  3. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  4. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  5. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  6. orca_sdk/_generated_api_client/models/__init__.py +10 -0
  7. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  8. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  9. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  10. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  11. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  12. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  13. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  15. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  16. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  17. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  18. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  19. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  20. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  21. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  22. orca_sdk/_utils/data_parsing.py +31 -2
  23. orca_sdk/_utils/data_parsing_test.py +18 -15
  24. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  25. orca_sdk/classification_model.py +32 -12
  26. orca_sdk/classification_model_test.py +95 -34
  27. orca_sdk/conftest.py +87 -25
  28. orca_sdk/datasource.py +56 -12
  29. orca_sdk/datasource_test.py +9 -0
  30. orca_sdk/embedding_model_test.py +6 -5
  31. orca_sdk/memoryset.py +78 -0
  32. orca_sdk/memoryset_test.py +199 -123
  33. orca_sdk/telemetry.py +5 -3
  34. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
  35. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
  36. {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/datasource.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import tempfile
+import zipfile
 from datetime import datetime
 from os import PathLike
 from pathlib import Path
@@ -12,6 +13,7 @@ import pyarrow as pa
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
+from tqdm.auto import tqdm
 
 from ._generated_api_client.api import (
     delete_datasource,
@@ -25,6 +27,7 @@ from ._generated_api_client.client import get_client
 from ._generated_api_client.models import ColumnType, DatasourceMetadata
 from ._utils.common import CreateMode, DropMode
 from ._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+from ._utils.tqdm_file_reader import TqdmFileReader
 
 
 class Datasource:
@@ -82,6 +85,39 @@ class Datasource:
             + "})"
         )
 
+    def download(self, output_path: str | PathLike) -> None:
+        """
+        Download the datasource as a ZIP archive and extract its contents to the specified path.
+
+        Params:
+            output_path: The local file path or directory where the downloaded files will be saved.
+
+        Returns:
+            None
+
+        Raises:
+            RuntimeError: If the download fails.
+        """
+
+        output_path = Path(output_path)
+        client = get_client().get_httpx_client()
+        url = f"/datasource/{self.id}/download"
+        response = client.get(url)
+        if response.status_code == 404:
+            raise LookupError(f"Datasource {self.id} not found")
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to download datasource: {response.status_code} {response.text}")
+
+        with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip:
+            tmp_zip.write(response.content)
+            tmp_zip.flush()
+            with zipfile.ZipFile(tmp_zip.name, "r") as zf:
+                output_path.mkdir(parents=True, exist_ok=True)
+                for file in zf.namelist():
+                    out_file = output_path / Path(file).name
+                    with zf.open(file) as af:
+                        out_file.write_bytes(af.read())
+
     @classmethod
     def from_hf_dataset(
         cls, name: str, dataset: Dataset, if_exists: CreateMode = "error", description: str | None = None
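A minimal usage sketch of the new download method (the output directory name is made up; datasource stands for any existing Datasource handle, for example one returned by from_hf_dataset):

    # Hypothetical example: pull the datasource's files down into a local directory
    datasource.download("./exported_datasource")
    # the contents of the downloaded ZIP are extracted into ./exported_datasource/
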
@@ -113,19 +149,27 @@ class Datasource:
         with tempfile.TemporaryDirectory() as tmp_dir:
             dataset.save_to_disk(tmp_dir)
             files = []
-            for file_path in Path(tmp_dir).iterdir():
-                buffered_reader = open(file_path, "rb")
-                files.append(("files", buffered_reader))
-
-            # Do not use Generated client for this endpoint b/c it does not handle files properly
-            metadata = parse_create_response(
-                response=client.get_httpx_client().request(
-                    method="post",
-                    url="/datasource/",
-                    files=files,
-                    data={"name": name, "description": description},
+
+            # Calculate total size for all files
+            file_paths = list(Path(tmp_dir).iterdir())
+            total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+            with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+                for file_path in file_paths:
+                    buffered_reader = open(file_path, "rb")
+                    tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+                    files.append(("files", (file_path.name, tqdm_reader)))
+
+                # Do not use Generated client for this endpoint b/c it does not handle files properly
+                metadata = parse_create_response(
+                    response=client.get_httpx_client().request(
+                        method="post",
+                        url="/datasource/",
+                        files=files,
+                        data={"name": name, "description": description},
+                    )
                 )
-            )
+
             return cls(metadata=metadata)
 
     @classmethod
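The upload path above streams every file through a TqdmFileReader so the progress bar advances as bytes are read (orca_sdk/_utils/tqdm_file_reader.py, +12 lines, is not shown in this diff). A rough sketch of how such a wrapper can work, assuming it only needs to forward read calls and report the bytes read to the shared tqdm bar:

    # Sketch only; the real orca_sdk/_utils/tqdm_file_reader.py may differ.
    class TqdmFileReader:
        def __init__(self, file, pbar):
            self.file = file  # underlying binary file object
            self.pbar = pbar  # tqdm progress bar shared across all uploaded files

        def read(self, size=-1):
            data = self.file.read(size)
            self.pbar.update(len(data))  # advance the bar by the bytes just read
            return data

        def __getattr__(self, name):
            # delegate everything else (name, close, seek, ...) to the wrapped file
            return getattr(self.file, name)
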
orca_sdk/datasource_test.py CHANGED
@@ -1,3 +1,5 @@
+import os
+import tempfile
 from uuid import uuid4
 
 import pytest
@@ -94,3 +96,10 @@ def test_drop_datasource_unauthorized(datasource, unauthorized):
 def test_drop_datasource_invalid_input():
     with pytest.raises(ValueError, match=r"Invalid input:.*"):
         Datasource.drop("not valid id")
+
+
+def test_download_datasource(datasource):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_path = os.path.join(temp_dir, "datasource.zip")
+        datasource.download(output_path)
+        assert os.path.exists(output_path)
orca_sdk/embedding_model_test.py CHANGED
@@ -53,7 +53,7 @@ def test_embed_text_unauthenticated(unauthenticated):
 
 @pytest.fixture(scope="session")
 def finetuned_model(datasource) -> FinetunedEmbeddingModel:
-    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, value_column="text")
+    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
 
 
 def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
@@ -65,8 +65,10 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
     assert finetuned_model._status == TaskStatus.COMPLETED
 
 
-def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
-    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_from_memoryset", memoryset)
+def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
+    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+        "test_finetuned_model_from_memoryset", readonly_memoryset
+    )
     assert finetuned_model is not None
     assert finetuned_model.name == "test_finetuned_model_from_memoryset"
     assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
@@ -109,7 +111,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
         "test_memoryset_finetuned_model",
         datasource,
         embedding_model=finetuned_model,
-        value_column="text",
     )
     assert memoryset is not None
     assert memoryset.name == "test_memoryset_finetuned_model"
@@ -152,7 +153,7 @@ def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: Finetu
 
 
 def test_drop_finetuned_model(datasource: Datasource):
-    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource, value_column="text")
+    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
     assert FinetunedEmbeddingModel.open("finetuned_model_to_delete")
     FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
     with pytest.raises(LookupError):
orca_sdk/memoryset.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, cast, overload
 
 import pandas as pd
 import pyarrow as pa
+from attrs import fields
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -29,11 +30,14 @@ from ._generated_api_client.api import (
     memoryset_lookup_gpu,
     potential_duplicate_groups,
     query_memoryset,
+    suggest_cascading_edits,
     update_memories_gpu,
     update_memory_gpu,
     update_memoryset,
 )
 from ._generated_api_client.models import (
+    CascadeEditSuggestionsRequest,
+    CascadingEditSuggestion,
     CloneLabeledMemorysetRequest,
     CreateLabeledMemorysetRequest,
     DeleteMemoriesRequest,
@@ -1180,6 +1184,63 @@
         updated_memories = [LabeledMemory(self.id, memory) for memory in response]
         return updated_memories[0] if isinstance(updates, dict) else updated_memories
 
+    def get_cascading_edits_suggestions(
+        self: LabeledMemoryset,
+        memory: LabeledMemory,
+        *,
+        old_label: int,
+        new_label: int,
+        max_neighbors: int = 50,
+        max_validation_neighbors: int = 10,
+        similarity_threshold: float | None = None,
+        only_if_has_old_label: bool = True,
+        exclude_if_new_label: bool = True,
+        suggestion_cooldown_time: float = 3600.0 * 24.0,  # 1 day
+        label_confirmation_cooldown_time: float = 3600.0 * 24.0 * 7,  # 1 week
+    ) -> list[CascadingEditSuggestion]:
+        """
+        Suggests cascading edits for a given memory based on nearby points with similar labels.
+
+        This function is triggered after a user changes a memory's label. It looks for nearby
+        candidates in embedding space that may be subject to similar relabeling and returns them
+        as suggestions. The system uses scoring heuristics, label filters, and cooldown tracking
+        to reduce noise and improve usability.
+
+        Params:
+            memory: The memory whose label was just changed.
+            old_label: The label this memory used to have.
+            new_label: The label it was changed to.
+            max_neighbors: Maximum number of neighbors to consider.
+            max_validation_neighbors: Maximum number of neighbors to use for label suggestion.
+            similarity_threshold: If set, only include neighbors with a lookup score above this threshold.
+            only_if_has_old_label: If True, only consider neighbors that have the old label.
+            exclude_if_new_label: If True, exclude neighbors that already have the new label.
+            suggestion_cooldown_time: Minimum time (in seconds) since the last suggestion for a neighbor
+                to be considered again.
+            label_confirmation_cooldown_time: Minimum time (in seconds) since a neighbor's label was confirmed
+                to be considered for suggestions.
+            _current_time: Optional override for the current timestamp (useful for testing).
+
+        Returns:
+            A list of CascadingEditSuggestion objects, each containing a neighbor and the suggested new label.
+        """
+
+        return suggest_cascading_edits(
+            name_or_id=self.id,
+            memory_id=memory.memory_id,
+            body=CascadeEditSuggestionsRequest(
+                old_label=old_label,
+                new_label=new_label,
+                max_neighbors=max_neighbors,
+                max_validation_neighbors=max_validation_neighbors,
+                similarity_threshold=similarity_threshold,
+                only_if_has_old_label=only_if_has_old_label,
+                exclude_if_new_label=exclude_if_new_label,
+                suggestion_cooldown_time=suggestion_cooldown_time,
+                label_confirmation_cooldown_time=label_confirmation_cooldown_time,
+            ),
+        )
+
     def delete(self, memory_id: str | Iterable[str]) -> None:
         """
         Delete memories from the memoryset
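A rough usage sketch for the new get_cascading_edits_suggestions method (the label values and keyword overrides are made up; memoryset and memory stand for an open LabeledMemoryset and the LabeledMemory whose label was just changed):

    # Ask for follow-up relabeling suggestions after changing a memory's label from 0 to 1
    suggestions = memoryset.get_cascading_edits_suggestions(
        memory,
        old_label=0,
        new_label=1,
        max_neighbors=25,
        similarity_threshold=0.8,
    )
    for suggestion in suggestions:
        # each CascadingEditSuggestion pairs a neighboring memory with a suggested new label
        print(suggestion)
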
@@ -1229,6 +1290,9 @@
         Returns:
             dictionary with aggregate metrics for each analysis that was run
 
+        Raises:
+            ValueError: If an invalid analysis name is provided
+
         Examples:
             Run label and duplicate analysis:
             >>> memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})
@@ -1263,12 +1327,26 @@
             Display label analysis to review potential mislabelings:
             >>> memoryset.display_label_analysis()
         """
+
+        # Get valid analysis names from MemorysetAnalysisConfigs
+        valid_analysis_names = {
+            field.name for field in fields(MemorysetAnalysisConfigs) if field.name != "additional_properties"
+        }
+
         configs: dict[str, dict] = {}
         for analysis in analyses:
             if isinstance(analysis, str):
+                error_msg = (
+                    f"Invalid analysis name: {analysis}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                )
+                if analysis not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[analysis] = {}
             else:
                 name = analysis.pop("name")  # type: ignore
+                error_msg = f"Invalid analysis name: {name}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                if name not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[name] = analysis  # type: ignore
 
         analysis = analyze_memoryset(
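A short sketch of the new validation behavior in analyze, assuming "label" and "duplicate" (the names used in the docstring examples above) are among the valid config fields of MemorysetAnalysisConfigs:

    # Known analysis names are accepted as before
    memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})

    # An unrecognized name now fails fast with a ValueError instead of reaching the server
    try:
        memoryset.analyze("labels")  # misspelled analysis name
    except ValueError as error:
        print(error)  # Invalid analysis name: labels. Valid names are: ...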