orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3942 -0
- orca_sdk/classification_model.py +218 -20
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +899 -712
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +64 -12
- orca_sdk/datasource_test.py +144 -18
- orca_sdk/embedding_model.py +54 -37
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +27 -21
- orca_sdk/memoryset.py +823 -205
- orca_sdk/memoryset_test.py +315 -33
- orca_sdk/regression_model.py +59 -15
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +76 -26
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/METADATA +1 -1
- orca_sdk-0.1.4.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -5,10 +5,11 @@ import pytest
|
|
|
5
5
|
from datasets.arrow_dataset import Dataset
|
|
6
6
|
|
|
7
7
|
from .classification_model import ClassificationModel
|
|
8
|
-
from .conftest import skip_in_prod
|
|
8
|
+
from .conftest import skip_in_ci, skip_in_prod
|
|
9
9
|
from .datasource import Datasource
|
|
10
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
11
11
|
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
12
|
+
from .regression_model import RegressionModel
|
|
12
13
|
|
|
13
14
|
"""
|
|
14
15
|
Test Performance Note:
|
|
@@ -39,9 +40,10 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
|
|
|
39
40
|
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
40
41
|
|
|
41
42
|
|
|
42
|
-
def test_create_memoryset_unauthenticated(
|
|
43
|
-
with
|
|
44
|
-
|
|
43
|
+
def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
|
|
44
|
+
with unauthenticated_client.use():
|
|
45
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
46
|
+
LabeledMemoryset.create("test_memoryset", datasource)
|
|
45
47
|
|
|
46
48
|
|
|
47
49
|
def test_create_memoryset_invalid_input(datasource):
|
|
@@ -87,6 +89,75 @@ def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_
|
|
|
87
89
|
assert opened_memoryset.length == len(hf_dataset)
|
|
88
90
|
|
|
89
91
|
|
|
92
|
+
def test_if_exists_error_no_datasource_creation(
|
|
93
|
+
readonly_memoryset: LabeledMemoryset,
|
|
94
|
+
):
|
|
95
|
+
memoryset_name = readonly_memoryset.name
|
|
96
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
97
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
98
|
+
assert not Datasource.exists(datasource_name)
|
|
99
|
+
with pytest.raises(ValueError):
|
|
100
|
+
LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
|
|
101
|
+
assert not Datasource.exists(datasource_name)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_if_exists_open_reuses_existing_datasource(
|
|
105
|
+
readonly_memoryset: LabeledMemoryset,
|
|
106
|
+
):
|
|
107
|
+
memoryset_name = readonly_memoryset.name
|
|
108
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
109
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
110
|
+
assert not Datasource.exists(datasource_name)
|
|
111
|
+
reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
|
|
112
|
+
assert reopened.id == readonly_memoryset.id
|
|
113
|
+
assert not Datasource.exists(datasource_name)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def test_create_memoryset_string_label():
|
|
117
|
+
assert not LabeledMemoryset.exists("test_string_label")
|
|
118
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
119
|
+
"test_string_label",
|
|
120
|
+
Dataset.from_dict({"value": ["terrible", "great"], "label": ["negative", "positive"]}),
|
|
121
|
+
)
|
|
122
|
+
assert memoryset is not None
|
|
123
|
+
assert memoryset.length == 2
|
|
124
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
125
|
+
assert memoryset[0].label == 0
|
|
126
|
+
assert memoryset[1].label == 1
|
|
127
|
+
assert memoryset[0].label_name == "negative"
|
|
128
|
+
assert memoryset[1].label_name == "positive"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_create_memoryset_integer_label():
|
|
132
|
+
assert not LabeledMemoryset.exists("test_integer_label")
|
|
133
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
134
|
+
"test_integer_label",
|
|
135
|
+
Dataset.from_dict({"value": ["terrible", "great"], "label": [0, 1]}),
|
|
136
|
+
label_names=["negative", "positive"],
|
|
137
|
+
)
|
|
138
|
+
assert memoryset is not None
|
|
139
|
+
assert memoryset.length == 2
|
|
140
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
141
|
+
assert memoryset[0].label == 0
|
|
142
|
+
assert memoryset[1].label == 1
|
|
143
|
+
assert memoryset[0].label_name == "negative"
|
|
144
|
+
assert memoryset[1].label_name == "positive"
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def test_create_memoryset_null_labels():
|
|
148
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
149
|
+
"test_null_labels",
|
|
150
|
+
Dataset.from_dict({"value": ["terrible", "great"]}),
|
|
151
|
+
label_names=["negative", "positive"],
|
|
152
|
+
label_column=None,
|
|
153
|
+
)
|
|
154
|
+
assert memoryset is not None
|
|
155
|
+
assert memoryset.length == 2
|
|
156
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
157
|
+
assert memoryset[0].label == None
|
|
158
|
+
assert memoryset[1].label == None
|
|
159
|
+
|
|
160
|
+
|
|
90
161
|
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
91
162
|
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
92
163
|
assert fetched_memoryset is not None
|
|
@@ -96,9 +167,10 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
|
96
167
|
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
97
168
|
|
|
98
169
|
|
|
99
|
-
def test_open_memoryset_unauthenticated(
|
|
100
|
-
with
|
|
101
|
-
|
|
170
|
+
def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
171
|
+
with unauthenticated_client.use():
|
|
172
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
173
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
102
174
|
|
|
103
175
|
|
|
104
176
|
def test_open_memoryset_not_found():
|
|
@@ -111,9 +183,10 @@ def test_open_memoryset_invalid_input():
|
|
|
111
183
|
LabeledMemoryset.open("not valid id")
|
|
112
184
|
|
|
113
185
|
|
|
114
|
-
def test_open_memoryset_unauthorized(
|
|
115
|
-
with
|
|
116
|
-
|
|
186
|
+
def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
187
|
+
with unauthorized_client.use():
|
|
188
|
+
with pytest.raises(LookupError):
|
|
189
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
117
190
|
|
|
118
191
|
|
|
119
192
|
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
@@ -142,18 +215,21 @@ def test_all_memorysets_hidden(
|
|
|
142
215
|
assert hidden_memoryset in all_memorysets
|
|
143
216
|
|
|
144
217
|
|
|
145
|
-
def test_all_memorysets_unauthenticated(
|
|
146
|
-
with
|
|
147
|
-
|
|
218
|
+
def test_all_memorysets_unauthenticated(unauthenticated_client):
|
|
219
|
+
with unauthenticated_client.use():
|
|
220
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
221
|
+
LabeledMemoryset.all()
|
|
148
222
|
|
|
149
223
|
|
|
150
|
-
def test_all_memorysets_unauthorized(
|
|
151
|
-
|
|
224
|
+
def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
|
|
225
|
+
with unauthorized_client.use():
|
|
226
|
+
assert readonly_memoryset not in LabeledMemoryset.all()
|
|
152
227
|
|
|
153
228
|
|
|
154
|
-
def test_drop_memoryset_unauthenticated(
|
|
155
|
-
with
|
|
156
|
-
|
|
229
|
+
def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
230
|
+
with unauthenticated_client.use():
|
|
231
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
232
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
157
233
|
|
|
158
234
|
|
|
159
235
|
def test_drop_memoryset_not_found():
|
|
@@ -163,9 +239,10 @@ def test_drop_memoryset_not_found():
|
|
|
163
239
|
LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
|
|
164
240
|
|
|
165
241
|
|
|
166
|
-
def test_drop_memoryset_unauthorized(
|
|
167
|
-
with
|
|
168
|
-
|
|
242
|
+
def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
243
|
+
with unauthorized_client.use():
|
|
244
|
+
with pytest.raises(LookupError):
|
|
245
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
169
246
|
|
|
170
247
|
|
|
171
248
|
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
@@ -304,6 +381,143 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
|
|
|
304
381
|
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
305
382
|
|
|
306
383
|
|
|
384
|
+
def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
|
|
385
|
+
"""Test that LabeledMemory.predictions() only returns classification predictions."""
|
|
386
|
+
# Given: A classification model with memories
|
|
387
|
+
memories = classification_model.memoryset.query(limit=1)
|
|
388
|
+
assert len(memories) > 0
|
|
389
|
+
memory = memories[0]
|
|
390
|
+
|
|
391
|
+
# When: I call the predictions method
|
|
392
|
+
predictions = memory.predictions()
|
|
393
|
+
|
|
394
|
+
# Then: It should return a list of ClassificationPrediction objects
|
|
395
|
+
assert isinstance(predictions, list)
|
|
396
|
+
for prediction in predictions:
|
|
397
|
+
assert prediction.__class__.__name__ == "ClassificationPrediction"
|
|
398
|
+
assert hasattr(prediction, "label")
|
|
399
|
+
assert not hasattr(prediction, "score") or prediction.score is None
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def test_scored_memory_predictions_property(regression_model: RegressionModel):
|
|
403
|
+
"""Test that ScoredMemory.predictions() only returns regression predictions."""
|
|
404
|
+
# Given: A regression model with memories
|
|
405
|
+
memories = regression_model.memoryset.query(limit=1)
|
|
406
|
+
assert len(memories) > 0
|
|
407
|
+
memory = memories[0]
|
|
408
|
+
|
|
409
|
+
# When: I call the predictions method
|
|
410
|
+
predictions = memory.predictions()
|
|
411
|
+
|
|
412
|
+
# Then: It should return a list of RegressionPrediction objects
|
|
413
|
+
assert isinstance(predictions, list)
|
|
414
|
+
for prediction in predictions:
|
|
415
|
+
assert prediction.__class__.__name__ == "RegressionPrediction"
|
|
416
|
+
assert hasattr(prediction, "score")
|
|
417
|
+
assert not hasattr(prediction, "label") or prediction.label is None
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def test_memory_feedback_property(classification_model: ClassificationModel):
|
|
421
|
+
"""Test that memory.feedback() returns feedback from relevant predictions."""
|
|
422
|
+
# Given: A prediction with recorded feedback
|
|
423
|
+
prediction = classification_model.predict("Test feedback")
|
|
424
|
+
feedback_category = f"test_feedback_{random.randint(0, 1000000)}"
|
|
425
|
+
prediction.record_feedback(category=feedback_category, value=True)
|
|
426
|
+
|
|
427
|
+
# And: A memory that was used in the prediction
|
|
428
|
+
memory_lookups = prediction.memory_lookups
|
|
429
|
+
assert len(memory_lookups) > 0
|
|
430
|
+
memory = memory_lookups[0]
|
|
431
|
+
|
|
432
|
+
# When: I access the feedback property
|
|
433
|
+
feedback = memory.feedback()
|
|
434
|
+
|
|
435
|
+
# Then: It should return feedback aggregated by category as a dict
|
|
436
|
+
assert isinstance(feedback, dict)
|
|
437
|
+
assert feedback_category in feedback
|
|
438
|
+
# Feedback values are lists (you may want to look at mean on the raw data)
|
|
439
|
+
assert isinstance(feedback[feedback_category], list)
|
|
440
|
+
assert len(feedback[feedback_category]) > 0
|
|
441
|
+
# For binary feedback, values should be booleans
|
|
442
|
+
assert isinstance(feedback[feedback_category][0], bool)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def test_memory_predictions_method_parameters(classification_model: ClassificationModel):
|
|
446
|
+
"""Test that memory.predictions() method supports pagination, sorting, and filtering."""
|
|
447
|
+
# Given: A classification model with memories
|
|
448
|
+
memories = classification_model.memoryset.query(limit=1)
|
|
449
|
+
assert len(memories) > 0
|
|
450
|
+
memory = memories[0]
|
|
451
|
+
|
|
452
|
+
# When: I call predictions with limit parameter
|
|
453
|
+
predictions_limited = memory.predictions(limit=2)
|
|
454
|
+
|
|
455
|
+
# Then: It should respect the limit
|
|
456
|
+
assert isinstance(predictions_limited, list)
|
|
457
|
+
assert len(predictions_limited) <= 2
|
|
458
|
+
|
|
459
|
+
# When: I call predictions with offset parameter
|
|
460
|
+
all_predictions = memory.predictions(limit=100)
|
|
461
|
+
if len(all_predictions) > 1:
|
|
462
|
+
predictions_offset = memory.predictions(limit=1, offset=1)
|
|
463
|
+
# Then: offset should skip the first prediction
|
|
464
|
+
assert predictions_offset[0].prediction_id != all_predictions[0].prediction_id
|
|
465
|
+
|
|
466
|
+
# When: I call predictions with sort parameter
|
|
467
|
+
predictions_sorted = memory.predictions(limit=10, sort=[("timestamp", "desc")])
|
|
468
|
+
# Then: It should return predictions (sorting verified by API)
|
|
469
|
+
assert isinstance(predictions_sorted, list)
|
|
470
|
+
|
|
471
|
+
# When: I call predictions with expected_label_match parameter
|
|
472
|
+
correct_predictions = memory.predictions(expected_label_match=True)
|
|
473
|
+
incorrect_predictions = memory.predictions(expected_label_match=False)
|
|
474
|
+
# Then: Both should return lists (correctness verified by API filtering)
|
|
475
|
+
assert isinstance(correct_predictions, list)
|
|
476
|
+
assert isinstance(incorrect_predictions, list)
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def test_memory_predictions_expected_label_filter(classification_model: ClassificationModel):
|
|
480
|
+
"""Test that memory.predictions(expected_label_match=...) filters predictions by correctness."""
|
|
481
|
+
# Given: Make an initial prediction to learn the model's label for a known input
|
|
482
|
+
baseline_prediction = classification_model.predict("Filter test input", save_telemetry="sync")
|
|
483
|
+
original_label = baseline_prediction.label
|
|
484
|
+
alternate_label = 0 if original_label else 1
|
|
485
|
+
|
|
486
|
+
# When: Make a second prediction with an intentionally incorrect expected label
|
|
487
|
+
mismatched_prediction = classification_model.predict(
|
|
488
|
+
"Filter test input",
|
|
489
|
+
expected_labels=alternate_label,
|
|
490
|
+
save_telemetry="sync",
|
|
491
|
+
)
|
|
492
|
+
mismatched_memory = mismatched_prediction.memory_lookups[0]
|
|
493
|
+
|
|
494
|
+
# Then: The prediction should show up when filtering for incorrect predictions
|
|
495
|
+
incorrect_predictions = mismatched_memory.predictions(expected_label_match=False)
|
|
496
|
+
assert any(pred.prediction_id == mismatched_prediction.prediction_id for pred in incorrect_predictions)
|
|
497
|
+
|
|
498
|
+
# Produce a correct prediction (predicted label matches expected label)
|
|
499
|
+
correct_prediction = classification_model.predict(
|
|
500
|
+
"Filter test input",
|
|
501
|
+
expected_labels=original_label,
|
|
502
|
+
save_telemetry="sync",
|
|
503
|
+
)
|
|
504
|
+
|
|
505
|
+
# Ensure we are inspecting a memory used by both correct and incorrect predictions
|
|
506
|
+
correct_lookup_ids = {lookup.memory_id for lookup in correct_prediction.memory_lookups}
|
|
507
|
+
if mismatched_memory.memory_id not in correct_lookup_ids:
|
|
508
|
+
shared_lookup = next(
|
|
509
|
+
(lookup for lookup in mismatched_prediction.memory_lookups if lookup.memory_id in correct_lookup_ids),
|
|
510
|
+
None,
|
|
511
|
+
)
|
|
512
|
+
assert shared_lookup is not None, "No shared memory lookup between correct and incorrect predictions"
|
|
513
|
+
mismatched_memory = shared_lookup
|
|
514
|
+
|
|
515
|
+
# And: The correct prediction should appear when filtering for correct predictions
|
|
516
|
+
correct_predictions = mismatched_memory.predictions(expected_label_match=True)
|
|
517
|
+
assert any(pred.prediction_id == correct_prediction.prediction_id for pred in correct_predictions)
|
|
518
|
+
assert all(pred.prediction_id != mismatched_prediction.prediction_id for pred in correct_predictions)
|
|
519
|
+
|
|
520
|
+
|
|
307
521
|
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
308
522
|
writable_memoryset.refresh()
|
|
309
523
|
prev_length = writable_memoryset.length
|
|
@@ -327,6 +541,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
327
541
|
|
|
328
542
|
|
|
329
543
|
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
544
|
+
@skip_in_ci("CI environment may not have session consistency guarantees")
|
|
330
545
|
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
331
546
|
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
332
547
|
|
|
@@ -385,17 +600,6 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
385
600
|
assert cloned_memoryset.insertion_status == Status.COMPLETED
|
|
386
601
|
|
|
387
602
|
|
|
388
|
-
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
389
|
-
results = LabeledMemoryset.run_embedding_evaluation(
|
|
390
|
-
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=3
|
|
391
|
-
)
|
|
392
|
-
assert isinstance(results, list)
|
|
393
|
-
assert len(results) == 1
|
|
394
|
-
assert results[0] is not None
|
|
395
|
-
assert results[0]["embedding_model_name"] == "CDE_SMALL"
|
|
396
|
-
assert results[0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
397
|
-
|
|
398
|
-
|
|
399
603
|
@pytest.fixture(scope="function")
|
|
400
604
|
async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
|
|
401
605
|
writable_memoryset.insert(
|
|
@@ -492,7 +696,8 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
|
492
696
|
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
493
697
|
assert scored_memoryset[0].value == "i love soup"
|
|
494
698
|
assert scored_memoryset[0].score is not None
|
|
495
|
-
assert scored_memoryset[0].metadata == {"key": "g1", "
|
|
699
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "label": 0}
|
|
700
|
+
assert scored_memoryset[0].source_id == "s1"
|
|
496
701
|
lookup = scored_memoryset.search("i love soup", count=1)
|
|
497
702
|
assert len(lookup) == 1
|
|
498
703
|
assert lookup[0].score is not None
|
|
@@ -508,3 +713,80 @@ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
|
508
713
|
assert scored_memoryset[0].label == 3
|
|
509
714
|
memory.update(label=4)
|
|
510
715
|
assert scored_memoryset[0].label == 4
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
@pytest.mark.asyncio
|
|
719
|
+
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
|
|
720
|
+
"""Test async insertion of a single memory"""
|
|
721
|
+
await writable_memoryset.arefresh()
|
|
722
|
+
prev_length = writable_memoryset.length
|
|
723
|
+
|
|
724
|
+
await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
|
|
725
|
+
|
|
726
|
+
await writable_memoryset.arefresh()
|
|
727
|
+
assert writable_memoryset.length == prev_length + 1
|
|
728
|
+
last_memory = writable_memoryset[-1]
|
|
729
|
+
assert last_memory.value == "async tomato soup is my favorite"
|
|
730
|
+
assert last_memory.label == 0
|
|
731
|
+
assert last_memory.metadata["key"] == "async_test"
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@pytest.mark.asyncio
|
|
735
|
+
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
|
|
736
|
+
"""Test async insertion of multiple memories"""
|
|
737
|
+
await writable_memoryset.arefresh()
|
|
738
|
+
prev_length = writable_memoryset.length
|
|
739
|
+
|
|
740
|
+
await writable_memoryset.ainsert(
|
|
741
|
+
[
|
|
742
|
+
dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
|
|
743
|
+
dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
|
|
744
|
+
]
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
await writable_memoryset.arefresh()
|
|
748
|
+
assert writable_memoryset.length == prev_length + 2
|
|
749
|
+
|
|
750
|
+
# Check the inserted memories
|
|
751
|
+
last_two_memories = writable_memoryset[-2:]
|
|
752
|
+
values = [memory.value for memory in last_two_memories]
|
|
753
|
+
labels = [memory.label for memory in last_two_memories]
|
|
754
|
+
keys = [memory.metadata.get("key") for memory in last_two_memories]
|
|
755
|
+
|
|
756
|
+
assert "async batch soup is delicious" in values
|
|
757
|
+
assert "async batch cats are adorable" in values
|
|
758
|
+
assert 0 in labels
|
|
759
|
+
assert 1 in labels
|
|
760
|
+
assert "batch_test_1" in keys
|
|
761
|
+
assert "batch_test_2" in keys
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
@pytest.mark.asyncio
|
|
765
|
+
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
|
|
766
|
+
"""Test async insertion with source_id and metadata"""
|
|
767
|
+
await writable_memoryset.arefresh()
|
|
768
|
+
prev_length = writable_memoryset.length
|
|
769
|
+
|
|
770
|
+
await writable_memoryset.ainsert(
|
|
771
|
+
dict(
|
|
772
|
+
value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
|
|
773
|
+
)
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
await writable_memoryset.arefresh()
|
|
777
|
+
assert writable_memoryset.length == prev_length + 1
|
|
778
|
+
last_memory = writable_memoryset[-1]
|
|
779
|
+
assert last_memory.value == "async soup with source id"
|
|
780
|
+
assert last_memory.label == 0
|
|
781
|
+
assert last_memory.source_id == "async_source_123"
|
|
782
|
+
assert last_memory.metadata["custom_field"] == "async_custom_value"
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
@pytest.mark.asyncio
|
|
786
|
+
async def test_insert_memories_async_unauthenticated(
|
|
787
|
+
unauthenticated_async_client, writable_memoryset: LabeledMemoryset
|
|
788
|
+
):
|
|
789
|
+
"""Test async insertion with invalid authentication"""
|
|
790
|
+
with unauthenticated_async_client.use():
|
|
791
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
792
|
+
await writable_memoryset.ainsert(dict(value="this should fail", label=0))
|