orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,86 +1,130 @@
1
+ import random
1
2
  from uuid import uuid4
2
3
 
3
4
  import pytest
4
5
  from datasets.arrow_dataset import Dataset
5
6
 
7
+ from .classification_model import ClassificationModel
8
+ from .conftest import skip_in_ci, skip_in_prod
9
+ from .datasource import Datasource
6
10
  from .embedding_model import PretrainedEmbeddingModel
7
- from .memoryset import LabeledMemoryset, TaskStatus
11
+ from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
8
12
 
13
+ """
14
+ Test Performance Note:
9
15
 
10
- def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
11
- assert memoryset is not None
12
- assert memoryset.name == "test_memoryset"
13
- assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
14
- assert memoryset.label_names == label_names
15
- assert memoryset.insertion_status == TaskStatus.COMPLETED
16
- assert isinstance(memoryset.length, int)
17
- assert memoryset.length == len(hf_dataset)
16
+ Creating new `LabeledMemoryset` objects is expensive, so this test file applies the following optimizations:
18
17
 
18
+ - Two fixtures are used to manage memorysets:
19
+ - `readonly_memoryset` is a session-scoped fixture shared across tests that do not modify state.
20
+ It should only be used in nullipotent tests.
21
+ - `writable_memoryset` is a function-scoped, regenerating fixture.
22
+ It can be used in tests that mutate or delete the memoryset, and will be reset before each test.
19
23
 
20
- def test_create_memoryset_unauthenticated(unauthenticated, datasource):
21
- with pytest.raises(ValueError, match="Invalid API key"):
22
- LabeledMemoryset.create("test_memoryset", datasource)
24
+ - To minimize fixture overhead, tests using `writable_memoryset` should combine related behaviors.
25
+ For example, prefer a single `test_delete` that covers both single and multiple deletion cases,
26
+ rather than separate `test_delete_single` and `test_delete_multiple` tests.
27
+ """
28
+
29
+
30
+ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
31
+ assert readonly_memoryset is not None
32
+ assert readonly_memoryset.name == "test_readonly_memoryset"
33
+ assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
34
+ assert readonly_memoryset.label_names == label_names
35
+ assert readonly_memoryset.insertion_status == Status.COMPLETED
36
+ assert isinstance(readonly_memoryset.length, int)
37
+ assert readonly_memoryset.length == len(hf_dataset)
38
+ assert readonly_memoryset.index_type == "IVF_FLAT"
39
+ assert readonly_memoryset.index_params == {"n_lists": 100}
40
+
41
+
42
+ def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
43
+ with unauthenticated_client.use():
44
+ with pytest.raises(ValueError, match="Invalid API key"):
45
+ LabeledMemoryset.create("test_memoryset", datasource)
23
46
 
24
47
 
25
48
  def test_create_memoryset_invalid_input(datasource):
26
49
  # invalid name
27
50
  with pytest.raises(ValueError, match=r"Invalid input:.*"):
28
51
  LabeledMemoryset.create("test memoryset", datasource)
29
- # invalid datasource
30
- datasource.id = str(uuid4())
31
- with pytest.raises(ValueError, match=r"Invalid input:.*"):
32
- LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
33
52
 
34
53
 
35
- def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
54
+ def test_create_memoryset_already_exists_error(hf_dataset, label_names, readonly_memoryset):
55
+ memoryset_name = readonly_memoryset.name
36
56
  with pytest.raises(ValueError):
37
- LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, label_names=label_names, value_column="text")
57
+ LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names)
38
58
  with pytest.raises(ValueError):
39
- LabeledMemoryset.from_hf_dataset(
40
- "test_memoryset", hf_dataset, label_names=label_names, value_column="text", if_exists="error"
41
- )
59
+ LabeledMemoryset.from_hf_dataset(memoryset_name, hf_dataset, label_names=label_names, if_exists="error")
42
60
 
43
61
 
44
- def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
62
+ def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_memoryset):
45
63
  # invalid label names
46
64
  with pytest.raises(ValueError):
47
65
  LabeledMemoryset.from_hf_dataset(
48
- memoryset.name,
66
+ readonly_memoryset.name,
49
67
  hf_dataset,
50
68
  label_names=["turtles", "frogs"],
51
- value_column="text",
52
69
  if_exists="open",
53
70
  )
54
71
  # different embedding model
55
72
  with pytest.raises(ValueError):
56
73
  LabeledMemoryset.from_hf_dataset(
57
- memoryset.name,
74
+ readonly_memoryset.name,
58
75
  hf_dataset,
59
76
  label_names=label_names,
60
77
  embedding_model=PretrainedEmbeddingModel.DISTILBERT,
61
78
  if_exists="open",
62
79
  )
63
80
  opened_memoryset = LabeledMemoryset.from_hf_dataset(
64
- memoryset.name,
81
+ readonly_memoryset.name,
65
82
  hf_dataset,
66
83
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
67
84
  if_exists="open",
68
85
  )
69
86
  assert opened_memoryset is not None
70
- assert opened_memoryset.name == memoryset.name
87
+ assert opened_memoryset.name == readonly_memoryset.name
71
88
  assert opened_memoryset.length == len(hf_dataset)
72
89
 
73
90
 
74
- def test_open_memoryset(memoryset, hf_dataset):
75
- fetched_memoryset = LabeledMemoryset.open(memoryset.name)
91
+ def test_if_exists_error_no_datasource_creation(
92
+ readonly_memoryset: LabeledMemoryset,
93
+ ):
94
+ memoryset_name = readonly_memoryset.name
95
+ datasource_name = f"{memoryset_name}_datasource"
96
+ Datasource.drop(datasource_name, if_not_exists="ignore")
97
+ assert not Datasource.exists(datasource_name)
98
+ with pytest.raises(ValueError):
99
+ LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
100
+ assert not Datasource.exists(datasource_name)
101
+
102
+
103
+ def test_if_exists_open_reuses_existing_datasource(
104
+ readonly_memoryset: LabeledMemoryset,
105
+ ):
106
+ memoryset_name = readonly_memoryset.name
107
+ datasource_name = f"{memoryset_name}_datasource"
108
+ Datasource.drop(datasource_name, if_not_exists="ignore")
109
+ assert not Datasource.exists(datasource_name)
110
+ reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
111
+ assert reopened.id == readonly_memoryset.id
112
+ assert not Datasource.exists(datasource_name)
113
+
114
+
115
+ def test_open_memoryset(readonly_memoryset, hf_dataset):
116
+ fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
76
117
  assert fetched_memoryset is not None
77
- assert fetched_memoryset.name == memoryset.name
118
+ assert fetched_memoryset.name == readonly_memoryset.name
78
119
  assert fetched_memoryset.length == len(hf_dataset)
120
+ assert fetched_memoryset.index_type == "IVF_FLAT"
121
+ assert fetched_memoryset.index_params == {"n_lists": 100}
79
122
 
80
123
 
81
- def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
82
- with pytest.raises(ValueError, match="Invalid API key"):
83
- LabeledMemoryset.open(memoryset.name)
124
+ def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
125
+ with unauthenticated_client.use():
126
+ with pytest.raises(ValueError, match="Invalid API key"):
127
+ LabeledMemoryset.open(readonly_memoryset.name)
84
128
 
85
129
 
86
130
  def test_open_memoryset_not_found():
@@ -93,57 +137,93 @@ def test_open_memoryset_invalid_input():
93
137
  LabeledMemoryset.open("not valid id")
94
138
 
95
139
 
96
- def test_open_memoryset_unauthorized(unauthorized, memoryset):
97
- with pytest.raises(LookupError):
98
- LabeledMemoryset.open(memoryset.name)
140
+ def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
141
+ with unauthorized_client.use():
142
+ with pytest.raises(LookupError):
143
+ LabeledMemoryset.open(readonly_memoryset.name)
99
144
 
100
145
 
101
- def test_all_memorysets(memoryset):
146
+ def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
102
147
  memorysets = LabeledMemoryset.all()
103
148
  assert len(memorysets) > 0
104
- assert any(memoryset.name == memoryset.name for memoryset in memorysets)
149
+ assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
105
150
 
106
151
 
107
- def test_all_memorysets_unauthenticated(unauthenticated):
108
- with pytest.raises(ValueError, match="Invalid API key"):
109
- LabeledMemoryset.all()
152
+ def test_all_memorysets_hidden(
153
+ readonly_memoryset: LabeledMemoryset,
154
+ ):
155
+ # Create a hidden memoryset
156
+ hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
157
+ hidden_memoryset.set(hidden=True)
110
158
 
159
+ # Test that show_hidden=False excludes hidden memorysets
160
+ visible_memorysets = LabeledMemoryset.all(show_hidden=False)
161
+ assert len(visible_memorysets) > 0
162
+ assert readonly_memoryset in visible_memorysets
163
+ assert hidden_memoryset not in visible_memorysets
111
164
 
112
- def test_all_memorysets_unauthorized(unauthorized, memoryset):
113
- assert memoryset not in LabeledMemoryset.all()
165
+ # Test that show_hidden=True includes hidden memorysets
166
+ all_memorysets = LabeledMemoryset.all(show_hidden=True)
167
+ assert len(all_memorysets) == len(visible_memorysets) + 1
168
+ assert readonly_memoryset in all_memorysets
169
+ assert hidden_memoryset in all_memorysets
114
170
 
115
171
 
116
- @pytest.mark.flaky
117
- def test_drop_memoryset(hf_dataset):
118
- memoryset = LabeledMemoryset.from_hf_dataset(
119
- "test_memoryset_delete",
120
- hf_dataset.select(range(1)),
121
- value_column="text",
122
- )
123
- assert LabeledMemoryset.exists(memoryset.name)
124
- LabeledMemoryset.drop(memoryset.name)
125
- assert not LabeledMemoryset.exists(memoryset.name)
172
+ def test_all_memorysets_unauthenticated(unauthenticated_client):
173
+ with unauthenticated_client.use():
174
+ with pytest.raises(ValueError, match="Invalid API key"):
175
+ LabeledMemoryset.all()
176
+
126
177
 
178
+ def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
179
+ with unauthorized_client.use():
180
+ assert readonly_memoryset not in LabeledMemoryset.all()
127
181
 
128
- def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
129
- with pytest.raises(ValueError, match="Invalid API key"):
130
- LabeledMemoryset.drop(memoryset.name)
131
182
 
183
+ def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
184
+ with unauthenticated_client.use():
185
+ with pytest.raises(ValueError, match="Invalid API key"):
186
+ LabeledMemoryset.drop(readonly_memoryset.name)
132
187
 
133
- def test_drop_memoryset_not_found(memoryset):
188
+
189
+ def test_drop_memoryset_not_found():
134
190
  with pytest.raises(LookupError):
135
191
  LabeledMemoryset.drop(str(uuid4()))
136
192
  # ignores error if specified
137
193
  LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
138
194
 
139
195
 
140
- def test_drop_memoryset_unauthorized(unauthorized, memoryset):
141
- with pytest.raises(LookupError):
142
- LabeledMemoryset.drop(memoryset.name)
196
+ def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
197
+ with unauthorized_client.use():
198
+ with pytest.raises(LookupError):
199
+ LabeledMemoryset.drop(readonly_memoryset.name)
143
200
 
144
201
 
145
- def test_search(memoryset: LabeledMemoryset):
146
- memory_lookups = memoryset.search(["i love soup", "cats are cute"])
202
+ def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
203
+ original_label_names = writable_memoryset.label_names
204
+ writable_memoryset.set(description="New description")
205
+ assert writable_memoryset.description == "New description"
206
+
207
+ writable_memoryset.set(description=None)
208
+ assert writable_memoryset.description is None
209
+
210
+ writable_memoryset.set(name="New_name")
211
+ assert writable_memoryset.name == "New_name"
212
+
213
+ writable_memoryset.set(name="test_writable_memoryset")
214
+ assert writable_memoryset.name == "test_writable_memoryset"
215
+
216
+ assert writable_memoryset.label_names == original_label_names
217
+
218
+ writable_memoryset.set(label_names=["New label 1", "New label 2"])
219
+ assert writable_memoryset.label_names == ["New label 1", "New label 2"]
220
+
221
+ writable_memoryset.set(hidden=True)
222
+ assert writable_memoryset.hidden is True
223
+
224
+
225
+ def test_search(readonly_memoryset: LabeledMemoryset):
226
+ memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
147
227
  assert len(memory_lookups) == 2
148
228
  assert len(memory_lookups[0]) == 1
149
229
  assert len(memory_lookups[1]) == 1
@@ -151,67 +231,125 @@ def test_search(memoryset: LabeledMemoryset):
151
231
  assert memory_lookups[1][0].label == 1
152
232
 
153
233
 
154
- def test_search_count(memoryset: LabeledMemoryset):
155
- memory_lookups = memoryset.search("i love soup", count=3)
234
+ def test_search_count(readonly_memoryset: LabeledMemoryset):
235
+ memory_lookups = readonly_memoryset.search("i love soup", count=3)
156
236
  assert len(memory_lookups) == 3
157
237
  assert memory_lookups[0].label == 0
158
238
  assert memory_lookups[1].label == 0
159
239
  assert memory_lookups[2].label == 0
160
240
 
161
241
 
162
- def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
163
- memory = memoryset[0]
164
- assert memory.value == hf_dataset[0]["text"]
242
+ def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
243
+ memory = readonly_memoryset[0]
244
+ assert memory.value == hf_dataset[0]["value"]
165
245
  assert memory.label == hf_dataset[0]["label"]
166
246
  assert memory.label_name == label_names[hf_dataset[0]["label"]]
167
247
  assert memory.source_id == hf_dataset[0]["source_id"]
168
248
  assert memory.score == hf_dataset[0]["score"]
169
249
  assert memory.key == hf_dataset[0]["key"]
170
- last_memory = memoryset[-1]
171
- assert last_memory.value == hf_dataset[-1]["text"]
250
+ last_memory = readonly_memoryset[-1]
251
+ assert last_memory.value == hf_dataset[-1]["value"]
172
252
  assert last_memory.label == hf_dataset[-1]["label"]
173
253
 
174
254
 
175
- def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
176
- memories = memoryset[1:3]
255
+ def test_get_range_of_memories(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
256
+ memories = readonly_memoryset[1:3]
177
257
  assert len(memories) == 2
178
- assert memories[0].value == hf_dataset["text"][1]
179
- assert memories[1].value == hf_dataset["text"][2]
258
+ assert memories[0].value == hf_dataset["value"][1]
259
+ assert memories[1].value == hf_dataset["value"][2]
180
260
 
181
261
 
182
- def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
183
- memory = memoryset.get(memoryset[0].memory_id)
184
- assert memory.value == hf_dataset[0]["text"]
185
- assert memory == memoryset[memory.memory_id]
262
+ def test_get_memory_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
263
+ memory = readonly_memoryset.get(readonly_memoryset[0].memory_id)
264
+ assert memory.value == hf_dataset[0]["value"]
265
+ assert memory == readonly_memoryset[memory.memory_id]
186
266
 
187
267
 
188
- def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
189
- memories = memoryset.get([memoryset[0].memory_id, memoryset[1].memory_id])
268
+ def test_get_memories_by_id(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset):
269
+ memories = readonly_memoryset.get([readonly_memoryset[0].memory_id, readonly_memoryset[1].memory_id])
190
270
  assert len(memories) == 2
191
- assert memories[0].value == hf_dataset[0]["text"]
192
- assert memories[1].value == hf_dataset[1]["text"]
271
+ assert memories[0].value == hf_dataset[0]["value"]
272
+ assert memories[1].value == hf_dataset[1]["value"]
193
273
 
194
274
 
195
- def test_query_memoryset(memoryset: LabeledMemoryset):
196
- memories = memoryset.query(filters=[("label", "==", 1)])
197
- assert len(memories) == 3
275
+ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
276
+ memories = readonly_memoryset.query(filters=[("label", "==", 1)])
277
+ assert len(memories) == 8
198
278
  assert all(memory.label == 1 for memory in memories)
199
- assert len(memoryset.query(limit=2)) == 2
200
- assert len(memoryset.query(filters=[("metadata.key", "==", "val1")])) == 1
279
+ assert len(readonly_memoryset.query(limit=2)) == 2
280
+ assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
281
+
282
+
283
+ def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
284
+ prediction = classification_model.predict("Do you love soup?")
285
+ feedback_name = f"correct_{random.randint(0, 1000000)}"
286
+ prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
287
+ memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
288
+
289
+ # Get the memory_ids that were actually used in the prediction
290
+ used_memory_ids = {memory.memory_id for memory in prediction.memory_lookups}
291
+
292
+ assert len(memories) == 8
293
+ assert all(memory.label == 0 for memory in memories)
294
+ for memory in memories:
295
+ assert memory.feedback_metrics is not None
296
+ if memory.memory_id in used_memory_ids:
297
+ assert feedback_name in memory.feedback_metrics
298
+ assert memory.feedback_metrics[feedback_name]["avg"] == 1.0
299
+ assert memory.feedback_metrics[feedback_name]["count"] == 1
300
+ else:
301
+ assert feedback_name not in memory.feedback_metrics or memory.feedback_metrics[feedback_name]["count"] == 0
302
+ assert isinstance(memory.lookup_count, int)
303
+
304
+
305
+ def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
306
+ prediction = classification_model.predict("Do you love soup?")
307
+ prediction.record_feedback(category="accurate", value=prediction.label == 0)
308
+ memories = prediction.memoryset.query(
309
+ filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
310
+ )
311
+ assert len(memories) == 3
312
+ assert all(memory.label == 0 for memory in memories)
313
+ for memory in memories:
314
+ assert memory.feedback_metrics is not None
315
+ assert memory.feedback_metrics["accurate"] is not None
316
+ assert memory.feedback_metrics["accurate"]["avg"] == 1.0
317
+ assert memory.feedback_metrics["accurate"]["count"] == 1
318
+
319
+
320
+ def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
321
+ prediction = classification_model.predict("Do you love soup?")
322
+ prediction.record_feedback(category="positive", value=1.0)
323
+ prediction2 = classification_model.predict("Do you like cats?")
324
+ prediction2.record_feedback(category="positive", value=-1.0)
325
+
326
+ memories = prediction.memoryset.query(
327
+ filters=[("feedback_metrics.positive.avg", ">=", -1.0)],
328
+ sort=[("feedback_metrics.positive.avg", "desc")],
329
+ with_feedback_metrics=True,
330
+ )
331
+ assert (
332
+ len(memories) == 6
333
+ ) # there are only 6 out of 16 memories that have a positive feedback metric. Look at SAMPLE_DATA in conftest.py
334
+ assert memories[0].feedback_metrics["positive"]["avg"] == 1.0
335
+ assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
201
336
 
202
337
 
203
- def test_insert_memories(memoryset: LabeledMemoryset):
204
- prev_length = memoryset.length
205
- memoryset.insert(
338
+ def test_insert_memories(writable_memoryset: LabeledMemoryset):
339
+ writable_memoryset.refresh()
340
+ prev_length = writable_memoryset.length
341
+ writable_memoryset.insert(
206
342
  [
207
343
  dict(value="tomato soup is my favorite", label=0),
208
344
  dict(value="cats are fun to play with", label=1),
209
345
  ]
210
346
  )
211
- assert memoryset.length == prev_length + 2
212
- memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
213
- assert memoryset.length == prev_length + 3
214
- last_memory = memoryset[-1]
347
+ writable_memoryset.refresh()
348
+ assert writable_memoryset.length == prev_length + 2
349
+ writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
350
+ writable_memoryset.refresh()
351
+ assert writable_memoryset.length == prev_length + 3
352
+ last_memory = writable_memoryset[-1]
215
353
  assert last_memory.value == "tomato soup is my favorite"
216
354
  assert last_memory.label == 0
217
355
  assert last_memory.metadata
@@ -219,25 +357,29 @@ def test_insert_memories(memoryset: LabeledMemoryset):
219
357
  assert last_memory.source_id == "test"
220
358
 
221
359
 
222
- def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
223
- memory_id = memoryset[0].memory_id
224
- updated_memory = memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
360
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
361
+ @skip_in_ci("CI environment may not have session consistency guarantees")
362
+ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
363
+ # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
364
+
365
+ # test updating a single memory
366
+ memory_id = writable_memoryset[0].memory_id
367
+ updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
225
368
  assert updated_memory.value == "i love soup so much"
226
369
  assert updated_memory.label == hf_dataset[0]["label"]
227
- assert memoryset.get(memory_id).value == "i love soup so much"
370
+ writable_memoryset.refresh() # Refresh to ensure consistency after update
371
+ assert writable_memoryset.get(memory_id).value == "i love soup so much"
228
372
 
229
-
230
- def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
231
- memory = memoryset[0]
373
+ # test updating a memory instance
374
+ memory = writable_memoryset[0]
232
375
  updated_memory = memory.update(value="i love soup even more")
233
376
  assert updated_memory is memory
234
377
  assert memory.value == "i love soup even more"
235
378
  assert memory.label == hf_dataset[0]["label"]
236
379
 
237
-
238
- def test_update_memories(memoryset: LabeledMemoryset):
239
- memory_ids = [memory.memory_id for memory in memoryset[:2]]
240
- updated_memories = memoryset.update(
380
+ # test updating multiple memories
381
+ memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
382
+ updated_memories = writable_memoryset.update(
241
383
  [
242
384
  dict(memory_id=memory_ids[0], value="i love soup so much"),
243
385
  dict(memory_id=memory_ids[1], value="cats are so cute"),
@@ -247,25 +389,220 @@ def test_update_memories(memoryset: LabeledMemoryset):
247
389
  assert updated_memories[1].value == "cats are so cute"
248
390
 
249
391
 
250
- def test_delete_memory(memoryset: LabeledMemoryset):
251
- prev_length = memoryset.length
252
- memory_id = memoryset[0].memory_id
253
- memoryset.delete(memory_id)
254
- with pytest.raises(LookupError):
255
- memoryset.get(memory_id)
256
- assert memoryset.length == prev_length - 1
392
+ def test_delete_memories(writable_memoryset: LabeledMemoryset):
393
+ # We've combined the delete tests into one to avoid multiple expensive requests for a writable_memoryset
257
394
 
395
+ # test deleting a single memory
396
+ prev_length = writable_memoryset.length
397
+ memory_id = writable_memoryset[0].memory_id
398
+ writable_memoryset.delete(memory_id)
399
+ with pytest.raises(LookupError):
400
+ writable_memoryset.get(memory_id)
401
+ assert writable_memoryset.length == prev_length - 1
258
402
 
259
- def test_delete_memories(memoryset: LabeledMemoryset):
260
- prev_length = memoryset.length
261
- memoryset.delete([memoryset[0].memory_id, memoryset[1].memory_id])
262
- assert memoryset.length == prev_length - 2
403
+ # test deleting multiple memories
404
+ prev_length = writable_memoryset.length
405
+ writable_memoryset.delete([writable_memoryset[0].memory_id, writable_memoryset[1].memory_id])
406
+ assert writable_memoryset.length == prev_length - 2
263
407
 
264
408
 
265
- def test_clone_memoryset(memoryset: LabeledMemoryset):
266
- cloned_memoryset = memoryset.clone("test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT)
409
+ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
410
+ cloned_memoryset = readonly_memoryset.clone(
411
+ "test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
412
+ )
267
413
  assert cloned_memoryset is not None
268
414
  assert cloned_memoryset.name == "test_cloned_memoryset"
269
- assert cloned_memoryset.length == memoryset.length
415
+ assert cloned_memoryset.length == readonly_memoryset.length
270
416
  assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
271
- assert cloned_memoryset.insertion_status == TaskStatus.COMPLETED
417
+ assert cloned_memoryset.insertion_status == Status.COMPLETED
418
+
419
+
420
+ @pytest.fixture(scope="function")
421
+ async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
422
+ writable_memoryset.insert(
423
+ [
424
+ dict(value="raspberry soup Is my favorite", label=0),
425
+ dict(value="Raspberry soup is MY favorite", label=0),
426
+ dict(value="rAspberry soup is my favorite", label=0),
427
+ dict(value="raSpberry SOuP is my favorite", label=0),
428
+ dict(value="rasPberry SOuP is my favorite", label=0),
429
+ dict(value="bunny rabbit Is not my mom", label=1),
430
+ dict(value="bunny rabbit is not MY mom", label=1),
431
+ dict(value="bunny rabbit Is not my moM", label=1),
432
+ dict(value="bunny rabbit is not my mom", label=1),
433
+ dict(value="bunny rabbit is not my mom", label=1),
434
+ dict(value="bunny rabbit is not My mom", label=1),
435
+ ]
436
+ )
437
+
438
+ writable_memoryset.analyze({"name": "duplicate", "possible_duplicate_threshold": 0.97})
439
+ response = writable_memoryset.get_potential_duplicate_groups()
440
+ assert isinstance(response, list)
441
+ assert sorted([len(res) for res in response]) == [5, 6] # 5 favorite, 6 mom
442
+
443
+
444
+ def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
445
+ # Insert a memory to test cascading edits
446
+ SOUP = 0
447
+ CATS = 1
448
+ query_text = "i love soup" # from SAMPLE_DATA in conftest.py
449
+ mislabeled_soup_text = "soup is comfort in a bowl"
450
+ writable_memoryset.insert(
451
+ [
452
+ dict(value=mislabeled_soup_text, label=CATS), # mislabeled soup memory
453
+ ]
454
+ )
455
+
456
+ # Fetch the memory to update
457
+ memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]
458
+
459
+ # Update the label and get cascading edit suggestions
460
+ suggestions = writable_memoryset.get_cascading_edits_suggestions(
461
+ memory=memory,
462
+ old_label=CATS,
463
+ new_label=SOUP,
464
+ max_neighbors=10,
465
+ max_validation_neighbors=5,
466
+ )
467
+
468
+ # Validate the suggestions
469
+ assert len(suggestions) == 1
470
+ assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
471
+
472
+
473
+ def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
474
+ """Test that analyze() raises ValueError for invalid analysis names"""
475
+ memoryset = LabeledMemoryset.open(readonly_memoryset.name)
476
+
477
+ # Test with string input
478
+ with pytest.raises(ValueError) as excinfo:
479
+ memoryset.analyze("invalid_name")
480
+ assert "Invalid analysis name: invalid_name" in str(excinfo.value)
481
+ assert "Valid names are:" in str(excinfo.value)
482
+
483
+ # Test with dict input
484
+ with pytest.raises(ValueError) as excinfo:
485
+ memoryset.analyze({"name": "invalid_name"})
486
+ assert "Invalid analysis name: invalid_name" in str(excinfo.value)
487
+ assert "Valid names are:" in str(excinfo.value)
488
+
489
+ # Test with multiple analyses where one is invalid
490
+ with pytest.raises(ValueError) as excinfo:
491
+ memoryset.analyze("duplicate", "invalid_name")
492
+ assert "Invalid analysis name: invalid_name" in str(excinfo.value)
493
+ assert "Valid names are:" in str(excinfo.value)
494
+
495
+ # Test with valid analysis names
496
+ result = memoryset.analyze("duplicate", "cluster")
497
+ assert isinstance(result, dict)
498
+ assert "duplicate" in result
499
+ assert "cluster" in result
500
+
501
+
502
+ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
503
+ # NOTE: Keep this test at the end to ensure the memoryset is dropped after all tests.
504
+ # Otherwise, it would be recreated on the next test run if it were dropped earlier, and
505
+ # that's expensive.
506
+ assert LabeledMemoryset.exists(writable_memoryset.name)
507
+ LabeledMemoryset.drop(writable_memoryset.name)
508
+ assert not LabeledMemoryset.exists(writable_memoryset.name)
509
+
510
+
511
+ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
512
+ assert scored_memoryset.length == 22
513
+ assert isinstance(scored_memoryset[0], ScoredMemory)
514
+ assert scored_memoryset[0].value == "i love soup"
515
+ assert scored_memoryset[0].score is not None
516
+ assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
517
+ lookup = scored_memoryset.search("i love soup", count=1)
518
+ assert len(lookup) == 1
519
+ assert lookup[0].score is not None
520
+ assert lookup[0].score < 0.11
521
+
522
+
523
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
524
+ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
525
+ # we are only updating an inconsequential metadata field so that we don't affect other tests
526
+ memory = scored_memoryset[0]
527
+ assert memory.label == 0
528
+ scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
529
+ assert scored_memoryset[0].label == 3
530
+ memory.update(label=4)
531
+ assert scored_memoryset[0].label == 4
532
+
533
+
534
+ @pytest.mark.asyncio
535
+ async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
536
+ """Test async insertion of a single memory"""
537
+ await writable_memoryset.arefresh()
538
+ prev_length = writable_memoryset.length
539
+
540
+ await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
541
+
542
+ await writable_memoryset.arefresh()
543
+ assert writable_memoryset.length == prev_length + 1
544
+ last_memory = writable_memoryset[-1]
545
+ assert last_memory.value == "async tomato soup is my favorite"
546
+ assert last_memory.label == 0
547
+ assert last_memory.metadata["key"] == "async_test"
548
+
549
+
550
+ @pytest.mark.asyncio
551
+ async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
552
+ """Test async insertion of multiple memories"""
553
+ await writable_memoryset.arefresh()
554
+ prev_length = writable_memoryset.length
555
+
556
+ await writable_memoryset.ainsert(
557
+ [
558
+ dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
559
+ dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
560
+ ]
561
+ )
562
+
563
+ await writable_memoryset.arefresh()
564
+ assert writable_memoryset.length == prev_length + 2
565
+
566
+ # Check the inserted memories
567
+ last_two_memories = writable_memoryset[-2:]
568
+ values = [memory.value for memory in last_two_memories]
569
+ labels = [memory.label for memory in last_two_memories]
570
+ keys = [memory.metadata.get("key") for memory in last_two_memories]
571
+
572
+ assert "async batch soup is delicious" in values
573
+ assert "async batch cats are adorable" in values
574
+ assert 0 in labels
575
+ assert 1 in labels
576
+ assert "batch_test_1" in keys
577
+ assert "batch_test_2" in keys
578
+
579
+
580
+ @pytest.mark.asyncio
581
+ async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
582
+ """Test async insertion with source_id and metadata"""
583
+ await writable_memoryset.arefresh()
584
+ prev_length = writable_memoryset.length
585
+
586
+ await writable_memoryset.ainsert(
587
+ dict(
588
+ value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
589
+ )
590
+ )
591
+
592
+ await writable_memoryset.arefresh()
593
+ assert writable_memoryset.length == prev_length + 1
594
+ last_memory = writable_memoryset[-1]
595
+ assert last_memory.value == "async soup with source id"
596
+ assert last_memory.label == 0
597
+ assert last_memory.source_id == "async_source_123"
598
+ assert last_memory.metadata["custom_field"] == "async_custom_value"
599
+
600
+
601
+ @pytest.mark.asyncio
602
+ async def test_insert_memories_async_unauthenticated(
603
+ unauthenticated_async_client, writable_memoryset: LabeledMemoryset
604
+ ):
605
+ """Test async insertion with invalid authentication"""
606
+ with unauthenticated_async_client.use():
607
+ with pytest.raises(ValueError, match="Invalid API key"):
608
+ await writable_memoryset.ainsert(dict(value="this should fail", label=0))