orca-sdk 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
orca_sdk/memoryset_test.py

@@ -0,0 +1,1150 @@
import random
from uuid import uuid4

import pytest
from datasets.arrow_dataset import Dataset

from .classification_model import ClassificationModel
from .conftest import skip_in_ci, skip_in_prod
from .datasource import Datasource
from .embedding_model import PretrainedEmbeddingModel
from .memoryset import (
    LabeledMemory,
    LabeledMemoryset,
    ScoredMemory,
    ScoredMemoryset,
    Status,
)
from .regression_model import RegressionModel

"""
Test Performance Note:

Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:

- Two fixtures are used to manage memorysets:
    - `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
      It should only be used in nullipotent tests.
    - `writable_memoryset` is a function-scoped, regenerating fixture.
      It can be used in tests that mutate or delete the memoryset, and will be reset before each test.

- To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
  For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
  rather than separate `test_delete_single` and `test_delete_multiple` tests.
"""
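
# A minimal sketch of the fixture split described above -- illustrative only, the
# actual fixture definitions live in conftest.py:
#
#     def test_reads_only(readonly_memoryset):        # nullipotent: shared, session-scoped
#         assert readonly_memoryset.length > 0
#
#     def test_mutates_state(writable_memoryset):     # regenerated before each test
#         writable_memoryset.delete(writable_memoryset[0].memory_id)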


def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    assert readonly_memoryset is not None
    assert readonly_memoryset.name == "test_readonly_memoryset"
    assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
    assert readonly_memoryset.label_names == label_names
    assert readonly_memoryset.insertion_status == Status.COMPLETED
    assert isinstance(readonly_memoryset.length, int)
    assert readonly_memoryset.length == len(hf_dataset)
    assert readonly_memoryset.index_type == "IVF_FLAT"
    assert readonly_memoryset.index_params == {"n_lists": 100}


def test_create_empty_labeled_memoryset():
    name = f"test_empty_labeled_{uuid4()}"
    label_names = ["negative", "positive"]
    try:
        memoryset = LabeledMemoryset.create(name, label_names=label_names, description="empty labeled test")
        assert memoryset is not None
        assert memoryset.name == name
        assert memoryset.length == 0
        assert memoryset.label_names == label_names
        assert memoryset.insertion_status is None

        # inserting should work on an empty memoryset
        memoryset.insert(dict(value="i love soup", label=1, key="k1"))
        memoryset.refresh()
        assert memoryset.length == 1
        m = memoryset[0]
        assert isinstance(m, LabeledMemory)
        assert m.value == "i love soup"
        assert m.label == 1
        assert m.label_name == "positive"
        assert m.metadata.get("key") == "k1"

        # if_exists="open" should re-open the same memoryset
        reopened = LabeledMemoryset.create(name, label_names=label_names, if_exists="open")
        assert reopened.id == memoryset.id
        assert len(reopened) == 1

        # if_exists="open" should raise if label_names mismatch
        with pytest.raises(ValueError, match=r"label names|requested"):
            LabeledMemoryset.create(name, label_names=["turtles", "frogs"], if_exists="open")

        # if_exists="open" should raise if embedding_model mismatch
        with pytest.raises(ValueError, match=r"embedding_model|requested"):
            LabeledMemoryset.create(
                name,
                label_names=label_names,
                embedding_model=PretrainedEmbeddingModel.DISTILBERT,
                if_exists="open",
            )

        # if_exists="error" should raise when it already exists
        with pytest.raises(ValueError, match="already exists"):
            LabeledMemoryset.create(name, label_names=label_names, if_exists="error")
    finally:
        LabeledMemoryset.drop(name, if_not_exists="ignore")


def test_create_empty_scored_memoryset():
    name = f"test_empty_scored_{uuid4()}"
    try:
        memoryset = ScoredMemoryset.create(name, description="empty scored test")
        assert memoryset is not None
        assert memoryset.name == name
        assert memoryset.length == 0
        assert memoryset.insertion_status is None

        # inserting should work on an empty memoryset
        memoryset.insert(dict(value="i love soup", score=0.25, key="k1", label=0))
        memoryset.refresh()
        assert memoryset.length == 1
        m = memoryset[0]
        assert isinstance(m, ScoredMemory)
        assert m.value == "i love soup"
        assert m.score == 0.25
        assert m.metadata.get("key") == "k1"
        assert m.metadata.get("label") == 0

        # if_exists="open" should re-open the same memoryset
        reopened = ScoredMemoryset.create(name, if_exists="open")
        assert reopened.id == memoryset.id

        # if_exists="open" should raise if embedding_model mismatch
        with pytest.raises(ValueError, match=r"embedding_model|requested"):
            ScoredMemoryset.create(name, embedding_model=PretrainedEmbeddingModel.DISTILBERT, if_exists="open")

        # if_exists="error" should raise when it already exists
        with pytest.raises(ValueError, match="already exists"):
            ScoredMemoryset.create(name, if_exists="error")
    finally:
        ScoredMemoryset.drop(name, if_not_exists="ignore")


def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            LabeledMemoryset.create("test_memoryset", datasource=datasource)


def test_create_memoryset_invalid_input(datasource):
    # invalid name
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource=datasource)


def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
    memoryset_name = readonly_memoryset.name
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")


def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
    # invalid label names
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            readonly_memoryset.name,
            hf_dataset,
            label_names=["turtles", "frogs"],
            if_exists="open",
        )
    # different embedding model
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            readonly_memoryset.name,
            hf_dataset,
            label_names=label_names,
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )
    opened_memoryset = LabeledMemoryset.from_hf_dataset(
        readonly_memoryset.name,
        hf_dataset,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        if_exists="open",
    )
    assert opened_memoryset is not None
    assert opened_memoryset.name == readonly_memoryset.name
    assert opened_memoryset.length == len(hf_dataset)


def test_if_exists_error_no_datasource_creation(
    readonly_memoryset: LabeledMemoryset,
):
    memoryset_name = readonly_memoryset.name
    datasource_name = f"{memoryset_name}_datasource"
    Datasource.drop(datasource_name, if_not_exists="ignore")
    assert not Datasource.exists(datasource_name)
    with pytest.raises(ValueError):
        LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
    assert not Datasource.exists(datasource_name)


def test_if_exists_open_reuses_existing_datasource(
    readonly_memoryset: LabeledMemoryset,
):
    memoryset_name = readonly_memoryset.name
    datasource_name = f"{memoryset_name}_datasource"
    Datasource.drop(datasource_name, if_not_exists="ignore")
    assert not Datasource.exists(datasource_name)
    reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
    assert reopened.id == readonly_memoryset.id
    assert not Datasource.exists(datasource_name)


def test_create_memoryset_string_label():
    assert not LabeledMemoryset.exists("test_string_label")
    memoryset = LabeledMemoryset.from_hf_dataset(
        "test_string_label",
        Dataset.from_dict({"value": ["terrible", "great"], "label": ["negative", "positive"]}),
    )
    assert memoryset is not None
    assert memoryset.length == 2
    assert memoryset.label_names == ["negative", "positive"]
    assert memoryset[0].label == 0
    assert memoryset[1].label == 1
    assert memoryset[0].label_name == "negative"
    assert memoryset[1].label_name == "positive"


def test_create_memoryset_integer_label():
    assert not LabeledMemoryset.exists("test_integer_label")
    memoryset = LabeledMemoryset.from_hf_dataset(
        "test_integer_label",
        Dataset.from_dict({"value": ["terrible", "great"], "label": [0, 1]}),
        label_names=["negative", "positive"],
    )
    assert memoryset is not None
    assert memoryset.length == 2
    assert memoryset.label_names == ["negative", "positive"]
    assert memoryset[0].label == 0
    assert memoryset[1].label == 1
    assert memoryset[0].label_name == "negative"
    assert memoryset[1].label_name == "positive"


def test_create_memoryset_null_labels():
    memoryset = LabeledMemoryset.from_hf_dataset(
        "test_null_labels",
        Dataset.from_dict({"value": ["terrible", "great"]}),
        label_names=["negative", "positive"],
        label_column=None,
    )
    assert memoryset is not None
    assert memoryset.length == 2
    assert memoryset.label_names == ["negative", "positive"]
    assert memoryset[0].label is None
    assert memoryset[1].label is None


def test_open_memoryset(readonly_memoryset, hf_dataset):
    fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
    assert fetched_memoryset is not None
    assert fetched_memoryset.name == readonly_memoryset.name
    assert fetched_memoryset.length == len(hf_dataset)
    assert fetched_memoryset.index_type == "IVF_FLAT"
    assert fetched_memoryset.index_params == {"n_lists": 100}


def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            LabeledMemoryset.open(readonly_memoryset.name)


def test_open_memoryset_not_found():
    with pytest.raises(LookupError):
        LabeledMemoryset.open(str(uuid4()))


def test_open_memoryset_invalid_input():
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open("not valid id")


def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
    with unauthorized_client.use():
        with pytest.raises(LookupError):
            LabeledMemoryset.open(readonly_memoryset.name)


def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
    memorysets = LabeledMemoryset.all()
    assert len(memorysets) > 0
    assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)


def test_all_memorysets_hidden(
    readonly_memoryset: LabeledMemoryset,
):
    # Create a hidden memoryset
    hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
    hidden_memoryset.set(hidden=True)

    # Test that show_hidden=False excludes hidden memorysets
    visible_memorysets = LabeledMemoryset.all(show_hidden=False)
    assert len(visible_memorysets) > 0
    assert readonly_memoryset in visible_memorysets
    assert hidden_memoryset not in visible_memorysets

    # Test that show_hidden=True includes hidden memorysets
    all_memorysets = LabeledMemoryset.all(show_hidden=True)
    assert len(all_memorysets) == len(visible_memorysets) + 1
    assert readonly_memoryset in all_memorysets
    assert hidden_memoryset in all_memorysets


def test_all_memorysets_unauthenticated(unauthenticated_client):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            LabeledMemoryset.all()


def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
    with unauthorized_client.use():
        assert readonly_memoryset not in LabeledMemoryset.all()


def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            LabeledMemoryset.drop(readonly_memoryset.name)


def test_drop_memoryset_not_found():
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(str(uuid4()))
    # ignores error if specified
    LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")


def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
    with unauthorized_client.use():
        with pytest.raises(LookupError):
            LabeledMemoryset.drop(readonly_memoryset.name)


def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
    original_label_names = writable_memoryset.label_names
    writable_memoryset.set(description="New description")
    assert writable_memoryset.description == "New description"

    writable_memoryset.set(description=None)
    assert writable_memoryset.description is None

    writable_memoryset.set(name="New_name")
    assert writable_memoryset.name == "New_name"

    writable_memoryset.set(name="test_writable_memoryset")
    assert writable_memoryset.name == "test_writable_memoryset"

    assert writable_memoryset.label_names == original_label_names

    writable_memoryset.set(label_names=["New label 1", "New label 2"])
    assert writable_memoryset.label_names == ["New label 1", "New label 2"]

    writable_memoryset.set(hidden=True)
    assert writable_memoryset.hidden is True


def test_search(readonly_memoryset: LabeledMemoryset):
    memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
    assert len(memory_lookups) == 2
    assert len(memory_lookups[0]) == 1
    assert len(memory_lookups[1]) == 1
    assert memory_lookups[0][0].label == 0
    assert memory_lookups[1][0].label == 1


def test_search_count(readonly_memoryset: LabeledMemoryset):
    memory_lookups = readonly_memoryset.search("i love soup", count=3)
    assert len(memory_lookups) == 3
    assert memory_lookups[0].label == 0
    assert memory_lookups[1].label == 0
    assert memory_lookups[2].label == 0


def test_search_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search within a specific partition - use "soup" which appears in both p1 and p2
    # Use exclude_global to ensure we only get results from the specified partition
    memory_lookups = readonly_partitioned_memoryset.search(
        "soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(memory_lookups) > 0
    # All results should be from partition p1 when partition_id is specified
    assert all(
        memory.partition_id == "p1" for memory in memory_lookups
    ), f"Expected all results from partition p1, but got: {[m.partition_id for m in memory_lookups]}"

    # Search in a different partition - use "cats" which appears in both p1 and p2
    memory_lookups_p2 = readonly_partitioned_memoryset.search(
        "cats", partition_id="p2", partition_filter_mode="exclude_global", count=5
    )
    assert len(memory_lookups_p2) > 0
    # All results should be from partition p2 when partition_id is specified
    assert all(
        memory.partition_id == "p2" for memory in memory_lookups_p2
    ), f"Expected all results from partition p2, but got: {[m.partition_id for m in memory_lookups_p2]}"


def test_search_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search excluding global memories - need to specify a partition_id when using exclude_global
    # This tests that exclude_global works with a specific partition
    memory_lookups = readonly_partitioned_memoryset.search(
        "soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(memory_lookups) > 0
    # All results should have a partition_id (not None) and be from p1
    assert all(memory.partition_id == "p1" for memory in memory_lookups)


def test_search_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search only in global memories (partition_id=None in the data)
    # Use a query that matches global memories and a reasonable count
    memory_lookups = readonly_partitioned_memoryset.search("beach", partition_filter_mode="only_global", count=3)
    # Should get at least some results (may be fewer than requested if not enough global memories match)
    assert len(memory_lookups) > 0
    # When using only_global, all results should be global (partition_id is None)
    partition_ids = {memory.partition_id for memory in memory_lookups}
    assert all(
        memory.partition_id is None for memory in memory_lookups
    ), f"Expected all results to be global (partition_id=None), but got partition_ids: {partition_ids}"


def test_search_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search including global memories (default behavior)
    # Use a reasonable count that won't exceed available memories
    memory_lookups = readonly_partitioned_memoryset.search(
        "i love soup", partition_filter_mode="include_global", count=5
    )
    assert len(memory_lookups) > 0
    # Results can include both partitioned and global memories
    partition_ids = {memory.partition_id for memory in memory_lookups}
    # Should have at least one partition or global memory
    assert len(partition_ids) > 0


def test_search_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search ignoring partition filtering entirely
    memory_lookups = readonly_partitioned_memoryset.search(
        "i love soup", partition_filter_mode="ignore_partitions", count=10
    )
    assert len(memory_lookups) > 0
    # Results can come from any partition or global
    partition_ids = {memory.partition_id for memory in memory_lookups}
    # Should have results from multiple partitions/global
    assert len(partition_ids) >= 1


def test_search_multiple_queries_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Search multiple queries within a specific partition
    memory_lookups = readonly_partitioned_memoryset.search(["i love soup", "cats are cute"], partition_id="p1", count=3)
    assert len(memory_lookups) == 2
    assert len(memory_lookups[0]) > 0
    assert len(memory_lookups[1]) > 0
    # All results should be from partition p1
    assert all(memory.partition_id == "p1" for memory in memory_lookups[0])
    assert all(memory.partition_id == "p1" for memory in memory_lookups[1])


def test_search_with_partition_id_and_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    # When partition_id is specified, partition_filter_mode should still work
    # Search in p1 with exclude_global (should only return p1 results)
    memory_lookups = readonly_partitioned_memoryset.search(
        "i love soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(memory_lookups) > 0
    assert all(memory.partition_id == "p1" for memory in memory_lookups)
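
# Taken together, the search tests above exercise all four partition filter modes.
# Summarizing the behavior these tests assert (not an exhaustive spec):
#   include_global    -- results from the given partition plus global memories (default)
#   exclude_global    -- results only from the given partition_id
#   only_global       -- results only from global memories (partition_id is None)
#   ignore_partitions -- no partition filtering at all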


def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    memory = readonly_memoryset[0]
    assert memory.value == hf_dataset[0]["value"]
    assert memory.label == hf_dataset[0]["label"]
    assert memory.label_name == label_names[hf_dataset[0]["label"]]
    assert memory.source_id == hf_dataset[0]["source_id"]
    assert memory.score == hf_dataset[0]["score"]
    assert memory.key == hf_dataset[0]["key"]
    last_memory = readonly_memoryset[-1]
    assert last_memory.value == hf_dataset[-1]["value"]
    assert last_memory.label == hf_dataset[-1]["label"]


def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    memories = readonly_memoryset[1:3]
    assert len(memories) == 2
    assert memories[0].value == hf_dataset["value"][1]
    assert memories[1].value == hf_dataset["value"][2]


def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
    assert memory.value == hf_dataset[0]["value"]
    assert memory == readonly_memoryset[memory.memory_id]


def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
    assert len(memories) == 2
    assert memories[0].value == hf_dataset[0]["value"]
    assert memories[1].value == hf_dataset[1]["value"]


def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
    memories = readonly_memoryset.query(filters=[("label", "==", 1)])
    assert len(memories) == 8
    assert all(memory.label == 1 for memory in memories)
    assert len(readonly_memoryset.query(limit=2)) == 2
    assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
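
# Query filters, as used above and below, are (field, operator, value) triples;
# dotted paths such as "metadata.key" or "feedback_metrics.accurate.avg" reach
# into nested fields.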


def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
    prediction = classification_model.predict("Do you love soup?")
    feedback_name = f"correct_{random.randint(0, 1000000)}"
    prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
    memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)

    # Get the memory_ids that were actually used in the prediction
    used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}

    assert len(memories) == 8
    assert all(memory.label == 0 for memory in memories)
    for memory in memories:
        assert memory.feedback_metrics is not None
        if memory.memory_id in used_memory_ids:
            assert feedback_name in memory.feedback_metrics
            assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
            assert memory.feedback_metrics[feedback_name]["count"] == 1
        else:
            assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
        assert isinstance(memory.lookup_count, int)


def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
    prediction = classification_model.predict("Do you love soup?")
    prediction.record_feedback(category="accurate", value=prediction.label == 0)
    memories = prediction.memoryset.query(
        filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
    )
    assert len(memories) == 3
    assert all(memory.label == 0 for memory in memories)
    for memory in memories:
        assert memory.feedback_metrics is not None
        assert memory.feedback_metrics["accurate"] is not None
        assert memory.feedback_metrics["accurate"]["avg"] == 1.0
        assert memory.feedback_metrics["accurate"]["count"] == 1


def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
    prediction = classification_model.predict("Do you love soup?")
    prediction.record_feedback(category="positive", value=1.0)
    prediction2 = classification_model.predict("Do you like cats?")
    prediction2.record_feedback(category="positive", value=-1.0)

    memories = prediction.memoryset.query(
        filters=[("feedback_metrics.positive.avg", ">=", -1.0)],
        sort=[("feedback_metrics.positive.avg", "desc")],
        with_feedback_metrics=True,
    )
    assert (
        len(memories) == 6
    )  # there are only 6 out of 16 memories that have a positive feedback metric. Look at SAMPLE_DATA in conftest.py
    assert memories[0].feedback_metrics["positive"]["avg"] == 1.0
    assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
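
# Shape of the feedback metrics asserted in the three tests above (inferred from
# these assertions, not a full schema):
#     memory.feedback_metrics[category] == {"avg": float, "count": int}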


def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with partition_id and include_global (default) - includes both p1 and global memories
    memories = readonly_partitioned_memoryset.query(partition_id="p1")
    assert len(memories) == 15  # 8 p1 + 7 global = 15
    # Results should include both p1 and global memories
    partition_ids = {memory.partition_id for memory in memories}
    assert "p1" in partition_ids
    assert None in partition_ids


def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with partition_id and exclude_global mode - only returns p1 memories
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
    assert len(memories) == 8  # Only 8 p1 memories (no global)
    # All results should be from partition p1 (no global memories)
    assert all(memory.partition_id == "p1" for memory in memories)


def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with partition_id and include_global mode (default) - includes both p1 and global
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="include_global")
    assert len(memories) == 15  # 8 p1 + 7 global = 15
    # Results should include both p1 and global memories
    partition_ids = {memory.partition_id for memory in memories}
    assert "p1" in partition_ids
    assert None in partition_ids


def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query excluding global memories requires a partition_id
    # Test with a specific partition_id
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
    assert len(memories) == 8  # Only p1 memories
    # All results should have a partition_id (not global)
    assert all(memory.partition_id == "p1" for memory in memories)


def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query only in global memories
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
    assert len(memories) == 7  # There are 7 global memories in SAMPLE_DATA
    # All results should be global (partition_id is None)
    assert all(memory.partition_id is None for memory in memories)


def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query including global memories - when no partition_id is specified,
    # include_global seems to only return global memories
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
    # Based on actual behavior, this returns only global memories
    assert len(memories) == 7
    # All results should be global
    assert all(memory.partition_id is None for memory in memories)


def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query ignoring partition filtering entirely - returns all memories
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
    assert len(memories) == 22  # All 22 memories
    # Results can come from any partition or global
    partition_ids = {memory.partition_id for memory in memories}
    # Should have results from multiple partitions/global
    assert len(partition_ids) >= 1
    # Verify we have p1, p2, and global
    assert "p1" in partition_ids
    assert "p2" in partition_ids
    assert None in partition_ids


def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with filters and partition_id
    memories = readonly_partitioned_memoryset.query(filters=[("label", "==", 0)], partition_id="p1")
    assert len(memories) > 0
    # All results should match the filter and be from partition p1
    assert all(memory.label == 0 for memory in memories)
    assert all(memory.partition_id == "p1" for memory in memories)


def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with filters and partition_filter_mode - exclude_global requires partition_id
    memories = readonly_partitioned_memoryset.query(
        filters=[("label", "==", 1)], partition_id="p1", partition_filter_mode="exclude_global"
    )
    assert len(memories) > 0
    # All results should match the filter and be from p1 (not global)
    assert all(memory.label == 1 for memory in memories)
    assert all(memory.partition_id == "p1" for memory in memories)


def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with limit and partition_id
    memories = readonly_partitioned_memoryset.query(partition_id="p2", limit=3)
    assert len(memories) == 3
    # All results should be from partition p2
    assert all(memory.partition_id == "p2" for memory in memories)


def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with offset and partition_id - use exclude_global to get only p1 memories
    memories_page1 = readonly_partitioned_memoryset.query(
        partition_id="p1", partition_filter_mode="exclude_global", limit=5
    )
    memories_page2 = readonly_partitioned_memoryset.query(
        partition_id="p1", partition_filter_mode="exclude_global", offset=5, limit=5
    )
    assert len(memories_page1) == 5
    assert len(memories_page2) == 3  # Only 3 remaining p1 memories (8 total - 5 = 3)
    # All results should be from partition p1
    assert all(memory.partition_id == "p1" for memory in memories_page1)
    assert all(memory.partition_id == "p1" for memory in memories_page2)
    # Results should be different (pagination works)
    memory_ids_page1 = {memory.memory_id for memory in memories_page1}
    memory_ids_page2 = {memory.memory_id for memory in memories_page2}
    assert memory_ids_page1.isdisjoint(memory_ids_page2)


def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query a different partition to verify it works
    # With include_global (default), it includes both p2 and global memories
    memories = readonly_partitioned_memoryset.query(partition_id="p2")
    assert len(memories) == 14  # 7 p2 + 7 global = 14
    # Results should include both p2 and global memories
    partition_ids = {memory.partition_id for memory in memories}
    assert "p2" in partition_ids
    assert None in partition_ids


def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Query with metadata filter and partition_id
    memories = readonly_partitioned_memoryset.query(filters=[("metadata.key", "==", "g1")], partition_id="p1")
    assert len(memories) > 0
    # All results should match the metadata filter and be from partition p1
    assert all(memory.metadata.get("key") == "g1" for memory in memories)
    assert all(memory.partition_id == "p1" for memory in memories)


def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
    readonly_partitioned_memoryset: LabeledMemoryset,
):
    # Query only global memories with filters
    memories = readonly_partitioned_memoryset.query(
        filters=[("metadata.key", "==", "g3")], partition_filter_mode="only_global"
    )
    assert len(memories) > 0
    # All results should match the filter and be global
    assert all(memory.metadata.get("key") == "g3" for memory in memories)
    assert all(memory.partition_id is None for memory in memories)


def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
    """Test that LabeledMemory.predictions() only returns classification predictions."""
    # Given: A classification model with memories
    memories = classification_model.memoryset.query(limit=1)
    assert len(memories) > 0
    memory = memories[0]

    # When: I call the predictions method
    predictions = memory.predictions()

    # Then: It should return a list of ClassificationPrediction objects
    assert isinstance(predictions, list)
    for prediction in predictions:
        assert prediction.__class__.__name__ == "ClassificationPrediction"
        assert hasattr(prediction, "label")
        assert not hasattr(prediction, "score") or prediction.score is None


def test_scored_memory_predictions_property(regression_model: RegressionModel):
    """Test that ScoredMemory.predictions() only returns regression predictions."""
    # Given: A regression model with memories
    memories = regression_model.memoryset.query(limit=1)
    assert len(memories) > 0
    memory = memories[0]

    # When: I call the predictions method
    predictions = memory.predictions()

    # Then: It should return a list of RegressionPrediction objects
    assert isinstance(predictions, list)
    for prediction in predictions:
        assert prediction.__class__.__name__ == "RegressionPrediction"
        assert hasattr(prediction, "score")
        assert not hasattr(prediction, "label") or prediction.label is None


def test_memory_feedback_property(classification_model: ClassificationModel):
    """Test that memory.feedback() returns feedback from relevant predictions."""
    # Given: A prediction with recorded feedback
    prediction = classification_model.predict("Test feedback")
    feedback_category = f"test_feedback_{random.randint(0, 1000000)}"
    prediction.record_feedback(category=feedback_category, value=True)

    # And: A memory that was used in the prediction
    memory_lookups = prediction.memory_lookups
    assert len(memory_lookups) > 0
    memory = memory_lookups[0]

    # When: I access the feedback property
    feedback = memory.feedback()

    # Then: It should return feedback aggregated by category as a dict
    assert isinstance(feedback, dict)
    assert feedback_category in feedback
    # Feedback values are lists (you may want to look at mean on the raw data)
    assert isinstance(feedback[feedback_category], list)
    assert len(feedback[feedback_category]) > 0
    # For binary feedback, values should be booleans
    assert isinstance(feedback[feedback_category][0], bool)


def test_memory_predictions_method_parameters(classification_model: ClassificationModel):
    """Test that memory.predictions() method supports pagination, sorting, and filtering."""
    # Given: A classification model with memories
    memories = classification_model.memoryset.query(limit=1)
    assert len(memories) > 0
    memory = memories[0]

    # When: I call predictions with limit parameter
    predictions_limited = memory.predictions(limit=2)

    # Then: It should respect the limit
    assert isinstance(predictions_limited, list)
    assert len(predictions_limited) <= 2

    # When: I call predictions with offset parameter
    all_predictions = memory.predictions(limit=100)
    if len(all_predictions) > 1:
        predictions_offset = memory.predictions(limit=1, offset=1)
        # Then: offset should skip the first prediction
        assert predictions_offset[0].prediction_id != all_predictions[0].prediction_id

    # When: I call predictions with sort parameter
    predictions_sorted = memory.predictions(limit=10, sort=[("timestamp", "desc")])
    # Then: It should return predictions (sorting verified by API)
    assert isinstance(predictions_sorted, list)

    # When: I call predictions with expected_label_match parameter
    correct_predictions = memory.predictions(expected_label_match=True)
    incorrect_predictions = memory.predictions(expected_label_match=False)
    # Then: Both should return lists (correctness verified by API filtering)
    assert isinstance(correct_predictions, list)
    assert isinstance(incorrect_predictions, list)


def test_memory_predictions_expected_label_filter(classification_model: ClassificationModel):
    """Test that memory.predictions(expected_label_match=...) filters predictions by correctness."""
    # Given: Make an initial prediction to learn the model's label for a known input
    baseline_prediction = classification_model.predict("Filter test input", save_telemetry="sync")
    original_label = baseline_prediction.label
    alternate_label = 0 if original_label else 1

    # When: Make a second prediction with an intentionally incorrect expected label
    mismatched_prediction = classification_model.predict(
        "Filter test input",
        expected_labels=alternate_label,
        save_telemetry="sync",
    )
    mismatched_memory = mismatched_prediction.memory_lookups[0]

    # Then: The prediction should show up when filtering for incorrect predictions
    incorrect_predictions = mismatched_memory.predictions(expected_label_match=False)
    assert any(pred.prediction_id == mismatched_prediction.prediction_id for pred in incorrect_predictions)

    # Produce a correct prediction (predicted label matches expected label)
    correct_prediction = classification_model.predict(
        "Filter test input",
        expected_labels=original_label,
        save_telemetry="sync",
    )

    # Ensure we are inspecting a memory used by both correct and incorrect predictions
    correct_lookup_ids = {lookup.memory_id for lookup in correct_prediction.memory_lookups}
    if mismatched_memory.memory_id not in correct_lookup_ids:
        shared_lookup = next(
            (lookup for lookup in mismatched_prediction.memory_lookups if lookup.memory_id in correct_lookup_ids),
            None,
        )
        assert shared_lookup is not None, "No shared memory lookup between correct and incorrect predictions"
        mismatched_memory = shared_lookup

    # And: The correct prediction should appear when filtering for correct predictions
    correct_predictions = mismatched_memory.predictions(expected_label_match=True)
    assert any(pred.prediction_id == correct_prediction.prediction_id for pred in correct_predictions)
    assert all(pred.prediction_id != mismatched_prediction.prediction_id for pred in correct_predictions)


def test_insert_memories(writable_memoryset: LabeledMemoryset):
    writable_memoryset.refresh()
    prev_length = writable_memoryset.length
    writable_memoryset.insert(
        [
            dict(value="tomato soup is my favorite", label=0),
            dict(value="cats are fun to play with", label=1),
        ],
        batch_size=1,
    )
    writable_memoryset.refresh()
    assert writable_memoryset.length == prev_length + 2
    writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
    writable_memoryset.refresh()
    assert writable_memoryset.length == prev_length + 3
    last_memory = writable_memoryset[-1]
    assert last_memory.value == "tomato soup is my favorite"
    assert last_memory.label == 0
    assert last_memory.metadata
    assert last_memory.metadata["key"] == "test"
    assert last_memory.source_id == "test"


@skip_in_prod("Production memorysets do not have session consistency guarantees")
@skip_in_ci("CI environment may not have session consistency guarantees")
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset

    # test updating a single memory
    memory_id = writable_memoryset[0].memory_id
    updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
    assert updated_memory.value == "i love soup so much"
    assert updated_memory.label == hf_dataset[0]["label"]
    writable_memoryset.refresh()  # Refresh to ensure consistency after update
    assert writable_memoryset.get(memory_id).value == "i love soup so much"

    # test updating a memory instance
    memory = writable_memoryset[0]
    updated_memory = memory.update(value="i love soup even more")
    assert updated_memory is memory
    assert memory.value == "i love soup even more"
    assert memory.label == hf_dataset[0]["label"]

    # test updating multiple memories
    memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
    updated_memories = writable_memoryset.update(
        [
            dict(memory_id=memory_ids[0], value="i love soup so much"),
            dict(memory_id=memory_ids[1], value="cats are so cute"),
        ],
        batch_size=1,
    )
    assert updated_memories[0].value == "i love soup so much"
    assert updated_memories[1].value == "cats are so cute"


def test_delete_memories(writable_memoryset: LabeledMemoryset):
    # We've combined the delete tests into one to avoid multiple expensive requests for a writable_memoryset

    # test deleting a single memory
    prev_length = writable_memoryset.length
    memory_id = writable_memoryset[0].memory_id
    writable_memoryset.delete(memory_id)
    with pytest.raises(LookupError):
        writable_memoryset.get(memory_id)
    assert writable_memoryset.length == prev_length - 1

    # test deleting multiple memories
    prev_length = writable_memoryset.length
    writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id], batch_size=1)
    assert writable_memoryset.length == prev_length - 2


def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
    cloned_memoryset = readonly_memoryset.clone(
        "test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
    )
    assert cloned_memoryset is not None
    assert cloned_memoryset.name == "test_cloned_memoryset"
    assert cloned_memoryset.length == readonly_memoryset.length
    assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
    assert cloned_memoryset.insertion_status == Status.COMPLETED


def test_clone_empty_memoryset():
    name = f"test_empty_to_clone_{uuid4()}"
    cloned_name = f"test_empty_cloned_{uuid4()}"
    label_names = ["negative", "positive"]
    try:
        # Create an empty memoryset
        empty_memoryset = LabeledMemoryset.create(name, label_names=label_names, description="empty memoryset to clone")
        assert empty_memoryset is not None
        assert empty_memoryset.name == name
        assert empty_memoryset.length == 0
        assert empty_memoryset.insertion_status is None  # Empty memorysets have None status

        # Clone the empty memoryset
        cloned_memoryset = empty_memoryset.clone(cloned_name, embedding_model=PretrainedEmbeddingModel.DISTILBERT)
        assert cloned_memoryset is not None
        assert cloned_memoryset.name == cloned_name
        assert cloned_memoryset.length == 0  # Clone should also be empty
        assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
        assert cloned_memoryset.insertion_status == Status.COMPLETED
        assert cloned_memoryset.label_names == label_names
    finally:
        LabeledMemoryset.drop(name, if_not_exists="ignore")
        LabeledMemoryset.drop(cloned_name, if_not_exists="ignore")


def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
    writable_memoryset.insert(
        [
            dict(value="raspberry soup Is my favorite", label=0),
            dict(value="Raspberry soup is MY favorite", label=0),
            dict(value="rAspberry soup is my favorite", label=0),
            dict(value="raSpberry SOuP is my favorite", label=0),
            dict(value="rasPberry SOuP is my favorite", label=0),
            dict(value="bunny rabbit Is not my mom", label=1),
            dict(value="bunny rabbit is not MY mom", label=1),
            dict(value="bunny rabbit Is not my moM", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not My mom", label=1),
        ]
    )

    writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
    response = writable_memoryset.get_potential_duplicate_groups()
    assert isinstance(response, list)
    assert sorted([len(res) for res in response]) == [5, 6]  # 5 favorite, 6 mom
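
# Note: analyze() accepts analysis config dicts with a "name" key (as above) as
# well as bare analysis names; test_analyze_invalid_analysis_name below exercises
# both forms.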


def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
    # Insert a memory to test cascading edits
    SOUP = 0
    CATS = 1
    query_text = "i love soup"  # from SAMPLE_DATA in conftest.py
    mislabeled_soup_text = "soup is comfort in a bowl"
    writable_memoryset.insert(
        [
            dict(value=mislabeled_soup_text, label=CATS),  # mislabeled soup memory
        ]
    )

    # Fetch the memory to update
    memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]

    # Update the label and get cascading edit suggestions
    suggestions = writable_memoryset.get_cascading_edits_suggestions(
        memory=memory,
        old_label=CATS,
        new_label=SOUP,
        max_neighbors=10,
        max_validation_neighbors=5,
    )

    # Validate the suggestions
    assert len(suggestions) == 1
    assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text


def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
    """Test that analyze() raises ValueError for invalid analysis names"""
    memoryset = LabeledMemoryset.open(readonly_memoryset.name)

    # Test with string input
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze("invalid_name")
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # Test with dict input
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze({"name": "invalid_name"})
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # Test with multiple analyses where one is invalid
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze("duplicate", "invalid_name")
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # Test with valid analysis names
    result = memoryset.analyze("duplicate", "cluster")
    assert isinstance(result, dict)
    assert "duplicate" in result
    assert "cluster" in result


def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
    # NOTE: Keep this test at the end to ensure the memoryset is dropped after all tests.
    # Otherwise, it would be recreated on the next test run if it were dropped earlier, and
    # that's expensive.
    assert LabeledMemoryset.exists(writable_memoryset.name)
    LabeledMemoryset.drop(writable_memoryset.name)
    assert not LabeledMemoryset.exists(writable_memoryset.name)


def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
    assert scored_memoryset.length == 22
    assert isinstance(scored_memoryset[0], ScoredMemory)
    assert scored_memoryset[0].value == "i love soup"
    assert scored_memoryset[0].score is not None
    assert scored_memoryset[0].metadata == {"key": "g1", "label": 0, "partition_id": "p1"}
    assert scored_memoryset[0].source_id == "s1"
    lookup = scored_memoryset.search("i love soup", count=1)
    assert len(lookup) == 1
    assert lookup[0].score is not None
    assert lookup[0].score < 0.11


@skip_in_prod("Production memorysets do not have session consistency guarantees")
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
    # we are only updating an inconsequential metadata field so that we don't affect other tests
    memory = scored_memoryset[0]
    assert memory.label == 0
    scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
    assert scored_memoryset[0].label == 3
    memory.update(label=4)
    assert scored_memoryset[0].label == 4


@pytest.mark.asyncio
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
    """Test async insertion of a single memory"""
    await writable_memoryset.arefresh()
    prev_length = writable_memoryset.length

    await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == prev_length + 1
    last_memory = writable_memoryset[-1]
    assert last_memory.value == "async tomato soup is my favorite"
    assert last_memory.label == 0
    assert last_memory.metadata["key"] == "async_test"


@pytest.mark.asyncio
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
    """Test async insertion of multiple memories"""
    await writable_memoryset.arefresh()
    prev_length = writable_memoryset.length

    await writable_memoryset.ainsert(
        [
            dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
            dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
        ]
    )

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == prev_length + 2

    # Check the inserted memories
    last_two_memories = writable_memoryset[-2:]
    values = [memory.value for memory in last_two_memories]
    labels = [memory.label for memory in last_two_memories]
    keys = [memory.metadata.get("key") for memory in last_two_memories]

    assert "async batch soup is delicious" in values
    assert "async batch cats are adorable" in values
    assert 0 in labels
    assert 1 in labels
    assert "batch_test_1" in keys
    assert "batch_test_2" in keys


@pytest.mark.asyncio
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
    """Test async insertion with source_id and metadata"""
    await writable_memoryset.arefresh()
    prev_length = writable_memoryset.length

    await writable_memoryset.ainsert(
        dict(
            value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
        )
    )

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == prev_length + 1
    last_memory = writable_memoryset[-1]
    assert last_memory.value == "async soup with source id"
    assert last_memory.label == 0
    assert last_memory.source_id == "async_source_123"
    assert last_memory.metadata["custom_field"] == "async_custom_value"


@pytest.mark.asyncio
async def test_insert_memories_async_unauthenticated(
    unauthenticated_async_client, writable_memoryset: LabeledMemoryset
):
    """Test async insertion with invalid authentication"""
    with unauthenticated_async_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            await writable_memoryset.ainsert(dict(value="this should fail", label=0))
|