orca-sdk 0.0.91__py3-none-any.whl → 0.0.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +4 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/models/__init__.py +4 -0
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +195 -0
- orca_sdk/_shared/metrics_test.py +169 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +170 -27
- orca_sdk/classification_model_test.py +74 -32
- orca_sdk/conftest.py +86 -25
- orca_sdk/datasource.py +22 -12
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +197 -123
- orca_sdk/telemetry.py +3 -0
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
orca_sdk/classification_model_test.py
CHANGED

@@ -1,5 +1,6 @@
 from uuid import uuid4
 
+import numpy as np
 import pytest
 from datasets.arrow_dataset import Dataset
 
@@ -9,45 +10,45 @@ from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset
 
 
-def test_create_model(model: ClassificationModel,
+def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
     assert model is not None
     assert model.name == "test_model"
-    assert model.memoryset ==
+    assert model.memoryset == readonly_memoryset
     assert model.num_classes == 2
     assert model.memory_lookup_count == 3
 
 
-def test_create_model_already_exists_error(
+def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset)
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")
 
 
-def test_create_model_already_exists_return(
+def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")
 
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)
 
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)
 
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)
 
-    new_model = ClassificationModel.create("test_model",
+    new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
     assert new_model is not None
     assert new_model.name == "test_model"
-    assert new_model.memoryset ==
+    assert new_model.memoryset == readonly_memoryset
     assert new_model.num_classes == 2
     assert new_model.memory_lookup_count == 3
 
 
-def test_create_model_unauthenticated(unauthenticated,
+def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset)
 
 
 def test_get_model(model: ClassificationModel):
@@ -106,8 +107,8 @@ def test_update_model_no_description(model: ClassificationModel):
     assert model.description is None
 
 
-def test_delete_model(
-    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(
+def test_delete_model(readonly_memoryset: LabeledMemoryset):
+    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
     assert ClassificationModel.open("model_to_delete")
     ClassificationModel.drop("model_to_delete")
     with pytest.raises(LookupError):
@@ -132,25 +133,23 @@ def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
-    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset
+    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
     ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
     with pytest.raises(RuntimeError):
        LabeledMemoryset.drop(memoryset.id)
 
 
-def test_evaluate(model):
-
-        "eval_datasource",
-        [
-            {"text": "chicken noodle soup is the best", "label": 1},
-            {"text": "cats are cute", "label": 0},
-            {"text": "soup is great for the winter", "label": 0},
-            {"text": "i love cats", "label": 1},
-        ],
-    )
-    result = model.evaluate(eval_datasource, value_column="text")
+def test_evaluate(model, eval_datasource: Datasource):
+    result = model.evaluate(eval_datasource)
     assert result is not None
     assert isinstance(result, dict)
+    # And anomaly score statistics are present and valid
+    assert isinstance(result["anomaly_score_mean"], float)
+    assert isinstance(result["anomaly_score_median"], float)
+    assert isinstance(result["anomaly_score_variance"], float)
+    assert -1.0 <= result["anomaly_score_mean"] <= 1.0
+    assert -1.0 <= result["anomaly_score_median"] <= 1.0
+    assert -1.0 <= result["anomaly_score_variance"] <= 1.0
     assert isinstance(result["accuracy"], float)
     assert isinstance(result["f1_score"], float)
     assert isinstance(result["loss"], float)
@@ -162,6 +161,40 @@ def test_evaluate(model):
     assert len(result["roc_curve"]["true_positive_rates"]) == 4
 
 
+def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
+    result_datasource = model.evaluate(eval_datasource)
+
+    result_dataset = model.evaluate(eval_dataset)
+
+    for result in [result_datasource, result_dataset]:
+        assert result is not None
+        assert isinstance(result, dict)
+        assert isinstance(result["accuracy"], float)
+        assert isinstance(result["f1_score"], float)
+        assert isinstance(result["loss"], float)
+        assert np.allclose(result["accuracy"], 0.5)
+        assert np.allclose(result["f1_score"], 0.5)
+
+        assert isinstance(result["precision_recall_curve"]["thresholds"], list)
+        assert isinstance(result["precision_recall_curve"]["precisions"], list)
+        assert isinstance(result["precision_recall_curve"]["recalls"], list)
+        assert isinstance(result["roc_curve"]["thresholds"], list)
+        assert isinstance(result["roc_curve"]["false_positive_rates"], list)
+        assert isinstance(result["roc_curve"]["true_positive_rates"], list)
+
+        assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+        assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+        assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+        assert np.allclose(result["roc_curve"]["auc"], 0.625)
+
+        assert np.allclose(
+            result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
+        )
+        assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
+        assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
+        assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
+
+
 def test_evaluate_with_telemetry(model):
     samples = [
         {"text": "chicken noodle soup is the best", "label": 1},
@@ -188,9 +221,16 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
     assert predictions[1].label_name == label_names[1]
     assert 0 <= predictions[1].confidence <= 1
 
+    assert predictions[0].logits is not None
+    assert predictions[1].logits is not None
+    assert len(predictions[0].logits) == 2
+    assert len(predictions[1].logits) == 2
+    assert predictions[0].logits[0] > predictions[0].logits[1]
+    assert predictions[1].logits[0] < predictions[1].logits[1]
+
 
 def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
-    predictions = model.predict(["Do you love soup?", "Are cats cute?"],
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
     assert len(predictions) == 2
     assert predictions[0].prediction_id is None
     assert predictions[1].prediction_id is None
@@ -212,9 +252,12 @@ def test_predict_unauthorized(unauthorized, model: ClassificationModel):
         model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_constraint_violation(
+def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
     model = ClassificationModel.create(
-        "test_model_lookup_count_too_high",
+        "test_model_lookup_count_too_high",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=readonly_memoryset.length + 2,
     )
     with pytest.raises(RuntimeError):
         model.predict("test")
@@ -254,7 +297,6 @@ def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset:
     inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
         "test_memoryset_inverted_labels",
         hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
-        value_column="text",
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
     )
     with model.use_memoryset(inverted_labeled_memoryset):
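
The override exercised above is a context manager on the model; a minimal usage sketch reusing the names from this test (illustrative only, not part of the package diff):

with model.use_memoryset(inverted_labeled_memoryset):
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
# outside the block the model falls back to its original memoryset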
orca_sdk/conftest.py
CHANGED
@@ -17,6 +17,8 @@ logging.basicConfig(level=logging.INFO)
 
 os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")
 
+os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
+
 
 def _create_org_id():
     # UUID start to identify test data (0xtest...)
@@ -69,22 +71,22 @@ def label_names():
 
 
 SAMPLE_DATA = [
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
+    {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
+    {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
+    {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
+    {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
+    {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
+    {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
+    {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
+    {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
+    {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
+    {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
+    {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
+    {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
+    {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
+    {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
+    {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
+    {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
 ]
 
 
@@ -94,7 +96,7 @@ def hf_dataset(label_names):
         SAMPLE_DATA,
         features=Features(
             {
-                "
+                "value": Value("string"),
                 "label": ClassLabel(names=label_names),
                 "key": Value("string"),
                 "score": Value("float"),
@@ -106,23 +108,82 @@ def hf_dataset(label_names):
 
 @pytest.fixture(scope="session")
 def datasource(hf_dataset) -> Datasource:
-
+    datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
+    return datasource
+
+
+EVAL_DATASET = [
+    {"value": "chicken noodle soup is the best", "label": 1},
+    {"value": "cats are cute", "label": 0},
+    {"value": "soup is great for the winter", "label": 0},
+    {"value": "i love cats", "label": 1},
+]
 
 
 @pytest.fixture(scope="session")
-def
-
-
+def eval_datasource() -> Datasource:
+    eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
+    return eval_datasource
+
+
+@pytest.fixture(scope="session")
+def eval_dataset() -> Dataset:
+    eval_dataset = Dataset.from_list(EVAL_DATASET)
+    return eval_dataset
+
+
+@pytest.fixture(scope="session")
+def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
+    memoryset = LabeledMemoryset.create(
+        "test_readonly_memoryset",
         datasource=datasource,
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
-        value_column="text",
         source_id_column="source_id",
         max_seq_length_override=32,
     )
+    return memoryset
+
+
+@pytest.fixture(scope="function")
+def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[LabeledMemoryset, None, None]:
+    """
+    Function-scoped fixture that provides a writable memoryset for tests that mutate state.
+
+    This fixture creates a fresh `LabeledMemoryset` named 'test_writable_memoryset' before each test.
+    After the test, it attempts to restore the memoryset to its initial state by deleting any added entries
+    and reinserting sample data — unless the memoryset has been dropped by the test itself, in which case
+    it will be recreated on the next invocation.
+
+    Note: Re-creating the memoryset from scratch is surprisingly more expensive than cleaning it up.
+    """
+    # It shouldn't be possible for this memoryset to already exist
+    memoryset = LabeledMemoryset.create(
+        "test_writable_memoryset",
+        datasource=datasource,
+        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+        source_id_column="source_id",
+        max_seq_length_override=32,
+        if_exists="open",
+    )
+    try:
+        yield memoryset
+    finally:
+        # Restore the memoryset to a clean state for the next test.
+        OrcaCredentials.set_api_key(api_key, check_validity=False)
+
+        if LabeledMemoryset.exists("test_writable_memoryset"):
+            memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
+
+            if memory_ids:
+                memoryset.delete(memory_ids)
+            assert len(memoryset) == 0
+            memoryset.insert(SAMPLE_DATA)
+        # If the test dropped the memoryset, do nothing — it will be recreated on the next use.
 
 
 @pytest.fixture(scope="session")
-def model(
-
-        "test_model",
+def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+    model = ClassificationModel.create(
+        "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
     )
+    return model
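
The new function-scoped `writable_memoryset` fixture is meant for tests that mutate state and then rely on the teardown above to restore it. A hypothetical test using it might look like the following; the inserted row mirrors the SAMPLE_DATA schema, and the exact insert contract and row count are assumptions, not shown in this diff:

def test_insert_then_cleanup(writable_memoryset):
    # hypothetical example: add one row shaped like the SAMPLE_DATA entries
    writable_memoryset.insert(
        [{"value": "lentil soup is underrated", "label": 0, "key": "val17", "score": 1.7, "source_id": "s17"}]
    )
    assert len(writable_memoryset) == 17  # 16 sample rows plus the inserted one
    # teardown deletes all rows and reinserts SAMPLE_DATA for the next test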
orca_sdk/datasource.py
CHANGED
@@ -12,6 +12,7 @@ import pyarrow as pa
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
+from tqdm.auto import tqdm
 
 from ._generated_api_client.api import (
     delete_datasource,
@@ -25,6 +26,7 @@ from ._generated_api_client.client import get_client
 from ._generated_api_client.models import ColumnType, DatasourceMetadata
 from ._utils.common import CreateMode, DropMode
 from ._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+from ._utils.tqdm_file_reader import TqdmFileReader
 
 
 class Datasource:
@@ -113,19 +115,27 @@
         with tempfile.TemporaryDirectory() as tmp_dir:
             dataset.save_to_disk(tmp_dir)
             files = []
-
-
-
-
-
-
-
-
-
-            files
-
+
+            # Calculate total size for all files
+            file_paths = list(Path(tmp_dir).iterdir())
+            total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+            with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+                for file_path in file_paths:
+                    buffered_reader = open(file_path, "rb")
+                    tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+                    files.append(("files", (file_path.name, tqdm_reader)))
+
+            # Do not use Generated client for this endpoint b/c it does not handle files properly
+            metadata = parse_create_response(
+                response=client.get_httpx_client().request(
+                    method="post",
+                    url="/datasource/",
+                    files=files,
+                    data={"name": name, "description": description},
+                )
             )
-
+
 
         return cls(metadata=metadata)
 
     @classmethod
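
The upload path above streams each file through a `TqdmFileReader` from the new `orca_sdk/_utils/tqdm_file_reader.py` (+12 lines, not shown in this diff). A minimal sketch of what such a wrapper typically looks like, assuming it only advances the progress bar as bytes are read; this is not the package's actual implementation:

from io import BufferedReader

from tqdm.auto import tqdm


class TqdmFileReader:
    """Wrap a binary file object and advance a tqdm bar as bytes are read."""

    def __init__(self, file: BufferedReader, pbar: tqdm):
        self._file = file
        self._pbar = pbar

    def read(self, size: int = -1) -> bytes:
        data = self._file.read(size)
        self._pbar.update(len(data))  # reflect bytes read in the progress bar
        return data

    def __getattr__(self, name):
        # delegate everything else (name, close, seek, ...) to the wrapped file
        return getattr(self._file, name)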
orca_sdk/embedding_model_test.py
CHANGED
@@ -53,7 +53,7 @@ def test_embed_text_unauthenticated(unauthenticated):
 
 @pytest.fixture(scope="session")
 def finetuned_model(datasource) -> FinetunedEmbeddingModel:
-    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource
+    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
 
 
 def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
@@ -65,8 +65,10 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
     assert finetuned_model._status == TaskStatus.COMPLETED
 
 
-def test_finetune_model_with_memoryset(
-    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
+    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
+        "test_finetuned_model_from_memoryset", readonly_memoryset
+    )
     assert finetuned_model is not None
     assert finetuned_model.name == "test_finetuned_model_from_memoryset"
     assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
@@ -109,7 +111,6 @@ def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_mode
         "test_memoryset_finetuned_model",
         datasource,
         embedding_model=finetuned_model,
-        value_column="text",
     )
     assert memoryset is not None
     assert memoryset.name == "test_memoryset_finetuned_model"
@@ -152,7 +153,7 @@ def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: Finetu
 
 
 def test_drop_finetuned_model(datasource: Datasource):
-    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource
+    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
     assert FinetunedEmbeddingModel.open("finetuned_model_to_delete")
     FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
     with pytest.raises(LookupError):
orca_sdk/memoryset.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, cast, overload
 
 import pandas as pd
 import pyarrow as pa
+from attrs import fields
 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -29,11 +30,14 @@ from ._generated_api_client.api import (
     memoryset_lookup_gpu,
     potential_duplicate_groups,
     query_memoryset,
+    suggest_cascading_edits,
     update_memories_gpu,
     update_memory_gpu,
     update_memoryset,
 )
 from ._generated_api_client.models import (
+    CascadeEditSuggestionsRequest,
+    CascadingEditSuggestion,
     CloneLabeledMemorysetRequest,
     CreateLabeledMemorysetRequest,
     DeleteMemoriesRequest,
@@ -1180,6 +1184,63 @@ class LabeledMemoryset:
         updated_memories = [LabeledMemory(self.id, memory) for memory in response]
         return updated_memories[0] if isinstance(updates, dict) else updated_memories
 
+    def get_cascading_edits_suggestions(
+        self: LabeledMemoryset,
+        memory: LabeledMemory,
+        *,
+        old_label: int,
+        new_label: int,
+        max_neighbors: int = 50,
+        max_validation_neighbors: int = 10,
+        similarity_threshold: float | None = None,
+        only_if_has_old_label: bool = True,
+        exclude_if_new_label: bool = True,
+        suggestion_cooldown_time: float = 3600.0 * 24.0,  # 1 day
+        label_confirmation_cooldown_time: float = 3600.0 * 24.0 * 7,  # 1 week
+    ) -> list[CascadingEditSuggestion]:
+        """
+        Suggests cascading edits for a given memory based on nearby points with similar labels.
+
+        This function is triggered after a user changes a memory's label. It looks for nearby
+        candidates in embedding space that may be subject to similar relabeling and returns them
+        as suggestions. The system uses scoring heuristics, label filters, and cooldown tracking
+        to reduce noise and improve usability.
+
+        Params:
+            memory: The memory whose label was just changed.
+            old_label: The label this memory used to have.
+            new_label: The label it was changed to.
+            max_neighbors: Maximum number of neighbors to consider.
+            max_validation_neighbors: Maximum number of neighbors to use for label suggestion.
+            similarity_threshold: If set, only include neighbors with a lookup score above this threshold.
+            only_if_has_old_label: If True, only consider neighbors that have the old label.
+            exclude_if_new_label: If True, exclude neighbors that already have the new label.
+            suggestion_cooldown_time: Minimum time (in seconds) since the last suggestion for a neighbor
+                to be considered again.
+            label_confirmation_cooldown_time: Minimum time (in seconds) since a neighbor's label was confirmed
+                to be considered for suggestions.
+            _current_time: Optional override for the current timestamp (useful for testing).
+
+        Returns:
+            A list of CascadingEditSuggestion objects, each containing a neighbor and the suggested new label.
+        """
+
+        return suggest_cascading_edits(
+            name_or_id=self.id,
+            memory_id=memory.memory_id,
+            body=CascadeEditSuggestionsRequest(
+                old_label=old_label,
+                new_label=new_label,
+                max_neighbors=max_neighbors,
+                max_validation_neighbors=max_validation_neighbors,
+                similarity_threshold=similarity_threshold,
+                only_if_has_old_label=only_if_has_old_label,
+                exclude_if_new_label=exclude_if_new_label,
+                suggestion_cooldown_time=suggestion_cooldown_time,
+                label_confirmation_cooldown_time=label_confirmation_cooldown_time,
+            ),
+        )
+
     def delete(self, memory_id: str | Iterable[str]) -> None:
         """
         Delete memories from the memoryset
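
A hypothetical call to the new method, with `memoryset` and `memory` as placeholders for an existing `LabeledMemoryset` and the `LabeledMemory` whose label was just changed; the keyword values are chosen for illustration only:

suggestions = memoryset.get_cascading_edits_suggestions(
    memory,
    old_label=0,
    new_label=1,
    max_neighbors=25,
    similarity_threshold=0.8,
)
for suggestion in suggestions:
    print(suggestion)  # each item is a CascadingEditSuggestion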
@@ -1229,6 +1290,9 @@ class LabeledMemoryset:
         Returns:
             dictionary with aggregate metrics for each analysis that was run
 
+        Raises:
+            ValueError: If an invalid analysis name is provided
+
         Examples:
             Run label and duplicate analysis:
             >>> memoryset.analyze("label", {"name": "duplicate", "possible_duplicate_threshold": 0.99})
@@ -1263,12 +1327,26 @@ class LabeledMemoryset:
             Display label analysis to review potential mislabelings:
             >>> memoryset.display_label_analysis()
         """
+
+        # Get valid analysis names from MemorysetAnalysisConfigs
+        valid_analysis_names = {
+            field.name for field in fields(MemorysetAnalysisConfigs) if field.name != "additional_properties"
+        }
+
         configs: dict[str, dict] = {}
         for analysis in analyses:
             if isinstance(analysis, str):
+                error_msg = (
+                    f"Invalid analysis name: {analysis}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                )
+                if analysis not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[analysis] = {}
             else:
                 name = analysis.pop("name")  # type: ignore
+                error_msg = f"Invalid analysis name: {name}. Valid names are: {', '.join(sorted(valid_analysis_names))}"
+                if name not in valid_analysis_names:
+                    raise ValueError(error_msg)
                 configs[name] = analysis  # type: ignore
 
         analysis = analyze_memoryset(
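
With the validation above in place, an unknown analysis name should fail fast; a hypothetical example, where `memoryset` is a placeholder for an existing `LabeledMemoryset`:

import pytest

with pytest.raises(ValueError, match="Invalid analysis name"):
    memoryset.analyze("not_a_real_analysis")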
|