orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +84 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +90 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +184 -174
- orca_sdk/classification_model_test.py +178 -142
- orca_sdk/conftest.py +77 -26
- orca_sdk/datasource.py +34 -0
- orca_sdk/datasource_test.py +9 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +58 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +225 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,19 +1,15 @@
|
|
|
1
|
+
import os
|
|
1
2
|
import random
|
|
2
|
-
import time
|
|
3
|
-
from typing import Generator
|
|
4
3
|
from uuid import uuid4
|
|
5
4
|
|
|
6
5
|
import pytest
|
|
7
|
-
from datasets import ClassLabel, Features, Value
|
|
8
6
|
from datasets.arrow_dataset import Dataset
|
|
9
7
|
|
|
10
|
-
from orca_sdk.conftest import SAMPLE_DATA
|
|
11
|
-
|
|
12
|
-
from ._generated_api_client.models import CascadingEditSuggestion
|
|
13
8
|
from .classification_model import ClassificationModel
|
|
9
|
+
from .conftest import skip_in_prod
|
|
14
10
|
from .datasource import Datasource
|
|
15
11
|
from .embedding_model import PretrainedEmbeddingModel
|
|
16
|
-
from .memoryset import LabeledMemoryset,
|
|
12
|
+
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
17
13
|
|
|
18
14
|
"""
|
|
19
15
|
Test Performance Note:
|
|
@@ -37,9 +33,11 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
|
|
|
37
33
|
assert readonly_memoryset.name == "test_readonly_memoryset"
|
|
38
34
|
assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
39
35
|
assert readonly_memoryset.label_names == label_names
|
|
40
|
-
assert readonly_memoryset.insertion_status ==
|
|
36
|
+
assert readonly_memoryset.insertion_status == Status.COMPLETED
|
|
41
37
|
assert isinstance(readonly_memoryset.length, int)
|
|
42
38
|
assert readonly_memoryset.length == len(hf_dataset)
|
|
39
|
+
assert readonly_memoryset.index_type == "IVF_FLAT"
|
|
40
|
+
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
43
41
|
|
|
44
42
|
|
|
45
43
|
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
|
|
@@ -95,6 +93,8 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
|
95
93
|
assert fetched_memoryset is not None
|
|
96
94
|
assert fetched_memoryset.name == readonly_memoryset.name
|
|
97
95
|
assert fetched_memoryset.length == len(hf_dataset)
|
|
96
|
+
assert fetched_memoryset.index_type == "IVF_FLAT"
|
|
97
|
+
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
98
98
|
|
|
99
99
|
|
|
100
100
|
def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
@@ -149,15 +149,25 @@ def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
|
149
149
|
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
def
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
writable_memoryset.update_metadata(description="New description")
|
|
152
|
+
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
153
|
+
original_label_names = writable_memoryset.label_names
|
|
154
|
+
writable_memoryset.set(description="New description")
|
|
156
155
|
assert writable_memoryset.description == "New description"
|
|
157
156
|
|
|
158
|
-
writable_memoryset.
|
|
157
|
+
writable_memoryset.set(description=None)
|
|
159
158
|
assert writable_memoryset.description is None
|
|
160
159
|
|
|
160
|
+
writable_memoryset.set(name="New_name")
|
|
161
|
+
assert writable_memoryset.name == "New_name"
|
|
162
|
+
|
|
163
|
+
writable_memoryset.set(name="test_writable_memoryset")
|
|
164
|
+
assert writable_memoryset.name == "test_writable_memoryset"
|
|
165
|
+
|
|
166
|
+
assert writable_memoryset.label_names == original_label_names
|
|
167
|
+
|
|
168
|
+
writable_memoryset.set(label_names=["New label 1", "New label 2"])
|
|
169
|
+
assert writable_memoryset.label_names == ["New label 1", "New label 2"]
|
|
170
|
+
|
|
161
171
|
|
|
162
172
|
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
163
173
|
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
@@ -214,11 +224,11 @@ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
214
224
|
assert len(memories) == 8
|
|
215
225
|
assert all(memory.label == 1 for memory in memories)
|
|
216
226
|
assert len(readonly_memoryset.query(limit=2)) == 2
|
|
217
|
-
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "
|
|
227
|
+
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
|
|
218
228
|
|
|
219
229
|
|
|
220
|
-
def test_query_memoryset_with_feedback_metrics(
|
|
221
|
-
prediction =
|
|
230
|
+
def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
|
|
231
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
222
232
|
feedback_name = f"correct_{random.randint(0, 1000000)}"
|
|
223
233
|
prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
|
|
224
234
|
memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
|
|
@@ -239,8 +249,8 @@ def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
|
|
|
239
249
|
assert isinstance(memory.lookup_count, int)
|
|
240
250
|
|
|
241
251
|
|
|
242
|
-
def test_query_memoryset_with_feedback_metrics_filter(
|
|
243
|
-
prediction =
|
|
252
|
+
def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
|
|
253
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
244
254
|
prediction.record_feedback(category="accurate", value=prediction.label == 0)
|
|
245
255
|
memories = prediction.memoryset.query(
|
|
246
256
|
filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
|
|
@@ -254,10 +264,10 @@ def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel
|
|
|
254
264
|
assert memory.feedback_metrics["accurate"]["count"] == 1
|
|
255
265
|
|
|
256
266
|
|
|
257
|
-
def test_query_memoryset_with_feedback_metrics_sort(
|
|
258
|
-
prediction =
|
|
267
|
+
def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
|
|
268
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
259
269
|
prediction.record_feedback(category="positive", value=1.0)
|
|
260
|
-
prediction2 =
|
|
270
|
+
prediction2 = classification_model.predict("Do you like cats?")
|
|
261
271
|
prediction2.record_feedback(category="positive", value=-1.0)
|
|
262
272
|
|
|
263
273
|
memories = prediction.memoryset.query(
|
|
@@ -281,8 +291,10 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
281
291
|
dict(value="cats are fun to play with", label=1),
|
|
282
292
|
]
|
|
283
293
|
)
|
|
294
|
+
writable_memoryset.refresh()
|
|
284
295
|
assert writable_memoryset.length == prev_length + 2
|
|
285
296
|
writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
|
|
297
|
+
writable_memoryset.refresh()
|
|
286
298
|
assert writable_memoryset.length == prev_length + 3
|
|
287
299
|
last_memory = writable_memoryset[-1]
|
|
288
300
|
assert last_memory.value == "tomato soup is my favorite"
|
|
@@ -292,6 +304,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
292
304
|
assert last_memory.source_id == "test"
|
|
293
305
|
|
|
294
306
|
|
|
307
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
295
308
|
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
296
309
|
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
297
310
|
|
|
@@ -300,6 +313,7 @@ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Datas
|
|
|
300
313
|
updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
301
314
|
assert updated_memory.value == "i love soup so much"
|
|
302
315
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
316
|
+
writable_memoryset.refresh() # Refresh to ensure consistency after update
|
|
303
317
|
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
304
318
|
|
|
305
319
|
# test updating a memory instance
|
|
@@ -346,7 +360,7 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
346
360
|
assert cloned_memoryset.name == "test_cloned_memoryset"
|
|
347
361
|
assert cloned_memoryset.length == readonly_memoryset.length
|
|
348
362
|
assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
|
|
349
|
-
assert cloned_memoryset.insertion_status ==
|
|
363
|
+
assert cloned_memoryset.insertion_status == Status.COMPLETED
|
|
350
364
|
|
|
351
365
|
|
|
352
366
|
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
@@ -361,7 +375,6 @@ def test_embedding_evaluation(eval_datasource: Datasource):
|
|
|
361
375
|
assert response["evaluation_results"][0] is not None
|
|
362
376
|
assert response["evaluation_results"][0]["embedding_model_name"] == "CDE_SMALL"
|
|
363
377
|
assert response["evaluation_results"][0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
364
|
-
Datasource.drop("eval_datasource")
|
|
365
378
|
|
|
366
379
|
|
|
367
380
|
@pytest.fixture(scope="function")
|
|
@@ -453,3 +466,25 @@ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
|
|
|
453
466
|
assert LabeledMemoryset.exists(writable_memoryset.name)
|
|
454
467
|
LabeledMemoryset.drop(writable_memoryset.name)
|
|
455
468
|
assert not LabeledMemoryset.exists(writable_memoryset.name)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
472
|
+
assert scored_memoryset.length == 16
|
|
473
|
+
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
474
|
+
assert scored_memoryset[0].value == "i love soup"
|
|
475
|
+
assert scored_memoryset[0].score is not None
|
|
476
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
|
|
477
|
+
lookup = scored_memoryset.search("i love soup", count=1)
|
|
478
|
+
assert len(lookup) == 1
|
|
479
|
+
assert lookup[0].score < 0.11
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
483
|
+
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
484
|
+
# we are only updating an inconsequential metadata field so that we don't affect other tests
|
|
485
|
+
memory = scored_memoryset[0]
|
|
486
|
+
assert memory.label == 0
|
|
487
|
+
scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
|
|
488
|
+
assert scored_memoryset[0].label == 3
|
|
489
|
+
memory.update(label=4)
|
|
490
|
+
assert scored_memoryset[0].label == 4
|