orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,20 +1,42 @@
1
+ import random
1
2
  from uuid import uuid4
2
3
 
3
4
  import pytest
4
5
  from datasets.arrow_dataset import Dataset
5
6
 
7
+ from .classification_model import ClassificationModel
8
+ from .conftest import skip_in_prod
9
+ from .datasource import Datasource
6
10
  from .embedding_model import PretrainedEmbeddingModel
7
- from .memoryset import LabeledMemoryset, TaskStatus
11
+ from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
8
12
 
13
+ """
14
+ Test Performance Note:
9
15
 
10
- def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
11
- assert memoryset is not None
12
- assert memoryset.name == "test_memoryset"
13
- assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
14
- assert memoryset.label_names == label_names
15
- assert memoryset.insertion_status == TaskStatus.COMPLETED
16
- assert isinstance(memoryset.length, int)
17
- assert memoryset.length == len(hf_dataset)
16
+ Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
17
+
18
+ - Two fixtures are used to manage memorysets:
19
+ - `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
20
+ It should only be used in nullipotent tests.
21
+ - `writable_memoryset` is a function-scoped, regenerating fixture.
22
+ It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
23
+
24
+ - To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
25
+ For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
26
+ rather than separate `test_delete_single` and `test_delete_multiple` tests.
27
+ """
28
+
29
+
30
+ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
31
+ assert readonly_memoryset is not None
32
+ assert readonly_memoryset.name == "test_readonly_memoryset"
33
+ assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
34
+ assert readonly_memoryset.label_names == label_names
35
+ assert readonly_memoryset.insertion_status == Status.COMPLETED
36
+ assert isinstance(readonly_memoryset.length, int)
37
+ assert readonly_memoryset.length == len(hf_dataset)
38
+ assert readonly_memoryset.index_type == "IVF_FLAT"
39
+ assert readonly_memoryset.index_params == {"n_lists": 100}
18
40
 
19
41
 
20
42
  def test_create_memoryset_unauthenticated(unauthenticated, datasource):
@@ -26,61 +48,57 @@ def test_create_memoryset_invalid_input(datasource):
26
48
  # invalid name
27
49
  with pytest.raises(ValueError, match=r"Invalid input:.*"):
28
50
  LabeledMemoryset.create("test memoryset", datasource)
29
- # invalid datasource
30
- datasource.id = str(uuid4())
31
- with pytest.raises(ValueError, match=r"Invalid input:.*"):
32
- LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
33
51
 
34
52
 
35
- def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
53
+ def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
54
+ memoryset_name = readonly_memoryset.name
36
55
  with pytest.raises(ValueError):
37
- LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, label_names=label_names, value_column="text")
56
+ LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
38
57
  with pytest.raises(ValueError):
39
- LabeledMemoryset.from_hf_dataset(
40
- "test_memoryset", hf_dataset, label_names=label_names, value_column="text", if_exists="error"
41
- )
58
+ LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")
42
59
 
43
60
 
44
- def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
61
+ def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
45
62
  # invalid label names
46
63
  with pytest.raises(ValueError):
47
64
  LabeledMemoryset.from_hf_dataset(
48
- memoryset.name,
65
+ readonly_memoryset.name,
49
66
  hf_dataset,
50
67
  label_names=["turtles", "frogs"],
51
- value_column="text",
52
68
  if_exists="open",
53
69
  )
54
70
  # different embedding model
55
71
  with pytest.raises(ValueError):
56
72
  LabeledMemoryset.from_hf_dataset(
57
- memoryset.name,
73
+ readonly_memoryset.name,
58
74
  hf_dataset,
59
75
  label_names=label_names,
60
76
  embedding_model=PretrainedEmbeddingModel.DISTILBERT,
61
77
  if_exists="open",
62
78
  )
63
79
  opened_memoryset = LabeledMemoryset.from_hf_dataset(
64
- memoryset.name,
80
+ readonly_memoryset.name,
65
81
  hf_dataset,
66
82
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
67
83
  if_exists="open",
68
84
  )
69
85
  assert opened_memoryset is not None
70
- assert opened_memoryset.name == memoryset.name
86
+ assert opened_memoryset.name == readonly_memoryset.name
71
87
  assert opened_memoryset.length == len(hf_dataset)
72
88
 
73
89
 
74
- def test_open_memoryset(memoryset, hf_dataset):
75
- fetched_memoryset = LabeledMemoryset.open(memoryset.name)
90
+ def test_open_memoryset(readonly_memoryset, hf_dataset):
91
+ fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
76
92
  assert fetched_memoryset is not None
77
- assert fetched_memoryset.name == memoryset.name
93
+ assert fetched_memoryset.name == readonly_memoryset.name
78
94
  assert fetched_memoryset.length == len(hf_dataset)
95
+ assert fetched_memoryset.index_type == "IVF_FLAT"
96
+ assert fetched_memoryset.index_params == {"n_lists": 100}
79
97
 
80
98
 
81
- def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
99
+ def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
82
100
  with pytest.raises(ValueError, match="Invalid API key"):
83
- LabeledMemoryset.open(memoryset.name)
101
+ LabeledMemoryset.open(readonly_memoryset.name)
84
102
 
85
103
 
86
104
  def test_open_memoryset_not_found():
@@ -93,15 +111,35 @@ def test_open_memoryset_invalid_input():
93
111
  LabeledMemoryset.open("not valid id")
94
112
 
95
113
 
96
- def test_open_memoryset_unauthorized(unauthorized, memoryset):
114
+ def test_open_memoryset_unauthorized(unauthorized, readonly_memoryset):
97
115
  with pytest.raises(LookupError):
98
- LabeledMemoryset.open(memoryset.name)
116
+ LabeledMemoryset.open(readonly_memoryset.name)
99
117
 
100
118
 
101
- def test_all_memorysets(memoryset):
119
+ def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
102
120
  memorysets = LabeledMemoryset.all()
103
121
  assert len(memorysets) > 0
104
- assert any(memoryset.name == memoryset.name for memoryset in memorysets)
122
+ assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
123
+
124
+
125
+ def test_all_memorysets_hidden(
126
+ readonly_memoryset: LabeledMemoryset,
127
+ ):
128
+ # Create a hidden memoryset
129
+ hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
130
+ hidden_memoryset.set(hidden=True)
131
+
132
+ # Test that show_hidden=False excludes hidden memorysets
133
+ visible_memorysets = LabeledMemoryset.all(show_hidden=False)
134
+ assert len(visible_memorysets) > 0
135
+ assert readonly_memoryset in visible_memorysets
136
+ assert hidden_memoryset not in visible_memorysets
137
+
138
+ # Test that show_hidden=True includes hidden memorysets
139
+ all_memorysets = LabeledMemoryset.all(show_hidden=True)
140
+ assert len(all_memorysets) == len(visible_memorysets) + 1
141
+ assert readonly_memoryset in all_memorysets
142
+ assert hidden_memoryset in all_memorysets
105
143
 
106
144
 
107
145
  def test_all_memorysets_unauthenticated(unauthenticated):
@@ -109,41 +147,52 @@ def test_all_memorysets_unauthenticated(unauthenticated):
109
147
  LabeledMemoryset.all()
110
148
 
111
149
 
112
- def test_all_memorysets_unauthorized(unauthorized, memoryset):
113
- assert memoryset not in LabeledMemoryset.all()
114
-
150
+ def test_all_memorysets_unauthorized(unauthorized, readonly_memoryset):
151
+ assert readonly_memoryset not in LabeledMemoryset.all()
115
152
 
116
- @pytest.mark.flaky
117
- def test_drop_memoryset(hf_dataset):
118
- memoryset = LabeledMemoryset.from_hf_dataset(
119
- "test_memoryset_delete",
120
- hf_dataset.select(range(1)),
121
- value_column="text",
122
- )
123
- assert LabeledMemoryset.exists(memoryset.name)
124
- LabeledMemoryset.drop(memoryset.name)
125
- assert not LabeledMemoryset.exists(memoryset.name)
126
153
 
127
-
128
- def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
154
+ def test_drop_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
129
155
  with pytest.raises(ValueError, match="Invalid API key"):
130
- LabeledMemoryset.drop(memoryset.name)
156
+ LabeledMemoryset.drop(readonly_memoryset.name)
131
157
 
132
158
 
133
- def test_drop_memoryset_not_found(memoryset):
159
+ def test_drop_memoryset_not_found():
134
160
  with pytest.raises(LookupError):
135
161
  LabeledMemoryset.drop(str(uuid4()))
136
162
  # ignores error if specified
137
163
  LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
138
164
 
139
165
 
140
- def test_drop_memoryset_unauthorized(unauthorized, memoryset):
166
+ def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
141
167
  with pytest.raises(LookupError):
142
- LabeledMemoryset.drop(memoryset.name)
168
+ LabeledMemoryset.drop(readonly_memoryset.name)
169
+
170
+
171
+ def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
172
+ original_label_names = writable_memoryset.label_names
173
+ writable_memoryset.set(description="New description")
174
+ assert writable_memoryset.description == "New description"
175
+
176
+ writable_memoryset.set(description=None)
177
+ assert writable_memoryset.description is None
178
+
179
+ writable_memoryset.set(name="New_name")
180
+ assert writable_memoryset.name == "New_name"
181
+
182
+ writable_memoryset.set(name="test_writable_memoryset")
183
+ assert writable_memoryset.name == "test_writable_memoryset"
143
184
 
185
+ assert writable_memoryset.label_names == original_label_names
144
186
 
145
- def test_search(memoryset: LabeledMemoryset):
146
- memory_lookups = memoryset.search(["i love soup", "cats are cute"])
187
+ writable_memoryset.set(label_names=["New label 1", "New label 2"])
188
+ assert writable_memoryset.label_names == ["New label 1", "New label 2"]
189
+
190
+ writable_memoryset.set(hidden=True)
191
+ assert writable_memoryset.hidden is True
192
+
193
+
194
+ def test_search(readonly_memoryset: LabeledMemoryset):
195
+ memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
147
196
  assert len(memory_lookups) == 2
148
197
  assert len(memory_lookups[0]) == 1
149
198
  assert len(memory_lookups[1]) == 1
@@ -151,67 +200,125 @@ def test_search(memoryset: LabeledMemoryset):
151
200
  assert memory_lookups[1][0].label == 1
152
201
 
153
202
 
154
- def test_search_count(memoryset: LabeledMemoryset):
155
- memory_lookups = memoryset.search("i love soup", count=3)
203
+ def test_search_count(readonly_memoryset: LabeledMemoryset):
204
+ memory_lookups = readonly_memoryset.search("i love soup", count=3)
156
205
  assert len(memory_lookups) == 3
157
206
  assert memory_lookups[0].label == 0
158
207
  assert memory_lookups[1].label == 0
159
208
  assert memory_lookups[2].label == 0
160
209
 
161
210
 
162
- def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
163
- memory = memoryset[0]
164
- assert memory.value == hf_dataset[0]["text"]
211
+ def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
212
+ memory = readonly_memoryset[0]
213
+ assert memory.value == hf_dataset[0]["value"]
165
214
  assert memory.label == hf_dataset[0]["label"]
166
215
  assert memory.label_name == label_names[hf_dataset[0]["label"]]
167
216
  assert memory.source_id == hf_dataset[0]["source_id"]
168
217
  assert memory.score == hf_dataset[0]["score"]
169
218
  assert memory.key == hf_dataset[0]["key"]
170
- last_memory = memoryset[-1]
171
- assert last_memory.value == hf_dataset[-1]["text"]
219
+ last_memory = readonly_memoryset[-1]
220
+ assert last_memory.value == hf_dataset[-1]["value"]
172
221
  assert last_memory.label == hf_dataset[-1]["label"]
173
222
 
174
223
 
175
- def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
176
- memories = memoryset[1:3]
224
+ def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
225
+ memories = readonly_memoryset[1:3]
177
226
  assert len(memories) == 2
178
- assert memories[0].value == hf_dataset["text"][1]
179
- assert memories[1].value == hf_dataset["text"][2]
227
+ assert memories[0].value == hf_dataset["value"][1]
228
+ assert memories[1].value == hf_dataset["value"][2]
180
229
 
181
230
 
182
- def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
183
- memory = memoryset.get(memoryset[0].memory_id)
184
- assert memory.value == hf_dataset[0]["text"]
185
- assert memory == memoryset[memory.memory_id]
231
+ def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
232
+ memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
233
+ assert memory.value == hf_dataset[0]["value"]
234
+ assert memory == readonly_memoryset[memory.memory_id]
186
235
 
187
236
 
188
- def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
189
- memories = memoryset.get([memoryset[0].memory_id, memoryset[1].memory_id])
237
+ def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
238
+ memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
190
239
  assert len(memories) == 2
191
- assert memories[0].value == hf_dataset[0]["text"]
192
- assert memories[1].value == hf_dataset[1]["text"]
240
+ assert memories[0].value == hf_dataset[0]["value"]
241
+ assert memories[1].value == hf_dataset[1]["value"]
193
242
 
194
243
 
195
- def test_query_memoryset(memoryset: LabeledMemoryset):
196
- memories = memoryset.query(filters=[("label", "==", 1)])
197
- assert len(memories) == 3
244
+ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
245
+ memories = readonly_memoryset.query(filters=[("label", "==", 1)])
246
+ assert len(memories) == 8
198
247
  assert all(memory.label == 1 for memory in memories)
199
- assert len(memoryset.query(limit=2)) == 2
200
- assert len(memoryset.query(filters=[("metadata.key", "==", "val1")])) == 1
248
+ assert len(readonly_memoryset.query(limit=2)) == 2
249
+ assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
250
+
251
+
252
+ def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
253
+ prediction = classification_model.predict("Do you love soup?")
254
+ feedback_name = f"correct_{random.randint(0, 1000000)}"
255
+ prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
256
+ memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
257
+
258
+ # Get the memory_ids that were actually used in the prediction
259
+ used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}
260
+
261
+ assert len(memories) == 8
262
+ assert all(memory.label == 0 for memory in memories)
263
+ for memory in memories:
264
+ assert memory.feedback_metrics is not None
265
+ if memory.memory_id in used_memory_ids:
266
+ assert feedback_name in memory.feedback_metrics
267
+ assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
268
+ assert memory.feedback_metrics[feedback_name]["count"] == 1
269
+ else:
270
+ assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
271
+ assert isinstance(memory.lookup_count, int)
272
+
273
+
274
+ def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
275
+ prediction = classification_model.predict("Do you love soup?")
276
+ prediction.record_feedback(category="accurate", value=prediction.label == 0)
277
+ memories = prediction.memoryset.query(
278
+ filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
279
+ )
280
+ assert len(memories) == 3
281
+ assert all(memory.label == 0 for memory in memories)
282
+ for memory in memories:
283
+ assert memory.feedback_metrics is not None
284
+ assert memory.feedback_metrics["accurate"] is not None
285
+ assert memory.feedback_metrics["accurate"]["avg"] == 1.0
286
+ assert memory.feedback_metrics["accurate"]["count"] == 1
287
+
288
+
289
+ def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
290
+ prediction = classification_model.predict("Do you love soup?")
291
+ prediction.record_feedback(category="positive", value=1.0)
292
+ prediction2 = classification_model.predict("Do you like cats?")
293
+ prediction2.record_feedback(category="positive", value=-1.0)
294
+
295
+ memories = prediction.memoryset.query(
296
+ filters=[("feedback_metrics.positive.avg", ">=", -1.0)],
297
+ sort=[("feedback_metrics.positive.avg", "desc")],
298
+ with_feedback_metrics=True,
299
+ )
300
+ assert (
301
+ len(memories) == 6
302
+ ) # there are only 6 out of 16 memories that have a positive feedback metric. Look at SAMPLE_DATA in conftest.py
303
+ assert memories[0].feedback_metrics["positive"]["avg"] == 1.0
304
+ assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
201
305
 
202
306
 
203
- def test_insert_memories(memoryset: LabeledMemoryset):
204
- prev_length = memoryset.length
205
- memoryset.insert(
307
+ def test_insert_memories(writable_memoryset: LabeledMemoryset):
308
+ writable_memoryset.refresh()
309
+ prev_length = writable_memoryset.length
310
+ writable_memoryset.insert(
206
311
  [
207
312
  dict(value="tomato soup is my favorite", label=0),
208
313
  dict(value="cats are fun to play with", label=1),
209
314
  ]
210
315
  )
211
- assert memoryset.length == prev_length + 2
212
- memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
213
- assert memoryset.length == prev_length + 3
214
- last_memory = memoryset[-1]
316
+ writable_memoryset.refresh()
317
+ assert writable_memoryset.length == prev_length + 2
318
+ writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
319
+ writable_memoryset.refresh()
320
+ assert writable_memoryset.length == prev_length + 3
321
+ last_memory = writable_memoryset[-1]
215
322
  assert last_memory.value == "tomato soup is my favorite"
216
323
  assert last_memory.label == 0
217
324
  assert last_memory.metadata
@@ -219,25 +326,28 @@ def test_insert_memories(memoryset: LabeledMemoryset):
219
326
  assert last_memory.source_id == "test"
220
327
 
221
328
 
222
- def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
223
- memory_id = memoryset[0].memory_id
224
- updated_memory = memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
329
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
330
+ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
331
+ # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
332
+
333
+ # test updating a single memory
334
+ memory_id = writable_memoryset[0].memory_id
335
+ updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
225
336
  assert updated_memory.value == "i love soup so much"
226
337
  assert updated_memory.label == hf_dataset[0]["label"]
227
- assert memoryset.get(memory_id).value == "i love soup so much"
338
+ writable_memoryset.refresh() # Refresh to ensure consistency after update
339
+ assert writable_memoryset.get(memory_id).value == "i love soup so much"
228
340
 
229
-
230
- def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
231
- memory = memoryset[0]
341
+ # test updating a memory instance
342
+ memory = writable_memoryset[0]
232
343
  updated_memory = memory.update(value="i love soup even more")
233
344
  assert updated_memory is memory
234
345
  assert memory.value == "i love soup even more"
235
346
  assert memory.label == hf_dataset[0]["label"]
236
347
 
237
-
238
- def test_update_memories(memoryset: LabeledMemoryset):
239
- memory_ids = [memory.memory_id for memory in memoryset[:2]]
240
- updated_memories = memoryset.update(
348
+ # test updating multiple memories
349
+ memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
350
+ updated_memories = writable_memoryset.update(
241
351
  [
242
352
  dict(memory_id=memory_ids[0], value="i love soup so much"),
243
353
  dict(memory_id=memory_ids[1], value="cats are so cute"),
@@ -247,25 +357,154 @@ def test_update_memories(memoryset: LabeledMemoryset):
247
357
  assert updated_memories[1].value == "cats are so cute"
248
358
 
249
359
 
250
def test_delete_memories(writable_memoryset: LabeledMemoryset):
    # We've combined the delete tests into one to avoid multiple expensive requests for a writable_memoryset

    # test deleting a single memory
    # Refresh before reading `length`, matching the insert/update tests, so the cached
    # length reflects the server state rather than a stale snapshot.
    writable_memoryset.refresh()
    prev_length = writable_memoryset.length
    memory_id = writable_memoryset[0].memory_id
    writable_memoryset.delete(memory_id)
    with pytest.raises(LookupError):
        writable_memoryset.get(memory_id)
    writable_memoryset.refresh()
    assert writable_memoryset.length == prev_length - 1

    # test deleting multiple memories
    prev_length = writable_memoryset.length
    writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id])
    writable_memoryset.refresh()
    assert writable_memoryset.length == prev_length - 2
263
375
 
264
376
 
265
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
    """Cloning onto another embedding model copies every memory and completes insertion."""
    target_model = PretrainedEmbeddingModel.DISTILBERT
    clone = readonly_memoryset.clone("test_cloned_memoryset", embedding_model=target_model)

    assert clone is not None
    assert clone.name == "test_cloned_memoryset"
    assert clone.length == readonly_memoryset.length
    assert clone.embedding_model == target_model
    assert clone.insertion_status == Status.COMPLETED
386
+
387
+
388
def test_embedding_evaluation(eval_datasource: Datasource):
    """Evaluating a single embedding model yields exactly one result describing that model."""
    results = LabeledMemoryset.run_embedding_evaluation(
        eval_datasource,
        embedding_models=["CDE_SMALL"],
        neighbor_count=3,
    )

    assert isinstance(results, list)
    assert len(results) == 1

    only_result = results[0]
    assert only_result is not None
    assert only_result["embedding_model_name"] == "CDE_SMALL"
    assert only_result["embedding_model_path"] == "OrcaDB/cde-small-v1"
397
+
398
+
399
def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
    """Memories that differ only in letter case are grouped by the duplicate analysis.

    Must be a plain test function: decorating it with ``@pytest.fixture`` keeps pytest
    from collecting it as a test, and the body awaits nothing, so it must not be ``async``.
    """
    writable_memoryset.insert(
        [
            dict(value="raspberry soup Is my favorite", label=0),
            dict(value="Raspberry soup is MY favorite", label=0),
            dict(value="rAspberry soup is my favorite", label=0),
            dict(value="raSpberry SOuP is my favorite", label=0),
            dict(value="rasPberry SOuP is my favorite", label=0),
            dict(value="bunny rabbit Is not my mom", label=1),
            dict(value="bunny rabbit is not MY mom", label=1),
            dict(value="bunny rabbit Is not my moM", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not my mom", label=1),
            dict(value="bunny rabbit is not My mom", label=1),
        ]
    )

    writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
    response = writable_memoryset.get_potential_duplicate_groups()
    assert isinstance(response, list)

    group_sizes = sorted(len(group) for group in response)
    # writable_memoryset persists across the tests in this module and may already contain
    # other near-duplicates (test_insert_memories inserts "tomato soup is my favorite" twice),
    # so require the two groups seeded above rather than an exact list of all groups.
    assert 5 in group_sizes  # 5 "favorite" variants
    assert 6 in group_sizes  # 6 "mom" variants
421
+
422
+
423
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
    """Relabeling a memory from CATS to SOUP suggests relabeling a similar mislabeled neighbor."""
    SOUP, CATS = 0, 1
    query_text = "i love soup"  # from SAMPLE_DATA in conftest.py
    mislabeled_soup_text = "soup is comfort in a bowl"

    # Seed a soup-themed memory that deliberately carries the CATS label.
    writable_memoryset.insert([dict(value=mislabeled_soup_text, label=CATS)])

    # Look up the memory whose (hypothetical) label change should cascade to its neighbors.
    # NOTE(review): earlier tests in this module update and delete writable_memoryset[0];
    # confirm the "i love soup" memory still exists by the time this test runs.
    matches = writable_memoryset.query(filters=[("value", "==", query_text)])
    memory = matches[0]

    suggestions = writable_memoryset.get_cascading_edits_suggestions(
        memory=memory,
        old_label=CATS,
        new_label=SOUP,
        max_neighbors=10,
        max_validation_neighbors=5,
    )

    # Only the seeded mislabeled memory should be proposed for relabeling.
    assert len(suggestions) == 1
    assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
450
+
451
+
452
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
    """analyze() raises ValueError naming the bad analysis and listing the valid ones."""
    memoryset = LabeledMemoryset.open(readonly_memoryset.name)

    # Positional argument tuples for analyze(): a bare string, a config dict, and a mix
    # where only one of several requested analyses is unknown.
    invalid_calls = [
        ("invalid_name",),
        ({"name": "invalid_name"},),
        ("duplicate", "invalid_name"),
    ]
    for call_args in invalid_calls:
        with pytest.raises(ValueError) as excinfo:
            memoryset.analyze(*call_args)
        message = str(excinfo.value)
        assert "Invalid analysis name: invalid_name" in message
        assert "Valid names are:" in message

    # Known names succeed and the result is keyed by analysis name.
    result = memoryset.analyze("duplicate", "cluster")
    assert isinstance(result, dict)
    assert "duplicate" in result
    assert "cluster" in result
479
+
480
+
481
def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
    # NOTE: Keep this test after every other test that uses writable_memoryset. If the
    # memoryset were dropped earlier, it would be recreated on the next test run, and
    # that's expensive.
    name = writable_memoryset.name
    assert LabeledMemoryset.exists(name)
    LabeledMemoryset.drop(name)
    assert not LabeledMemoryset.exists(name)
488
+
489
+
490
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
    """A scored memoryset exposes ScoredMemory items and supports similarity search."""
    assert scored_memoryset.length == 22

    first = scored_memoryset[0]
    assert isinstance(first, ScoredMemory)
    assert first.value == "i love soup"
    assert first.score is not None
    assert first.metadata == {"key": "g1", "source_id": "s1", "label": 0}

    lookup = scored_memoryset.search("i love soup", count=1)
    assert len(lookup) == 1
    best_match = lookup[0]
    assert best_match.score is not None
    assert best_match.score < 0.11
500
+
501
+
502
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
503
+ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
504
+ # we are only updating an inconsequential metadata field so that we don't affect other tests
505
+ memory = scored_memoryset[0]
506
+ assert memory.label == 0
507
+ scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
508
+ assert scored_memoryset[0].label == 3
509
+ memory.update(label=4)
510
+ assert scored_memoryset[0].label == 4