orca-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +19 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +193 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +159 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +194 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +63 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +499 -0
- orca_sdk/classification_model_test.py +266 -0
- orca_sdk/conftest.py +117 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +95 -0
- orca_sdk/embedding_model.py +336 -0
- orca_sdk/embedding_model_test.py +173 -0
- orca_sdk/labeled_memoryset.py +1154 -0
- orca_sdk/labeled_memoryset_test.py +271 -0
- orca_sdk/orca_credentials.py +75 -0
- orca_sdk/orca_credentials_test.py +37 -0
- orca_sdk/telemetry.py +386 -0
- orca_sdk/telemetry_test.py +100 -0
- orca_sdk-0.1.0.dist-info/METADATA +39 -0
- orca_sdk-0.1.0.dist-info/RECORD +175 -0
- orca_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets.arrow_dataset import Dataset
|
|
5
|
+
|
|
6
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
7
|
+
from .labeled_memoryset import LabeledMemoryset, TaskStatus
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """The ``memoryset`` fixture has the expected name, embedding model, labels, status and size."""
    assert memoryset is not None
    # configuration the fixture was created with
    assert memoryset.name == "test_memoryset"
    assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
    assert memoryset.label_names == label_names
    # insertion must have finished before the fixture is handed out
    assert memoryset.insertion_status == TaskStatus.COMPLETED
    # one memory per row of the source dataset
    expected_rows = len(hf_dataset)
    assert isinstance(memoryset.length, int)
    assert memoryset.length == expected_rows
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
    """Creating a memoryset without a valid API key raises a ``ValueError`` mentioning the key."""
    auth_error = "Invalid API key"
    with pytest.raises(ValueError, match=auth_error):
        LabeledMemoryset.create("test_memoryset", datasource)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_create_memoryset_invalid_input(datasource):
    """Creating a memoryset with a bad name or an unknown datasource is rejected as invalid input.

    The datasource case temporarily points the ``datasource`` fixture at a random id. The
    original id is restored afterwards so the mutation cannot leak into other tests that may
    share the same fixture instance.
    """
    # invalid name (presumably rejected because it contains a space)
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource)

    # invalid datasource: a random id that does not refer to an existing datasource
    original_id = datasource.id
    datasource.id = str(uuid4())
    try:
        with pytest.raises(ValueError, match=r"Invalid input:.*"):
            LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
    finally:
        datasource.id = original_id
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
    """Re-creating an existing memoryset raises, both by default and with ``if_exists="error"``."""
    shared_kwargs = {"label_names": label_names, "value_column": "text"}

    # no if_exists argument given
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, **shared_kwargs)

    # explicit if_exists="error"
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, if_exists="error", **shared_kwargs)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
    """With ``if_exists="open"``, an existing memoryset is only reopened when its config matches.

    Each mismatch case differs from the existing memoryset in exactly one setting, so the
    expected ``ValueError`` can only be attributed to that setting.
    """
    # mismatched label names
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            memoryset.name,
            hf_dataset,
            label_names=["turtles", "frogs"],
            value_column="text",
            if_exists="open",
        )

    # mismatched embedding model; value_column is passed explicitly (as in the case above)
    # so that a missing/unknown value column cannot be what triggers the error
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            memoryset.name,
            hf_dataset,
            label_names=label_names,
            value_column="text",
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )

    # matching embedding model: the existing memoryset should be returned
    opened_memoryset = LabeledMemoryset.from_hf_dataset(
        memoryset.name,
        hf_dataset,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        if_exists="open",
    )
    assert opened_memoryset is not None
    assert opened_memoryset.name == memoryset.name
    assert opened_memoryset.length == len(hf_dataset)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_open_memoryset(memoryset, hf_dataset):
    """A memoryset can be reopened by name and reports the same name and size."""
    reopened = LabeledMemoryset.open(memoryset.name)
    assert reopened is not None
    assert reopened.name == memoryset.name
    assert reopened.length == len(hf_dataset)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
    """Opening a memoryset without a valid API key raises a ``ValueError`` mentioning the key."""
    with pytest.raises(ValueError, match=r"Invalid API key"):
        LabeledMemoryset.open(memoryset.name)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def test_open_memoryset_not_found():
    """Opening a memoryset by a random, unused id raises ``LookupError``."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.open(unknown_id)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def test_open_memoryset_invalid_input():
    """Opening a memoryset with a malformed identifier is rejected as invalid input."""
    malformed_id = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open(malformed_id)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_open_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, opening the memoryset raises ``LookupError``."""
    expected_error = LookupError
    with pytest.raises(expected_error):
        LabeledMemoryset.open(memoryset.name)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_all_memorysets(memoryset):
    """Listing memorysets returns a non-empty list that includes the ``memoryset`` fixture."""
    memorysets = LabeledMemoryset.all()
    assert len(memorysets) > 0
    # Use a distinct loop variable. Reusing `memoryset` would shadow the fixture inside the
    # generator, so the check would compare each item with itself and always pass.
    assert any(listed.name == memoryset.name for listed in memorysets)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def test_all_memorysets_unauthenticated(unauthenticated):
    """Listing memorysets without a valid API key raises a ``ValueError`` mentioning the key."""
    with pytest.raises(ValueError, match=r"Invalid API key"):
        LabeledMemoryset.all()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_all_memorysets_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, the fixture memoryset is not listed."""
    # NOTE(review): this membership check relies on LabeledMemoryset's equality semantics,
    # which are not visible here; if it falls back to identity, the check is always true.
    visible_memorysets = LabeledMemoryset.all()
    assert memoryset not in visible_memorysets
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@pytest.mark.flaky
def test_drop_memoryset(hf_dataset):
    """A freshly created memoryset exists until it is dropped, and no longer exists afterwards."""
    one_row = hf_dataset.select(range(1))
    created = LabeledMemoryset.from_hf_dataset("test_memoryset_delete", one_row, value_column="text")
    assert LabeledMemoryset.exists(created.name)

    LabeledMemoryset.drop(created.name)
    assert not LabeledMemoryset.exists(created.name)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
    """Dropping a memoryset without a valid API key raises a ``ValueError`` mentioning the key."""
    with pytest.raises(ValueError, match=r"Invalid API key"):
        LabeledMemoryset.drop(memoryset.name)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_drop_memoryset_not_found(memoryset):
    """Dropping an unknown memoryset raises ``LookupError`` unless ``if_not_exists="ignore"``."""
    missing_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(missing_id)

    # with if_not_exists="ignore" the missing memoryset must not raise
    another_missing_id = str(uuid4())
    LabeledMemoryset.drop(another_missing_id, if_not_exists="ignore")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_drop_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, dropping the memoryset raises ``LookupError``."""
    expected_error = LookupError
    with pytest.raises(expected_error):
        LabeledMemoryset.drop(memoryset.name)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_search(memoryset: LabeledMemoryset):
    """Searching with a list of queries returns one single-result lookup list per query."""
    queries = ["i love soup", "cats are cute"]
    results = memoryset.search(queries)
    assert len(results) == len(queries)

    soup_hits = results[0]
    cat_hits = results[1]
    assert len(soup_hits) == 1
    assert len(cat_hits) == 1
    # the soup query matches label 0, the cat query matches label 1
    assert soup_hits[0].label == 0
    assert cat_hits[0].label == 1
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_search_count(memoryset: LabeledMemoryset):
    """A single query with ``count=3`` returns three lookups, all carrying label 0."""
    lookups = memoryset.search("i love soup", count=3)
    assert len(lookups) == 3
    for lookup in lookups:
        assert lookup.label == 0
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """Positional indexing (including negative indices) returns memories matching the source rows."""
    first_row = hf_dataset[0]
    first_memory = memoryset[0]
    assert first_memory.value == first_row["text"]
    assert first_memory.label == first_row["label"]
    assert first_memory.label_name == label_names[first_row["label"]]
    assert first_memory.source_id == first_row["source_id"]
    assert first_memory.score == first_row["score"]
    assert first_memory.key == first_row["key"]

    # negative indices count from the end
    final_memory = memoryset[-1]
    final_row = hf_dataset[-1]
    assert final_memory.value == final_row["text"]
    assert final_memory.label == final_row["label"]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Slicing a memoryset returns the memories for the corresponding source rows."""
    window = memoryset[1:3]
    assert len(window) == 2
    source_texts = hf_dataset["text"]
    assert window[0].value == source_texts[1]
    assert window[1].value == source_texts[2]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a single id and indexing by id both return the same memory."""
    target_id = memoryset[0].memory_id
    fetched = memoryset.get(target_id)
    assert fetched.value == hf_dataset[0]["text"]
    # indexing with a memory id is equivalent to get()
    assert fetched == memoryset[fetched.memory_id]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a list of ids returns the memories in the requested order."""
    requested_ids = [memoryset[position].memory_id for position in range(2)]
    fetched = memoryset.get(requested_ids)
    assert len(fetched) == 2
    for position in range(2):
        assert fetched[position].value == hf_dataset[position]["text"]
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_query_memoryset(memoryset: LabeledMemoryset):
    """``query`` supports label filters, a result limit, and filters on metadata fields."""
    label_one = memoryset.query(filters=[("label", "==", 1)])
    assert len(label_one) == 3
    assert all(item.label == 1 for item in label_one)

    limited = memoryset.query(limit=2)
    assert len(limited) == 2

    by_metadata = memoryset.query(filters=[("metadata.key", "==", "val1")])
    assert len(by_metadata) == 1
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_insert_memories(memoryset: LabeledMemoryset):
    """Inserting a batch or a single memory grows the memoryset and preserves extra fields."""
    starting_length = memoryset.length

    # batch insert
    batch = [
        {"value": "tomato soup is my favorite", "label": 0},
        {"value": "cats are fun to play with", "label": 1},
    ]
    memoryset.insert(batch)
    assert memoryset.length == starting_length + 2

    # single insert; the unrecognised "key" field is expected to land in metadata
    memoryset.insert({"value": "tomato soup is my favorite", "label": 0, "key": "test", "source_id": "test"})
    assert memoryset.length == starting_length + 3

    newest = memoryset[-1]
    assert newest.value == "tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata
    assert newest.metadata["key"] == "test"
    assert newest.source_id == "test"
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Updating one memory's value by id changes only the value and is visible via ``get``."""
    new_value = "i love soup so much"
    target_id = memoryset[0].memory_id

    result = memoryset.update({"memory_id": target_id, "value": new_value})
    assert result.value == new_value
    # fields that were not part of the update are left untouched
    assert result.label == hf_dataset[0]["label"]
    assert memoryset.get(target_id).value == new_value
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``memory.update`` mutates and returns the same memory object."""
    new_value = "i love soup even more"
    target = memoryset[0]
    returned = target.update(value=new_value)
    assert returned is target
    assert target.value == new_value
    assert target.label == hf_dataset[0]["label"]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_update_memories(memoryset: LabeledMemoryset):
    """Updating a list of memories returns the updated memories in request order."""
    first_two_ids = [item.memory_id for item in memoryset[:2]]
    changes = [
        {"memory_id": first_two_ids[0], "value": "i love soup so much"},
        {"memory_id": first_two_ids[1], "value": "cats are so cute"},
    ]
    results = memoryset.update(changes)
    assert results[0].value == "i love soup so much"
    assert results[1].value == "cats are so cute"
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def test_delete_memory(memoryset: LabeledMemoryset):
    """Deleting a memory by id removes it (``get`` raises ``LookupError``) and shrinks the set."""
    length_before = memoryset.length
    doomed_id = memoryset[0].memory_id

    memoryset.delete(doomed_id)

    with pytest.raises(LookupError):
        memoryset.get(doomed_id)
    assert memoryset.length == length_before - 1
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def test_delete_memories(memoryset: LabeledMemoryset):
    """Deleting a list of memory ids shrinks the memoryset by that many entries."""
    length_before = memoryset.length
    doomed_ids = [memoryset[0].memory_id, memoryset[1].memory_id]
    memoryset.delete(doomed_ids)
    assert memoryset.length == length_before - 2
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def test_clone_memoryset(memoryset: LabeledMemoryset):
    """Cloning with a different embedding model yields a complete copy using that model."""
    clone_name = "test_cloned_memoryset"
    clone = memoryset.clone(clone_name, embedding_model=PretrainedEmbeddingModel.DISTILBERT)
    assert clone is not None
    assert clone.name == clone_name
    assert clone.length == memoryset.length
    assert clone.embedding_model == PretrainedEmbeddingModel.DISTILBERT
    assert clone.insertion_status == TaskStatus.COMPLETED
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import NamedTuple
|
|
3
|
+
|
|
4
|
+
from ._generated_api_client.api import check_authentication, list_api_keys
|
|
5
|
+
from ._generated_api_client.client import get_base_url, get_headers, set_headers
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ApiKeyInfo(NamedTuple):
    """
    Immutable record describing a single Orca API key

    Attributes:
        name: Unique name of the API key
        created_at: Timestamp at which the API key was created
    """

    # unique, human-readable identifier of the key
    name: str
    # creation time as reported by the API
    created_at: datetime
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OrcaCredentials:
    """
    Class for managing Orca API credentials

    All methods are static. The API key is stored in the request headers managed by the
    generated API client (via ``get_headers``/``set_headers``), not on instances of this class.
    """

    @staticmethod
    def get_api_url() -> str:
        """
        Get the Orca API base URL that is currently being used

        Returns:
            The base URL reported by the generated API client
        """
        return get_base_url()

    @staticmethod
    def list_api_keys() -> list[ApiKeyInfo]:
        """
        List all API keys that have been created for your org

        Returns:
            A list of named tuples, with the name and creation date time of the API key
        """
        # `list_api_keys` here resolves to the module-level function imported from the
        # generated client: class scope is not searched from inside a method body, so this
        # is not a recursive call to this static method.
        return [ApiKeyInfo(name=api_key.name, created_at=api_key.created_at) for api_key in list_api_keys()]

    @staticmethod
    def is_authenticated() -> bool:
        """
        Check if you are authenticated to interact with the Orca API

        Returns:
            True if you are authenticated, False otherwise

        Raises:
            ValueError: for any error from the authentication check other than an invalid API key
        """
        try:
            return check_authentication()
        except ValueError as e:
            # an invalid key is an expected "not authenticated" outcome, not an error
            if "Invalid API key" in str(e):
                return False
            # bare `raise` re-raises the active exception with its original traceback
            raise

    @staticmethod
    def set_api_key(api_key: str, check_validity: bool = True) -> None:
        """
        Set the API key to use for authenticating with the Orca API

        Note:
            The API key can also be provided by setting the `ORCA_API_KEY` environment variable

        Note:
            The key is stored before it is validated, so when validation fails the invalid
            key remains configured (subsequent calls will be unauthenticated).

        Args:
            api_key: The API key to set
            check_validity: Whether to check if the API key is valid and raise an error otherwise

        Raises:
            ValueError: if the API key is invalid and `check_validity` is True
        """
        # merge into the existing headers so other headers are preserved
        set_headers(get_headers() | {"Api-Key": api_key})
        if check_validity:
            check_authentication()
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from .orca_credentials import OrcaCredentials
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_list_api_keys():
    """At least one API key is listed, including the ``orca_sdk_test`` key."""
    keys = OrcaCredentials.list_api_keys()
    assert len(keys) >= 1
    key_names = {key.name for key in keys}
    assert "orca_sdk_test" in key_names
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_list_api_keys_unauthenticated(unauthenticated):
    """Listing API keys without a valid API key raises a ``ValueError`` mentioning the key."""
    with pytest.raises(ValueError, match=r"Invalid API key"):
        OrcaCredentials.list_api_keys()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_is_authenticated():
    """With the default test credentials, ``is_authenticated`` reports True."""
    authenticated = OrcaCredentials.is_authenticated()
    assert authenticated
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_is_authenticated_false(unauthenticated):
    """With the ``unauthenticated`` fixture active, ``is_authenticated`` reports False."""
    authenticated = OrcaCredentials.is_authenticated()
    assert not authenticated
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_set_api_key(api_key, unauthenticated):
    """Setting a valid API key turns an unauthenticated client into an authenticated one."""
    # precondition: the `unauthenticated` fixture leaves us without valid credentials
    assert not OrcaCredentials.is_authenticated()

    OrcaCredentials.set_api_key(api_key)

    assert OrcaCredentials.is_authenticated()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_set_invalid_api_key(api_key):
    """Setting a random (invalid) API key raises and leaves the client unauthenticated.

    ``set_api_key`` stores the key before validating it, so the invalid key stays configured
    after the failed check. The valid ``api_key`` is restored afterwards so later tests do not
    run with the invalid key.
    """
    assert OrcaCredentials.is_authenticated()
    try:
        with pytest.raises(ValueError, match="Invalid API key"):
            OrcaCredentials.set_api_key(str(uuid4()))
        assert not OrcaCredentials.is_authenticated()
    finally:
        # check_validity=False: restoring makes no extra API call and cannot itself raise
        OrcaCredentials.set_api_key(api_key, check_validity=False)
|