orca-sdk 0.0.78__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +24 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +205 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +179 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
- orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +192 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +68 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +508 -0
- orca_sdk/classification_model_test.py +272 -0
- orca_sdk/conftest.py +116 -0
- orca_sdk/credentials.py +126 -0
- orca_sdk/credentials_test.py +37 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +96 -0
- orca_sdk/embedding_model.py +347 -0
- orca_sdk/embedding_model_test.py +176 -0
- orca_sdk/memoryset.py +1209 -0
- orca_sdk/memoryset_test.py +287 -0
- orca_sdk/telemetry.py +398 -0
- orca_sdk/telemetry_test.py +109 -0
- orca_sdk-0.0.78.dist-info/METADATA +79 -0
- orca_sdk-0.0.78.dist-info/RECORD +188 -0
- orca_sdk-0.0.78.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets.arrow_dataset import Dataset
|
|
5
|
+
|
|
6
|
+
from .classification_model import ClassificationModel
|
|
7
|
+
from .datasource import Datasource
|
|
8
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
9
|
+
from .memoryset import LabeledMemoryset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
    """The session `model` fixture reflects the configuration it was created with."""
    created = model
    assert created is not None
    assert created.name == "test_model"
    assert created.memoryset == memoryset
    assert created.num_classes == 2
    assert created.memory_lookup_count == 3
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
    """Re-creating an existing model fails both by default and with `if_exists="error"`."""
    for extra_kwargs in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, **extra_kwargs)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
    """`if_exists="open"` reopens a matching model but rejects conflicting settings."""
    # Each override disagrees with how the session model was created.
    conflicting_overrides = [
        {"head_type": "MMOE"},
        {"memory_lookup_count": 37},
        {"num_classes": 19},
        {"min_memory_weight": 0.77},
    ]
    for override in conflicting_overrides:
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, if_exists="open", **override)

    reopened = ClassificationModel.create("test_model", memoryset, if_exists="open")
    assert reopened is not None
    assert reopened.name == "test_model"
    assert reopened.memoryset == memoryset
    assert reopened.num_classes == 2
    assert reopened.memory_lookup_count == 3
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
    """Model creation is refused when the active API key is invalid."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        ClassificationModel.create("test_model", memoryset)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_get_model(model: ClassificationModel):
    """Opening a model by name returns an object equal to the created one."""
    reopened = ClassificationModel.open(model.name)
    assert reopened is not None
    assert reopened.id == model.id
    assert reopened.name == model.name
    assert reopened.num_classes == 2
    assert reopened.memory_lookup_count == 3
    assert reopened == model
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_get_model_unauthenticated(unauthenticated):
    """Opening a model with an invalid API key raises an authentication error."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        ClassificationModel.open("test_model")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_get_model_invalid_input():
    """A malformed model identifier is rejected as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match="Invalid input"):
        ClassificationModel.open(malformed_identifier)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_get_model_not_found():
    """Opening a well-formed but unknown model id raises LookupError."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        ClassificationModel.open(unknown_id)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
    """A model owned by another org is reported as not found."""
    foreign_name = model.name
    with pytest.raises(LookupError):
        ClassificationModel.open(foreign_name)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_list_models(model: ClassificationModel):
    """Listing models returns a non-empty list that includes the session test model."""
    models = ClassificationModel.all()
    assert len(models) > 0
    # Use a distinct loop variable: reusing `model` here would shadow the fixture,
    # turning the check into the tautology `x.name == x.name`.
    assert any(listed.name == model.name for listed in models)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def test_list_models_unauthenticated(unauthenticated):
    """Listing models with an invalid API key raises an authentication error."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        ClassificationModel.all()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
    """Another org sees none of this org's models."""
    visible_models = ClassificationModel.all()
    assert visible_models == []
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_delete_model(memoryset: LabeledMemoryset):
    """A dropped model can no longer be opened."""
    doomed_name = "model_to_delete"
    ClassificationModel.create(doomed_name, LabeledMemoryset.open(memoryset.name))
    assert ClassificationModel.open(doomed_name)

    ClassificationModel.drop(doomed_name)

    with pytest.raises(LookupError):
        ClassificationModel.open(doomed_name)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
    """Dropping a model with an invalid API key raises an authentication error."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        ClassificationModel.drop(model.name)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def test_delete_model_not_found():
    """Dropping an unknown model raises, unless `if_not_exists="ignore"` is given."""
    first_unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        ClassificationModel.drop(first_unknown_id)

    # The same situation is silently accepted when explicitly requested.
    second_unknown_id = str(uuid4())
    ClassificationModel.drop(second_unknown_id, if_not_exists="ignore")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
    """Another org cannot drop this org's model; it is reported as not found."""
    foreign_name = model.name
    with pytest.raises(LookupError):
        ClassificationModel.drop(foreign_name)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
    """A memoryset that still backs a model cannot be dropped."""
    backing_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_delete_before_model", hf_dataset, value_column="text"
    )
    ClassificationModel.create("test_model_delete_before_memoryset", backing_memoryset)

    with pytest.raises(RuntimeError):
        LabeledMemoryset.drop(backing_memoryset.id)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_evaluate(model):
    """Evaluation returns scalar metrics and per-threshold curves for a 4-row datasource."""
    rows = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
        {"text": "soup is great for the winter", "label": 0},
        {"text": "i love cats", "label": 1},
    ]
    eval_datasource = Datasource.from_list("eval_datasource", rows)

    result = model.evaluate(eval_datasource, value_column="text")
    assert result is not None
    assert isinstance(result, dict)

    for metric_name in ("accuracy", "f1_score", "loss"):
        assert isinstance(result[metric_name], float)

    # Each curve is expected to carry one entry per evaluated row.
    expected_curve_fields = {
        "precision_recall_curve": ("thresholds", "precisions", "recalls"),
        "roc_curve": ("thresholds", "false_positive_rates", "true_positive_rates"),
    }
    for curve_name, field_names in expected_curve_fields.items():
        for field_name in field_names:
            assert len(result[curve_name][field_name]) == 4
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_evaluate_with_telemetry(model):
    """Evaluating with `record_predictions=True` stores tagged predictions with expected labels."""
    samples = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
    ]
    eval_datasource = Datasource.from_list("eval_datasource_2", samples)

    evaluation = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
    assert evaluation is not None

    recorded = model.predictions(tag="test")
    assert len(recorded) == 2
    assert all(prediction.tags == {"test"} for prediction in recorded)
    # NOTE(review): pairs recorded predictions with samples positionally, which assumes
    # `predictions()` returns them in insertion order — confirm against the API.
    assert all(prediction.expected_label == sample["label"] for prediction, sample in zip(recorded, samples))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def test_predict(model: ClassificationModel, label_names: list[str]):
    """Batch prediction labels a soup question as class 0 and a cat question as class 1."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    assert len(predictions) == 2
    # The i-th query is expected to receive label i.
    for expected_label, prediction in enumerate(predictions):
        assert prediction.label == expected_label
        assert prediction.label_name == label_names[expected_label]
        assert 0 <= prediction.confidence <= 1
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
    """Prediction with an invalid API key raises an authentication error."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(ValueError, match="Invalid API key"):
        model.predict(queries)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def test_predict_unauthorized(unauthorized, model: ClassificationModel):
    """Another org cannot predict with this org's model; it is reported as not found."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(LookupError):
        model.predict(queries)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def test_predict_constraint_violation(memoryset: LabeledMemoryset):
    """Predicting fails when the lookup count exceeds the number of memories."""
    oversized_lookup_count = memoryset.length + 2
    oversized_model = ClassificationModel.create(
        "test_model_lookup_count_too_high",
        memoryset,
        num_classes=2,
        memory_lookup_count=oversized_lookup_count,
    )
    with pytest.raises(RuntimeError):
        oversized_model.predict("test")
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def test_record_prediction_feedback(model: ClassificationModel):
    """Feedback for a batch of predictions can be recorded from an iterable of dicts."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    expected_labels = [0, 1]
    # A lazy generator is passed, exactly as a caller streaming feedback would.
    feedback_items = (
        {
            "prediction_id": prediction.prediction_id,
            "category": "correct",
            "value": prediction.label == wanted,
        }
        for wanted, prediction in zip(expected_labels, predictions)
    )
    model.record_feedback(feedback_items)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def test_record_prediction_feedback_missing_category(model: ClassificationModel):
    """Feedback without a `category` key is rejected."""
    prediction = model.predict("Do you love soup?")
    incomplete_feedback = {"prediction_id": prediction.prediction_id, "value": True}
    with pytest.raises(ValueError):
        model.record_feedback(incomplete_feedback)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
    """A non-boolean value for the `correct` category is rejected as invalid input."""
    prediction = model.predict("Do you love soup?")
    bad_feedback = {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.record_feedback(bad_feedback)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
    """Feedback referencing a malformed prediction id is rejected as invalid input."""
    bad_feedback = {"prediction_id": "invalid", "category": "correct", "value": True}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.record_feedback(bad_feedback)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
    """`use_memoryset` temporarily swaps in a memoryset whose 0/1 labels are flipped."""
    flipped_dataset = hf_dataset.map(lambda row: {"label": 1 if row["label"] == 0 else 0})
    inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_inverted_labels",
        flipped_dataset,
        value_column="text",
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    )
    queries = ["Do you love soup?", "Are cats cute?"]

    with model.use_memoryset(inverted_labeled_memoryset):
        overridden = model.predict(queries)
        assert overridden[0].label == 1
        assert overridden[1].label == 0

    # Outside the context manager the model's own memoryset is used again.
    restored = model.predict(queries)
    assert restored[0].label == 0
    assert restored[1].label == 1
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def test_predict_with_expected_labels(model: ClassificationModel):
    """An expected label passed to `predict` is echoed back on the prediction."""
    expected = 1
    prediction = model.predict("Do you love soup?", expected_labels=expected)
    assert prediction.expected_label == expected
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
    """Mismatched or out-of-range expected labels are rejected."""
    # Two queries but only one expected label.
    batch_queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.predict(batch_queries, expected_labels=[0])

    # A label outside the model's classes.
    with pytest.raises(ValueError):
        model.predict("Do you love soup?", expected_labels=5)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def test_last_prediction_with_batch(model: ClassificationModel):
    """After a batch call, `last_prediction` is the final item of the batch."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    latest = model.last_prediction
    assert latest is not None
    assert latest.prediction_id == predictions[-1].prediction_id
    assert latest.input_value == "Are cats cute?"
    assert model._last_prediction_was_batch is True
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def test_last_prediction_with_single(model: ClassificationModel):
    """After a single-input call, `last_prediction` is that prediction and not flagged as batch."""
    prediction = model.predict("Do you love soup?")
    latest = model.last_prediction
    assert latest is not None
    assert latest.prediction_id == prediction.prediction_id
    assert latest.input_value == "Do you love soup?"
    assert model._last_prediction_was_batch is False
|
orca_sdk/conftest.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Generator
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from datasets import ClassLabel, Dataset, Features, Value
|
|
8
|
+
|
|
9
|
+
from ._utils.auth import _create_api_key, _delete_org
|
|
10
|
+
from .classification_model import ClassificationModel
|
|
11
|
+
from .credentials import OrcaCredentials
|
|
12
|
+
from .datasource import Datasource
|
|
13
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
14
|
+
from .memoryset import LabeledMemoryset
|
|
15
|
+
|
|
16
|
+
logging.basicConfig(level=logging.INFO)

# Point the SDK at a locally running API unless the environment already overrides it.
os.environ.setdefault("ORCA_API_URL", "http://localhost:1584/")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _create_org_id():
|
|
22
|
+
# UUID start to identify test data (0xtest...)
|
|
23
|
+
return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture(scope="session")
def org_id():
    """Session-wide test org id, generated once per test run."""
    session_org_id = _create_org_id()
    return session_org_id
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@pytest.fixture(autouse=True, scope="session")
def api_key(org_id) -> Generator[str, None, None]:
    """Create a session-wide API key for the test org and delete the org at teardown.

    The key is installed as the active credential, with a validity check against
    the API, before any test runs.

    Yields:
        The API key string created for ``org_id``.
    """
    api_key = _create_api_key(org_id=org_id, name="orca_sdk_test")
    try:
        OrcaCredentials.set_api_key(api_key, check_validity=True)
        yield api_key
    finally:
        # pytest skips post-yield teardown when a fixture fails before yielding, so
        # without `finally` a failed validity check would leak the org created above.
        _delete_org(org_id)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.fixture(autouse=True)
def authenticated(api_key):
    """Re-install the session API key before every test, skipping the validity check."""
    OrcaCredentials.set_api_key(api_key, check_validity=False)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture()
def unauthenticated(api_key):
    """Swap in a random (unregistered) API key for one test, then restore the session key."""
    bogus_key = str(uuid4())
    OrcaCredentials.set_api_key(bogus_key, check_validity=False)
    yield
    # Restore the real key so later tests are not affected.
    OrcaCredentials.set_api_key(api_key, check_validity=False)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.fixture()
def other_org_id():
    """Per-test id for a second org, used to exercise cross-org access."""
    foreign_org_id = _create_org_id()
    return foreign_org_id
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@pytest.fixture()
def unauthorized(api_key, other_org_id):
    """Authenticate as a different org for one test.

    After the test, the session key is restored and the other org is deleted.
    """
    foreign_key = _create_api_key(org_id=other_org_id, name="orca_sdk_test_other_org")
    OrcaCredentials.set_api_key(foreign_key, check_validity=False)
    yield
    OrcaCredentials.set_api_key(api_key, check_validity=False)
    _delete_org(other_org_id)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.fixture(scope="session")
def label_names():
    """Class names for the sample data: index 0 is "soup", index 1 is "cats"."""
    names = ["soup", "cats"]
    return names
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Six labeled rows (0 = soup, 1 = cats). Row i (1-based) gets key "val{i}",
# score i / 10 and source id "s{i}".
SAMPLE_DATA = [
    {"text": text, "label": label, "key": f"val{i}", "score": i / 10, "source_id": f"s{i}"}
    for i, (text, label) in enumerate(
        [
            ("i love soup", 0),
            ("cats are cute", 1),
            ("soup is good", 0),
            ("i love cats", 1),
            ("everyone loves cats", 1),
            ("soup is great for the winter", 0),
        ],
        start=1,
    )
]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.fixture(scope="session")
def hf_dataset(label_names):
    """Hugging Face dataset built from SAMPLE_DATA with an explicit schema.

    `label` is a ClassLabel over `label_names`; `score` is a float column.
    """
    schema = Features(
        {
            "text": Value("string"),
            "label": ClassLabel(names=label_names),
            "key": Value("string"),
            "score": Value("float"),
            "source_id": Value("string"),
        }
    )
    return Dataset.from_list(SAMPLE_DATA, features=schema)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.fixture(scope="session")
def datasource(hf_dataset) -> Datasource:
    """Create the session's "test_datasource" from `hf_dataset` (once per test session)."""
    source = Datasource.from_hf_dataset("test_datasource", hf_dataset)
    return source
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@pytest.fixture(scope="session")
def memoryset(datasource) -> LabeledMemoryset:
    """
    Session-wide labeled memoryset named "test_memoryset", built from `datasource`.

    Uses "text" as the value column, "source_id" as the source id column, the GTE_BASE
    pretrained embedding model, and `max_seq_length_override=32`.
    """
    created = LabeledMemoryset.create(
        "test_memoryset",
        datasource=datasource,
        value_column="text",
        source_id_column="source_id",
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        max_seq_length_override=32,
    )
    return created
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.fixture(scope="session")
def model(memoryset) -> ClassificationModel:
    """Session-wide classification model "test_model" on `memoryset` (`num_classes=2`, `memory_lookup_count=3`)."""
    classifier = ClassificationModel.create(
        "test_model",
        memoryset,
        memory_lookup_count=3,
        num_classes=2,
    )
    return classifier
|
orca_sdk/credentials.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from datetime import datetime
from typing import Literal, NamedTuple, Optional

from ._generated_api_client.api import (
    check_authentication,
    create_api_key,
    delete_api_key,
    list_api_keys,
)
from ._generated_api_client.client import get_base_url, get_headers, set_headers
from ._generated_api_client.models import (
    CreateApiKeyRequest,
    CreateApiKeyRequestScopeItem,
)
|
|
15
|
+
|
|
16
|
+
# Allowed API key scope names: accepted by OrcaCredentials.create_api_key(scopes=...) and
# reported in ApiKeyInfo.scopes.
Scope = Literal["ADMINISTER", "PREDICT"]
"""
The scopes of an API key.

- `ADMINISTER`: Can do anything, including creating and deleting organizations, models, and API keys.
- `PREDICT`: Can only call model.predict and perform CRUD operations on predictions.
"""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ApiKeyInfo(NamedTuple):
    """
    Named tuple containing information about an API key

    Attributes:
        name: Unique name of the API key
        created_at: When the API key was created
        scopes: The set of scopes granted to the API key (see `Scope`)
    """

    name: str
    created_at: datetime
    scopes: set[Scope]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class OrcaCredentials:
    """
    Class for managing Orca API credentials

    All methods are static. The active API key is stored in the generated API client's headers
    via `set_headers` (see `set_api_key`).
    """

    @staticmethod
    def get_api_url() -> str:
        """
        Get the Orca API base URL that is currently being used

        Returns:
            The base URL reported by the generated API client
        """
        return get_base_url()

    @staticmethod
    def list_api_keys() -> list[ApiKeyInfo]:
        """
        List all API keys that have been created for your org

        Returns:
            A list of named tuples with the name, creation date time and scopes of each API key
        """
        # Inside a method body, `list_api_keys` resolves to the module-level generated API
        # function, not to this static method (class attributes are not in scope there).
        return [
            ApiKeyInfo(
                name=api_key.name,
                created_at=api_key.created_at,
                scopes={scope.value for scope in api_key.scope},
            )
            for api_key in list_api_keys()
        ]

    @staticmethod
    def is_authenticated() -> bool:
        """
        Check if you are authenticated to interact with the Orca API

        Returns:
            True if you are authenticated, False otherwise

        Raises:
            ValueError: for any authentication-check error other than an invalid API key
        """
        try:
            return check_authentication()
        except ValueError as e:
            # An unrecognised key surfaces as a ValueError with this message; any other
            # ValueError is re-raised unchanged (bare `raise` keeps the original traceback).
            if "Invalid API key" in str(e):
                return False
            raise

    @staticmethod
    def create_api_key(name: str, scopes: Optional[set[Scope]] = None) -> str:
        """
        Create a new API key with the given name and scopes

        Params:
            name: The name of the API key
            scopes: The scopes of the API key, defaults to `{"ADMINISTER"}` when not given

        Returns:
            The secret value of the API key. Make sure to save this value as it will not be shown again.
        """
        # A None sentinel instead of a set literal default, so no mutable object is shared
        # between calls.
        if scopes is None:
            scopes = {"ADMINISTER"}
        res = create_api_key(
            body=CreateApiKeyRequest(
                name=name,
                # sorted so the request payload does not depend on set iteration order
                scope=[CreateApiKeyRequestScopeItem(scope) for scope in sorted(scopes)],
            )
        )
        return res.api_key

    @staticmethod
    def revoke_api_key(name: str) -> None:
        """
        Delete an API key

        Params:
            name: The name of the API key to delete

        Raises:
            ValueError: if the API key is not found
        """
        delete_api_key(name_or_id=name)

    @staticmethod
    def set_api_key(api_key: str, check_validity: bool = True) -> None:
        """
        Set the API key to use for authenticating with the Orca API

        Note:
            The API key can also be provided by setting the `ORCA_API_KEY` environment variable.
            The key is installed before it is validated, so if validation fails the invalid key
            remains the active one.

        Params:
            api_key: The API key to set
            check_validity: Whether to check if the API key is valid and raise an error otherwise

        Raises:
            ValueError: if the API key is invalid and `check_validity` is True
        """
        # Merge into the existing headers so other headers are preserved.
        set_headers(get_headers() | {"Api-Key": api_key})
        if check_validity:
            check_authentication()
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from .credentials import OrcaCredentials
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_list_api_keys():
    """The org has at least one API key, and the test key "orca_sdk_test" is among them."""
    key_names = [info.name for info in OrcaCredentials.list_api_keys()]
    assert len(key_names) >= 1
    assert "orca_sdk_test" in key_names
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_list_api_keys_unauthenticated(unauthenticated):
    """Listing keys with a random API key fails with an "Invalid API key" ValueError."""
    expected_message = "Invalid API key"
    with pytest.raises(ValueError, match=expected_message):
        OrcaCredentials.list_api_keys()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_is_authenticated():
    """With the valid session key active, the client reports itself as authenticated."""
    is_auth = OrcaCredentials.is_authenticated()
    assert is_auth
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_is_authenticated_false(unauthenticated):
    """With a random API key active, is_authenticated reports False instead of raising."""
    is_auth = OrcaCredentials.is_authenticated()
    assert not is_auth
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_set_api_key(api_key, unauthenticated):
    """Setting the valid key (validated by default) turns an unauthenticated client into an authenticated one."""
    was_authenticated = OrcaCredentials.is_authenticated()
    assert not was_authenticated
    OrcaCredentials.set_api_key(api_key)
    now_authenticated = OrcaCredentials.is_authenticated()
    assert now_authenticated
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_set_invalid_api_key(api_key):
    """
    Setting a random key with validation raises "Invalid API key" and leaves that key active.

    The valid key is restored afterwards in a `finally`. Session-scoped fixtures first requested
    by a later test are set up before the function-scoped autouse `authenticated` fixture, so they
    would otherwise run with the invalid key.
    """
    assert OrcaCredentials.is_authenticated()
    try:
        with pytest.raises(ValueError, match="Invalid API key"):
            OrcaCredentials.set_api_key(str(uuid4()))
        # set_api_key installs the key before validating it, so the invalid key is now active
        assert not OrcaCredentials.is_authenticated()
    finally:
        OrcaCredentials.set_api_key(api_key, check_validity=False)
|