orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +31 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +601 -129
- orca_sdk/classification_model_test.py +415 -117
- orca_sdk/client.py +3787 -0
- orca_sdk/conftest.py +184 -38
- orca_sdk/credentials.py +162 -20
- orca_sdk/credentials_test.py +100 -16
- orca_sdk/datasource.py +268 -68
- orca_sdk/datasource_test.py +266 -18
- orca_sdk/embedding_model.py +434 -82
- orca_sdk/embedding_model_test.py +66 -33
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1690 -324
- orca_sdk/memoryset_test.py +456 -119
- orca_sdk/regression_model.py +694 -0
- orca_sdk/regression_model_test.py +378 -0
- orca_sdk/telemetry.py +460 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,86 +1,130 @@
|
|
|
1
|
+
import random
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
4
5
|
from datasets.arrow_dataset import Dataset
|
|
5
6
|
|
|
7
|
+
from .classification_model import ClassificationModel
|
|
8
|
+
from .conftest import skip_in_ci, skip_in_prod
|
|
9
|
+
from .datasource import Datasource
|
|
6
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
7
|
-
from .memoryset import LabeledMemoryset,
|
|
11
|
+
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
8
12
|
|
|
13
|
+
"""
|
|
14
|
+
Test Performance Note:
|
|
9
15
|
|
|
10
|
-
|
|
11
|
-
assert memoryset is not None
|
|
12
|
-
assert memoryset.name == "test_memoryset"
|
|
13
|
-
assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
14
|
-
assert memoryset.label_names == label_names
|
|
15
|
-
assert memoryset.insertion_status == TaskStatus.COMPLETED
|
|
16
|
-
assert isinstance(memoryset.length, int)
|
|
17
|
-
assert memoryset.length == len(hf_dataset)
|
|
16
|
+
Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
|
|
18
17
|
|
|
18
|
+
- Two fixtures are used to manage memorysets:
|
|
19
|
+
- `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
|
|
20
|
+
It should only be used in nullipotent tests.
|
|
21
|
+
- `writable_memoryset` is a function-scoped, regenerating fixture.
|
|
22
|
+
It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
|
|
19
23
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
24
|
+
- To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
|
|
25
|
+
For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
|
|
26
|
+
rather than separate `test_delete_single` and `test_delete_multiple` tests.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
31
|
+
assert readonly_memoryset is not None
|
|
32
|
+
assert readonly_memoryset.name == "test_readonly_memoryset"
|
|
33
|
+
assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
34
|
+
assert readonly_memoryset.label_names == label_names
|
|
35
|
+
assert readonly_memoryset.insertion_status == Status.COMPLETED
|
|
36
|
+
assert isinstance(readonly_memoryset.length, int)
|
|
37
|
+
assert readonly_memoryset.length == len(hf_dataset)
|
|
38
|
+
assert readonly_memoryset.index_type == "IVF_FLAT"
|
|
39
|
+
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
|
|
43
|
+
with unauthenticated_client.use():
|
|
44
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
45
|
+
LabeledMemoryset.create("test_memoryset", datasource)
|
|
23
46
|
|
|
24
47
|
|
|
25
48
|
def test_create_memoryset_invalid_input(datasource):
|
|
26
49
|
# invalid name
|
|
27
50
|
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
28
51
|
LabeledMemoryset.create("test memoryset", datasource)
|
|
29
|
-
# invalid datasource
|
|
30
|
-
datasource.id = str(uuid4())
|
|
31
|
-
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
32
|
-
LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
|
|
33
52
|
|
|
34
53
|
|
|
35
|
-
def test_create_memoryset_already_exists_error(hf_dataset, label_names,
|
|
54
|
+
def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
|
|
55
|
+
memoryset_name = readonly_memoryset.name
|
|
36
56
|
with pytest.raises(ValueError):
|
|
37
|
-
LabeledMemoryset.from_hf_dataset(
|
|
57
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
|
|
38
58
|
with pytest.raises(ValueError):
|
|
39
|
-
LabeledMemoryset.from_hf_dataset(
|
|
40
|
-
"test_memoryset", hf_dataset, label_names=label_names, value_column="text", if_exists="error"
|
|
41
|
-
)
|
|
59
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")
|
|
42
60
|
|
|
43
61
|
|
|
44
|
-
def test_create_memoryset_already_exists_open(hf_dataset, label_names,
|
|
62
|
+
def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
|
|
45
63
|
# invalid label names
|
|
46
64
|
with pytest.raises(ValueError):
|
|
47
65
|
LabeledMemoryset.from_hf_dataset(
|
|
48
|
-
|
|
66
|
+
readonly_memoryset.name,
|
|
49
67
|
hf_dataset,
|
|
50
68
|
label_names=["turtles", "frogs"],
|
|
51
|
-
value_column="text",
|
|
52
69
|
if_exists="open",
|
|
53
70
|
)
|
|
54
71
|
# different embedding model
|
|
55
72
|
with pytest.raises(ValueError):
|
|
56
73
|
LabeledMemoryset.from_hf_dataset(
|
|
57
|
-
|
|
74
|
+
readonly_memoryset.name,
|
|
58
75
|
hf_dataset,
|
|
59
76
|
label_names=label_names,
|
|
60
77
|
embedding_model=PretrainedEmbeddingModel.DISTILBERT,
|
|
61
78
|
if_exists="open",
|
|
62
79
|
)
|
|
63
80
|
opened_memoryset = LabeledMemoryset.from_hf_dataset(
|
|
64
|
-
|
|
81
|
+
readonly_memoryset.name,
|
|
65
82
|
hf_dataset,
|
|
66
83
|
embedding_model=PretrainedEmbeddingModel.GTE_BASE,
|
|
67
84
|
if_exists="open",
|
|
68
85
|
)
|
|
69
86
|
assert opened_memoryset is not None
|
|
70
|
-
assert opened_memoryset.name ==
|
|
87
|
+
assert opened_memoryset.name == readonly_memoryset.name
|
|
71
88
|
assert opened_memoryset.length == len(hf_dataset)
|
|
72
89
|
|
|
73
90
|
|
|
74
|
-
def
|
|
75
|
-
|
|
91
|
+
def test_if_exists_error_no_datasource_creation(
|
|
92
|
+
readonly_memoryset: LabeledMemoryset,
|
|
93
|
+
):
|
|
94
|
+
memoryset_name = readonly_memoryset.name
|
|
95
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
96
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
97
|
+
assert not Datasource.exists(datasource_name)
|
|
98
|
+
with pytest.raises(ValueError):
|
|
99
|
+
LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
|
|
100
|
+
assert not Datasource.exists(datasource_name)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_if_exists_open_reuses_existing_datasource(
|
|
104
|
+
readonly_memoryset: LabeledMemoryset,
|
|
105
|
+
):
|
|
106
|
+
memoryset_name = readonly_memoryset.name
|
|
107
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
108
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
109
|
+
assert not Datasource.exists(datasource_name)
|
|
110
|
+
reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
|
|
111
|
+
assert reopened.id == readonly_memoryset.id
|
|
112
|
+
assert not Datasource.exists(datasource_name)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
116
|
+
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
76
117
|
assert fetched_memoryset is not None
|
|
77
|
-
assert fetched_memoryset.name ==
|
|
118
|
+
assert fetched_memoryset.name == readonly_memoryset.name
|
|
78
119
|
assert fetched_memoryset.length == len(hf_dataset)
|
|
120
|
+
assert fetched_memoryset.index_type == "IVF_FLAT"
|
|
121
|
+
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
79
122
|
|
|
80
123
|
|
|
81
|
-
def test_open_memoryset_unauthenticated(
|
|
82
|
-
with
|
|
83
|
-
|
|
124
|
+
def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
125
|
+
with unauthenticated_client.use():
|
|
126
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
127
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
84
128
|
|
|
85
129
|
|
|
86
130
|
def test_open_memoryset_not_found():
|
|
@@ -93,57 +137,93 @@ def test_open_memoryset_invalid_input():
|
|
|
93
137
|
LabeledMemoryset.open("not valid id")
|
|
94
138
|
|
|
95
139
|
|
|
96
|
-
def test_open_memoryset_unauthorized(
|
|
97
|
-
with
|
|
98
|
-
|
|
140
|
+
def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
141
|
+
with unauthorized_client.use():
|
|
142
|
+
with pytest.raises(LookupError):
|
|
143
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
99
144
|
|
|
100
145
|
|
|
101
|
-
def test_all_memorysets(
|
|
146
|
+
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
102
147
|
memorysets = LabeledMemoryset.all()
|
|
103
148
|
assert len(memorysets) > 0
|
|
104
|
-
assert any(memoryset.name ==
|
|
149
|
+
assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
|
|
105
150
|
|
|
106
151
|
|
|
107
|
-
def
|
|
108
|
-
|
|
109
|
-
|
|
152
|
+
def test_all_memorysets_hidden(
|
|
153
|
+
readonly_memoryset: LabeledMemoryset,
|
|
154
|
+
):
|
|
155
|
+
# Create a hidden memoryset
|
|
156
|
+
hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
|
|
157
|
+
hidden_memoryset.set(hidden=True)
|
|
110
158
|
|
|
159
|
+
# Test that show_hidden=False excludes hidden memorysets
|
|
160
|
+
visible_memorysets = LabeledMemoryset.all(show_hidden=False)
|
|
161
|
+
assert len(visible_memorysets) > 0
|
|
162
|
+
assert readonly_memoryset in visible_memorysets
|
|
163
|
+
assert hidden_memoryset not in visible_memorysets
|
|
111
164
|
|
|
112
|
-
|
|
113
|
-
|
|
165
|
+
# Test that show_hidden=True includes hidden memorysets
|
|
166
|
+
all_memorysets = LabeledMemoryset.all(show_hidden=True)
|
|
167
|
+
assert len(all_memorysets) == len(visible_memorysets) + 1
|
|
168
|
+
assert readonly_memoryset in all_memorysets
|
|
169
|
+
assert hidden_memoryset in all_memorysets
|
|
114
170
|
|
|
115
171
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
value_column="text",
|
|
122
|
-
)
|
|
123
|
-
assert LabeledMemoryset.exists(memoryset.name)
|
|
124
|
-
LabeledMemoryset.drop(memoryset.name)
|
|
125
|
-
assert not LabeledMemoryset.exists(memoryset.name)
|
|
172
|
+
def test_all_memorysets_unauthenticated(unauthenticated_client):
|
|
173
|
+
with unauthenticated_client.use():
|
|
174
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
175
|
+
LabeledMemoryset.all()
|
|
176
|
+
|
|
126
177
|
|
|
178
|
+
def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
|
|
179
|
+
with unauthorized_client.use():
|
|
180
|
+
assert readonly_memoryset not in LabeledMemoryset.all()
|
|
127
181
|
|
|
128
|
-
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
|
|
129
|
-
with pytest.raises(ValueError, match="Invalid API key"):
|
|
130
|
-
LabeledMemoryset.drop(memoryset.name)
|
|
131
182
|
|
|
183
|
+
def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
184
|
+
with unauthenticated_client.use():
|
|
185
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
186
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
132
187
|
|
|
133
|
-
|
|
188
|
+
|
|
189
|
+
def test_drop_memoryset_not_found():
|
|
134
190
|
with pytest.raises(LookupError):
|
|
135
191
|
LabeledMemoryset.drop(str(uuid4()))
|
|
136
192
|
# ignores error if specified
|
|
137
193
|
LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
|
|
138
194
|
|
|
139
195
|
|
|
140
|
-
def test_drop_memoryset_unauthorized(
|
|
141
|
-
with
|
|
142
|
-
|
|
196
|
+
def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
197
|
+
with unauthorized_client.use():
|
|
198
|
+
with pytest.raises(LookupError):
|
|
199
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
143
200
|
|
|
144
201
|
|
|
145
|
-
def
|
|
146
|
-
|
|
202
|
+
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
203
|
+
original_label_names = writable_memoryset.label_names
|
|
204
|
+
writable_memoryset.set(description="New description")
|
|
205
|
+
assert writable_memoryset.description == "New description"
|
|
206
|
+
|
|
207
|
+
writable_memoryset.set(description=None)
|
|
208
|
+
assert writable_memoryset.description is None
|
|
209
|
+
|
|
210
|
+
writable_memoryset.set(name="New_name")
|
|
211
|
+
assert writable_memoryset.name == "New_name"
|
|
212
|
+
|
|
213
|
+
writable_memoryset.set(name="test_writable_memoryset")
|
|
214
|
+
assert writable_memoryset.name == "test_writable_memoryset"
|
|
215
|
+
|
|
216
|
+
assert writable_memoryset.label_names == original_label_names
|
|
217
|
+
|
|
218
|
+
writable_memoryset.set(label_names=["New label 1", "New label 2"])
|
|
219
|
+
assert writable_memoryset.label_names == ["New label 1", "New label 2"]
|
|
220
|
+
|
|
221
|
+
writable_memoryset.set(hidden=True)
|
|
222
|
+
assert writable_memoryset.hidden is True
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
226
|
+
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
147
227
|
assert len(memory_lookups) == 2
|
|
148
228
|
assert len(memory_lookups[0]) == 1
|
|
149
229
|
assert len(memory_lookups[1]) == 1
|
|
@@ -151,67 +231,125 @@ def test_search(memoryset: LabeledMemoryset):
|
|
|
151
231
|
assert memory_lookups[1][0].label == 1
|
|
152
232
|
|
|
153
233
|
|
|
154
|
-
def test_search_count(
|
|
155
|
-
memory_lookups =
|
|
234
|
+
def test_search_count(readonly_memoryset: LabeledMemoryset):
|
|
235
|
+
memory_lookups = readonly_memoryset.search("i love soup", count=3)
|
|
156
236
|
assert len(memory_lookups) == 3
|
|
157
237
|
assert memory_lookups[0].label == 0
|
|
158
238
|
assert memory_lookups[1].label == 0
|
|
159
239
|
assert memory_lookups[2].label == 0
|
|
160
240
|
|
|
161
241
|
|
|
162
|
-
def test_get_memory_at_index(
|
|
163
|
-
memory =
|
|
164
|
-
assert memory.value == hf_dataset[0]["
|
|
242
|
+
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
243
|
+
memory = readonly_memoryset[0]
|
|
244
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
165
245
|
assert memory.label == hf_dataset[0]["label"]
|
|
166
246
|
assert memory.label_name == label_names[hf_dataset[0]["label"]]
|
|
167
247
|
assert memory.source_id == hf_dataset[0]["source_id"]
|
|
168
248
|
assert memory.score == hf_dataset[0]["score"]
|
|
169
249
|
assert memory.key == hf_dataset[0]["key"]
|
|
170
|
-
last_memory =
|
|
171
|
-
assert last_memory.value == hf_dataset[-1]["
|
|
250
|
+
last_memory = readonly_memoryset[-1]
|
|
251
|
+
assert last_memory.value == hf_dataset[-1]["value"]
|
|
172
252
|
assert last_memory.label == hf_dataset[-1]["label"]
|
|
173
253
|
|
|
174
254
|
|
|
175
|
-
def test_get_range_of_memories(
|
|
176
|
-
memories =
|
|
255
|
+
def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
256
|
+
memories = readonly_memoryset[1:3]
|
|
177
257
|
assert len(memories) == 2
|
|
178
|
-
assert memories[0].value == hf_dataset["
|
|
179
|
-
assert memories[1].value == hf_dataset["
|
|
258
|
+
assert memories[0].value == hf_dataset["value"][1]
|
|
259
|
+
assert memories[1].value == hf_dataset["value"][2]
|
|
180
260
|
|
|
181
261
|
|
|
182
|
-
def test_get_memory_by_id(
|
|
183
|
-
memory =
|
|
184
|
-
assert memory.value == hf_dataset[0]["
|
|
185
|
-
assert memory ==
|
|
262
|
+
def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
263
|
+
memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
|
|
264
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
265
|
+
assert memory == readonly_memoryset[memory.memory_id]
|
|
186
266
|
|
|
187
267
|
|
|
188
|
-
def test_get_memories_by_id(
|
|
189
|
-
memories =
|
|
268
|
+
def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
269
|
+
memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
|
|
190
270
|
assert len(memories) == 2
|
|
191
|
-
assert memories[0].value == hf_dataset[0]["
|
|
192
|
-
assert memories[1].value == hf_dataset[1]["
|
|
271
|
+
assert memories[0].value == hf_dataset[0]["value"]
|
|
272
|
+
assert memories[1].value == hf_dataset[1]["value"]
|
|
193
273
|
|
|
194
274
|
|
|
195
|
-
def test_query_memoryset(
|
|
196
|
-
memories =
|
|
197
|
-
assert len(memories) ==
|
|
275
|
+
def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
276
|
+
memories = readonly_memoryset.query(filters=[("label", "==", 1)])
|
|
277
|
+
assert len(memories) == 8
|
|
198
278
|
assert all(memory.label == 1 for memory in memories)
|
|
199
|
-
assert len(
|
|
200
|
-
assert len(
|
|
279
|
+
assert len(readonly_memoryset.query(limit=2)) == 2
|
|
280
|
+
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
|
|
284
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
285
|
+
feedback_name = f"correct_{random.randint(0, 1000000)}"
|
|
286
|
+
prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
|
|
287
|
+
memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
|
|
288
|
+
|
|
289
|
+
# Get the memory_ids that were actually used in the prediction
|
|
290
|
+
used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}
|
|
291
|
+
|
|
292
|
+
assert len(memories) == 8
|
|
293
|
+
assert all(memory.label == 0 for memory in memories)
|
|
294
|
+
for memory in memories:
|
|
295
|
+
assert memory.feedback_metrics is not None
|
|
296
|
+
if memory.memory_id in used_memory_ids:
|
|
297
|
+
assert feedback_name in memory.feedback_metrics
|
|
298
|
+
assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
|
|
299
|
+
assert memory.feedback_metrics[feedback_name]["count"] == 1
|
|
300
|
+
else:
|
|
301
|
+
assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
|
|
302
|
+
assert isinstance(memory.lookup_count, int)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
|
|
306
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
307
|
+
prediction.record_feedback(category="accurate", value=prediction.label == 0)
|
|
308
|
+
memories = prediction.memoryset.query(
|
|
309
|
+
filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
|
|
310
|
+
)
|
|
311
|
+
assert len(memories) == 3
|
|
312
|
+
assert all(memory.label == 0 for memory in memories)
|
|
313
|
+
for memory in memories:
|
|
314
|
+
assert memory.feedback_metrics is not None
|
|
315
|
+
assert memory.feedback_metrics["accurate"] is not None
|
|
316
|
+
assert memory.feedback_metrics["accurate"]["avg"] == 1.0
|
|
317
|
+
assert memory.feedback_metrics["accurate"]["count"] == 1
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
|
|
321
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
322
|
+
prediction.record_feedback(category="positive", value=1.0)
|
|
323
|
+
prediction2 = classification_model.predict("Do you like cats?")
|
|
324
|
+
prediction2.record_feedback(category="positive", value=-1.0)
|
|
325
|
+
|
|
326
|
+
memories = prediction.memoryset.query(
|
|
327
|
+
filters=[("feedback_metrics.positive.avg", ">=", -1.0)],
|
|
328
|
+
sort=[("feedback_metrics.positive.avg", "desc")],
|
|
329
|
+
with_feedback_metrics=True,
|
|
330
|
+
)
|
|
331
|
+
assert (
|
|
332
|
+
len(memories) == 6
|
|
333
|
+
) # there are only 6 out of 16 memories that have a positive feedback metric. Look at SAMPLE_DATA in conftest.py
|
|
334
|
+
assert memories[0].feedback_metrics["positive"]["avg"] == 1.0
|
|
335
|
+
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
201
336
|
|
|
202
337
|
|
|
203
|
-
def test_insert_memories(
|
|
204
|
-
|
|
205
|
-
|
|
338
|
+
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
339
|
+
writable_memoryset.refresh()
|
|
340
|
+
prev_length = writable_memoryset.length
|
|
341
|
+
writable_memoryset.insert(
|
|
206
342
|
[
|
|
207
343
|
dict(value="tomato soup is my favorite", label=0),
|
|
208
344
|
dict(value="cats are fun to play with", label=1),
|
|
209
345
|
]
|
|
210
346
|
)
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
347
|
+
writable_memoryset.refresh()
|
|
348
|
+
assert writable_memoryset.length == prev_length + 2
|
|
349
|
+
writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
|
|
350
|
+
writable_memoryset.refresh()
|
|
351
|
+
assert writable_memoryset.length == prev_length + 3
|
|
352
|
+
last_memory = writable_memoryset[-1]
|
|
215
353
|
assert last_memory.value == "tomato soup is my favorite"
|
|
216
354
|
assert last_memory.label == 0
|
|
217
355
|
assert last_memory.metadata
|
|
@@ -219,25 +357,29 @@ def test_insert_memories(memoryset: LabeledMemoryset):
|
|
|
219
357
|
assert last_memory.source_id == "test"
|
|
220
358
|
|
|
221
359
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
360
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
361
|
+
@skip_in_ci("CI environment may not have session consistency guarantees")
|
|
362
|
+
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
363
|
+
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
364
|
+
|
|
365
|
+
# test updating a single memory
|
|
366
|
+
memory_id = writable_memoryset[0].memory_id
|
|
367
|
+
updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
225
368
|
assert updated_memory.value == "i love soup so much"
|
|
226
369
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
227
|
-
|
|
370
|
+
writable_memoryset.refresh() # Refresh to ensure consistency after update
|
|
371
|
+
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
228
372
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
memory = memoryset[0]
|
|
373
|
+
# test updating a memory instance
|
|
374
|
+
memory = writable_memoryset[0]
|
|
232
375
|
updated_memory = memory.update(value="i love soup even more")
|
|
233
376
|
assert updated_memory is memory
|
|
234
377
|
assert memory.value == "i love soup even more"
|
|
235
378
|
assert memory.label == hf_dataset[0]["label"]
|
|
236
379
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
updated_memories = memoryset.update(
|
|
380
|
+
# test updating multiple memories
|
|
381
|
+
memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
|
|
382
|
+
updated_memories = writable_memoryset.update(
|
|
241
383
|
[
|
|
242
384
|
dict(memory_id=memory_ids[0], value="i love soup so much"),
|
|
243
385
|
dict(memory_id=memory_ids[1], value="cats are so cute"),
|
|
@@ -247,25 +389,220 @@ def test_update_memories(memoryset: LabeledMemoryset):
|
|
|
247
389
|
assert updated_memories[1].value == "cats are so cute"
|
|
248
390
|
|
|
249
391
|
|
|
250
|
-
def
|
|
251
|
-
|
|
252
|
-
memory_id = memoryset[0].memory_id
|
|
253
|
-
memoryset.delete(memory_id)
|
|
254
|
-
with pytest.raises(LookupError):
|
|
255
|
-
memoryset.get(memory_id)
|
|
256
|
-
assert memoryset.length == prev_length - 1
|
|
392
|
+
def test_delete_memories(writable_memoryset: LabeledMemoryset):
|
|
393
|
+
# We've combined the delete tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
257
394
|
|
|
395
|
+
# test deleting a single memory
|
|
396
|
+
prev_length = writable_memoryset.length
|
|
397
|
+
memory_id = writable_memoryset[0].memory_id
|
|
398
|
+
writable_memoryset.delete(memory_id)
|
|
399
|
+
with pytest.raises(LookupError):
|
|
400
|
+
writable_memoryset.get(memory_id)
|
|
401
|
+
assert writable_memoryset.length == prev_length - 1
|
|
258
402
|
|
|
259
|
-
|
|
260
|
-
prev_length =
|
|
261
|
-
|
|
262
|
-
assert
|
|
403
|
+
# test deleting multiple memories
|
|
404
|
+
prev_length = writable_memoryset.length
|
|
405
|
+
writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id])
|
|
406
|
+
assert writable_memoryset.length == prev_length - 2
|
|
263
407
|
|
|
264
408
|
|
|
265
|
-
def test_clone_memoryset(
|
|
266
|
-
cloned_memoryset =
|
|
409
|
+
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
410
|
+
cloned_memoryset = readonly_memoryset.clone(
|
|
411
|
+
"test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
|
|
412
|
+
)
|
|
267
413
|
assert cloned_memoryset is not None
|
|
268
414
|
assert cloned_memoryset.name == "test_cloned_memoryset"
|
|
269
|
-
assert cloned_memoryset.length ==
|
|
415
|
+
assert cloned_memoryset.length == readonly_memoryset.length
|
|
270
416
|
assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
|
|
271
|
-
assert cloned_memoryset.insertion_status ==
|
|
417
|
+
assert cloned_memoryset.insertion_status == Status.COMPLETED
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
@pytest.fixture(scope="function")
|
|
421
|
+
async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
|
|
422
|
+
writable_memoryset.insert(
|
|
423
|
+
[
|
|
424
|
+
dict(value="raspberry soup Is my favorite", label=0),
|
|
425
|
+
dict(value="Raspberry soup is MY favorite", label=0),
|
|
426
|
+
dict(value="rAspberry soup is my favorite", label=0),
|
|
427
|
+
dict(value="raSpberry SOuP is my favorite", label=0),
|
|
428
|
+
dict(value="rasPberry SOuP is my favorite", label=0),
|
|
429
|
+
dict(value="bunny rabbit Is not my mom", label=1),
|
|
430
|
+
dict(value="bunny rabbit is not MY mom", label=1),
|
|
431
|
+
dict(value="bunny rabbit Is not my moM", label=1),
|
|
432
|
+
dict(value="bunny rabbit is not my mom", label=1),
|
|
433
|
+
dict(value="bunny rabbit is not my mom", label=1),
|
|
434
|
+
dict(value="bunny rabbit is not My mom", label=1),
|
|
435
|
+
]
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
|
|
439
|
+
response = writable_memoryset.get_potential_duplicate_groups()
|
|
440
|
+
assert isinstance(response, list)
|
|
441
|
+
assert sorted([len(res) for res in response]) == [5, 6] # 5 favorite, 6 mom
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
445
|
+
# Insert a memory to test cascading edits
|
|
446
|
+
SOUP = 0
|
|
447
|
+
CATS = 1
|
|
448
|
+
query_text = "i love soup" # from SAMPLE_DATA in conftest.py
|
|
449
|
+
mislabeled_soup_text = "soup is comfort in a bowl"
|
|
450
|
+
writable_memoryset.insert(
|
|
451
|
+
[
|
|
452
|
+
dict(value=mislabeled_soup_text, label=CATS), # mislabeled soup memory
|
|
453
|
+
]
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
# Fetch the memory to update
|
|
457
|
+
memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]
|
|
458
|
+
|
|
459
|
+
# Update the label and get cascading edit suggestions
|
|
460
|
+
suggestions = writable_memoryset.get_cascading_edits_suggestions(
|
|
461
|
+
memory=memory,
|
|
462
|
+
old_label=CATS,
|
|
463
|
+
new_label=SOUP,
|
|
464
|
+
max_neighbors=10,
|
|
465
|
+
max_validation_neighbors=5,
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
# Validate the suggestions
|
|
469
|
+
assert len(suggestions) == 1
|
|
470
|
+
assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
|
|
474
|
+
"""Test that analyze() raises ValueError for invalid analysis names"""
|
|
475
|
+
memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
476
|
+
|
|
477
|
+
# Test with string input
|
|
478
|
+
with pytest.raises(ValueError) as excinfo:
|
|
479
|
+
memoryset.analyze("invalid_name")
|
|
480
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
481
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
482
|
+
|
|
483
|
+
# Test with dict input
|
|
484
|
+
with pytest.raises(ValueError) as excinfo:
|
|
485
|
+
memoryset.analyze({"name": "invalid_name"})
|
|
486
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
487
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
488
|
+
|
|
489
|
+
# Test with multiple analyses where one is invalid
|
|
490
|
+
with pytest.raises(ValueError) as excinfo:
|
|
491
|
+
memoryset.analyze("duplicate", "invalid_name")
|
|
492
|
+
assert "Invalid analysis name: invalid_name" in str(excinfo.value)
|
|
493
|
+
assert "Valid names are:" in str(excinfo.value)
|
|
494
|
+
|
|
495
|
+
# Test with valid analysis names
|
|
496
|
+
result = memoryset.analyze("duplicate", "cluster")
|
|
497
|
+
assert isinstance(result, dict)
|
|
498
|
+
assert "duplicate" in result
|
|
499
|
+
assert "cluster" in result
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
|
|
503
|
+
# NOTE: Keep this test at the end to ensure the memoryset is dropped after all tests.
|
|
504
|
+
# Otherwise, it would be recreated on the next test run if it were dropped earlier, and
|
|
505
|
+
# that's expensive.
|
|
506
|
+
assert LabeledMemoryset.exists(writable_memoryset.name)
|
|
507
|
+
LabeledMemoryset.drop(writable_memoryset.name)
|
|
508
|
+
assert not LabeledMemoryset.exists(writable_memoryset.name)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
512
|
+
assert scored_memoryset.length == 22
|
|
513
|
+
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
514
|
+
assert scored_memoryset[0].value == "i love soup"
|
|
515
|
+
assert scored_memoryset[0].score is not None
|
|
516
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
|
|
517
|
+
lookup = scored_memoryset.search("i love soup", count=1)
|
|
518
|
+
assert len(lookup) == 1
|
|
519
|
+
assert lookup[0].score is not None
|
|
520
|
+
assert lookup[0].score < 0.11
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
524
|
+
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
525
|
+
# we are only updating an inconsequential metadata field so that we don't affect other tests
|
|
526
|
+
memory = scored_memoryset[0]
|
|
527
|
+
assert memory.label == 0
|
|
528
|
+
scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
|
|
529
|
+
assert scored_memoryset[0].label == 3
|
|
530
|
+
memory.update(label=4)
|
|
531
|
+
assert scored_memoryset[0].label == 4
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
@pytest.mark.asyncio
|
|
535
|
+
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
|
|
536
|
+
"""Test async insertion of a single memory"""
|
|
537
|
+
await writable_memoryset.arefresh()
|
|
538
|
+
prev_length = writable_memoryset.length
|
|
539
|
+
|
|
540
|
+
await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
|
|
541
|
+
|
|
542
|
+
await writable_memoryset.arefresh()
|
|
543
|
+
assert writable_memoryset.length == prev_length + 1
|
|
544
|
+
last_memory = writable_memoryset[-1]
|
|
545
|
+
assert last_memory.value == "async tomato soup is my favorite"
|
|
546
|
+
assert last_memory.label == 0
|
|
547
|
+
assert last_memory.metadata["key"] == "async_test"
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@pytest.mark.asyncio
|
|
551
|
+
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
|
|
552
|
+
"""Test async insertion of multiple memories"""
|
|
553
|
+
await writable_memoryset.arefresh()
|
|
554
|
+
prev_length = writable_memoryset.length
|
|
555
|
+
|
|
556
|
+
await writable_memoryset.ainsert(
|
|
557
|
+
[
|
|
558
|
+
dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
|
|
559
|
+
dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
|
|
560
|
+
]
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
await writable_memoryset.arefresh()
|
|
564
|
+
assert writable_memoryset.length == prev_length + 2
|
|
565
|
+
|
|
566
|
+
# Check the inserted memories
|
|
567
|
+
last_two_memories = writable_memoryset[-2:]
|
|
568
|
+
values = [memory.value for memory in last_two_memories]
|
|
569
|
+
labels = [memory.label for memory in last_two_memories]
|
|
570
|
+
keys = [memory.metadata.get("key") for memory in last_two_memories]
|
|
571
|
+
|
|
572
|
+
assert "async batch soup is delicious" in values
|
|
573
|
+
assert "async batch cats are adorable" in values
|
|
574
|
+
assert 0 in labels
|
|
575
|
+
assert 1 in labels
|
|
576
|
+
assert "batch_test_1" in keys
|
|
577
|
+
assert "batch_test_2" in keys
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
@pytest.mark.asyncio
|
|
581
|
+
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
|
|
582
|
+
"""Test async insertion with source_id and metadata"""
|
|
583
|
+
await writable_memoryset.arefresh()
|
|
584
|
+
prev_length = writable_memoryset.length
|
|
585
|
+
|
|
586
|
+
await writable_memoryset.ainsert(
|
|
587
|
+
dict(
|
|
588
|
+
value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
|
|
589
|
+
)
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
await writable_memoryset.arefresh()
|
|
593
|
+
assert writable_memoryset.length == prev_length + 1
|
|
594
|
+
last_memory = writable_memoryset[-1]
|
|
595
|
+
assert last_memory.value == "async soup with source id"
|
|
596
|
+
assert last_memory.label == 0
|
|
597
|
+
assert last_memory.source_id == "async_source_123"
|
|
598
|
+
assert last_memory.metadata["custom_field"] == "async_custom_value"
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
@pytest.mark.asyncio
|
|
602
|
+
async def test_insert_memories_async_unauthenticated(
|
|
603
|
+
unauthenticated_async_client, writable_memoryset: LabeledMemoryset
|
|
604
|
+
):
|
|
605
|
+
"""Test async insertion with invalid authentication"""
|
|
606
|
+
with unauthenticated_async_client.use():
|
|
607
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
608
|
+
await writable_memoryset.ainsert(dict(value="this should fail", label=0))
|