orca-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. orca_sdk/__init__.py +19 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +193 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
  22. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
  23. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
  24. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
  25. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  26. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
  27. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
  28. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
  30. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
  34. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  35. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
  36. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
  37. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
  38. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
  39. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
  40. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
  41. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
  42. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
  43. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
  45. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
  46. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
  47. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
  48. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
  49. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
  50. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
  51. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
  52. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  53. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
  54. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
  56. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
  58. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
  59. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
  60. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
  62. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
  63. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
  64. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
  65. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
  66. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
  67. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
  68. orca_sdk/_generated_api_client/client.py +216 -0
  69. orca_sdk/_generated_api_client/errors.py +38 -0
  70. orca_sdk/_generated_api_client/models/__init__.py +159 -0
  71. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
  72. orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
  73. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  74. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  75. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
  76. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  77. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  78. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  79. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  80. orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
  81. orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
  82. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
  83. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  84. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  85. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  86. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  87. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  88. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  89. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  90. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  91. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  92. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  93. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  94. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
  95. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  96. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  97. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  98. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  99. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  100. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  101. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  102. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  103. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  104. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  105. orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
  106. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
  107. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  108. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
  109. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  110. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  111. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  112. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  113. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
  114. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  115. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  116. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  117. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  118. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  119. orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
  120. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
  121. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
  122. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  123. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  124. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  125. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  126. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  127. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  128. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  129. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
  130. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  131. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  132. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  133. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  134. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  135. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  136. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
  137. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  138. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  139. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  140. orca_sdk/_generated_api_client/models/task.py +198 -0
  141. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  142. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  143. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  144. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  145. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  146. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  147. orca_sdk/_generated_api_client/py.typed +1 -0
  148. orca_sdk/_generated_api_client/types.py +56 -0
  149. orca_sdk/_utils/__init__.py +0 -0
  150. orca_sdk/_utils/analysis_ui.py +194 -0
  151. orca_sdk/_utils/analysis_ui_style.css +54 -0
  152. orca_sdk/_utils/auth.py +63 -0
  153. orca_sdk/_utils/auth_test.py +31 -0
  154. orca_sdk/_utils/common.py +37 -0
  155. orca_sdk/_utils/data_parsing.py +99 -0
  156. orca_sdk/_utils/data_parsing_test.py +244 -0
  157. orca_sdk/_utils/prediction_result_ui.css +18 -0
  158. orca_sdk/_utils/prediction_result_ui.py +64 -0
  159. orca_sdk/_utils/task.py +73 -0
  160. orca_sdk/classification_model.py +499 -0
  161. orca_sdk/classification_model_test.py +266 -0
  162. orca_sdk/conftest.py +117 -0
  163. orca_sdk/datasource.py +333 -0
  164. orca_sdk/datasource_test.py +95 -0
  165. orca_sdk/embedding_model.py +336 -0
  166. orca_sdk/embedding_model_test.py +173 -0
  167. orca_sdk/labeled_memoryset.py +1154 -0
  168. orca_sdk/labeled_memoryset_test.py +271 -0
  169. orca_sdk/orca_credentials.py +75 -0
  170. orca_sdk/orca_credentials_test.py +37 -0
  171. orca_sdk/telemetry.py +386 -0
  172. orca_sdk/telemetry_test.py +100 -0
  173. orca_sdk-0.1.0.dist-info/METADATA +39 -0
  174. orca_sdk-0.1.0.dist-info/RECORD +175 -0
  175. orca_sdk-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,271 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+ from datasets.arrow_dataset import Dataset
5
+
6
+ from .embedding_model import PretrainedEmbeddingModel
7
+ from .labeled_memoryset import LabeledMemoryset, TaskStatus
8
+
9
+
10
def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """The ``memoryset`` fixture is created with the expected name, model, labels and size."""
    assert memoryset is not None
    # configuration the fixture was created with
    assert memoryset.name == "test_memoryset"
    assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
    assert memoryset.label_names == label_names
    # insertion has finished and every dataset row became a memory
    assert memoryset.insertion_status == TaskStatus.COMPLETED
    assert isinstance(memoryset.length, int)
    assert memoryset.length == len(hf_dataset)
18
+
19
+
20
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
    """Creating a memoryset without valid credentials is rejected with an API-key error."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        LabeledMemoryset.create("test_memoryset", datasource)
23
+
24
+
25
def test_create_memoryset_invalid_input(datasource, monkeypatch):
    """``LabeledMemoryset.create`` rejects an invalid name and an unknown datasource id.

    Args:
        datasource: datasource fixture; its ``id`` is temporarily replaced.
        monkeypatch: pytest fixture used so the ``id`` change is reverted after the
            test instead of leaking into other tests that share ``datasource``.
    """
    # invalid name (contains a space)
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource)
    # invalid datasource: point the fixture at a random id; monkeypatch restores the
    # original id on teardown, unlike a plain attribute assignment
    monkeypatch.setattr(datasource, "id", str(uuid4()))
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
33
+
34
+
35
def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
    """Re-creating an existing memoryset fails both by default and with ``if_exists="error"``."""
    shared_kwargs = {"label_names": label_names, "value_column": "text"}
    # default ``if_exists`` behaviour
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, **shared_kwargs)
    # explicit ``if_exists="error"``
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, if_exists="error", **shared_kwargs)
42
+
43
+
44
def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
    """With ``if_exists="open"`` an existing memoryset is reused only when its config matches."""
    existing_name = memoryset.name

    # mismatching label names are rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name, hf_dataset, label_names=["turtles", "frogs"], value_column="text", if_exists="open"
        )

    # a mismatching embedding model is rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name,
            hf_dataset,
            label_names=label_names,
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )

    # a matching configuration opens the existing memoryset
    reopened = LabeledMemoryset.from_hf_dataset(
        existing_name, hf_dataset, embedding_model=PretrainedEmbeddingModel.GTE_BASE, if_exists="open"
    )
    assert reopened is not None
    assert reopened.name == existing_name
    assert reopened.length == len(hf_dataset)
72
+
73
+
74
def test_open_memoryset(memoryset, hf_dataset):
    """An existing memoryset can be opened by name and reports the fixture's size."""
    reopened = LabeledMemoryset.open(memoryset.name)
    assert reopened is not None
    assert reopened.name == memoryset.name
    assert reopened.length == len(hf_dataset)
79
+
80
+
81
def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
    """Opening a memoryset without valid credentials is rejected with an API-key error."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        LabeledMemoryset.open(memoryset.name)
84
+
85
+
86
def test_open_memoryset_not_found():
    """Opening a memoryset by a random, unused id raises ``LookupError``."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.open(unknown_id)
89
+
90
+
91
def test_open_memoryset_invalid_input():
    """A malformed name/id (contains spaces) is reported as invalid input."""
    malformed = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open(malformed)
94
+
95
+
96
def test_open_memoryset_unauthorized(unauthorized, memoryset):
    """A memoryset is reported as not found when opened with the ``unauthorized`` fixture active."""
    target_name = memoryset.name
    with pytest.raises(LookupError):
        LabeledMemoryset.open(target_name)
99
+
100
+
101
def test_all_memorysets(memoryset):
    """``LabeledMemoryset.all`` returns a non-empty list that includes the fixture's memoryset."""
    memorysets = LabeledMemoryset.all()
    assert len(memorysets) > 0
    # Use a distinct loop variable: reusing ``memoryset`` here shadowed the fixture,
    # so every listed name was compared with itself and the check could never fail.
    assert any(listed.name == memoryset.name for listed in memorysets)
105
+
106
+
107
def test_all_memorysets_unauthenticated(unauthenticated):
    """Listing memorysets without valid credentials is rejected with an API-key error."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        LabeledMemoryset.all()
110
+
111
+
112
def test_all_memorysets_unauthorized(unauthorized, memoryset):
    """The fixture's memoryset is absent from the listing when the ``unauthorized`` fixture is active."""
    visible_memorysets = LabeledMemoryset.all()
    assert memoryset not in visible_memorysets
114
+
115
+
116
@pytest.mark.flaky
def test_drop_memoryset(hf_dataset):
    """A freshly created one-row memoryset no longer exists after being dropped."""
    one_row = hf_dataset.select(range(1))
    created = LabeledMemoryset.from_hf_dataset("test_memoryset_delete", one_row, value_column="text")
    assert LabeledMemoryset.exists(created.name)
    LabeledMemoryset.drop(created.name)
    assert not LabeledMemoryset.exists(created.name)
126
+
127
+
128
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
    """Dropping a memoryset without valid credentials is rejected with an API-key error."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        LabeledMemoryset.drop(memoryset.name)
131
+
132
+
133
def test_drop_memoryset_not_found(memoryset):
    """Dropping an unknown id raises ``LookupError`` unless ``if_not_exists="ignore"`` is given."""
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(str(uuid4()))
    # with "ignore" the missing memoryset is skipped without an error
    LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
138
+
139
+
140
def test_drop_memoryset_unauthorized(unauthorized, memoryset):
    """Dropping the memoryset with the ``unauthorized`` fixture active reports it as not found."""
    target_name = memoryset.name
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(target_name)
143
+
144
+
145
def test_search(memoryset: LabeledMemoryset):
    """A batch search returns one single-hit lookup list per query, with the expected labels."""
    queries = ["i love soup", "cats are cute"]
    results = memoryset.search(queries)
    assert len(results) == 2
    first_hits, second_hits = results[0], results[1]
    assert len(first_hits) == 1
    assert len(second_hits) == 1
    assert first_hits[0].label == 0
    assert second_hits[0].label == 1
152
+
153
+
154
def test_search_count(memoryset: LabeledMemoryset):
    """A single-query search with ``count=3`` returns three hits, all labelled 0."""
    hits = memoryset.search("i love soup", count=3)
    assert len(hits) == 3
    for position in (0, 1, 2):
        assert hits[position].label == 0
160
+
161
+
162
def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """Integer indexing (including negative indices) mirrors the rows of the source dataset."""
    first_memory = memoryset[0]
    first_row = hf_dataset[0]
    assert first_memory.value == first_row["text"]
    assert first_memory.label == first_row["label"]
    assert first_memory.label_name == label_names[first_row["label"]]
    assert first_memory.source_id == first_row["source_id"]
    assert first_memory.score == first_row["score"]
    assert first_memory.key == first_row["key"]

    final_memory = memoryset[-1]
    final_row = hf_dataset[-1]
    assert final_memory.value == final_row["text"]
    assert final_memory.label == final_row["label"]
173
+
174
+
175
def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Slicing returns the memories for the matching dataset rows."""
    window = memoryset[1:3]
    assert len(window) == 2
    texts = hf_dataset["text"]
    assert window[0].value == texts[1]
    assert window[1].value == texts[2]
180
+
181
+
182
def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a single id and indexing by id return the same memory."""
    first_id = memoryset[0].memory_id
    fetched = memoryset.get(first_id)
    assert fetched.value == hf_dataset[0]["text"]
    assert fetched == memoryset[fetched.memory_id]
186
+
187
+
188
def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a list of ids returns the memories in the requested order."""
    requested_ids = [memoryset[index].memory_id for index in (0, 1)]
    fetched = memoryset.get(requested_ids)
    assert len(fetched) == 2
    assert fetched[0].value == hf_dataset[0]["text"]
    assert fetched[1].value == hf_dataset[1]["text"]
193
+
194
+
195
def test_query_memoryset(memoryset: LabeledMemoryset):
    """``query`` supports label filters, result limits and metadata filters."""
    labelled_one = memoryset.query(filters=[("label", "==", 1)])
    assert len(labelled_one) == 3
    assert all(memory.label == 1 for memory in labelled_one)

    limited = memoryset.query(limit=2)
    assert len(limited) == 2

    metadata_hits = memoryset.query(filters=[("metadata.key", "==", "val1")])
    assert len(metadata_hits) == 1
201
+
202
+
203
def test_insert_memories(memoryset: LabeledMemoryset):
    """Memories can be inserted in batches or singly; extra fields are stored as metadata/source id."""
    starting_length = memoryset.length

    batch = [
        {"value": "tomato soup is my favorite", "label": 0},
        {"value": "cats are fun to play with", "label": 1},
    ]
    memoryset.insert(batch)
    assert memoryset.length == starting_length + 2

    single = {"value": "tomato soup is my favorite", "label": 0, "key": "test", "source_id": "test"}
    memoryset.insert(single)
    assert memoryset.length == starting_length + 3

    newest = memoryset[-1]
    assert newest.value == "tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata
    assert newest.metadata["key"] == "test"
    assert newest.source_id == "test"
220
+
221
+
222
def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Updating only the value keeps the label and is visible on a subsequent ``get``."""
    target_id = memoryset[0].memory_id
    new_value = "i love soup so much"
    result = memoryset.update({"memory_id": target_id, "value": new_value})
    assert result.value == new_value
    assert result.label == hf_dataset[0]["label"]
    assert memoryset.get(target_id).value == new_value
228
+
229
+
230
def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``LabeledMemory.update`` is expected to return the same instance with its value changed."""
    target = memoryset[0]
    new_value = "i love soup even more"
    returned = target.update(value=new_value)
    assert returned is target
    assert target.value == new_value
    assert target.label == hf_dataset[0]["label"]
236
+
237
+
238
def test_update_memories(memoryset: LabeledMemoryset):
    """A list of updates is applied and the updated memories are returned in order."""
    ids = [memory.memory_id for memory in memoryset[:2]]
    new_values = ["i love soup so much", "cats are so cute"]
    results = memoryset.update(
        [
            {"memory_id": ids[0], "value": new_values[0]},
            {"memory_id": ids[1], "value": new_values[1]},
        ]
    )
    assert results[0].value == new_values[0]
    assert results[1].value == new_values[1]
248
+
249
+
250
def test_delete_memory(memoryset: LabeledMemoryset):
    """Deleting one memory makes it unfetchable and shrinks the memoryset by one."""
    starting_length = memoryset.length
    removed_id = memoryset[0].memory_id
    memoryset.delete(removed_id)
    with pytest.raises(LookupError):
        memoryset.get(removed_id)
    assert memoryset.length == starting_length - 1
257
+
258
+
259
def test_delete_memories(memoryset: LabeledMemoryset):
    """Deleting a list of ids shrinks the memoryset by that many memories."""
    starting_length = memoryset.length
    removed_ids = [memoryset[0].memory_id, memoryset[1].memory_id]
    memoryset.delete(removed_ids)
    assert memoryset.length == starting_length - 2
263
+
264
+
265
def test_clone_memoryset(memoryset: LabeledMemoryset):
    """Cloning with a different embedding model copies every memory under the new name."""
    new_model = PretrainedEmbeddingModel.DISTILBERT
    clone = memoryset.clone("test_cloned_memoryset", embedding_model=new_model)
    assert clone is not None
    assert clone.name == "test_cloned_memoryset"
    assert clone.length == memoryset.length
    assert clone.embedding_model == new_model
    assert clone.insertion_status == TaskStatus.COMPLETED
@@ -0,0 +1,75 @@
1
+ from datetime import datetime
2
+ from typing import NamedTuple
3
+
4
+ from ._generated_api_client.api import check_authentication, list_api_keys
5
+ from ._generated_api_client.client import get_base_url, get_headers, set_headers
6
+
7
+
8
class ApiKeyInfo(NamedTuple):
    """
    Immutable record describing one API key of the current org

    Attributes:
        name: Unique name of the API key
        created_at: Timestamp at which the API key was created
    """

    # unique, human-readable key name
    name: str
    # creation timestamp reported by the API
    created_at: datetime
19
+
20
+
21
class OrcaCredentials:
    """
    Class for managing Orca API credentials
    """

    @staticmethod
    def get_api_url() -> str:
        """
        Get the Orca API base URL that is currently being used

        Returns:
            The base URL configured on the generated API client
        """
        return get_base_url()

    @staticmethod
    def list_api_keys() -> list[ApiKeyInfo]:
        """
        List all API keys that have been created for your org

        Returns:
            A list of named tuples, with the name and creation date time of the API key
        """
        # Inside the method body, `list_api_keys` resolves to the generated API client
        # function imported at module level (class attributes are not in method scope).
        return [ApiKeyInfo(name=api_key.name, created_at=api_key.created_at) for api_key in list_api_keys()]

    @staticmethod
    def is_authenticated() -> bool:
        """
        Check if you are authenticated to interact with the Orca API

        Returns:
            True if you are authenticated, False otherwise

        Raises:
            ValueError: for any error from the authentication check other than an invalid API key
        """
        try:
            return check_authentication()
        except ValueError as e:
            # A rejected key surfaces as a ValueError mentioning "Invalid API key";
            # any other ValueError is an unrelated failure and is propagated unchanged.
            if "Invalid API key" in str(e):
                return False
            raise

    @staticmethod
    def set_api_key(api_key: str, check_validity: bool = True) -> None:
        """
        Set the API key to use for authenticating with the Orca API

        Note:
            The API key can also be provided by setting the `ORCA_API_KEY` environment variable

        Args:
            api_key: The API key to set
            check_validity: Whether to check if the API key is valid and raise an error otherwise

        Raises:
            ValueError: if the API key is invalid and `check_validity` is True
        """
        # The key is stored before it is validated, so an invalid key remains set even
        # when the validity check below raises.
        set_headers(get_headers() | {"Api-Key": api_key})
        if check_validity:
            check_authentication()
@@ -0,0 +1,37 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+
5
+ from .orca_credentials import OrcaCredentials
6
+
7
+
8
def test_list_api_keys():
    """The test org has at least one API key, including the SDK test key."""
    key_names = [info.name for info in OrcaCredentials.list_api_keys()]
    assert len(key_names) >= 1
    assert "orca_sdk_test" in key_names
12
+
13
+
14
def test_list_api_keys_unauthenticated(unauthenticated):
    """Listing API keys without valid credentials is rejected with an API-key error."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        OrcaCredentials.list_api_keys()
17
+
18
+
19
def test_is_authenticated():
    """The default test credentials are accepted."""
    authenticated = OrcaCredentials.is_authenticated()
    assert authenticated
21
+
22
+
23
def test_is_authenticated_false(unauthenticated):
    """With the ``unauthenticated`` fixture active, the authentication check reports False."""
    authenticated = OrcaCredentials.is_authenticated()
    assert not authenticated
25
+
26
+
27
def test_set_api_key(api_key, unauthenticated):
    """Setting the valid key turns an unauthenticated session into an authenticated one."""
    before = OrcaCredentials.is_authenticated()
    assert not before
    OrcaCredentials.set_api_key(api_key)
    after = OrcaCredentials.is_authenticated()
    assert after
31
+
32
+
33
def test_set_invalid_api_key(api_key):
    """Setting a random key fails validation and leaves the session unauthenticated.

    Args:
        api_key: the valid test API key, used to restore the client headers afterwards.
    """
    assert OrcaCredentials.is_authenticated()
    try:
        with pytest.raises(ValueError, match="Invalid API key"):
            OrcaCredentials.set_api_key(str(uuid4()))
        assert not OrcaCredentials.is_authenticated()
    finally:
        # set_api_key stores the key before validating it, so the invalid key would
        # otherwise stay in the global client headers and break subsequent tests.
        OrcaCredentials.set_api_key(api_key, check_validity=False)