orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +8 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/models/__init__.py +10 -0
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +32 -12
- orca_sdk/classification_model_test.py +95 -34
- orca_sdk/conftest.py +87 -25
- orca_sdk/datasource.py +56 -12
- orca_sdk/datasource_test.py +9 -0
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +199 -123
- orca_sdk/telemetry.py +5 -3
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/datasource.py
CHANGED

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import tempfile
+import zipfile
 from datetime import datetime
 from os import PathLike
 from pathlib import Path
@@ -12,6 +13,7 @@ import pyarrow as pa
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
+from tqdm.auto import tqdm
 
 from ._generated_api_client.api import (
     delete_datasource,
@@ -25,6 +27,7 @@ from ._generated_api_client.client import get_client
 from ._generated_api_client.models import ColumnType, DatasourceMetadata
 from ._utils.common import CreateMode, DropMode
 from ._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+from ._utils.tqdm_file_reader import TqdmFileReader
 
 
 class Datasource:
@@ -82,6 +85,39 @@ class Datasource:
             + "})"
         )
 
+    def download(self, output_path: str | PathLike) -> None:
+        """
+        Download the datasource as a ZIP archive and extract its contents to the specified path.
+
+        Params:
+            output_path: The local directory where the downloaded files will be saved.
+
+        Returns:
+            None
+
+        Raises:
+            LookupError: If the datasource is not found.
+            RuntimeError: If the download fails.
+        """
+        output_path = Path(output_path)
+        client = get_client().get_httpx_client()
+        url = f"/datasource/{self.id}/download"
+        response = client.get(url)
+        if response.status_code == 404:
+            raise LookupError(f"Datasource {self.id} not found")
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to download datasource: {response.status_code} {response.text}")
+
+        with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip:
+            tmp_zip.write(response.content)
+            tmp_zip.flush()
+            with zipfile.ZipFile(tmp_zip.name, "r") as zf:
+                output_path.mkdir(parents=True, exist_ok=True)
+                for file in zf.namelist():
+                    out_file = output_path / Path(file).name
+                    with zf.open(file) as af:
+                        out_file.write_bytes(af.read())
+
     @classmethod
     def from_hf_dataset(
         cls, name: str, dataset: Dataset, if_exists: CreateMode = "error", description: str | None = None
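
For context, a minimal usage sketch of the new download() method. The import path and names below are assumptions, not part of this diff; from_hf_dataset and download are the methods shown above:

    from datasets import Dataset
    from orca_sdk import Datasource  # assuming the class is exported at the package root

    ds = Datasource.from_hf_dataset(
        "demo_datasource", Dataset.from_dict({"text": ["hello", "world"], "label": [0, 1]})
    )
    ds.download("exported_datasource")  # ZIP contents are extracted into this directory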
@@ -113,19 +149,27 @@ class Datasource:
         with tempfile.TemporaryDirectory() as tmp_dir:
             dataset.save_to_disk(tmp_dir)
             files = []
-
-
-
-
-
-
-
-
-
-            files
-
+
+            # Calculate total size for all files
+            file_paths = list(Path(tmp_dir).iterdir())
+            total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+            with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+                for file_path in file_paths:
+                    buffered_reader = open(file_path, "rb")
+                    tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+                    files.append(("files", (file_path.name, tqdm_reader)))
+
+            # Do not use Generated client for this endpoint b/c it does not handle files properly
+            metadata = parse_create_response(
+                response=client.get_httpx_client().request(
+                    method="post",
+                    url="/datasource/",
+                    files=files,
+                    data={"name": name, "description": description},
+                )
             )
-
+
         return cls(metadata=metadata)
 
     @classmethod
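
The upload path now streams files through orca_sdk/_utils/tqdm_file_reader.py (+12 lines), which this diff does not show. A minimal sketch of such a wrapper, inferred only from the call site above; the actual implementation may differ:

    class TqdmFileReader:
        """Wraps a binary file and advances a shared tqdm bar as bytes are read."""

        def __init__(self, file, pbar):
            self._file = file
            self._pbar = pbar  # one tqdm bar shared across all files in the upload

        def read(self, size: int = -1) -> bytes:
            chunk = self._file.read(size)
            self._pbar.update(len(chunk))  # count only the bytes actually read
            return chunk

        def __getattr__(self, name):
            # delegate everything else (name, close, ...) to the wrapped file
            return getattr(self._file, name)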
orca_sdk/datasource_test.py
CHANGED

@@ -1,3 +1,5 @@
+import os
+import tempfile
 from uuid import uuid4
 
 import pytest
@@ -94,3 +96,10 @@ def test_drop_datasource_unauthorized(datasource, unauthorized):
 def test_drop_datasource_invalid_input():
     with pytest.raises(ValueError, match=r"Invalid input:.*"):
         Datasource.drop("not valid id")
+
+
+def test_download_datasource(datasource):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_path = os.path.join(temp_dir, "datasource.zip")
+        datasource.download(output_path)
+        assert os.path.exists(output_path)
orca_sdk/embedding_model_test.py
CHANGED

@@ -53,7 +53,7 @@ def test_embed_text_unauthenticated(unauthenticated):
 
 @pytest.fixture(scope="session")
 def finetuned_model(datasource) -> FinetunedEmbeddingModel:
-    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource
+    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
 
 
 def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
@@ -65,8 +65,10 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
     assert finetuned_model._status == TaskStatus.COMPLETED
 
 
-def test_finetune_model_with_memoryset(
-    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
+    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+        "test_finetuned_model_from_memoryset", readonly_memoryset
+    )
     assert finetuned_model is not None
     assert finetuned_model.name == "test_finetuned_model_from_memoryset"
     assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
@@ -109,7 +111,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
         "test_memoryset_finetuned_model",
         datasource,
         embedding_model=finetuned_model,
-        value_column="text",
     )
     assert memoryset is not None
     assert memoryset.name == "test_memoryset_finetuned_model"
@@ -152,7 +153,7 @@ def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: Finetu
 
 
 def test_drop_finetuned_model(datasource: Datasource):
-    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource
+    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
     assert FinetunedEmbeddingModel.open("finetuned_model_to_delete")
     FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
     with pytest.raises(LookupError):
orca_sdk/memoryset.py
CHANGED

@@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, cast, overload
 
 import pandas as pd
 import pyarrow as pa
+from attrs import fields
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -29,11 +30,14 @@ from ._generated_api_client.api import (
     memoryset_lookup_gpu,
     potential_duplicate_groups,
     query_memoryset,
+    suggest_cascading_edits,
     update_memories_gpu,
     update_memory_gpu,
     update_memoryset,
 )
 from ._generated_api_client.models import (
+    CascadeEditSuggestionsRequest,
+    CascadingEditSuggestion,
     CloneLabeledMemorysetRequest,
     CreateLabeledMemorysetRequest,
     DeleteMemoriesRequest,
@@ -1180,6 +1184,63 @@ class LabeledMemoryset:
         updated_memories = [LabeledMemory(self.id, memory) for memory in response]
         return updated_memories[0] if isinstance(updates, dict) else updated_memories
 
+    def get_cascading_edits_suggestions(
+        self: LabeledMemoryset,
+        memory: LabeledMemory,
+        *,
+        old_label: int,
+        new_label: int,
+        max_neighbors: int = 50,
+        max_validation_neighbors: int = 10,
+        similarity_threshold: float | None = None,
+        only_if_has_old_label: bool = True,
+        exclude_if_new_label: bool = True,
+        suggestion_cooldown_time: float = 3600.0 * 24.0,  # 1 day
+        label_confirmation_cooldown_time: float = 3600.0 * 24.0 * 7,  # 1 week
+    ) -> list[CascadingEditSuggestion]:
+        """
+        Suggest cascading edits for a given memory based on nearby points with similar labels.
+
+        This method is meant to be called after a user changes a memory's label. It looks for
+        nearby candidates in embedding space that may be subject to similar relabeling and
+        returns them as suggestions. The system uses scoring heuristics, label filters, and
+        cooldown tracking to reduce noise and improve usability.
+
+        Params:
+            memory: The memory whose label was just changed.
+            old_label: The label this memory used to have.
+            new_label: The label it was changed to.
+            max_neighbors: Maximum number of neighbors to consider.
+            max_validation_neighbors: Maximum number of neighbors to use for label suggestion.
+            similarity_threshold: If set, only include neighbors with a lookup score above this threshold.
+            only_if_has_old_label: If True, only consider neighbors that have the old label.
+            exclude_if_new_label: If True, exclude neighbors that already have the new label.
+            suggestion_cooldown_time: Minimum time (in seconds) since the last suggestion for a neighbor
+                to be considered again.
+            label_confirmation_cooldown_time: Minimum time (in seconds) since a neighbor's label was
+                confirmed for it to be considered for suggestions.
+
+        Returns:
+            A list of CascadingEditSuggestion objects, each containing a neighbor and the suggested new label.
+        """
+        return suggest_cascading_edits(
+            name_or_id=self.id,
+            memory_id=memory.memory_id,
+            body=CascadeEditSuggestionsRequest(
+                old_label=old_label,
+                new_label=new_label,
+                max_neighbors=max_neighbors,
+                max_validation_neighbors=max_validation_neighbors,
+                similarity_threshold=similarity_threshold,
+                only_if_has_old_label=only_if_has_old_label,
+                exclude_if_new_label=exclude_if_new_label,
+                suggestion_cooldown_time=suggestion_cooldown_time,
+                label_confirmation_cooldown_time=label_confirmation_cooldown_time,
+            ),
+        )
+
     def delete(self, memory_id: str | Iterable[str]) -> None:
         """
         Delete memories from the memoryset
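
For context, a hedged sketch of calling the new suggestion API after relabeling a memory; the memoryset handle, memory object, and label values are illustrative assumptions, not part of this diff:

    # memoryset is a LabeledMemoryset and memory a LabeledMemory obtained elsewhere,
    # e.g. via LabeledMemoryset.open(...) (hypothetical here)
    suggestions = memoryset.get_cascading_edits_suggestions(
        memory,
        old_label=0,
        new_label=1,
        max_neighbors=25,           # consider fewer neighbors than the default 50
        similarity_threshold=0.85,  # only suggest edits for fairly close neighbors
    )
    for suggestion in suggestions:
        print(suggestion)  # each CascadingEditSuggestion pairs a neighbor with a suggested label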
@@ -1229,6 +1290,9 @@ class LabeledMemoryset:
         Returns:
             dictionary with aggregate metrics for each analysis that was run
 
+        Raises:
+            ValueError: If an invalid analysis name is provided
+
         Examples:
             Run label and duplicate analysis:
             >>> memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})
@@ -1263,12 +1327,26 @@ class LabeledMemoryset:
            Display label analysis to review potential mislabelings:
            >>> memoryset.display_label_analysis()
        """
+
+        # Get valid analysis names from MemorysetAnalysisConfigs
+        valid_analysis_names = {
+            field.name for field in fields(MemorysetAnalysisConfigs) if field.name != "additional_properties"
+        }
+
         configs: dict[str, dict] = {}
         for analysis in analyses:
             if isinstance(analysis, str):
+                error_msg = (
+                    f"Invalid analysis name: {analysis}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                )
+                if analysis not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[analysis] = {}
             else:
                 name = analysis.pop("name")  # type: ignore
+                error_msg = f"Invalid analysis name: {name}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                if name not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[name] = analysis  # type: ignore
 
         analysis = analyze_memoryset(