orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,20 +1,42 @@
|
|
|
1
|
+
import random
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
4
5
|
from datasets.arrow_dataset import Dataset
|
|
5
6
|
|
|
7
|
+
from .classification_model import ClassificationModel
|
|
8
|
+
from .conftest import skip_in_prod
|
|
9
|
+
from .datasource import Datasource
|
|
6
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
7
|
-
from .memoryset import LabeledMemoryset,
|
|
11
|
+
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
8
12
|
|
|
13
|
+
"""
|
|
14
|
+
Test Performance Note:
|
|
9
15
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
16
|
+
Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
|
|
17
|
+
|
|
18
|
+
- Two fixtures are used to manage memorysets:
|
|
19
|
+
- `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
|
|
20
|
+
It should only be used in nullipotent tests.
|
|
21
|
+
- `writable_memoryset` is a function-scoped, regenerating fixture.
|
|
22
|
+
It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
|
|
23
|
+
|
|
24
|
+
- To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
|
|
25
|
+
For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
|
|
26
|
+
rather than separate `test_delete_single` and `test_delete_multiple` tests.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
31
|
+
assert readonly_memoryset is not None
|
|
32
|
+
assert readonly_memoryset.name == "test_readonly_memoryset"
|
|
33
|
+
assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
34
|
+
assert readonly_memoryset.label_names == label_names
|
|
35
|
+
assert readonly_memoryset.insertion_status == Status.COMPLETED
|
|
36
|
+
assert isinstance(readonly_memoryset.length, int)
|
|
37
|
+
assert readonly_memoryset.length == len(hf_dataset)
|
|
38
|
+
assert readonly_memoryset.index_type == "IVF_FLAT"
|
|
39
|
+
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
18
40
|
|
|
19
41
|
|
|
20
42
|
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
|
|
@@ -26,61 +48,57 @@ def test_create_memoryset_invalid_input(datasource):
|
|
|
26
48
|
# invalid name
|
|
27
49
|
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
28
50
|
LabeledMemoryset.create("test memoryset", datasource)
|
|
29
|
-
# invalid datasource
|
|
30
|
-
datasource.id = str(uuid4())
|
|
31
|
-
with pytest.raises(ValueError, match=r"Invalid input:.*"):
|
|
32
|
-
LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
|
|
33
51
|
|
|
34
52
|
|
|
35
|
-
def test_create_memoryset_already_exists_error(hf_dataset, label_names,
|
|
53
|
+
def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
|
|
54
|
+
memoryset_name = readonly_memoryset.name
|
|
36
55
|
with pytest.raises(ValueError):
|
|
37
|
-
LabeledMemoryset.from_hf_dataset(
|
|
56
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
|
|
38
57
|
with pytest.raises(ValueError):
|
|
39
|
-
LabeledMemoryset.from_hf_dataset(
|
|
40
|
-
"test_memoryset", hf_dataset, label_names=label_names, value_column="text", if_exists="error"
|
|
41
|
-
)
|
|
58
|
+
LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")
|
|
42
59
|
|
|
43
60
|
|
|
44
|
-
def test_create_memoryset_already_exists_open(hf_dataset, label_names,
|
|
61
|
+
def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
|
|
45
62
|
# invalid label names
|
|
46
63
|
with pytest.raises(ValueError):
|
|
47
64
|
LabeledMemoryset.from_hf_dataset(
|
|
48
|
-
|
|
65
|
+
readonly_memoryset.name,
|
|
49
66
|
hf_dataset,
|
|
50
67
|
label_names=["turtles", "frogs"],
|
|
51
|
-
value_column="text",
|
|
52
68
|
if_exists="open",
|
|
53
69
|
)
|
|
54
70
|
# different embedding model
|
|
55
71
|
with pytest.raises(ValueError):
|
|
56
72
|
LabeledMemoryset.from_hf_dataset(
|
|
57
|
-
|
|
73
|
+
readonly_memoryset.name,
|
|
58
74
|
hf_dataset,
|
|
59
75
|
label_names=label_names,
|
|
60
76
|
embedding_model=PretrainedEmbeddingModel.DISTILBERT,
|
|
61
77
|
if_exists="open",
|
|
62
78
|
)
|
|
63
79
|
opened_memoryset = LabeledMemoryset.from_hf_dataset(
|
|
64
|
-
|
|
80
|
+
readonly_memoryset.name,
|
|
65
81
|
hf_dataset,
|
|
66
82
|
embedding_model=PretrainedEmbeddingModel.GTE_BASE,
|
|
67
83
|
if_exists="open",
|
|
68
84
|
)
|
|
69
85
|
assert opened_memoryset is not None
|
|
70
|
-
assert opened_memoryset.name ==
|
|
86
|
+
assert opened_memoryset.name == readonly_memoryset.name
|
|
71
87
|
assert opened_memoryset.length == len(hf_dataset)
|
|
72
88
|
|
|
73
89
|
|
|
74
|
-
def test_open_memoryset(
|
|
75
|
-
fetched_memoryset = LabeledMemoryset.open(
|
|
90
|
+
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
91
|
+
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
76
92
|
assert fetched_memoryset is not None
|
|
77
|
-
assert fetched_memoryset.name ==
|
|
93
|
+
assert fetched_memoryset.name == readonly_memoryset.name
|
|
78
94
|
assert fetched_memoryset.length == len(hf_dataset)
|
|
95
|
+
assert fetched_memoryset.index_type == "IVF_FLAT"
|
|
96
|
+
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
79
97
|
|
|
80
98
|
|
|
81
|
-
def test_open_memoryset_unauthenticated(unauthenticated,
|
|
99
|
+
def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
82
100
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
83
|
-
LabeledMemoryset.open(
|
|
101
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
84
102
|
|
|
85
103
|
|
|
86
104
|
def test_open_memoryset_not_found():
|
|
@@ -93,15 +111,35 @@ def test_open_memoryset_invalid_input():
|
|
|
93
111
|
LabeledMemoryset.open("not valid id")
|
|
94
112
|
|
|
95
113
|
|
|
96
|
-
def test_open_memoryset_unauthorized(unauthorized,
|
|
114
|
+
def test_open_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
97
115
|
with pytest.raises(LookupError):
|
|
98
|
-
LabeledMemoryset.open(
|
|
116
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
99
117
|
|
|
100
118
|
|
|
101
|
-
def test_all_memorysets(
|
|
119
|
+
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
102
120
|
memorysets = LabeledMemoryset.all()
|
|
103
121
|
assert len(memorysets) > 0
|
|
104
|
-
assert any(memoryset.name ==
|
|
122
|
+
assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_all_memorysets_hidden(
|
|
126
|
+
readonly_memoryset: LabeledMemoryset,
|
|
127
|
+
):
|
|
128
|
+
# Create a hidden memoryset
|
|
129
|
+
hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
|
|
130
|
+
hidden_memoryset.set(hidden=True)
|
|
131
|
+
|
|
132
|
+
# Test that show_hidden=False excludes hidden memorysets
|
|
133
|
+
visible_memorysets = LabeledMemoryset.all(show_hidden=False)
|
|
134
|
+
assert len(visible_memorysets) > 0
|
|
135
|
+
assert readonly_memoryset in visible_memorysets
|
|
136
|
+
assert hidden_memoryset not in visible_memorysets
|
|
137
|
+
|
|
138
|
+
# Test that show_hidden=True includes hidden memorysets
|
|
139
|
+
all_memorysets = LabeledMemoryset.all(show_hidden=True)
|
|
140
|
+
assert len(all_memorysets) == len(visible_memorysets) + 1
|
|
141
|
+
assert readonly_memoryset in all_memorysets
|
|
142
|
+
assert hidden_memoryset in all_memorysets
|
|
105
143
|
|
|
106
144
|
|
|
107
145
|
def test_all_memorysets_unauthenticated(unauthenticated):
|
|
@@ -109,41 +147,52 @@ def test_all_memorysets_unauthenticated(unauthenticated):
|
|
|
109
147
|
LabeledMemoryset.all()
|
|
110
148
|
|
|
111
149
|
|
|
112
|
-
def test_all_memorysets_unauthorized(unauthorized,
|
|
113
|
-
assert
|
|
114
|
-
|
|
150
|
+
def test_all_memorysets_unauthorized(unauthorized, readonly_memoryset):
|
|
151
|
+
assert readonly_memoryset not in LabeledMemoryset.all()
|
|
115
152
|
|
|
116
|
-
@pytest.mark.flaky
|
|
117
|
-
def test_drop_memoryset(hf_dataset):
|
|
118
|
-
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
119
|
-
"test_memoryset_delete",
|
|
120
|
-
hf_dataset.select(range(1)),
|
|
121
|
-
value_column="text",
|
|
122
|
-
)
|
|
123
|
-
assert LabeledMemoryset.exists(memoryset.name)
|
|
124
|
-
LabeledMemoryset.drop(memoryset.name)
|
|
125
|
-
assert not LabeledMemoryset.exists(memoryset.name)
|
|
126
153
|
|
|
127
|
-
|
|
128
|
-
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
|
|
154
|
+
def test_drop_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
129
155
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
130
|
-
LabeledMemoryset.drop(
|
|
156
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
131
157
|
|
|
132
158
|
|
|
133
|
-
def test_drop_memoryset_not_found(
|
|
159
|
+
def test_drop_memoryset_not_found():
|
|
134
160
|
with pytest.raises(LookupError):
|
|
135
161
|
LabeledMemoryset.drop(str(uuid4()))
|
|
136
162
|
# ignores error if specified
|
|
137
163
|
LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
|
|
138
164
|
|
|
139
165
|
|
|
140
|
-
def test_drop_memoryset_unauthorized(unauthorized,
|
|
166
|
+
def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
141
167
|
with pytest.raises(LookupError):
|
|
142
|
-
LabeledMemoryset.drop(
|
|
168
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
172
|
+
original_label_names = writable_memoryset.label_names
|
|
173
|
+
writable_memoryset.set(description="New description")
|
|
174
|
+
assert writable_memoryset.description == "New description"
|
|
175
|
+
|
|
176
|
+
writable_memoryset.set(description=None)
|
|
177
|
+
assert writable_memoryset.description is None
|
|
178
|
+
|
|
179
|
+
writable_memoryset.set(name="New_name")
|
|
180
|
+
assert writable_memoryset.name == "New_name"
|
|
181
|
+
|
|
182
|
+
writable_memoryset.set(name="test_writable_memoryset")
|
|
183
|
+
assert writable_memoryset.name == "test_writable_memoryset"
|
|
143
184
|
|
|
185
|
+
assert writable_memoryset.label_names == original_label_names
|
|
144
186
|
|
|
145
|
-
|
|
146
|
-
|
|
187
|
+
writable_memoryset.set(label_names=["New label 1", "New label 2"])
|
|
188
|
+
assert writable_memoryset.label_names == ["New label 1", "New label 2"]
|
|
189
|
+
|
|
190
|
+
writable_memoryset.set(hidden=True)
|
|
191
|
+
assert writable_memoryset.hidden is True
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
195
|
+
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
147
196
|
assert len(memory_lookups) == 2
|
|
148
197
|
assert len(memory_lookups[0]) == 1
|
|
149
198
|
assert len(memory_lookups[1]) == 1
|
|
@@ -151,67 +200,125 @@ def test_search(memoryset: LabeledMemoryset):
|
|
|
151
200
|
assert memory_lookups[1][0].label == 1
|
|
152
201
|
|
|
153
202
|
|
|
154
|
-
def test_search_count(
|
|
155
|
-
memory_lookups =
|
|
203
|
+
def test_search_count(readonly_memoryset: LabeledMemoryset):
|
|
204
|
+
memory_lookups = readonly_memoryset.search("i love soup", count=3)
|
|
156
205
|
assert len(memory_lookups) == 3
|
|
157
206
|
assert memory_lookups[0].label == 0
|
|
158
207
|
assert memory_lookups[1].label == 0
|
|
159
208
|
assert memory_lookups[2].label == 0
|
|
160
209
|
|
|
161
210
|
|
|
162
|
-
def test_get_memory_at_index(
|
|
163
|
-
memory =
|
|
164
|
-
assert memory.value == hf_dataset[0]["
|
|
211
|
+
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
212
|
+
memory = readonly_memoryset[0]
|
|
213
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
165
214
|
assert memory.label == hf_dataset[0]["label"]
|
|
166
215
|
assert memory.label_name == label_names[hf_dataset[0]["label"]]
|
|
167
216
|
assert memory.source_id == hf_dataset[0]["source_id"]
|
|
168
217
|
assert memory.score == hf_dataset[0]["score"]
|
|
169
218
|
assert memory.key == hf_dataset[0]["key"]
|
|
170
|
-
last_memory =
|
|
171
|
-
assert last_memory.value == hf_dataset[-1]["
|
|
219
|
+
last_memory = readonly_memoryset[-1]
|
|
220
|
+
assert last_memory.value == hf_dataset[-1]["value"]
|
|
172
221
|
assert last_memory.label == hf_dataset[-1]["label"]
|
|
173
222
|
|
|
174
223
|
|
|
175
|
-
def test_get_range_of_memories(
|
|
176
|
-
memories =
|
|
224
|
+
def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
225
|
+
memories = readonly_memoryset[1:3]
|
|
177
226
|
assert len(memories) == 2
|
|
178
|
-
assert memories[0].value == hf_dataset["
|
|
179
|
-
assert memories[1].value == hf_dataset["
|
|
227
|
+
assert memories[0].value == hf_dataset["value"][1]
|
|
228
|
+
assert memories[1].value == hf_dataset["value"][2]
|
|
180
229
|
|
|
181
230
|
|
|
182
|
-
def test_get_memory_by_id(
|
|
183
|
-
memory =
|
|
184
|
-
assert memory.value == hf_dataset[0]["
|
|
185
|
-
assert memory ==
|
|
231
|
+
def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
232
|
+
memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
|
|
233
|
+
assert memory.value == hf_dataset[0]["value"]
|
|
234
|
+
assert memory == readonly_memoryset[memory.memory_id]
|
|
186
235
|
|
|
187
236
|
|
|
188
|
-
def test_get_memories_by_id(
|
|
189
|
-
memories =
|
|
237
|
+
def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
238
|
+
memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
|
|
190
239
|
assert len(memories) == 2
|
|
191
|
-
assert memories[0].value == hf_dataset[0]["
|
|
192
|
-
assert memories[1].value == hf_dataset[1]["
|
|
240
|
+
assert memories[0].value == hf_dataset[0]["value"]
|
|
241
|
+
assert memories[1].value == hf_dataset[1]["value"]
|
|
193
242
|
|
|
194
243
|
|
|
195
|
-
def test_query_memoryset(
|
|
196
|
-
memories =
|
|
197
|
-
assert len(memories) ==
|
|
244
|
+
def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
245
|
+
memories = readonly_memoryset.query(filters=[("label", "==", 1)])
|
|
246
|
+
assert len(memories) == 8
|
|
198
247
|
assert all(memory.label == 1 for memory in memories)
|
|
199
|
-
assert len(
|
|
200
|
-
assert len(
|
|
248
|
+
assert len(readonly_memoryset.query(limit=2)) == 2
|
|
249
|
+
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
|
|
253
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
254
|
+
feedback_name = f"correct_{random.randint(0, 1000000)}"
|
|
255
|
+
prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
|
|
256
|
+
memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
|
|
257
|
+
|
|
258
|
+
# Get the memory_ids that were actually used in the prediction
|
|
259
|
+
used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}
|
|
260
|
+
|
|
261
|
+
assert len(memories) == 8
|
|
262
|
+
assert all(memory.label == 0 for memory in memories)
|
|
263
|
+
for memory in memories:
|
|
264
|
+
assert memory.feedback_metrics is not None
|
|
265
|
+
if memory.memory_id in used_memory_ids:
|
|
266
|
+
assert feedback_name in memory.feedback_metrics
|
|
267
|
+
assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
|
|
268
|
+
assert memory.feedback_metrics[feedback_name]["count"] == 1
|
|
269
|
+
else:
|
|
270
|
+
assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
|
|
271
|
+
assert isinstance(memory.lookup_count, int)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
|
|
275
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
276
|
+
prediction.record_feedback(category="accurate", value=prediction.label == 0)
|
|
277
|
+
memories = prediction.memoryset.query(
|
|
278
|
+
filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
|
|
279
|
+
)
|
|
280
|
+
assert len(memories) == 3
|
|
281
|
+
assert all(memory.label == 0 for memory in memories)
|
|
282
|
+
for memory in memories:
|
|
283
|
+
assert memory.feedback_metrics is not None
|
|
284
|
+
assert memory.feedback_metrics["accurate"] is not None
|
|
285
|
+
assert memory.feedback_metrics["accurate"]["avg"] == 1.0
|
|
286
|
+
assert memory.feedback_metrics["accurate"]["count"] == 1
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
    """Sorting by an aggregated feedback metric orders memories from highest to lowest average.

    A unique feedback category is generated per run, mirroring
    ``test_query_memoryset_with_feedback_metrics``. This ensures that only the memories used
    by the two predictions below carry this metric, which the exact length assertion relies on.
    """
    feedback_name = f"positive_{random.randint(0, 1000000)}"
    metric_path = f"feedback_metrics.{feedback_name}.avg"

    prediction = classification_model.predict("Do you love soup?")
    prediction.record_feedback(category=feedback_name, value=1.0)
    prediction2 = classification_model.predict("Do you like cats?")
    prediction2.record_feedback(category=feedback_name, value=-1.0)

    memories = prediction.memoryset.query(
        filters=[(metric_path, ">=", -1.0)],
        sort=[(metric_path, "desc")],
        with_feedback_metrics=True,
    )
    # there are only 6 out of 16 memories that have this feedback metric. Look at SAMPLE_DATA in conftest.py
    assert len(memories) == 6
    # descending sort: the +1.0 feedback comes first, the -1.0 feedback last
    assert memories[0].feedback_metrics is not None
    assert memories[0].feedback_metrics[feedback_name]["avg"] == 1.0
    assert memories[-1].feedback_metrics is not None
    assert memories[-1].feedback_metrics[feedback_name]["avg"] == -1.0
|
|
201
305
|
|
|
202
306
|
|
|
203
|
-
def test_insert_memories(
|
|
204
|
-
|
|
205
|
-
|
|
307
|
+
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
308
|
+
writable_memoryset.refresh()
|
|
309
|
+
prev_length = writable_memoryset.length
|
|
310
|
+
writable_memoryset.insert(
|
|
206
311
|
[
|
|
207
312
|
dict(value="tomato soup is my favorite", label=0),
|
|
208
313
|
dict(value="cats are fun to play with", label=1),
|
|
209
314
|
]
|
|
210
315
|
)
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
316
|
+
writable_memoryset.refresh()
|
|
317
|
+
assert writable_memoryset.length == prev_length + 2
|
|
318
|
+
writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
|
|
319
|
+
writable_memoryset.refresh()
|
|
320
|
+
assert writable_memoryset.length == prev_length + 3
|
|
321
|
+
last_memory = writable_memoryset[-1]
|
|
215
322
|
assert last_memory.value == "tomato soup is my favorite"
|
|
216
323
|
assert last_memory.label == 0
|
|
217
324
|
assert last_memory.metadata
|
|
@@ -219,25 +326,28 @@ def test_insert_memories(memoryset: LabeledMemoryset):
|
|
|
219
326
|
assert last_memory.source_id == "test"
|
|
220
327
|
|
|
221
328
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
329
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
330
|
+
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
331
|
+
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
332
|
+
|
|
333
|
+
# test updating a single memory
|
|
334
|
+
memory_id = writable_memoryset[0].memory_id
|
|
335
|
+
updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
225
336
|
assert updated_memory.value == "i love soup so much"
|
|
226
337
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
227
|
-
|
|
338
|
+
writable_memoryset.refresh() # Refresh to ensure consistency after update
|
|
339
|
+
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
228
340
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
memory = memoryset[0]
|
|
341
|
+
# test updating a memory instance
|
|
342
|
+
memory = writable_memoryset[0]
|
|
232
343
|
updated_memory = memory.update(value="i love soup even more")
|
|
233
344
|
assert updated_memory is memory
|
|
234
345
|
assert memory.value == "i love soup even more"
|
|
235
346
|
assert memory.label == hf_dataset[0]["label"]
|
|
236
347
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
updated_memories = memoryset.update(
|
|
348
|
+
# test updating multiple memories
|
|
349
|
+
memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
|
|
350
|
+
updated_memories = writable_memoryset.update(
|
|
241
351
|
[
|
|
242
352
|
dict(memory_id=memory_ids[0], value="i love soup so much"),
|
|
243
353
|
dict(memory_id=memory_ids[1], value="cats are so cute"),
|
|
@@ -247,25 +357,154 @@ def test_update_memories(memoryset: LabeledMemoryset):
|
|
|
247
357
|
assert updated_memories[1].value == "cats are so cute"
|
|
248
358
|
|
|
249
359
|
|
|
250
|
-
def
|
|
251
|
-
|
|
252
|
-
memory_id = memoryset[0].memory_id
|
|
253
|
-
memoryset.delete(memory_id)
|
|
254
|
-
with pytest.raises(LookupError):
|
|
255
|
-
memoryset.get(memory_id)
|
|
256
|
-
assert memoryset.length == prev_length - 1
|
|
360
|
+
def test_delete_memories(writable_memoryset: LabeledMemoryset):
    """Deleting by a single id and by a list of ids both shrink the memoryset."""
    # Both deletion modes share one test so only one expensive writable_memoryset setup is paid for.

    # Single id: the deleted memory must no longer be retrievable.
    length_before = writable_memoryset.length
    doomed_id = writable_memoryset[0].memory_id
    writable_memoryset.delete(doomed_id)
    with pytest.raises(LookupError):
        writable_memoryset.get(doomed_id)
    assert writable_memoryset.length == length_before - 1

    # List of ids: two memories removed in one call.
    length_before = writable_memoryset.length
    doomed_ids = [writable_memoryset[position].memory_id for position in (0, 1)]
    writable_memoryset.delete(doomed_ids)
    assert writable_memoryset.length == length_before - 2
|
|
263
375
|
|
|
264
376
|
|
|
265
|
-
def test_clone_memoryset(
|
|
266
|
-
cloned_memoryset =
|
|
377
|
+
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
    """Cloning with a different embedding model yields a complete copy under the new name."""
    target_model = PretrainedEmbeddingModel.DISTILBERT
    clone = readonly_memoryset.clone("test_cloned_memoryset", embedding_model=target_model)

    assert clone is not None
    assert clone.name == "test_cloned_memoryset"
    # every source memory was copied and insertion has finished
    assert clone.length == readonly_memoryset.length
    assert clone.embedding_model == target_model
    assert clone.insertion_status == Status.COMPLETED
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def test_embedding_evaluation(eval_datasource: Datasource):
    """Evaluating a single embedding model returns a single result describing that model."""
    results = LabeledMemoryset.run_embedding_evaluation(
        eval_datasource,
        embedding_models=["CDE_SMALL"],
        neighbor_count=3,
    )
    assert isinstance(results, list)
    assert len(results) == 1

    evaluation = results[0]
    assert evaluation is not None
    # the result identifies the model by both its short name and its hub path
    assert evaluation["embedding_model_name"] == "CDE_SMALL"
    assert evaluation["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
    """Memories differing only in letter case are reported as potential-duplicate groups.

    This is a plain synchronous test. Decorating it with ``@pytest.fixture`` would register
    it as a fixture instead of collecting it as a test. Declaring it ``async`` is unnecessary
    because nothing here is awaited.
    """
    writable_memoryset.insert(
        [
            dict(value="raspberry soup Is my favorite", label=0),
            dict(value="Raspberry soup is MY favorite", label=0),
            dict(value="rAspberry soup is my favorite", label=0),
            dict(value="raSpberry SOuP is my favorite", label=0),
            dict(value="rasPberry SOuP is my favorite", label=0),
            dict(value="bunny rabbit Is not my mom", label=1),
            dict(value="bunny rabbit is not MY mom", label=1),
            dict(value="bunny rabbit Is not my moM", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not My mom", label=1),
        ]
    )

    # run only the duplicate analysis, with a high similarity threshold
    writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
    response = writable_memoryset.get_potential_duplicate_groups()
    assert isinstance(response, list)
    assert sorted([len(res) for res in response]) == [5, 6]  # 5 favorite, 6 mom
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
    """Asking for cascading-edit suggestions on a label change surfaces the mislabeled neighbor."""
    soup_label, cats_label = 0, 1
    anchor_text = "i love soup"  # from SAMPLE_DATA in conftest.py
    mislabeled_text = "soup is comfort in a bowl"

    # Seed a soup-themed memory that is deliberately labeled as cats.
    writable_memoryset.insert([dict(value=mislabeled_text, label=cats_label)])

    # Ask what should follow if the anchor memory were relabeled from cats to soup.
    # The test never calls update; it only requests the suggestions.
    anchor = writable_memoryset.query(filters=[("value", "==", anchor_text)])[0]
    suggestions = writable_memoryset.get_cascading_edits_suggestions(
        memory=anchor,
        old_label=cats_label,
        new_label=soup_label,
        max_neighbors=10,
        max_validation_neighbors=5,
    )

    # The only expected suggestion is the mislabeled memory seeded above.
    assert len(suggestions) == 1
    assert suggestions[0]["neighbor"]["value"] == mislabeled_text
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
    """Test that analyze() raises ValueError for invalid analysis names"""
    memoryset = LabeledMemoryset.open(readonly_memoryset.name)

    # Each entry is the positional-argument tuple passed to analyze().
    rejected_calls = [
        ("invalid_name",),  # plain string
        ({"name": "invalid_name"},),  # config dict
        ("duplicate", "invalid_name"),  # one bad name among several valid ones
    ]
    for call_args in rejected_calls:
        with pytest.raises(ValueError) as excinfo:
            memoryset.analyze(*call_args)
        message = str(excinfo.value)
        assert "Invalid analysis name: invalid_name" in message
        assert "Valid names are:" in message

    # Valid names run and report their results keyed by analysis name.
    result = memoryset.analyze("duplicate", "cluster")
    assert isinstance(result, dict)
    assert {"duplicate", "cluster"} <= result.keys()
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
    # NOTE: Keep this test last in the module so the memoryset is dropped only after every
    # other test has used it. If it were dropped earlier, it would be recreated on the next
    # test run, and that's expensive.
    memoryset_name = writable_memoryset.name
    assert LabeledMemoryset.exists(memoryset_name)
    LabeledMemoryset.drop(memoryset_name)
    assert not LabeledMemoryset.exists(memoryset_name)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
    """A scored memoryset yields ScoredMemory items and answers a single-result search."""
    assert scored_memoryset.length == 22

    # Index the memoryset once instead of once per assertion. Each index may be a separate
    # fetch from the server (TODO confirm). Fetching once also guarantees that every check
    # below inspects the same memory.
    first_memory = scored_memoryset[0]
    assert isinstance(first_memory, ScoredMemory)
    assert first_memory.value == "i love soup"
    assert first_memory.score is not None
    assert first_memory.metadata == {"key": "g1", "source_id": "s1", "label": 0}

    lookup = scored_memoryset.search("i love soup", count=1)
    assert len(lookup) == 1
    assert lookup[0].score is not None
    assert lookup[0].score < 0.11
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
503
|
+
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
504
|
+
# we are only updating an inconsequential metadata field so that we don't affect other tests
|
|
505
|
+
memory = scored_memoryset[0]
|
|
506
|
+
assert memory.label == 0
|
|
507
|
+
scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
|
|
508
|
+
assert scored_memoryset[0].label == 3
|
|
509
|
+
memory.update(label=4)
|
|
510
|
+
assert scored_memoryset[0].label == 4
|