orca-sdk 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
@@ -0,0 +1,1150 @@
1
+ import random
2
+ from uuid import uuid4
3
+
4
+ import pytest
5
+ from datasets.arrow_dataset import Dataset
6
+
7
+ from .classification_model import ClassificationModel
8
+ from .conftest import skip_in_ci, skip_in_prod
9
+ from .datasource import Datasource
10
+ from .embedding_model import PretrainedEmbeddingModel
11
+ from .memoryset import (
12
+ LabeledMemory,
13
+ LabeledMemoryset,
14
+ ScoredMemory,
15
+ ScoredMemoryset,
16
+ Status,
17
+ )
18
+ from .regression_model import RegressionModel
19
+
20
+ """
21
+ Test Performance Note:
22
+
23
+ Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
24
+
25
+ - Two fixtures are used to manage memorysets:
26
+ - `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
27
+ It should only be used in nullipotent tests.
28
+ - `writable_memoryset` is a function-scoped, regenerating fixture.
29
+ It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
30
+
31
+ - To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
32
+ For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
33
+ rather than separate `test_delete_single` and `test_delete_multiple` tests.
34
+ """
35
+
36
+
37
def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """The session-scoped readonly memoryset reflects the dataset it was created from."""
    ms = readonly_memoryset
    assert ms is not None
    assert ms.name == "test_readonly_memoryset"
    assert ms.embedding_model == PretrainedEmbeddingModel.GTE_BASE
    assert ms.label_names == label_names
    assert ms.insertion_status == Status.COMPLETED
    # length mirrors the source dataset and is exposed as a plain int
    assert isinstance(ms.length, int)
    assert ms.length == len(hf_dataset)
    # vector index configuration used by the fixture
    assert ms.index_type == "IVF_FLAT"
    assert ms.index_params == {"n_lists": 100}
47
+
48
+
49
def test_create_empty_labeled_memoryset():
    """An empty labeled memoryset can be created, inserted into, and re-opened."""
    name = f"test_empty_labeled_{uuid4()}"
    label_names = ["negative", "positive"]
    try:
        created = LabeledMemoryset.create(name, label_names=label_names, description="empty labeled test")
        assert created is not None
        assert created.name == name
        assert created.length == 0
        assert created.label_names == label_names
        # nothing has been inserted yet, so there is no insertion status
        assert created.insertion_status is None

        # insertion works on an empty memoryset; extra keys land in metadata
        created.insert({"value": "i love soup", "label": 1, "key": "k1"})
        created.refresh()
        assert created.length == 1
        inserted = created[0]
        assert isinstance(inserted, LabeledMemory)
        assert inserted.value == "i love soup"
        assert inserted.label == 1
        assert inserted.label_name == "positive"
        assert inserted.metadata.get("key") == "k1"

        # if_exists="open" re-opens the existing memoryset when the config matches
        reopened = LabeledMemoryset.create(name, label_names=label_names, if_exists="open")
        assert reopened.id == created.id
        assert len(reopened) == 1

        # if_exists="open" rejects mismatched label names
        with pytest.raises(ValueError, match=r"label names|requested"):
            LabeledMemoryset.create(name, label_names=["turtles", "frogs"], if_exists="open")

        # if_exists="open" rejects a mismatched embedding model
        with pytest.raises(ValueError, match=r"embedding_model|requested"):
            LabeledMemoryset.create(
                name,
                label_names=label_names,
                embedding_model=PretrainedEmbeddingModel.DISTILBERT,
                if_exists="open",
            )

        # if_exists="error" raises because the memoryset already exists
        with pytest.raises(ValueError, match="already exists"):
            LabeledMemoryset.create(name, label_names=label_names, if_exists="error")
    finally:
        LabeledMemoryset.drop(name, if_not_exists="ignore")
94
+
95
+
96
def test_create_empty_scored_memoryset():
    """An empty scored memoryset can be created, inserted into, and re-opened."""
    name = f"test_empty_scored_{uuid4()}"
    try:
        created = ScoredMemoryset.create(name, description="empty scored test")
        assert created is not None
        assert created.name == name
        assert created.length == 0
        # nothing inserted yet, so no insertion status
        assert created.insertion_status is None

        # insertion works on an empty memoryset; non-score extras land in metadata
        created.insert({"value": "i love soup", "score": 0.25, "key": "k1", "label": 0})
        created.refresh()
        assert created.length == 1
        inserted = created[0]
        assert isinstance(inserted, ScoredMemory)
        assert inserted.value == "i love soup"
        assert inserted.score == 0.25
        assert inserted.metadata.get("key") == "k1"
        # "label" has no special meaning for scored memorysets and is kept as metadata
        assert inserted.metadata.get("label") == 0

        # if_exists="open" re-opens the same memoryset
        reopened = ScoredMemoryset.create(name, if_exists="open")
        assert reopened.id == created.id

        # if_exists="open" rejects a mismatched embedding model
        with pytest.raises(ValueError, match=r"embedding_model|requested"):
            ScoredMemoryset.create(name, embedding_model=PretrainedEmbeddingModel.DISTILBERT, if_exists="open")

        # if_exists="error" raises because the memoryset already exists
        with pytest.raises(ValueError, match="already exists"):
            ScoredMemoryset.create(name, if_exists="error")
    finally:
        ScoredMemoryset.drop(name, if_not_exists="ignore")
129
+
130
+
131
def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
    """Creation is rejected when the client has no valid API key."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.create("test_memoryset", datasource=datasource)
135
+
136
+
137
def test_create_memoryset_invalid_input(datasource):
    """Names containing spaces are rejected as invalid input."""
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource=datasource)
141
+
142
+
143
def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
    """Creating over an existing memoryset raises, both by default and with explicit if_exists."""
    existing_name = readonly_memoryset.name
    # default behavior is to error on an existing name
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(existing_name, hf_dataset, label_names=label_names)
    # explicit if_exists="error" behaves the same way
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(existing_name, hf_dataset, label_names=label_names, if_exists="error")
149
+
150
+
151
def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
    """if_exists="open" only re-opens an existing memoryset when its config matches."""
    existing_name = readonly_memoryset.name

    # mismatched label names are rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name,
            hf_dataset,
            label_names=["turtles", "frogs"],
            if_exists="open",
        )

    # a different embedding model is rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name,
            hf_dataset,
            label_names=label_names,
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )

    # a matching embedding model re-opens the existing memoryset
    opened = LabeledMemoryset.from_hf_dataset(
        existing_name,
        hf_dataset,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        if_exists="open",
    )
    assert opened is not None
    assert opened.name == existing_name
    assert opened.length == len(hf_dataset)
178
+
179
+
180
def test_if_exists_error_no_datasource_creation(
    readonly_memoryset: LabeledMemoryset,
):
    """A failed create with if_exists="error" must not leave a datasource behind."""
    target = readonly_memoryset.name
    side_datasource = f"{target}_datasource"
    # start from a known-clean state
    Datasource.drop(side_datasource, if_not_exists="ignore")
    assert not Datasource.exists(side_datasource)
    with pytest.raises(ValueError):
        LabeledMemoryset.from_list(target, [{"value": "new value", "label": 0}], if_exists="error")
    # the intermediate datasource must not have been created as a side effect
    assert not Datasource.exists(side_datasource)
190
+
191
+
192
def test_if_exists_open_reuses_existing_datasource(
    readonly_memoryset: LabeledMemoryset,
):
    """Opening an existing memoryset via if_exists="open" creates no datasource."""
    target = readonly_memoryset.name
    side_datasource = f"{target}_datasource"
    # start from a known-clean state
    Datasource.drop(side_datasource, if_not_exists="ignore")
    assert not Datasource.exists(side_datasource)
    reopened = LabeledMemoryset.from_list(target, [{"value": "new value", "label": 0}], if_exists="open")
    assert reopened.id == readonly_memoryset.id
    # the open path also created no intermediate datasource
    assert not Datasource.exists(side_datasource)
202
+
203
+
204
def test_create_memoryset_string_label():
    """String labels in the dataset are converted to label indexes with inferred label_names.

    Fix: the created memoryset was never dropped, leaking state across runs and
    breaking this test's own `exists()` precondition on reruns. Cleanup now
    mirrors the try/finally convention used by the other create tests.
    """
    assert not LabeledMemoryset.exists("test_string_label")
    try:
        memoryset = LabeledMemoryset.from_hf_dataset(
            "test_string_label",
            Dataset.from_dict({"value": ["terrible", "great"], "label": ["negative", "positive"]}),
        )
        assert memoryset is not None
        assert memoryset.length == 2
        # label names are inferred from the distinct string labels
        assert memoryset.label_names == ["negative", "positive"]
        assert memoryset[0].label == 0
        assert memoryset[1].label == 1
        assert memoryset[0].label_name == "negative"
        assert memoryset[1].label_name == "positive"
    finally:
        LabeledMemoryset.drop("test_string_label", if_not_exists="ignore")
217
+
218
+
219
def test_create_memoryset_integer_label():
    """Integer labels combine with explicitly supplied label_names.

    Fix: the created memoryset was never dropped, leaking state across runs and
    breaking this test's own `exists()` precondition on reruns. Cleanup now
    mirrors the try/finally convention used by the other create tests.
    """
    assert not LabeledMemoryset.exists("test_integer_label")
    try:
        memoryset = LabeledMemoryset.from_hf_dataset(
            "test_integer_label",
            Dataset.from_dict({"value": ["terrible", "great"], "label": [0, 1]}),
            label_names=["negative", "positive"],
        )
        assert memoryset is not None
        assert memoryset.length == 2
        assert memoryset.label_names == ["negative", "positive"]
        assert memoryset[0].label == 0
        assert memoryset[1].label == 1
        assert memoryset[0].label_name == "negative"
        assert memoryset[1].label_name == "positive"
    finally:
        LabeledMemoryset.drop("test_integer_label", if_not_exists="ignore")
233
+
234
+
235
def test_create_memoryset_null_labels():
    """With label_column=None, memories are created without labels.

    Fix: the created memoryset was never dropped, leaking state across test
    runs. Cleanup now mirrors the try/finally convention used by the other
    create tests.
    """
    try:
        memoryset = LabeledMemoryset.from_hf_dataset(
            "test_null_labels",
            Dataset.from_dict({"value": ["terrible", "great"]}),
            label_names=["negative", "positive"],
            label_column=None,
        )
        assert memoryset is not None
        assert memoryset.length == 2
        # label_names are kept even though no memory carries a label
        assert memoryset.label_names == ["negative", "positive"]
        assert memoryset[0].label is None
        assert memoryset[1].label is None
    finally:
        LabeledMemoryset.drop("test_null_labels", if_not_exists="ignore")
247
+
248
+
249
def test_open_memoryset(readonly_memoryset, hf_dataset):
    """open() fetches an existing memoryset by name with its index config intact."""
    fetched = LabeledMemoryset.open(readonly_memoryset.name)
    assert fetched is not None
    assert fetched.name == readonly_memoryset.name
    assert fetched.length == len(hf_dataset)
    assert fetched.index_type == "IVF_FLAT"
    assert fetched.index_params == {"n_lists": 100}
256
+
257
+
258
def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
    """open() is rejected without a valid API key."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.open(readonly_memoryset.name)
262
+
263
+
264
def test_open_memoryset_not_found():
    """Opening a nonexistent memoryset raises a LookupError."""
    with pytest.raises(LookupError):
        # a random UUID is guaranteed not to name an existing memoryset
        LabeledMemoryset.open(str(uuid4()))
267
+
268
+
269
def test_open_memoryset_invalid_input():
    """Identifiers with spaces are rejected as invalid input."""
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open("not valid id")
272
+
273
+
274
def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
    """A caller without access cannot see the memoryset; open() reports not-found."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        LabeledMemoryset.open(readonly_memoryset.name)
278
+
279
+
280
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
    """all() lists existing memorysets, including the readonly fixture."""
    listed = LabeledMemoryset.all()
    assert len(listed) > 0
    assert any(entry.name == readonly_memoryset.name for entry in listed)
284
+
285
+
286
def test_all_memorysets_hidden(
    readonly_memoryset: LabeledMemoryset,
):
    """Hidden memorysets are excluded from all() unless show_hidden=True.

    Fix: the cloned hidden memoryset was never dropped, leaking into other
    tests' all() results (and skewing the count comparison on reruns). The
    clone is now cleaned up in a finally block.
    """
    hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
    try:
        hidden_memoryset.set(hidden=True)

        # show_hidden=False excludes hidden memorysets
        visible_memorysets = LabeledMemoryset.all(show_hidden=False)
        assert len(visible_memorysets) > 0
        assert readonly_memoryset in visible_memorysets
        assert hidden_memoryset not in visible_memorysets

        # show_hidden=True includes hidden memorysets
        all_memorysets = LabeledMemoryset.all(show_hidden=True)
        assert len(all_memorysets) == len(visible_memorysets) + 1
        assert readonly_memoryset in all_memorysets
        assert hidden_memoryset in all_memorysets
    finally:
        LabeledMemoryset.drop("test_hidden_memoryset", if_not_exists="ignore")
304
+
305
+
306
def test_all_memorysets_unauthenticated(unauthenticated_client):
    """Listing memorysets is rejected without a valid API key."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.all()
310
+
311
+
312
def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
    """Memorysets a caller cannot access are omitted from all()."""
    with unauthorized_client.use():
        assert readonly_memoryset not in LabeledMemoryset.all()
315
+
316
+
317
def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
    """drop() is rejected without a valid API key."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.drop(readonly_memoryset.name)
321
+
322
+
323
def test_drop_memoryset_not_found():
    """Dropping an unknown memoryset raises unless if_not_exists="ignore"."""
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(str(uuid4()))
    # the same operation is a silent no-op when told to ignore missing memorysets
    LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
328
+
329
+
330
def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
    """A caller without access cannot drop the memoryset; drop() reports not-found."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        LabeledMemoryset.drop(readonly_memoryset.name)
334
+
335
+
336
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
    """set() updates description, name, label_names, and hidden in place."""
    original_label_names = writable_memoryset.label_names

    # description can be set and cleared again
    writable_memoryset.set(description="New description")
    assert writable_memoryset.description == "New description"
    writable_memoryset.set(description=None)
    assert writable_memoryset.description is None

    # renaming works in both directions (restore the fixture's name at the end)
    writable_memoryset.set(name="New_name")
    assert writable_memoryset.name == "New_name"
    writable_memoryset.set(name="test_writable_memoryset")
    assert writable_memoryset.name == "test_writable_memoryset"

    # unrelated updates leave label names untouched
    assert writable_memoryset.label_names == original_label_names

    # label names themselves are also updatable
    writable_memoryset.set(label_names=["New label 1", "New label 2"])
    assert writable_memoryset.label_names == ["New label 1", "New label 2"]

    writable_memoryset.set(hidden=True)
    assert writable_memoryset.hidden is True
357
+
358
+
359
def test_search(readonly_memoryset: LabeledMemoryset):
    """A batch search returns one lookup list per query, each with one hit by default."""
    results = readonly_memoryset.search(["i love soup", "cats are cute"])
    assert len(results) == 2
    soup_hits, cat_hits = results
    assert len(soup_hits) == 1
    assert len(cat_hits) == 1
    # the nearest memory for each query carries the expected label
    assert soup_hits[0].label == 0
    assert cat_hits[0].label == 1
366
+
367
+
368
def test_search_count(readonly_memoryset: LabeledMemoryset):
    """A single-query search honors count and returns hits with consistent labels."""
    hits = readonly_memoryset.search("i love soup", count=3)
    assert len(hits) == 3
    # all three nearest neighbors share the same label
    for hit in hits:
        assert hit.label == 0
374
+
375
+
376
def test_search_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """Searching with a partition_id plus exclude_global only returns that partition's hits."""
    # "soup" appears in both p1 and p2; exclude_global keeps the result set to p1
    p1_hits = readonly_partitioned_memoryset.search(
        "soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(p1_hits) > 0
    assert all(
        memory.partition_id == "p1" for memory in p1_hits
    ), f"Expected all results from partition p1, but got: {[m.partition_id for m in p1_hits]}"

    # "cats" also appears in both partitions; repeat the check for p2
    p2_hits = readonly_partitioned_memoryset.search(
        "cats", partition_id="p2", partition_filter_mode="exclude_global", count=5
    )
    assert len(p2_hits) > 0
    assert all(
        memory.partition_id == "p2" for memory in p2_hits
    ), f"Expected all results from partition p2, but got: {[m.partition_id for m in p2_hits]}"
397
+
398
+
399
def test_search_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """exclude_global with a partition_id limits hits to that partition only."""
    # exclude_global needs a partition_id; scope the search to p1
    hits = readonly_partitioned_memoryset.search(
        "soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(hits) > 0
    # no global memories may appear; everything comes from p1
    assert all(hit.partition_id == "p1" for hit in hits)
408
+
409
+
410
def test_search_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """only_global restricts search hits to memories that have no partition."""
    # "beach" matches global memories; fewer than count may come back
    hits = readonly_partitioned_memoryset.search("beach", partition_filter_mode="only_global", count=3)
    assert len(hits) > 0
    partition_ids = {hit.partition_id for hit in hits}
    # every hit must be global, i.e. carry no partition id
    assert all(
        hit.partition_id is None for hit in hits
    ), f"Expected all results to be global (partition_id=None), but got partition_ids: {partition_ids}"
422
+
423
+
424
def test_search_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """include_global (the default) may mix partitioned and global memories in the hits."""
    hits = readonly_partitioned_memoryset.search(
        "i love soup", partition_filter_mode="include_global", count=5
    )
    assert len(hits) > 0
    # hits may come from any partition or be global; just verify something came back
    seen_partitions = {hit.partition_id for hit in hits}
    assert len(seen_partitions) > 0
435
+
436
+
437
def test_search_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    """ignore_partitions searches across all partitions and global memories alike."""
    hits = readonly_partitioned_memoryset.search(
        "i love soup", partition_filter_mode="ignore_partitions", count=10
    )
    assert len(hits) > 0
    # hits may originate anywhere; just verify at least one origin was observed
    seen_partitions = {hit.partition_id for hit in hits}
    assert len(seen_partitions) >= 1
447
+
448
+
449
def test_search_multiple_queries_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """Batch search with a partition_id scopes every query to that partition."""
    results = readonly_partitioned_memoryset.search(["i love soup", "cats are cute"], partition_id="p1", count=3)
    assert len(results) == 2
    for per_query_hits in results:
        assert len(per_query_hits) > 0
        # every hit for every query stays within partition p1
        assert all(hit.partition_id == "p1" for hit in per_query_hits)
458
+
459
+
460
def test_search_with_partition_id_and_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    """partition_id and partition_filter_mode compose: p1 + exclude_global yields only p1 hits."""
    hits = readonly_partitioned_memoryset.search(
        "i love soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
    )
    assert len(hits) > 0
    assert all(hit.partition_id == "p1" for hit in hits)
468
+
469
+
470
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """Integer indexing mirrors the backing dataset, including negative indexes."""
    first = readonly_memoryset[0]
    row = hf_dataset[0]
    assert first.value == row["value"]
    assert first.label == row["label"]
    assert first.label_name == label_names[row["label"]]
    assert first.source_id == row["source_id"]
    assert first.score == row["score"]
    assert first.key == row["key"]
    # negative indexes count from the end, like a list
    last = readonly_memoryset[-1]
    assert last.value == hf_dataset[-1]["value"]
    assert last.label == hf_dataset[-1]["label"]
481
+
482
+
483
def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Slicing returns the memories at the corresponding dataset positions."""
    sliced = readonly_memoryset[1:3]
    assert len(sliced) == 2
    assert sliced[0].value == hf_dataset["value"][1]
    assert sliced[1].value == hf_dataset["value"][2]
488
+
489
+
490
def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """get() fetches a memory by id; indexing with an id is equivalent."""
    target_id = readonly_memoryset[0].memory_id
    memory = readonly_memoryset.get(target_id)
    assert memory.value == hf_dataset[0]["value"]
    # id-based subscript access resolves to the same memory
    assert memory == readonly_memoryset[memory.memory_id]
494
+
495
+
496
def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """get() with a list of ids returns the memories in the requested order."""
    ids = [readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id]
    fetched = readonly_memoryset.get(ids)
    assert len(fetched) == 2
    assert fetched[0].value == hf_dataset[0]["value"]
    assert fetched[1].value == hf_dataset[1]["value"]
501
+
502
+
503
def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
    """query() supports equality filters, limits, and metadata-prefixed filters."""
    positives = readonly_memoryset.query(filters=[("label", "==", 1)])
    assert len(positives) == 8
    assert all(memory.label == 1 for memory in positives)
    # limit caps the number of returned memories
    assert len(readonly_memoryset.query(limit=2)) == 2
    # metadata fields are filterable via the "metadata." prefix
    assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
509
+
510
+
511
def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
    """Feedback recorded on a prediction shows up on the memories used for it."""
    prediction = classification_model.predict("Do you love soup?")
    # unique category name so reruns do not collide with stale feedback
    feedback_name = f"correct_{random.randint(0, 1000000)}"
    prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
    memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)

    # ids of the memories that actually contributed to the prediction
    used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}

    assert len(memories) == 8
    assert all(memory.label == 0 for memory in memories)
    for memory in memories:
        assert memory.feedback_metrics is not None
        if memory.memory_id in used_memory_ids:
            # memories used in the prediction received exactly one feedback value
            assert feedback_name in memory.feedback_metrics
            assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
            assert memory.feedback_metrics[feedback_name]["count"] == 1
        else:
            # unused memories either lack the category or carry a zero count
            assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
        assert isinstance(memory.lookup_count, int)
531
+
532
+
533
def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
    """Feedback metrics can be used directly in query filters."""
    prediction = classification_model.predict("Do you love soup?")
    prediction.record_feedback(category="accurate", value=prediction.label == 0)
    matching = prediction.memoryset.query(
        filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
    )
    assert len(matching) == 3
    assert all(memory.label == 0 for memory in matching)
    for memory in matching:
        metrics = memory.feedback_metrics
        assert metrics is not None
        assert metrics["accurate"] is not None
        # each matching memory received a single positive feedback value
        assert metrics["accurate"]["avg"] == 1.0
        assert metrics["accurate"]["count"] == 1
546
+
547
+
548
def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
    """Query results can be sorted by feedback metric averages."""
    prediction = classification_model.predict("Do you love soup?")
    prediction.record_feedback(category="positive", value=1.0)
    prediction2 = classification_model.predict("Do you like cats?")
    prediction2.record_feedback(category="positive", value=-1.0)

    memories = prediction.memoryset.query(
        filters=[("feedback_metrics.positive.avg", ">=", -1.0)],
        sort=[("feedback_metrics.positive.avg", "desc")],
        with_feedback_metrics=True,
    )
    # only 6 of the 16 memories carry a "positive" feedback metric (see SAMPLE_DATA in conftest.py)
    assert len(memories) == 6
    # descending sort: best average first, worst last
    assert memories[0].feedback_metrics["positive"]["avg"] == 1.0
    assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
564
+
565
+
566
def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """With only a partition_id, the default include_global mode also returns global memories."""
    memories = readonly_partitioned_memoryset.query(partition_id="p1")
    assert len(memories) == 15  # 8 p1 + 7 global
    seen_partitions = {memory.partition_id for memory in memories}
    # both p1 and global (None) memories are present
    assert "p1" in seen_partitions
    assert None in seen_partitions
574
+
575
+
576
def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """exclude_global drops the global memories from a partition-scoped query."""
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
    assert len(memories) == 8  # only the p1 memories remain
    assert all(memory.partition_id == "p1" for memory in memories)
582
+
583
+
584
def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """Explicit include_global behaves like the default: partition plus global memories."""
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="include_global")
    assert len(memories) == 15  # 8 p1 + 7 global
    seen_partitions = {memory.partition_id for memory in memories}
    assert "p1" in seen_partitions
    assert None in seen_partitions
592
+
593
+
594
def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """exclude_global requires a partition_id and then returns only that partition's memories."""
    memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
    assert len(memories) == 8  # p1 memories only
    assert all(memory.partition_id == "p1" for memory in memories)
601
+
602
+
603
def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """only_global returns exactly the memories that carry no partition."""
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
    assert len(memories) == 7  # SAMPLE_DATA contains 7 global memories
    assert all(memory.partition_id is None for memory in memories)
609
+
610
+
611
def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """Without a partition_id, include_global returns only the global memories."""
    # NOTE(review): observed behavior — include_global with no partition_id
    # degenerates to globals only; confirm this is intentional in the backend
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
    assert len(memories) == 7
    assert all(memory.partition_id is None for memory in memories)
619
+
620
+
621
def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    """ignore_partitions returns every memory regardless of partition."""
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
    assert len(memories) == 22  # every memory in the fixture
    seen_partitions = {memory.partition_id for memory in memories}
    assert len(seen_partitions) >= 1
    # both partitions and the global bucket are represented
    assert "p1" in seen_partitions
    assert "p2" in seen_partitions
    assert None in seen_partitions
633
+
634
+
635
def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """Field filters and partition scoping apply together."""
    memories = readonly_partitioned_memoryset.query(filters=[("label", "==", 0)], partition_id="p1")
    assert len(memories) > 0
    # every result matches both the label filter and the partition scope
    for memory in memories:
        assert memory.label == 0
        assert memory.partition_id == "p1"
642
+
643
+
644
def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    """Filters compose with partition_id + exclude_global scoping."""
    memories = readonly_partitioned_memoryset.query(
        filters=[("label", "==", 1)], partition_id="p1", partition_filter_mode="exclude_global"
    )
    assert len(memories) > 0
    # every result matches the filter and comes from p1 (never global)
    for memory in memories:
        assert memory.label == 1
        assert memory.partition_id == "p1"
653
+
654
+
655
def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # limit caps the number of results returned for the requested partition
    results = readonly_partitioned_memoryset.query(partition_id="p2", limit=3)
    assert len(results) == 3
    for mem in results:
        assert mem.partition_id == "p2"
661
+
662
+
663
def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Paginate through p1 only (exclude_global so global memories don't interleave).
    first_page = readonly_partitioned_memoryset.query(
        partition_id="p1", partition_filter_mode="exclude_global", limit=5
    )
    second_page = readonly_partitioned_memoryset.query(
        partition_id="p1", partition_filter_mode="exclude_global", offset=5, limit=5
    )
    assert len(first_page) == 5
    assert len(second_page) == 3  # 8 p1 memories total, 5 already consumed
    # everything on both pages belongs to p1
    for mem in first_page + second_page:
        assert mem.partition_id == "p1"
    # pagination must not return the same memory twice
    first_page_ids = {mem.memory_id for mem in first_page}
    second_page_ids = {mem.memory_id for mem in second_page}
    assert not (first_page_ids & second_page_ids)
680
+
681
+
682
def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
    # Default filter mode (include_global) merges a partition with global memories.
    results = readonly_partitioned_memoryset.query(partition_id="p2")
    assert len(results) == 14  # 7 from p2 plus 7 global
    seen_partitions = {mem.partition_id for mem in results}
    assert "p2" in seen_partitions
    assert None in seen_partitions
691
+
692
+
693
def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    # Metadata filters compose with partition scoping.
    results = readonly_partitioned_memoryset.query(filters=[("metadata.key", "==", "g1")], partition_id="p1")
    assert results
    for mem in results:
        assert mem.metadata.get("key") == "g1"
        assert mem.partition_id == "p1"
700
+
701
+
702
def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
    readonly_partitioned_memoryset: LabeledMemoryset,
):
    # only_global restricts results to unpartitioned memories; filters still apply.
    results = readonly_partitioned_memoryset.query(
        filters=[("metadata.key", "==", "g3")], partition_filter_mode="only_global"
    )
    assert results
    for mem in results:
        assert mem.metadata.get("key") == "g3"
        assert mem.partition_id is None
713
+
714
+
715
def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
    """LabeledMemory.predictions() should yield only classification predictions."""
    # grab any memory from the classification model's memoryset
    candidates = classification_model.memoryset.query(limit=1)
    assert len(candidates) > 0
    memory = candidates[0]

    predictions = memory.predictions()

    # every entry must be a ClassificationPrediction with a label and no score
    assert isinstance(predictions, list)
    for pred in predictions:
        assert pred.__class__.__name__ == "ClassificationPrediction"
        assert hasattr(pred, "label")
        assert not hasattr(pred, "score") or pred.score is None
731
+
732
+
733
def test_scored_memory_predictions_property(regression_model: RegressionModel):
    """ScoredMemory.predictions() should yield only regression predictions."""
    # grab any memory from the regression model's memoryset
    candidates = regression_model.memoryset.query(limit=1)
    assert len(candidates) > 0
    memory = candidates[0]

    predictions = memory.predictions()

    # every entry must be a RegressionPrediction with a score and no label
    assert isinstance(predictions, list)
    for pred in predictions:
        assert pred.__class__.__name__ == "RegressionPrediction"
        assert hasattr(pred, "score")
        assert not hasattr(pred, "label") or pred.label is None
749
+
750
+
751
def test_memory_feedback_property(classification_model: ClassificationModel):
    """memory.feedback() should surface feedback recorded on predictions that used the memory."""
    # record feedback under a unique category so this run can't collide with others
    prediction = classification_model.predict("Test feedback")
    feedback_category = f"test_feedback_{random.randint(0, 1000000)}"
    prediction.record_feedback(category=feedback_category, value=True)

    # pick a memory the prediction actually looked up
    lookups = prediction.memory_lookups
    assert len(lookups) > 0
    memory = lookups[0]

    feedback = memory.feedback()

    # feedback comes back aggregated by category: {category: [values, ...]}
    assert isinstance(feedback, dict)
    assert feedback_category in feedback
    recorded_values = feedback[feedback_category]
    assert isinstance(recorded_values, list)
    assert len(recorded_values) > 0
    # binary feedback is represented as booleans
    assert isinstance(recorded_values[0], bool)
774
+
775
+
776
def test_memory_predictions_method_parameters(classification_model: ClassificationModel):
    """memory.predictions() should honor limit, offset, sort, and expected_label_match."""
    candidates = classification_model.memoryset.query(limit=1)
    assert len(candidates) > 0
    memory = candidates[0]

    # limit caps the result count
    limited = memory.predictions(limit=2)
    assert isinstance(limited, list)
    assert len(limited) <= 2

    # offset skips earlier predictions (only verifiable with at least two on record)
    everything = memory.predictions(limit=100)
    if len(everything) > 1:
        offset_page = memory.predictions(limit=1, offset=1)
        assert offset_page[0].prediction_id != everything[0].prediction_id

    # sort parameter is accepted (ordering itself is enforced server-side)
    sorted_predictions = memory.predictions(limit=10, sort=[("timestamp", "desc")])
    assert isinstance(sorted_predictions, list)

    # expected_label_match filters by correctness (filtering happens server-side)
    matched = memory.predictions(expected_label_match=True)
    mismatched = memory.predictions(expected_label_match=False)
    assert isinstance(matched, list)
    assert isinstance(mismatched, list)
808
+
809
+
810
def test_memory_predictions_expected_label_filter(classification_model: ClassificationModel):
    """memory.predictions(expected_label_match=...) should split predictions by correctness."""
    # first learn which label the model assigns to this input
    baseline = classification_model.predict("Filter test input", save_telemetry="sync")
    original_label = baseline.label
    alternate_label = 0 if original_label else 1

    # predict again with a deliberately wrong expected label
    mismatched_prediction = classification_model.predict(
        "Filter test input",
        expected_labels=alternate_label,
        save_telemetry="sync",
    )
    mismatched_memory = mismatched_prediction.memory_lookups[0]

    # the mismatched prediction must show up in the "incorrect" bucket
    incorrect_predictions = mismatched_memory.predictions(expected_label_match=False)
    assert any(pred.prediction_id == mismatched_prediction.prediction_id for pred in incorrect_predictions)

    # now produce a prediction whose expected label matches the model's output
    correct_prediction = classification_model.predict(
        "Filter test input",
        expected_labels=original_label,
        save_telemetry="sync",
    )

    # make sure the inspected memory was used by both predictions; otherwise
    # switch to a lookup that both predictions share
    correct_lookup_ids = {lookup.memory_id for lookup in correct_prediction.memory_lookups}
    if mismatched_memory.memory_id not in correct_lookup_ids:
        shared_lookup = next(
            (lookup for lookup in mismatched_prediction.memory_lookups if lookup.memory_id in correct_lookup_ids),
            None,
        )
        assert shared_lookup is not None, "No shared memory lookup between correct and incorrect predictions"
        mismatched_memory = shared_lookup

    # the correct prediction lands in the "correct" bucket; the mismatched one does not
    correct_predictions = mismatched_memory.predictions(expected_label_match=True)
    assert any(pred.prediction_id == correct_prediction.prediction_id for pred in correct_predictions)
    assert all(pred.prediction_id != mismatched_prediction.prediction_id for pred in correct_predictions)
850
+
851
+
852
def test_insert_memories(writable_memoryset: LabeledMemoryset):
    writable_memoryset.refresh()
    initial_length = writable_memoryset.length

    # batch insert with a tiny batch size to exercise batching
    writable_memoryset.insert(
        [
            dict(value="tomato soup is my favorite", label=0),
            dict(value="cats are fun to play with", label=1),
        ],
        batch_size=1,
    )
    writable_memoryset.refresh()
    assert writable_memoryset.length == initial_length + 2

    # single insert carrying extra metadata and a source_id
    writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
    writable_memoryset.refresh()
    assert writable_memoryset.length == initial_length + 3
    newest = writable_memoryset[-1]
    assert newest.value == "tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata
    assert newest.metadata["key"] == "test"
    assert newest.source_id == "test"
873
+
874
+
875
@skip_in_prod("Production memorysets do not have session consistency guarantees")
@skip_in_ci("CI environment may not have session consistency guarantees")
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
    # Single combined test to avoid repeatedly provisioning an expensive writable_memoryset.

    # update a single memory by id
    memory_id = writable_memoryset[0].memory_id
    updated = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
    assert updated.value == "i love soup so much"
    assert updated.label == hf_dataset[0]["label"]
    writable_memoryset.refresh()  # re-sync local state after the server-side update
    assert writable_memoryset.get(memory_id).value == "i love soup so much"

    # update through a memory instance (mutates in place and returns itself)
    memory = writable_memoryset[0]
    result = memory.update(value="i love soup even more")
    assert result is memory
    assert memory.value == "i love soup even more"
    assert memory.label == hf_dataset[0]["label"]

    # batch update several memories at once, with a tiny batch size
    memory_ids = [mem.memory_id for mem in writable_memoryset[:2]]
    updated_batch = writable_memoryset.update(
        [
            dict(memory_id=memory_ids[0], value="i love soup so much"),
            dict(memory_id=memory_ids[1], value="cats are so cute"),
        ],
        batch_size=1,
    )
    assert updated_batch[0].value == "i love soup so much"
    assert updated_batch[1].value == "cats are so cute"
906
+
907
+
908
def test_delete_memories(writable_memoryset: LabeledMemoryset):
    # Single combined test to avoid repeatedly provisioning an expensive writable_memoryset.

    # delete one memory and confirm it can no longer be fetched
    length_before = writable_memoryset.length
    memory_id = writable_memoryset[0].memory_id
    writable_memoryset.delete(memory_id)
    with pytest.raises(LookupError):
        writable_memoryset.get(memory_id)
    assert writable_memoryset.length == length_before - 1

    # delete several memories with a tiny batch size to exercise batching
    length_before = writable_memoryset.length
    writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id], batch_size=1)
    assert writable_memoryset.length == length_before - 2
923
+
924
+
925
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
    # Clone onto a different embedding model; the contents should carry over.
    clone = readonly_memoryset.clone(
        "test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
    )
    assert clone is not None
    assert clone.name == "test_cloned_memoryset"
    assert clone.length == readonly_memoryset.length
    assert clone.embedding_model == PretrainedEmbeddingModel.DISTILBERT
    assert clone.insertion_status == Status.COMPLETED
934
+
935
+
936
def test_clone_empty_memoryset():
    source_name = f"test_empty_to_clone_{uuid4()}"
    clone_name = f"test_empty_cloned_{uuid4()}"
    label_names = ["negative", "positive"]
    try:
        # build an empty memoryset; empty memorysets report no insertion status
        original = LabeledMemoryset.create(
            source_name, label_names=label_names, description="empty memoryset to clone"
        )
        assert original is not None
        assert original.name == source_name
        assert original.length == 0
        assert original.insertion_status is None

        # cloning an empty memoryset yields another empty one on the new model
        clone = original.clone(clone_name, embedding_model=PretrainedEmbeddingModel.DISTILBERT)
        assert clone is not None
        assert clone.name == clone_name
        assert clone.length == 0
        assert clone.embedding_model == PretrainedEmbeddingModel.DISTILBERT
        assert clone.insertion_status == Status.COMPLETED
        assert clone.label_names == label_names
    finally:
        # clean up both memorysets even if an assertion failed
        LabeledMemoryset.drop(source_name, if_not_exists="ignore")
        LabeledMemoryset.drop(clone_name, if_not_exists="ignore")
959
+
960
+
961
def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
    """Near-duplicate memories should be grouped by the duplicate analysis.

    Bug fix: this was previously declared ``async def`` and decorated with
    ``@pytest.fixture(scope="function")``, which made pytest collect it as a
    fixture instead of a test — it never ran. The body contains no ``await``,
    so it is now a plain synchronous test function.
    """
    # two clusters of case-variant duplicates: 5 "favorite" and 6 "mom"
    writable_memoryset.insert(
        [
            dict(value="raspberry soup Is my favorite", label=0),
            dict(value="Raspberry soup is MY favorite", label=0),
            dict(value="rAspberry soup is my favorite", label=0),
            dict(value="raSpberry SOuP is my favorite", label=0),
            dict(value="rasPberry SOuP is my favorite", label=0),
            dict(value="bunny rabbit Is not my mom", label=1),
            dict(value="bunny rabbit is not MY mom", label=1),
            dict(value="bunny rabbit Is not my moM", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not My mom", label=1),
        ]
    )

    # run the duplicate analysis with a high similarity threshold
    writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
    response = writable_memoryset.get_potential_duplicate_groups()
    assert isinstance(response, list)
    assert sorted(len(group) for group in response) == [5, 6]  # 5 favorite, 6 mom
983
+
984
+
985
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
    SOUP = 0
    CATS = 1
    query_text = "i love soup"  # from SAMPLE_DATA in conftest.py
    mislabeled_soup_text = "soup is comfort in a bowl"
    # plant a deliberately mislabeled soup memory
    writable_memoryset.insert(
        [
            dict(value=mislabeled_soup_text, label=CATS),
        ]
    )

    # locate the anchor memory whose label change should cascade
    memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]

    # flip the label and ask for cascading edit suggestions
    suggestions = writable_memoryset.get_cascading_edits_suggestions(
        memory=memory,
        old_label=CATS,
        new_label=SOUP,
        max_neighbors=10,
        max_validation_neighbors=5,
    )

    # the mislabeled neighbor should be the sole suggestion
    assert len(suggestions) == 1
    assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
1012
+
1013
+
1014
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
    """analyze() should reject unknown analysis names with a helpful ValueError."""
    memoryset = LabeledMemoryset.open(readonly_memoryset.name)

    # plain string form
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze("invalid_name")
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # dict/config form
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze({"name": "invalid_name"})
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # one bad name among otherwise valid ones still fails
    with pytest.raises(ValueError) as excinfo:
        memoryset.analyze("duplicate", "invalid_name")
    assert "Invalid analysis name: invalid_name" in str(excinfo.value)
    assert "Valid names are:" in str(excinfo.value)

    # all-valid names succeed and return one result entry per analysis
    result = memoryset.analyze("duplicate", "cluster")
    assert isinstance(result, dict)
    assert "duplicate" in result
    assert "cluster" in result
1041
+
1042
+
1043
def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
    # NOTE: keep this test last — dropping the memoryset earlier would force an
    # expensive re-creation on the next test run.
    assert LabeledMemoryset.exists(writable_memoryset.name)
    LabeledMemoryset.drop(writable_memoryset.name)
    assert not LabeledMemoryset.exists(writable_memoryset.name)
1050
+
1051
+
1052
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
    assert scored_memoryset.length == 22
    first = scored_memoryset[0]
    assert isinstance(first, ScoredMemory)
    assert first.value == "i love soup"
    assert first.score is not None
    assert first.metadata == {"key": "g1", "label": 0, "partition_id": "p1"}
    assert first.source_id == "s1"
    # searching for an exact stored value should return it with a score below 0.11
    hits = scored_memoryset.search("i love soup", count=1)
    assert len(hits) == 1
    assert hits[0].score is not None
    assert hits[0].score < 0.11
1063
+
1064
+
1065
@skip_in_prod("Production memorysets do not have session consistency guarantees")
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
    # Only an inconsequential metadata field is touched so other tests stay unaffected.
    memory = scored_memoryset[0]
    assert memory.label == 0
    # update by id through the memoryset
    scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
    assert scored_memoryset[0].label == 3
    # update through the memory instance itself
    memory.update(label=4)
    assert scored_memoryset[0].label == 4
1074
+
1075
+
1076
@pytest.mark.asyncio
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
    """A single memory can be inserted via the async API."""
    await writable_memoryset.arefresh()
    length_before = writable_memoryset.length

    await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == length_before + 1
    newest = writable_memoryset[-1]
    assert newest.value == "async tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata["key"] == "async_test"
1090
+
1091
+
1092
@pytest.mark.asyncio
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
    """Multiple memories can be inserted in one async call."""
    await writable_memoryset.arefresh()
    length_before = writable_memoryset.length

    await writable_memoryset.ainsert(
        [
            dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
            dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
        ]
    )

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == length_before + 2

    # the two newest memories should be exactly the ones just inserted
    newest_two = writable_memoryset[-2:]
    values = [mem.value for mem in newest_two]
    labels = [mem.label for mem in newest_two]
    keys = [mem.metadata.get("key") for mem in newest_two]

    assert "async batch soup is delicious" in values
    assert "async batch cats are adorable" in values
    assert 0 in labels
    assert 1 in labels
    assert "batch_test_1" in keys
    assert "batch_test_2" in keys
1120
+
1121
+
1122
@pytest.mark.asyncio
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
    """Async insertion preserves source_id and arbitrary metadata fields."""
    await writable_memoryset.arefresh()
    length_before = writable_memoryset.length

    await writable_memoryset.ainsert(
        dict(
            value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
        )
    )

    await writable_memoryset.arefresh()
    assert writable_memoryset.length == length_before + 1
    newest = writable_memoryset[-1]
    assert newest.value == "async soup with source id"
    assert newest.label == 0
    assert newest.source_id == "async_source_123"
    assert newest.metadata["custom_field"] == "async_custom_value"
1141
+
1142
+
1143
@pytest.mark.asyncio
async def test_insert_memories_async_unauthenticated(
    unauthenticated_async_client, writable_memoryset: LabeledMemoryset
):
    """Async insertion must be rejected when the client is not authenticated."""
    # the unauthenticated client context should make the server reject the request
    with unauthenticated_async_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        await writable_memoryset.ainsert(dict(value="this should fail", label=0))