orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl
This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +8 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/models/__init__.py +10 -0
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +32 -12
- orca_sdk/classification_model_test.py +95 -34
- orca_sdk/conftest.py +87 -25
- orca_sdk/datasource.py +56 -12
- orca_sdk/datasource_test.py +9 -0
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +199 -123
- orca_sdk/telemetry.py +5 -3
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,23 +1,45 @@
|
|
|
1
1
|
import random
|
|
2
|
+
import time
|
|
3
|
+
from typing import Generator
|
|
2
4
|
from uuid import uuid4
|
|
3
5
|
|
|
4
6
|
import pytest
|
|
7
|
+
from datasets import ClassLabel, Features, Value
|
|
5
8
|
from datasets.arrow_dataset import Dataset
|
|
6
9
|
|
|
10
|
+
from orca_sdk.conftest import SAMPLE_DATA
|
|
11
|
+
|
|
12
|
+
from ._generated_api_client.models import CascadingEditSuggestion
|
|
7
13
|
from .classification_model import ClassificationModel
|
|
8
14
|
from .datasource import Datasource
|
|
9
15
|
from .embedding_model import PretrainedEmbeddingModel
|
|
10
16
|
from .memoryset import LabeledMemoryset, TaskStatus
|
|
11
17
|
|
|
18
|
+
"""
|
|
19
|
+
Test Performance Note:
|
|
20
|
+
|
|
21
|
+
Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
|
|
22
|
+
|
|
23
|
+
- Two fixtures are used to manage memorysets:
|
|
24
|
+
- `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
|
|
25
|
+
It should only be used in nullipotent tests.
|
|
26
|
+
- `writable_memoryset` is a function-scoped, regenerating fixture.
|
|
27
|
+
It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
|
|
28
|
+
|
|
29
|
+
- To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
|
|
30
|
+
For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
|
|
31
|
+
rather than separate `test_delete_single` and `test_delete_multiple` tests.
|
|
32
|
+
"""
|
|
12
33
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
assert
|
|
16
|
-
assert
|
|
17
|
-
assert
|
|
18
|
-
assert
|
|
19
|
-
assert
|
|
20
|
-
assert
|
|
34
|
+
|
|
35
|
+
def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
36
|
+
assert readonly_memoryset is not None
|
|
37
|
+
assert readonly_memoryset.name == "test_readonly_memoryset"
|
|
38
|
+
assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
39
|
+
assert readonly_memoryset.label_names == label_names
|
|
40
|
+
assert readonly_memoryset.insertion_status == TaskStatus.COMPLETED
|
|
41
|
+
assert isinstance(readonly_memoryset.length, int)
|
|
42
|
+
assert readonly_memoryset.length == len(hf_dataset)
|
|
21
43
|
|
|
22
44
|
|
|
23
45
|
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
|
|
@@ -29,61 +51,55 @@ def test_create_memoryset_invalid_input(datasource):
|
|
|
29
51
|
# invalid name
|
|
30
52
|
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
31
53
|
LabeledMemoryset.create("test memoryset", datasource)
|
|
32
|
-
# invalid datasource
|
|
33
|
-
datasource.id = str(uuid4())
|
|
34
|
-
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
35
|
-
LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
|
|
36
54
|
|
|
37
55
|
|
|
38
|
-
def test_create_memoryset_already_exists_error(hf_dataset, label_names,
|
|
56
|
+
def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
|
|
57
|
+
memoryset_name = readonly_memoryset.name
|
|
39
58
|
with pytest.raises(ValueError):
|
|
40
|
-
LabeledMemoryset.from_hf_dataset(
|
|
59
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
|
|
41
60
|
with pytest.raises(ValueError):
|
|
42
|
-
LabeledMemoryset.from_hf_dataset(
|
|
43
|
-
"test_memoryset", hf_dataset, label_names=label_names, value_column="text", if_exists="error"
|
|
44
|
-
)
|
|
61
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")
|
|
45
62
|
|
|
46
63
|
|
|
47
|
-
def test_create_memoryset_already_exists_open(hf_dataset, label_names,
|
|
64
|
+
def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
|
|
48
65
|
# invalid label names
|
|
49
66
|
with pytest.raises(ValueError):
|
|
50
67
|
LabeledMemoryset.from_hf_dataset(
|
|
51
|
-
|
|
68
|
+
readonly_memoryset.name,
|
|
52
69
|
hf_dataset,
|
|
53
70
|
label_names=["turtles", "frogs"],
|
|
54
|
-
value_column="text",
|
|
55
71
|
if_exists="open",
|
|
56
72
|
)
|
|
57
73
|
# different embedding model
|
|
58
74
|
with pytest.raises(ValueError):
|
|
59
75
|
LabeledMemoryset.from_hf_dataset(
|
|
60
|
-
|
|
76
|
+
readonly_memoryset.name,
|
|
61
77
|
hf_dataset,
|
|
62
78
|
label_names=label_names,
|
|
63
79
|
embedding_model=PretrainedEmbeddingModel.DISTILBERT,
|
|
64
80
|
if_exists="open",
|
|
65
81
|
)
|
|
66
82
|
opened_memoryset = LabeledMemoryset.from_hf_dataset(
|
|
67
|
-
|
|
83
|
+
readonly_memoryset.name,
|
|
68
84
|
hf_dataset,
|
|
69
85
|
embedding_model=PretrainedEmbeddingModel.GTE_BASE,
|
|
70
86
|
if_exists="open",
|
|
71
87
|
)
|
|
72
88
|
assert opened_memoryset is not None
|
|
73
|
-
assert opened_memoryset.name ==
|
|
89
|
+
assert opened_memoryset.name == readonly_memoryset.name
|
|
74
90
|
assert opened_memoryset.length == len(hf_dataset)
|
|
75
91
|
|
|
76
92
|
|
|
77
|
-
def test_open_memoryset(
|
|
78
|
-
fetched_memoryset = LabeledMemoryset.open(
|
|
93
|
+
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
94
|
+
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
79
95
|
assert fetched_memoryset is not None
|
|
80
|
-
assert fetched_memoryset.name ==
|
|
96
|
+
assert fetched_memoryset.name == readonly_memoryset.name
|
|
81
97
|
assert fetched_memoryset.length == len(hf_dataset)
|
|
82
98
|
|
|
83
99
|
|
|
84
|
-
def test_open_memoryset_unauthenticated(unauthenticated,
|
|
100
|
+
def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
85
101
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
86
|
-
LabeledMemoryset.open(
|
|
102
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
87
103
|
|
|
88
104
|
|
|
89
105
|
def test_open_memoryset_not_found():
|
|
@@ -96,15 +112,15 @@ def test_open_memoryset_invalid_input():
|
|
|
96
112
|
LabeledMemoryset.open("not valid id")
|
|
97
113
|
|
|
98
114
|
|
|
99
|
-
def test_open_memoryset_unauthorized(unauthorized,
|
|
115
|
+
def test_open_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
100
116
|
with pytest.raises(LookupError):
|
|
101
|
-
LabeledMemoryset.open(
|
|
117
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
102
118
|
|
|
103
119
|
|
|
104
|
-
def test_all_memorysets(
|
|
120
|
+
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
105
121
|
memorysets = LabeledMemoryset.all()
|
|
106
122
|
assert len(memorysets) > 0
|
|
107
|
-
assert any(memoryset.name ==
|
|
123
|
+
assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
|
|
108
124
|
|
|
109
125
|
|
|
110
126
|
def test_all_memorysets_unauthenticated(unauthenticated):
|
|
@@ -112,51 +128,39 @@ def test_all_memorysets_unauthenticated(unauthenticated):
|
|
|
112
128
|
LabeledMemoryset.all()
|
|
113
129
|
|
|
114
130
|
|
|
115
|
-
def test_all_memorysets_unauthorized(unauthorized,
|
|
116
|
-
assert
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
def test_drop_memoryset(hf_dataset):
|
|
120
|
-
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
121
|
-
"test_memoryset_delete",
|
|
122
|
-
hf_dataset.select(range(1)),
|
|
123
|
-
value_column="text",
|
|
124
|
-
)
|
|
125
|
-
assert LabeledMemoryset.exists(memoryset.name)
|
|
126
|
-
LabeledMemoryset.drop(memoryset.name)
|
|
127
|
-
assert not LabeledMemoryset.exists(memoryset.name)
|
|
131
|
+
def test_all_memorysets_unauthorized(unauthorized, readonly_memoryset):
|
|
132
|
+
assert readonly_memoryset not in LabeledMemoryset.all()
|
|
128
133
|
|
|
129
134
|
|
|
130
|
-
def test_drop_memoryset_unauthenticated(unauthenticated,
|
|
135
|
+
def test_drop_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
131
136
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
132
|
-
LabeledMemoryset.drop(
|
|
137
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
133
138
|
|
|
134
139
|
|
|
135
|
-
def test_drop_memoryset_not_found(
|
|
140
|
+
def test_drop_memoryset_not_found():
|
|
136
141
|
with pytest.raises(LookupError):
|
|
137
142
|
LabeledMemoryset.drop(str(uuid4()))
|
|
138
143
|
# ignores error if specified
|
|
139
144
|
LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
|
|
140
145
|
|
|
141
146
|
|
|
142
|
-
def test_drop_memoryset_unauthorized(unauthorized,
|
|
147
|
+
def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
143
148
|
with pytest.raises(LookupError):
|
|
144
|
-
LabeledMemoryset.drop(
|
|
149
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
145
150
|
|
|
146
151
|
|
|
147
|
-
def test_update_memoryset_metadata(
|
|
148
|
-
|
|
149
|
-
assert memoryset.description == "New description"
|
|
152
|
+
def test_update_memoryset_metadata(writable_memoryset: LabeledMemoryset):
|
|
153
|
+
# NOTE: We're combining multiple tests into one here to avoid multiple API calls
|
|
150
154
|
|
|
155
|
+
writable_memoryset.update_metadata(description="New description")
|
|
156
|
+
assert writable_memoryset.description == "New description"
|
|
151
157
|
|
|
152
|
-
|
|
153
|
-
assert
|
|
154
|
-
memoryset.update_metadata(description=None)
|
|
155
|
-
assert memoryset.description is None
|
|
158
|
+
writable_memoryset.update_metadata(description=None)
|
|
159
|
+
assert writable_memoryset.description is None
|
|
156
160
|
|
|
157
161
|
|
|
158
|
-
def test_search(
|
|
159
|
-
memory_lookups =
|
|
162
|
+
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
163
|
+
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
160
164
|
assert len(memory_lookups) == 2
|
|
161
165
|
assert len(memory_lookups[0]) == 1
|
|
162
166
|
assert len(memory_lookups[1]) == 1
|
|
@@ -164,53 +168,53 @@ def test_search(memoryset: LabeledMemoryset):
|
|
|
164
168
|
assert memory_lookups[1][0].label == 1
|
|
165
169
|
|
|
166
170
|
|
|
167
|
-
def test_search_count(
|
|
168
|
-
memory_lookups =
|
|
171
|
+
def test_search_count(readonly_memoryset: LabeledMemoryset):
|
|
172
|
+
memory_lookups = readonly_memoryset.search("i love soup", count=3)
|
|
169
173
|
assert len(memory_lookups) == 3
|
|
170
174
|
assert memory_lookups[0].label == 0
|
|
171
175
|
assert memory_lookups[1].label == 0
|
|
172
176
|
assert memory_lookups[2].label == 0
|
|
173
177
|
|
|
174
178
|
|
|
175
|
-
def test_get_memory_at_index(
|
|
176
|
-
memory =
|
|
177
|
-
assert memory.value == hf_dataset[0]["
|
|
179
|
+
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
180
|
+
memory = readonly_memoryset[0]
|
|
181
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
178
182
|
assert memory.label == hf_dataset[0]["label"]
|
|
179
183
|
assert memory.label_name == label_names[hf_dataset[0]["label"]]
|
|
180
184
|
assert memory.source_id == hf_dataset[0]["source_id"]
|
|
181
185
|
assert memory.score == hf_dataset[0]["score"]
|
|
182
186
|
assert memory.key == hf_dataset[0]["key"]
|
|
183
|
-
last_memory =
|
|
184
|
-
assert last_memory.value == hf_dataset[-1]["
|
|
187
|
+
last_memory = readonly_memoryset[-1]
|
|
188
|
+
assert last_memory.value == hf_dataset[-1]["value"]
|
|
185
189
|
assert last_memory.label == hf_dataset[-1]["label"]
|
|
186
190
|
|
|
187
191
|
|
|
188
|
-
def test_get_range_of_memories(
|
|
189
|
-
memories =
|
|
192
|
+
def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
193
|
+
memories = readonly_memoryset[1:3]
|
|
190
194
|
assert len(memories) == 2
|
|
191
|
-
assert memories[0].value == hf_dataset["
|
|
192
|
-
assert memories[1].value == hf_dataset["
|
|
195
|
+
assert memories[0].value == hf_dataset["value"][1]
|
|
196
|
+
assert memories[1].value == hf_dataset["value"][2]
|
|
193
197
|
|
|
194
198
|
|
|
195
|
-
def test_get_memory_by_id(
|
|
196
|
-
memory =
|
|
197
|
-
assert memory.value == hf_dataset[0]["
|
|
198
|
-
assert memory ==
|
|
199
|
+
def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
200
|
+
memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
|
|
201
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
202
|
+
assert memory == readonly_memoryset[memory.memory_id]
|
|
199
203
|
|
|
200
204
|
|
|
201
|
-
def test_get_memories_by_id(
|
|
202
|
-
memories =
|
|
205
|
+
def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
206
|
+
memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
|
|
203
207
|
assert len(memories) == 2
|
|
204
|
-
assert memories[0].value == hf_dataset[0]["
|
|
205
|
-
assert memories[1].value == hf_dataset[1]["
|
|
208
|
+
assert memories[0].value == hf_dataset[0]["value"]
|
|
209
|
+
assert memories[1].value == hf_dataset[1]["value"]
|
|
206
210
|
|
|
207
211
|
|
|
208
|
-
def test_query_memoryset(
|
|
209
|
-
memories =
|
|
212
|
+
def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
213
|
+
memories = readonly_memoryset.query(filters=[("label", "==", 1)])
|
|
210
214
|
assert len(memories) == 8
|
|
211
215
|
assert all(memory.label == 1 for memory in memories)
|
|
212
|
-
assert len(
|
|
213
|
-
assert len(
|
|
216
|
+
assert len(readonly_memoryset.query(limit=2)) == 2
|
|
217
|
+
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "val1")])) == 1
|
|
214
218
|
|
|
215
219
|
|
|
216
220
|
def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
|
|
@@ -268,19 +272,21 @@ def test_query_memoryset_with_feedback_metrics_sort(model: ClassificationModel):
|
|
|
268
272
|
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
269
273
|
|
|
270
274
|
|
|
271
|
-
def test_insert_memories(
|
|
272
|
-
|
|
273
|
-
prev_length =
|
|
274
|
-
|
|
275
|
+
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
276
|
+
writable_memoryset.refresh()
|
|
277
|
+
prev_length = writable_memoryset.length
|
|
278
|
+
writable_memoryset.insert(
|
|
275
279
|
[
|
|
276
280
|
dict(value="tomato soup is my favorite", label=0),
|
|
277
281
|
dict(value="cats are fun to play with", label=1),
|
|
278
282
|
]
|
|
279
283
|
)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
+
writable_memoryset.refresh()
|
|
285
|
+
assert writable_memoryset.length == prev_length + 2
|
|
286
|
+
writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
|
|
287
|
+
writable_memoryset.refresh()
|
|
288
|
+
assert writable_memoryset.length == prev_length + 3
|
|
289
|
+
last_memory = writable_memoryset[-1]
|
|
284
290
|
assert last_memory.value == "tomato soup is my favorite"
|
|
285
291
|
assert last_memory.label == 0
|
|
286
292
|
assert last_memory.metadata
|
|
@@ -288,25 +294,26 @@ def test_insert_memories(memoryset: LabeledMemoryset):
|
|
|
288
294
|
assert last_memory.source_id == "test"
|
|
289
295
|
|
|
290
296
|
|
|
291
|
-
def
|
|
292
|
-
|
|
293
|
-
|
|
297
|
+
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
298
|
+
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
299
|
+
|
|
300
|
+
# test updating a single memory
|
|
301
|
+
memory_id = writable_memoryset[0].memory_id
|
|
302
|
+
updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
294
303
|
assert updated_memory.value == "i love soup so much"
|
|
295
304
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
296
|
-
assert
|
|
297
|
-
|
|
305
|
+
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
298
306
|
|
|
299
|
-
|
|
300
|
-
memory =
|
|
307
|
+
# test updating a memory instance
|
|
308
|
+
memory = writable_memoryset[0]
|
|
301
309
|
updated_memory = memory.update(value="i love soup even more")
|
|
302
310
|
assert updated_memory is memory
|
|
303
311
|
assert memory.value == "i love soup even more"
|
|
304
312
|
assert memory.label == hf_dataset[0]["label"]
|
|
305
313
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
updated_memories = memoryset.update(
|
|
314
|
+
# test updating multiple memories
|
|
315
|
+
memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
|
|
316
|
+
updated_memories = writable_memoryset.update(
|
|
310
317
|
[
|
|
311
318
|
dict(memory_id=memory_ids[0], value="i love soup so much"),
|
|
312
319
|
dict(memory_id=memory_ids[1], value="cats are so cute"),
|
|
@@ -316,35 +323,37 @@ def test_update_memories(memoryset: LabeledMemoryset):
|
|
|
316
323
|
assert updated_memories[1].value == "cats are so cute"
|
|
317
324
|
|
|
318
325
|
|
|
319
|
-
def
|
|
320
|
-
|
|
321
|
-
prev_length = memoryset.length
|
|
322
|
-
memory_id = memoryset[0].memory_id
|
|
323
|
-
memoryset.delete(memory_id)
|
|
324
|
-
with pytest.raises(LookupError):
|
|
325
|
-
memoryset.get(memory_id)
|
|
326
|
-
assert memoryset.length == prev_length - 1
|
|
326
|
+
def test_delete_memories(writable_memoryset: LabeledMemoryset):
|
|
327
|
+
# We've combined the delete tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
327
328
|
|
|
329
|
+
# test deleting a single memory
|
|
330
|
+
prev_length = writable_memoryset.length
|
|
331
|
+
memory_id = writable_memoryset[0].memory_id
|
|
332
|
+
writable_memoryset.delete(memory_id)
|
|
333
|
+
with pytest.raises(LookupError):
|
|
334
|
+
writable_memoryset.get(memory_id)
|
|
335
|
+
assert writable_memoryset.length == prev_length - 1
|
|
328
336
|
|
|
329
|
-
|
|
330
|
-
prev_length =
|
|
331
|
-
|
|
332
|
-
assert
|
|
337
|
+
# test deleting multiple memories
|
|
338
|
+
prev_length = writable_memoryset.length
|
|
339
|
+
writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id])
|
|
340
|
+
assert writable_memoryset.length == prev_length - 2
|
|
333
341
|
|
|
334
342
|
|
|
335
|
-
def test_clone_memoryset(
|
|
336
|
-
cloned_memoryset =
|
|
343
|
+
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
344
|
+
cloned_memoryset = readonly_memoryset.clone(
|
|
345
|
+
"test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
|
|
346
|
+
)
|
|
337
347
|
assert cloned_memoryset is not None
|
|
338
348
|
assert cloned_memoryset.name == "test_cloned_memoryset"
|
|
339
|
-
assert cloned_memoryset.length ==
|
|
349
|
+
assert cloned_memoryset.length == readonly_memoryset.length
|
|
340
350
|
assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
|
|
341
351
|
assert cloned_memoryset.insertion_status == TaskStatus.COMPLETED
|
|
342
352
|
|
|
343
353
|
|
|
344
|
-
def test_embedding_evaluation(
|
|
345
|
-
datasource = Datasource.from_hf_dataset("eval_datasource", hf_dataset, if_exists="open")
|
|
354
|
+
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
346
355
|
response = LabeledMemoryset.run_embedding_evaluation(
|
|
347
|
-
|
|
356
|
+
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=2
|
|
348
357
|
)
|
|
349
358
|
assert response is not None
|
|
350
359
|
assert isinstance(response, dict)
|
|
@@ -358,8 +367,8 @@ def test_embedding_evaluation(hf_dataset):
|
|
|
358
367
|
|
|
359
368
|
|
|
360
369
|
@pytest.fixture(scope="function")
|
|
361
|
-
async def test_group_potential_duplicates(
|
|
362
|
-
|
|
370
|
+
async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
|
|
371
|
+
writable_memoryset.insert(
|
|
363
372
|
[
|
|
364
373
|
dict(value="raspberry soup Is my favorite", label=0),
|
|
365
374
|
dict(value="Raspberry soup is MY favorite", label=0),
|
|
@@ -375,7 +384,74 @@ async def test_group_potential_duplicates(memoryset: LabeledMemoryset):
|
|
|
375
384
|
]
|
|
376
385
|
)
|
|
377
386
|
|
|
378
|
-
|
|
379
|
-
response =
|
|
387
|
+
writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
|
|
388
|
+
response = writable_memoryset.get_potential_duplicate_groups()
|
|
380
389
|
assert isinstance(response, list)
|
|
381
390
|
assert sorted([len(res) for res in response]) == [5, 6] # 5 favorite, 6 mom
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
394
|
+
# Insert a memory to test cascading edits
|
|
395
|
+
SOUP = 0
|
|
396
|
+
CATS = 1
|
|
397
|
+
query_text = "i love soup" # from SAMPLE_DATA in conftest.py
|
|
398
|
+
mislabeled_soup_text = "soup is comfort in a bowl"
|
|
399
|
+
writable_memoryset.insert(
|
|
400
|
+
[
|
|
401
|
+
dict(value=mislabeled_soup_text, label=CATS), # mislabeled soup memory
|
|
402
|
+
]
|
|
403
|
+
)
|
|
404
|
+
|
|
405
|
+
# Fetch the memory to update
|
|
406
|
+
memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]
|
|
407
|
+
|
|
408
|
+
# Update the label and get cascading edit suggestions
|
|
409
|
+
suggestions = writable_memoryset.get_cascading_edits_suggestions(
|
|
410
|
+
memory=memory,
|
|
411
|
+
old_label=CATS,
|
|
412
|
+
new_label=SOUP,
|
|
413
|
+
max_neighbors=10,
|
|
414
|
+
max_validation_neighbors=5,
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
# Validate the suggestions
|
|
418
|
+
assert len(suggestions) == 1
|
|
419
|
+
assert suggestions[0].neighbor.value == mislabeled_soup_text
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
|
|
423
|
+
"""Test that analyze() raises ValueError for invalid analysis names"""
|
|
424
|
+
memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
425
|
+
|
|
426
|
+
# Test with string input
|
|
427
|
+
with pytest.raises(ValueError) as excinfo:
|
|
428
|
+
memoryset.analyze("invalid_name")
|
|
429
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
430
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
431
|
+
|
|
432
|
+
# Test with dict input
|
|
433
|
+
with pytest.raises(ValueError) as excinfo:
|
|
434
|
+
memoryset.analyze({"name": "invalid_name"})
|
|
435
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
436
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
437
|
+
|
|
438
|
+
# Test with multiple analyses where one is invalid
|
|
439
|
+
with pytest.raises(ValueError) as excinfo:
|
|
440
|
+
memoryset.analyze("duplicate", "invalid_name")
|
|
441
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
442
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
443
|
+
|
|
444
|
+
# Test with valid analysis names
|
|
445
|
+
result = memoryset.analyze("duplicate", "cluster")
|
|
446
|
+
assert isinstance(result, dict)
|
|
447
|
+
assert "duplicate" in result
|
|
448
|
+
assert "cluster" in result
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
|
|
452
|
+
# NOTE: Keep this test at the end to ensure the memoryset is dropped after all tests.
|
|
453
|
+
# Otherwise, it would be recreated on the next test run if it were dropped earlier, and
|
|
454
|
+
# that's expensive.
|
|
455
|
+
assert LabeledMemoryset.exists(writable_memoryset.name)
|
|
456
|
+
LabeledMemoryset.drop(writable_memoryset.name)
|
|
457
|
+
assert not LabeledMemoryset.exists(writable_memoryset.name)
|
orca_sdk/telemetry.py
CHANGED
|
@@ -149,6 +149,7 @@ class LabelPrediction:
|
|
|
149
149
|
model: ClassificationModel | str,
|
|
150
150
|
telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
|
|
151
151
|
logits: list[float] | None = None,
|
|
152
|
+
input_value: str | list[list[float]] | None = None,
|
|
152
153
|
):
|
|
153
154
|
# for internal use only, do not document
|
|
154
155
|
from .classification_model import ClassificationModel
|
|
@@ -162,15 +163,14 @@ class LabelPrediction:
|
|
|
162
163
|
self.model = ClassificationModel.open(model) if isinstance(model, str) else model
|
|
163
164
|
self.__telemetry = telemetry if telemetry else None
|
|
164
165
|
self.logits = logits
|
|
166
|
+
self._input_value = input_value
|
|
165
167
|
|
|
166
168
|
def __repr__(self):
|
|
167
169
|
return (
|
|
168
170
|
"LabelPrediction({"
|
|
169
171
|
+ f"label: <{self.label_name}: {self.label}>, "
|
|
170
172
|
+ f"confidence: {self.confidence:.2f}, "
|
|
171
|
-
+ f"anomaly_score: {self.anomaly_score:.2f}, "
|
|
172
|
-
if self.anomaly_score is not None
|
|
173
|
-
else ""
|
|
173
|
+
+ (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
|
|
174
174
|
+ f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
|
|
175
175
|
+ "})"
|
|
176
176
|
)
|
|
@@ -188,6 +188,8 @@ class LabelPrediction:
|
|
|
188
188
|
|
|
189
189
|
@property
|
|
190
190
|
def input_value(self) -> str | list[list[float]] | None:
|
|
191
|
+
if self._input_value is not None:
|
|
192
|
+
return self._input_value
|
|
191
193
|
return self._telemetry.input_value
|
|
192
194
|
|
|
193
195
|
@property
|