orca-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +19 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +193 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +159 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +194 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +63 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +499 -0
- orca_sdk/classification_model_test.py +266 -0
- orca_sdk/conftest.py +117 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +95 -0
- orca_sdk/embedding_model.py +336 -0
- orca_sdk/embedding_model_test.py +173 -0
- orca_sdk/labeled_memoryset.py +1154 -0
- orca_sdk/labeled_memoryset_test.py +271 -0
- orca_sdk/orca_credentials.py +75 -0
- orca_sdk/orca_credentials_test.py +37 -0
- orca_sdk/telemetry.py +386 -0
- orca_sdk/telemetry_test.py +100 -0
- orca_sdk-0.1.0.dist-info/METADATA +39 -0
- orca_sdk-0.1.0.dist-info/RECORD +175 -0
- orca_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets.arrow_dataset import Dataset
|
|
5
|
+
|
|
6
|
+
from .classification_model import ClassificationModel
|
|
7
|
+
from .datasource import Datasource
|
|
8
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
9
|
+
from .labeled_memoryset import LabeledMemoryset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
    """The session ``model`` fixture exists and carries the configuration it was created with."""
    assert model is not None
    expected_attributes = {
        "name": "test_model",
        "memoryset": memoryset,
        "num_classes": 2,
        "memory_lookup_count": 3,
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(model, attribute) == expected
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
    """Re-creating an existing model name fails, both by default and with ``if_exists="error"``."""
    for extra_kwargs in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, **extra_kwargs)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
    """``if_exists="open"`` returns the existing model, but refuses when a requested setting does not match it."""
    # Each override is expected to be rejected because it does not match the already-created model.
    mismatching_overrides = [
        {"head_type": "MMOE"},
        {"memory_lookup_count": 37},
        {"num_classes": 19},
        {"min_memory_weight": 0.77},
    ]
    for override in mismatching_overrides:
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, if_exists="open", **override)

    reopened = ClassificationModel.create("test_model", memoryset, if_exists="open")
    assert reopened is not None
    assert reopened.name == "test_model"
    assert reopened.memoryset == memoryset
    assert reopened.num_classes == 2
    assert reopened.memory_lookup_count == 3
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
    """Creating a model with an invalid API key is rejected."""
    expect_auth_failure = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_failure:
        ClassificationModel.create("test_model", memoryset)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_get_model(model: ClassificationModel):
    """Opening a model by name yields an object equal to the fixture that created it."""
    fetched = ClassificationModel.open(model.name)
    assert fetched is not None
    for attribute in ("id", "name"):
        assert getattr(fetched, attribute) == getattr(model, attribute)
    assert fetched.num_classes == 2
    assert fetched.memory_lookup_count == 3
    assert fetched == model
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_get_model_unauthenticated(unauthenticated):
    """Opening a model with an invalid API key is rejected."""
    model_name = "test_model"
    with pytest.raises(ValueError, match="Invalid API key"):
        ClassificationModel.open(model_name)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_get_model_invalid_input():
    """A malformed name/id is reported as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match="Invalid input"):
        ClassificationModel.open(malformed_identifier)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_get_model_not_found():
    """Opening a well-formed but unknown id raises ``LookupError``."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        ClassificationModel.open(unknown_id)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
    """A model owned by another org is invisible: opening it raises ``LookupError``."""
    foreign_model_name = model.name
    with pytest.raises(LookupError):
        ClassificationModel.open(foreign_model_name)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_list_models(model: ClassificationModel):
    """Listing models returns a non-empty list that includes the session ``model`` fixture."""
    models = ClassificationModel.all()
    assert len(models) > 0
    # Use a distinct loop variable: naming it ``model`` would shadow the fixture, turning the
    # check into ``listed.name == listed.name``, which always passes.
    assert any(listed.name == model.name for listed in models)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def test_list_models_unauthenticated(unauthenticated):
    """Listing models with an invalid API key is rejected."""
    expect_auth_failure = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_failure:
        ClassificationModel.all()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
    """A different org sees none of this org's models."""
    visible_models = ClassificationModel.all()
    assert visible_models == []
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_delete_model(memoryset: LabeledMemoryset):
    """A model can be created, opened, dropped, and is no longer openable afterwards."""
    doomed_name = "model_to_delete"
    ClassificationModel.create(doomed_name, LabeledMemoryset.open(memoryset.name))
    assert ClassificationModel.open(doomed_name)
    ClassificationModel.drop(doomed_name)
    with pytest.raises(LookupError):
        ClassificationModel.open(doomed_name)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
    """Dropping a model with an invalid API key is rejected."""
    target_name = model.name
    with pytest.raises(ValueError, match="Invalid API key"):
        ClassificationModel.drop(target_name)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def test_delete_model_not_found():
    """Dropping an unknown model raises ``LookupError`` unless ``if_not_exists="ignore"`` is passed."""

    def random_id():
        return str(uuid4())

    with pytest.raises(LookupError):
        ClassificationModel.drop(random_id())
    # With if_not_exists="ignore" the missing model is tolerated and no error is raised.
    ClassificationModel.drop(random_id(), if_not_exists="ignore")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
    """Another org cannot drop this org's model; the model is reported as not found."""
    foreign_model_name = model.name
    with pytest.raises(LookupError):
        ClassificationModel.drop(foreign_model_name)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@pytest.mark.flaky
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
    """A memoryset that a model was created on cannot be dropped; the API raises ``RuntimeError``."""
    referenced_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_delete_before_model",
        hf_dataset,
        value_column="text",
    )
    ClassificationModel.create("test_model_delete_before_memoryset", referenced_memoryset)
    with pytest.raises(RuntimeError):
        LabeledMemoryset.drop(referenced_memoryset.id)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_evaluate(model):
    """Evaluating on a small datasource returns float accuracy, f1 score and loss."""
    rows = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
        {"text": "soup is great for the winter", "label": 0},
        {"text": "i love cats", "label": 1},
    ]
    eval_datasource = Datasource.from_list("eval_datasource", rows)
    result = model.evaluate(eval_datasource, value_column="text")
    assert result is not None
    for metric in ("accuracy", "f1_score", "loss"):
        assert isinstance(result[metric], float)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_evaluate_with_telemetry(model):
    """With ``record_predictions=True``, evaluation stores tagged predictions carrying their expected labels."""
    samples = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
    ]
    eval_datasource = Datasource.from_list("eval_datasource_2", samples)
    result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
    assert result is not None

    recorded = model.predictions(tag="test")
    assert len(recorded) == 2
    assert all(prediction.tags == {"test"} for prediction in recorded)
    # NOTE(review): pairing by position assumes model.predictions returns rows in sample order — confirm.
    assert all(
        prediction.expected_label == sample["label"] for prediction, sample in zip(recorded, samples)
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_predict(model: ClassificationModel, label_names: list[str]):
    """Batch prediction yields one result per input with the expected label, its name, and a confidence in [0, 1]."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    assert len(predictions) == 2
    # The i-th query is expected to be classified as label i.
    for expected_label, prediction in enumerate(predictions):
        assert prediction.label == expected_label
        assert prediction.label_name == label_names[expected_label]
        assert 0 <= prediction.confidence <= 1
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
    """Predicting with an invalid API key is rejected."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(ValueError, match="Invalid API key"):
        model.predict(queries)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def test_predict_unauthorized(unauthorized, model: ClassificationModel):
    """Another org cannot run predictions on this org's model."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(LookupError):
        model.predict(queries)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_predict_constraint_violation(memoryset: LabeledMemoryset):
    """Predicting with a lookup count larger than the memoryset's length fails with ``RuntimeError``."""
    oversized_lookup_count = memoryset.length + 2
    oversized_model = ClassificationModel.create(
        "test_model_lookup_count_too_high",
        memoryset,
        num_classes=2,
        memory_lookup_count=oversized_lookup_count,
    )
    with pytest.raises(RuntimeError):
        oversized_model.predict("test")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def test_record_prediction_feedback(model: ClassificationModel):
    """Feedback for several predictions can be recorded in a single call from an iterable."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    expected_labels = [0, 1]
    feedback_items = (
        {
            "prediction_id": prediction.prediction_id,
            "category": "correct",
            "value": prediction.label == label,
        }
        for label, prediction in zip(expected_labels, predictions)
    )
    model.record_feedback(feedback_items)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_record_prediction_feedback_missing_category(model: ClassificationModel):
    """Feedback without a ``category`` key is rejected with ``ValueError``."""
    prediction = model.predict("Do you love soup?")
    with pytest.raises(ValueError):
        feedback_without_category = {"prediction_id": prediction.prediction_id, "value": True}
        model.record_feedback(feedback_without_category)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
    """A feedback ``value`` of the wrong type is reported as invalid input."""
    prediction = model.predict("Do you love soup?")
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        bad_feedback = {
            "prediction_id": prediction.prediction_id,
            "category": "correct",
            "value": "invalid",
        }
        model.record_feedback(bad_feedback)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
    """Feedback referencing a malformed prediction id is reported as invalid input."""
    bad_feedback = {"prediction_id": "invalid", "category": "correct", "value": True}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.record_feedback(bad_feedback)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
    """Inside ``use_memoryset`` predictions follow the override memoryset; afterwards the original one is used again."""
    queries = ["Do you love soup?", "Are cats cute?"]

    def flip_label(row):
        # Swap the two classes so predictions against this memoryset are inverted.
        return {"label": 1 if row["label"] == 0 else 0}

    inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_inverted_labels",
        hf_dataset.map(flip_label),
        value_column="text",
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    )
    with model.use_memoryset(inverted_labeled_memoryset):
        overridden = model.predict(queries)
        assert (overridden[0].label, overridden[1].label) == (1, 0)

    restored = model.predict(queries)
    assert (restored[0].label, restored[1].label) == (0, 1)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_predict_with_expected_labels(model: ClassificationModel):
    """An expected label passed to ``predict`` is echoed back on the prediction."""
    expected = 1
    prediction = model.predict("Do you love soup?", expected_labels=expected)
    assert prediction.expected_label == expected
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
    """``expected_labels`` must match the batch size and contain valid label values."""
    batch = ["Do you love soup?", "Are cats cute?"]
    # One expected label for a two-item batch is a size mismatch.
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.predict(batch, expected_labels=[0])
    # Label 5 is outside the model's classes.
    with pytest.raises(ValueError):
        model.predict("Do you love soup?", expected_labels=5)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def test_last_prediction_with_batch(model: ClassificationModel):
    """After a batch prediction, ``last_prediction`` refers to the final item of the batch."""
    batch_results = model.predict(["Do you love soup?", "Are cats cute?"])
    final_result = batch_results[-1]
    assert model.last_prediction is not None
    assert model.last_prediction.prediction_id == final_result.prediction_id
    assert model.last_prediction.input_value == "Are cats cute?"
    assert model._last_prediction_was_batch is True
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def test_last_prediction_with_single(model: ClassificationModel):
    """After a single prediction, ``last_prediction`` is that prediction and is not flagged as a batch."""
    single = model.predict("Do you love soup?")
    assert model.last_prediction is not None
    assert model.last_prediction.prediction_id == single.prediction_id
    assert model.last_prediction.input_value == "Do you love soup?"
    assert model._last_prediction_was_batch is False
|
orca_sdk/conftest.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Generator
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from datasets import ClassLabel, Dataset, Features, Value
|
|
8
|
+
|
|
9
|
+
from ._generated_api_client.client import set_headers
|
|
10
|
+
from ._utils.auth import _create_api_key, _delete_org
|
|
11
|
+
from .classification_model import ClassificationModel
|
|
12
|
+
from .datasource import Datasource
|
|
13
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
14
|
+
from .labeled_memoryset import LabeledMemoryset
|
|
15
|
+
from .orca_credentials import OrcaCredentials
|
|
16
|
+
|
|
17
|
+
logging.basicConfig(level=logging.INFO)

# Point the SDK at a locally running API unless ORCA_API_URL is already set in the environment.
os.environ.setdefault("ORCA_API_URL", "http://localhost:1584/")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _create_org_id():
|
|
23
|
+
# UUID start to identify test data (0xtest...)
|
|
24
|
+
return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture(scope="session")
def org_id():
    """Test org id shared by every test in the session."""
    session_org_id = _create_org_id()
    return session_org_id
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture(autouse=True, scope="session")
def api_key(org_id) -> Generator[str, None, None]:
    """Create an API key for the session org, activate it, and delete the org when the session ends.

    Yields:
        The API key, so function-scoped fixtures can restore it after temporarily swapping credentials.
    """
    key = _create_api_key(org_id=org_id, name="orca_sdk_test")
    try:
        # check_validity=True may raise; the finally clause makes sure the org is still deleted
        # in that case (code after a bare ``yield`` would never run if setup fails).
        OrcaCredentials.set_api_key(key, check_validity=True)
        yield key
    finally:
        _delete_org(org_id)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@pytest.fixture(autouse=True)
def authenticated(api_key):
    """Before every test, make the session API key the active credential again (no validity check)."""
    session_key = api_key
    OrcaCredentials.set_api_key(session_key, check_validity=False)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pytest.fixture()
def unauthenticated(api_key):
    """Swap in a random, unknown API key for the duration of one test, then restore the session key."""
    bogus_key = str(uuid4())
    OrcaCredentials.set_api_key(bogus_key, check_validity=False)
    yield
    # Restore the real key so later tests are not left unauthenticated.
    OrcaCredentials.set_api_key(api_key, check_validity=False)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.fixture()
def other_org_id():
    """A second, per-test org id, distinct from the session org, used to simulate another tenant."""
    foreign_org_id = _create_org_id()
    return foreign_org_id
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@pytest.fixture()
def unauthorized(api_key, other_org_id):
    """Authenticate as a different org for one test, then restore the session key and delete that org.

    The other org's key is valid, so requests succeed at the auth layer, but resources owned by the
    session org should not be visible to it.
    """
    different_api_key = _create_api_key(org_id=other_org_id, name="orca_sdk_test_other_org")
    try:
        OrcaCredentials.set_api_key(different_api_key, check_validity=False)
        yield
    finally:
        # Runs even if activating the other key fails, so neither the credentials nor the extra org leak.
        OrcaCredentials.set_api_key(api_key, check_validity=False)
        _delete_org(other_org_id)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.fixture(scope="session")
def label_names():
    """Class names for the sample data: index 0 is "soup", index 1 is "cats"."""
    names = ["soup", "cats"]
    return names
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Six sample rows (label 0 = soup texts, label 1 = cat texts) with extra metadata columns.
SAMPLE_DATA = [
    {"text": text, "label": label, "key": key, "score": score, "source_id": source_id}
    for text, label, key, score, source_id in (
        ("i love soup", 0, "val1", 0.1, "s1"),
        ("cats are cute", 1, "val2", 0.2, "s2"),
        ("soup is good", 0, "val3", 0.3, "s3"),
        ("i love cats", 1, "val4", 0.4, "s4"),
        ("everyone loves cats", 1, "val5", 0.5, "s5"),
        ("soup is great for the winter", 0, "val6", 0.6, "s6"),
    )
]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.fixture(scope="session")
def hf_dataset(label_names):
    """``SAMPLE_DATA`` as a Hugging Face ``Dataset`` whose ``label`` column is a ``ClassLabel`` over ``label_names``."""
    schema = Features(
        {
            "text": Value("string"),
            "label": ClassLabel(names=label_names),
            "key": Value("string"),
            "score": Value("float"),
            "source_id": Value("string"),
        }
    )
    return Dataset.from_list(SAMPLE_DATA, features=schema)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.fixture(scope="session")
def datasource(hf_dataset) -> Datasource:
    """The ``test_datasource`` datasource built once per session from the sample dataset."""
    built = Datasource.from_hf_dataset("test_datasource", hf_dataset)
    return built
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@pytest.fixture(scope="session")
def memoryset(datasource) -> LabeledMemoryset:
    """The ``test_memoryset`` labeled memoryset, created once per session from the shared datasource."""
    creation_options = dict(
        datasource=datasource,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        value_column="text",
        source_id_column="source_id",
        max_seq_length_override=32,
    )
    return LabeledMemoryset.create("test_memoryset", **creation_options)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@pytest.fixture(scope="session")
def model(memoryset) -> ClassificationModel:
    """The ``test_model`` classifier over the shared memoryset: 2 classes, 3 memory lookups per prediction."""
    return ClassificationModel.create(
        "test_model",
        memoryset,
        num_classes=2,
        memory_lookup_count=3,
    )
|