orca-sdk 0.0.78__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +24 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +205 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +179 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
- orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +192 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +68 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +508 -0
- orca_sdk/classification_model_test.py +272 -0
- orca_sdk/conftest.py +116 -0
- orca_sdk/credentials.py +126 -0
- orca_sdk/credentials_test.py +37 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +96 -0
- orca_sdk/embedding_model.py +347 -0
- orca_sdk/embedding_model_test.py +176 -0
- orca_sdk/memoryset.py +1209 -0
- orca_sdk/memoryset_test.py +287 -0
- orca_sdk/telemetry.py +398 -0
- orca_sdk/telemetry_test.py +109 -0
- orca_sdk-0.0.78.dist-info/METADATA +79 -0
- orca_sdk-0.0.78.dist-info/RECORD +188 -0
- orca_sdk-0.0.78.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets.arrow_dataset import Dataset
|
|
5
|
+
|
|
6
|
+
from .datasource import Datasource
|
|
7
|
+
from .embedding_model import PretrainedEmbeddingModel
|
|
8
|
+
from .memoryset import LabeledMemoryset, TaskStatus
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """The ``memoryset`` fixture exposes the expected configuration and holds one memory per dataset row."""
    assert memoryset is not None
    expected_attributes = {
        "name": "test_memoryset",
        "embedding_model": PretrainedEmbeddingModel.GTE_BASE,
        "label_names": label_names,
        "insertion_status": TaskStatus.COMPLETED,
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(memoryset, attribute) == expected, attribute
    # the reported size is an int and matches the source dataset
    assert isinstance(memoryset.length, int)
    assert memoryset.length == len(hf_dataset)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
    """With the ``unauthenticated`` fixture active, creation fails with an invalid-API-key ``ValueError``."""
    memoryset_name = "test_memoryset"
    with pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.create(memoryset_name, datasource)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_create_memoryset_invalid_input(datasource):
    """Creating a memoryset with a bad name or an unknown datasource id is rejected as invalid input.

    The unknown-datasource case temporarily overwrites ``datasource.id`` with a random UUID. The
    original id is restored in a ``finally`` block so that the mutation cannot leak into other tests
    that may share the ``datasource`` fixture (its scope is defined in conftest, not visible here).
    """
    # invalid name (contains a space)
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource)
    # invalid datasource: point the fixture at an id that was never created
    original_id = datasource.id
    datasource.id = str(uuid4())
    try:
        with pytest.raises(ValueError, match=r"Invalid input:.*"):
            LabeledMemoryset.create("test_memoryset_invalid_datasource", datasource)
    finally:
        datasource.id = original_id
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
    """Re-creating an existing memoryset raises ``ValueError`` both by default and with ``if_exists="error"``."""
    for extra_kwargs in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            LabeledMemoryset.from_hf_dataset(
                "test_memoryset",
                hf_dataset,
                label_names=label_names,
                value_column="text",
                **extra_kwargs,
            )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
    """``if_exists="open"`` reopens an existing memoryset but rejects conflicting settings.

    Reopening with different label names or a different embedding model is expected to raise
    ``ValueError``. Reopening with the original embedding model returns a memoryset with the same
    name and size as the fixture.
    """
    # invalid label names
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            memoryset.name,
            hf_dataset,
            label_names=["turtles", "frogs"],
            value_column="text",
            if_exists="open",
        )
    # different embedding model
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            memoryset.name,
            hf_dataset,
            label_names=label_names,
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )
    reopened = LabeledMemoryset.from_hf_dataset(
        memoryset.name, hf_dataset, embedding_model=PretrainedEmbeddingModel.GTE_BASE, if_exists="open"
    )
    assert reopened is not None
    assert reopened.name == memoryset.name
    assert reopened.length == len(hf_dataset)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def test_open_memoryset(memoryset, hf_dataset):
    """Opening an existing memoryset by name yields a handle with the same name and size."""
    opened = LabeledMemoryset.open(memoryset.name)
    assert opened is not None
    assert opened.name == memoryset.name
    assert opened.length == len(hf_dataset)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
    """With the ``unauthenticated`` fixture active, opening fails with an invalid-API-key ``ValueError``."""
    with pytest.raises(ValueError, match="Invalid API key"):
        # the name is read inside the context, exactly as before, in case attribute access hits the API
        LabeledMemoryset.open(memoryset.name)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_open_memoryset_not_found():
    """Opening a memoryset under a freshly generated, never-used UUID raises ``LookupError``."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.open(unknown_id)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_open_memoryset_invalid_input():
    """A malformed identifier (containing spaces) is rejected as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open(malformed_identifier)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def test_open_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, the fixture memoryset is reported as not found.

    The test expects ``LookupError`` rather than a permission error.
    """
    with pytest.raises(LookupError):
        LabeledMemoryset.open(memoryset.name)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_all_memorysets(memoryset):
    """``LabeledMemoryset.all()`` returns a non-empty list that includes the fixture memoryset."""
    memorysets = LabeledMemoryset.all()
    assert len(memorysets) > 0
    # Use a distinct loop name. Reusing ``memoryset`` shadowed the fixture, so the comparison
    # became ``x.name == x.name`` and the assertion could never fail.
    assert any(listed.name == memoryset.name for listed in memorysets)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_all_memorysets_unauthenticated(unauthenticated):
    """With the ``unauthenticated`` fixture active, listing fails with an invalid-API-key ``ValueError``."""
    with pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.all()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_all_memorysets_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, the fixture memoryset does not appear in ``all()``."""
    # NOTE(review): ``in`` relies on LabeledMemoryset equality; if __eq__ is identity-based this is vacuous
    visible_memorysets = LabeledMemoryset.all()
    assert memoryset not in visible_memorysets
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_drop_memoryset(hf_dataset):
    """Dropping a memoryset by name removes it, as observed through ``LabeledMemoryset.exists``.

    A throwaway one-row memoryset is created for this test. It is always dropped in ``finally``
    (ignoring "not found"). Otherwise a failed assertion would leave it behind, and the next run's
    ``from_hf_dataset`` call would fail because the name already exists.
    """
    name = "test_memoryset_delete"
    memoryset = LabeledMemoryset.from_hf_dataset(
        name,
        hf_dataset.select(range(1)),
        value_column="text",
    )
    try:
        assert LabeledMemoryset.exists(memoryset.name)
        LabeledMemoryset.drop(memoryset.name)
        assert not LabeledMemoryset.exists(memoryset.name)
    finally:
        LabeledMemoryset.drop(name, if_not_exists="ignore")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
    """With the ``unauthenticated`` fixture active, dropping fails with an invalid-API-key ``ValueError``."""
    with pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.drop(memoryset.name)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_drop_memoryset_not_found(memoryset):
    """Dropping an unknown memoryset raises ``LookupError`` unless ``if_not_exists="ignore"`` is given."""
    missing_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(missing_id)
    # a second unknown id is silently ignored when requested
    another_missing_id = str(uuid4())
    LabeledMemoryset.drop(another_missing_id, if_not_exists="ignore")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_drop_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, dropping the fixture memoryset raises ``LookupError``."""
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(memoryset.name)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_search(memoryset: LabeledMemoryset):
    """A batched search returns one lookup list per query, each with a single hit of the expected label.

    With no ``count`` given, each query is expected to yield exactly one lookup. The soup query
    should hit label 0 and the cats query should hit label 1.
    """
    results = memoryset.search(["i love soup", "cats are cute"])
    assert len(results) == 2
    for expected_label, lookups in enumerate(results):
        assert len(lookups) == 1
        assert lookups[0].label == expected_label
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_search_count(memoryset: LabeledMemoryset):
    """A single-string query with ``count=3`` returns a flat list of three lookups, all labeled 0."""
    lookups = memoryset.search("i love soup", count=3)
    assert len(lookups) == 3
    assert all(lookup.label == 0 for lookup in lookups)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """Positional indexing (including negative indices) returns memories mirroring the dataset rows."""
    first_row = hf_dataset[0]
    first_memory = memoryset[0]
    assert first_memory.value == first_row["text"]
    assert first_memory.label == first_row["label"]
    assert first_memory.label_name == label_names[first_row["label"]]
    # remaining dataset columns are expected to be exposed as attributes of the memory
    for field in ("source_id", "score", "key"):
        assert getattr(first_memory, field) == first_row[field], field

    last_memory = memoryset[-1]
    last_row = hf_dataset[-1]
    assert last_memory.value == last_row["text"]
    assert last_memory.label == last_row["label"]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Slicing the memoryset returns the memories for the corresponding dataset rows, in order."""
    window = memoryset[1:3]
    assert len(window) == 2
    texts = hf_dataset["text"]
    assert window[0].value == texts[1]
    assert window[1].value == texts[2]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with one id returns that memory, which also equals indexing the memoryset by the id."""
    first_id = memoryset[0].memory_id
    fetched = memoryset.get(first_id)
    assert fetched.value == hf_dataset[0]["text"]
    assert fetched == memoryset[fetched.memory_id]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a list of ids returns the matching memories in the requested order."""
    requested_ids = [memoryset[position].memory_id for position in (0, 1)]
    fetched = memoryset.get(requested_ids)
    assert len(fetched) == 2
    for position, memory in enumerate(fetched):
        assert memory.value == hf_dataset[position]["text"]
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_query_memoryset(memoryset: LabeledMemoryset):
    """``query`` supports label filters, ``limit`` and filters on nested metadata fields."""
    label_one = memoryset.query(filters=[("label", "==", 1)])
    assert len(label_one) == 3
    assert all(memory.label == 1 for memory in label_one)

    limited = memoryset.query(limit=2)
    assert len(limited) == 2

    by_metadata = memoryset.query(filters=[("metadata.key", "==", "val1")])
    assert len(by_metadata) == 1
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_insert_memories(memoryset: LabeledMemoryset):
    """``insert`` accepts a list or a single dict and grows the memoryset accordingly.

    The single-dict insert also carries ``key`` and ``source_id``. The test expects ``key`` to land
    in the memory's metadata and ``source_id`` to be exposed directly on the memory.
    """
    starting_length = memoryset.length
    memoryset.insert(
        [
            {"value": "tomato soup is my favorite", "label": 0},
            {"value": "cats are fun to play with", "label": 1},
        ]
    )
    assert memoryset.length == starting_length + 2

    memoryset.insert({"value": "tomato soup is my favorite", "label": 0, "key": "test", "source_id": "test"})
    assert memoryset.length == starting_length + 3

    newest = memoryset[-1]
    assert newest.value == "tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata
    assert newest.metadata["key"] == "test"
    assert newest.source_id == "test"
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Updating one memory's value by id persists the change and leaves its label untouched."""
    target_id = memoryset[0].memory_id
    new_value = "i love soup so much"
    updated = memoryset.update({"memory_id": target_id, "value": new_value})
    assert updated.value == new_value
    assert updated.label == hf_dataset[0]["label"]
    # re-fetch to confirm the update was stored, not just reflected in the returned object
    assert memoryset.get(target_id).value == new_value
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``memory.update`` changes the instance in place and returns that same instance."""
    new_value = "i love soup even more"
    memory = memoryset[0]
    returned = memory.update(value=new_value)
    assert returned is memory
    assert memory.value == new_value
    assert memory.label == hf_dataset[0]["label"]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_update_memories(memoryset: LabeledMemoryset):
    """A batched ``update`` returns the updated memories in the order they were submitted."""
    first_two_ids = [memory.memory_id for memory in memoryset[:2]]
    changes = [
        {"memory_id": first_two_ids[0], "value": "i love soup so much"},
        {"memory_id": first_two_ids[1], "value": "cats are so cute"},
    ]
    results = memoryset.update(changes)
    assert results[0].value == "i love soup so much"
    assert results[1].value == "cats are so cute"
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def test_delete_memory(memoryset: LabeledMemoryset):
    """Deleting a memory by id makes it unfetchable and shrinks the memoryset by one."""
    length_before = memoryset.length
    doomed_id = memoryset[0].memory_id
    memoryset.delete(doomed_id)
    with pytest.raises(LookupError):
        memoryset.get(doomed_id)
    assert memoryset.length == length_before - 1
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def test_delete_memories(memoryset: LabeledMemoryset):
    """Deleting a list of memory ids removes all of them in one call."""
    length_before = memoryset.length
    doomed_ids = [memoryset[position].memory_id for position in (0, 1)]
    memoryset.delete(doomed_ids)
    assert memoryset.length == length_before - 2
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def test_clone_memoryset(memoryset: LabeledMemoryset):
    """Cloning with a different embedding model yields a fully inserted copy under the new name."""
    clone_name = "test_cloned_memoryset"
    target_model = PretrainedEmbeddingModel.DISTILBERT
    clone = memoryset.clone(clone_name, embedding_model=target_model)
    assert clone is not None
    assert clone.name == clone_name
    assert clone.length == memoryset.length
    assert clone.embedding_model == target_model
    assert clone.insertion_status == TaskStatus.COMPLETED
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def test_embedding_evaluation(hf_dataset):
    """Running an embedding evaluation for one model returns a single result describing that model.

    The temporary ``eval_datasource`` is dropped in ``finally`` so that a failing assertion does
    not leave it behind.
    """
    datasource = Datasource.from_hf_dataset("eval_datasource", hf_dataset, if_exists="open")
    try:
        response = LabeledMemoryset.run_embedding_evaluation(
            datasource, embedding_models=["CDE_SMALL"], neighbor_count=2, value_column="text"
        )
        # isinstance(..., dict) also rules out None, so no separate None check is needed
        assert isinstance(response, dict)
        results = response["evaluation_results"]
        assert isinstance(results, list)
        assert len(results) == 1
        result = results[0]
        assert result is not None
        assert result["embedding_model_name"] == "CDE_SMALL"
        assert result["embedding_model_path"] == "OrcaDB/cde-small-v1"
    finally:
        Datasource.drop("eval_datasource")
|
orca_sdk/telemetry.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Iterable, overload
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from orca_sdk._utils.common import UNSET
|
|
9
|
+
|
|
10
|
+
from ._generated_api_client.api import (
|
|
11
|
+
drop_feedback_category_with_data,
|
|
12
|
+
get_prediction,
|
|
13
|
+
list_feedback_categories,
|
|
14
|
+
list_predictions,
|
|
15
|
+
record_prediction_feedback,
|
|
16
|
+
update_prediction,
|
|
17
|
+
)
|
|
18
|
+
from ._generated_api_client.models import (
|
|
19
|
+
FeedbackType,
|
|
20
|
+
LabelPredictionWithMemoriesAndFeedback,
|
|
21
|
+
ListPredictionsRequest,
|
|
22
|
+
PredictionFeedbackCategory,
|
|
23
|
+
PredictionFeedbackRequest,
|
|
24
|
+
UpdatePredictionRequest,
|
|
25
|
+
)
|
|
26
|
+
from ._generated_api_client.types import UNSET as CLIENT_UNSET
|
|
27
|
+
from ._utils.prediction_result_ui import inspect_prediction_result
|
|
28
|
+
from .memoryset import LabeledMemoryLookup, LabeledMemoryset
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from .classification_model import ClassificationModel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
|
|
35
|
+
category = feedback.get("category", None)
|
|
36
|
+
if category is None:
|
|
37
|
+
raise ValueError("`category` must be specified")
|
|
38
|
+
prediction_id = feedback.get("prediction_id", None)
|
|
39
|
+
if prediction_id is None:
|
|
40
|
+
raise ValueError("`prediction_id` must be specified")
|
|
41
|
+
return PredictionFeedbackRequest(
|
|
42
|
+
prediction_id=prediction_id,
|
|
43
|
+
category_name=category,
|
|
44
|
+
value=feedback.get("value", CLIENT_UNSET),
|
|
45
|
+
comment=feedback.get("comment", CLIENT_UNSET),
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class FeedbackCategory:
    """
    A named bucket under which feedback on predictions is recorded.

    A category comes into existence the first time feedback is recorded under a new name. Its value
    type is inferred from that first value, and all later feedback in the same category must use the
    same type. Categories are shared across models rather than belonging to a single model.

    Attributes:
        id: Unique identifier of the category.
        name: Name of the category.
        value_type: Python type (``bool`` for binary categories, otherwise ``float``) that feedback
            values in this category must have.
        created_at: Timestamp at which the category was created.
    """

    id: str
    name: str
    value_type: type[bool] | type[float]
    created_at: datetime

    def __init__(self, category: PredictionFeedbackCategory):
        # for internal use only, do not document
        self.id = category.id
        self.name = category.name
        if category.type == FeedbackType.BINARY:
            self.value_type = bool
        else:
            self.value_type = float
        self.created_at = category.created_at

    @classmethod
    def all(cls) -> list[FeedbackCategory]:
        """
        Fetch every feedback category that currently exists.

        Returns:
            One ``FeedbackCategory`` per existing category.
        """
        return list(map(FeedbackCategory, list_feedback_categories()))

    @classmethod
    def drop(cls, name: str) -> None:
        """
        Delete a category together with all feedback recorded under it.

        Once dropped, the category name can be reused with a different value type.

        Warning:
            Feedback in this category is removed for every model, not just one.

        Params:
            name: Name of the category to drop.

        Raises:
            LookupError: If the category is not found.
        """
        drop_feedback_category_with_data(name)
        logging.info(f"Deleted feedback category {name} with all associated feedback")

    def __repr__(self):
        # doubled braces emit the literal "{" and "}" around the field list
        return f"FeedbackCategory({{name: {self.name}, value_type: {self.value_type}}})"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LabelPrediction:
    """
    A prediction made by a model

    Attributes:
        prediction_id: Unique identifier for the prediction
        label: Predicted label for the input value
        label_name: Name of the predicted label
        confidence: Confidence of the prediction
        anomaly_score: The score for how anomalous the input is relative to the memories
        memory_lookups: List of memories used to ground the prediction
        input_value: Input value that this prediction was for
        model: Model that was used to make the prediction
        memoryset: Memoryset that was used to lookup memories to ground the prediction
        expected_label: Optional expected label that was set for the prediction
        tags: tags that were set for the prediction
        feedback: Feedback recorded, mapping from category name to value
    """

    prediction_id: str
    label: int
    label_name: str | None
    confidence: float
    anomaly_score: float | None
    memoryset: LabeledMemoryset
    model: ClassificationModel

    def __init__(
        self,
        prediction_id: str,
        *,
        label: int,
        label_name: str | None,
        confidence: float,
        anomaly_score: float | None,
        memoryset: LabeledMemoryset | str,
        model: ClassificationModel | str,
        telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
    ):
        # for internal use only, do not document
        # imported here rather than at module level, presumably to avoid an import cycle
        # (the module-level import of ClassificationModel is guarded by TYPE_CHECKING)
        from .classification_model import ClassificationModel

        self.prediction_id = prediction_id
        self.label = label
        self.label_name = label_name
        self.confidence = confidence
        self.anomaly_score = anomaly_score
        # memoryset / model given as a string are resolved eagerly via `.open(...)`
        self.memoryset = LabeledMemoryset.open(memoryset) if isinstance(memoryset, str) else memoryset
        self.model = ClassificationModel.open(model) if isinstance(model, str) else model
        # telemetry (input value, memories, feedback, tags, expected label) is fetched lazily
        # by `_telemetry` when it is not supplied here
        self.__telemetry = telemetry or None

    def __repr__(self):
        # Build the pieces explicitly: the previous single expression mixed `+` with a
        # conditional expression, and because the conditional binds looser than `+` it dropped
        # either the leading fields or the input value depending on `anomaly_score`.
        input_text = str(self.input_value)
        if len(input_text) > 100:
            input_text = input_text[:100] + "..."
        parts = [
            f"label: <{self.label_name}: {self.label}>",
            f"confidence: {self.confidence:.2f}",
        ]
        if self.anomaly_score is not None:
            parts.append(f"anomaly_score: {self.anomaly_score:.2f}")
        parts.append(f"input_value: '{input_text}'")
        return "LabelPrediction({" + ", ".join(parts) + "})"

    @property
    def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback:
        # for internal use only, do not document
        # fetch the full prediction record on first access and cache it on the instance
        if self.__telemetry is None:
            self.__telemetry = get_prediction(prediction_id=UUID(self.prediction_id))
        return self.__telemetry

    @property
    def memory_lookups(self) -> list[LabeledMemoryLookup]:
        """Memories that were looked up to ground this prediction."""
        return [LabeledMemoryLookup(self.memoryset.id, lookup) for lookup in self._telemetry.memories]

    @property
    def input_value(self) -> str | None:
        """Input value that this prediction was made for."""
        return self._telemetry.input_value

    @property
    def feedback(self) -> dict[str, bool | float]:
        """Recorded feedback, mapping from category name to value."""
        # continuous feedback is returned as-is; any other category type is converted to a
        # bool, where a stored value of 1 means True
        return {
            f.category_name: (f.value if f.category_type == FeedbackType.CONTINUOUS else f.value == 1)
            for f in self._telemetry.feedbacks
        }

    @property
    def expected_label(self) -> int | None:
        """Expected label that was set for this prediction, if any."""
        return self._telemetry.expected_label

    @property
    def tags(self) -> set[str]:
        """Tags that were set for this prediction."""
        return set(self._telemetry.tags)

    @overload
    @classmethod
    def get(cls, prediction_id: str) -> LabelPrediction:  # type: ignore -- this takes precedence
        pass

    @overload
    @classmethod
    def get(cls, prediction_id: Iterable[str]) -> list[LabelPrediction]:
        pass

    @classmethod
    def get(cls, prediction_id: str | Iterable[str]) -> LabelPrediction | list[LabelPrediction]:
        """
        Fetch a prediction or predictions

        Params:
            prediction_id: Unique identifier of the prediction or predictions to fetch

        Returns:
            Prediction or list of predictions

        Raises:
            LookupError: If no prediction with the given id is found

        Examples:
            Fetch a single prediction:
            >>> LabelPrediction.get("0195019a-5bc7-7afb-b902-5945ee1fb766")
            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'})

            Fetch multiple predictions:
            >>> LabelPrediction.get([
            ...     "0195019a-5bc7-7afb-b902-5945ee1fb766",
            ...     "019501a1-ea08-76b2-9f62-95e4800b4841",
            ... ])
            [LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'}),
             LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.20, input_value: 'I am sad'})]
        """
        # a `str` is itself iterable, so it must be checked before the iterable branch
        if isinstance(prediction_id, str):
            return cls._from_telemetry(get_prediction(prediction_id=UUID(prediction_id)))
        return [
            cls._from_telemetry(prediction)
            for prediction in list_predictions(body=ListPredictionsRequest(prediction_ids=list(prediction_id)))
        ]

    @classmethod
    def _from_telemetry(cls, prediction: LabelPredictionWithMemoriesAndFeedback) -> LabelPrediction:
        # for internal use only: wrap an already fetched prediction record, handing it over as
        # telemetry so that it is not fetched a second time
        return cls(
            prediction_id=prediction.prediction_id,
            label=prediction.label,
            label_name=prediction.label_name,
            confidence=prediction.confidence,
            anomaly_score=prediction.anomaly_score,
            memoryset=prediction.memoryset_id,
            model=prediction.model_id,
            telemetry=prediction,
        )

    def refresh(self):
        """Refresh the prediction data from the OrcaCloud"""
        # NOTE: this goes through `get`, which re-opens the memoryset and model by id as well
        self.__dict__.update(LabelPrediction.get(self.prediction_id).__dict__)

    def inspect(self):
        """Open a UI to inspect the memories used by this prediction"""
        inspect_prediction_result(self)

    def update(self, *, expected_label: int | None = UNSET, tags: set[str] | None = UNSET) -> None:
        """
        Update editable prediction properties.

        Params:
            expected_label: Value to set for the expected label, defaults to `[UNSET]` if not provided.
                Pass `None` to remove the expected label.
            tags: Value to replace existing tags with, defaults to `[UNSET]` if not provided.
                Pass `None` to remove all tags.

        Examples:
            Update the expected label:
            >>> prediction.update(expected_label=1)

            Add a new tag:
            >>> prediction.update(tags=prediction.tags | {"new_tag"})

            Remove expected label and tags:
            >>> prediction.update(expected_label=None, tags=None)
        """
        # `None` clears the tags (sent as an empty list); `UNSET` leaves them untouched
        if tags is None:
            tags_value = []
        elif tags is UNSET:
            tags_value = CLIENT_UNSET
        else:
            tags_value = list(tags)

        update_prediction(
            prediction_id=self.prediction_id,
            body=UpdatePredictionRequest(
                expected_label=expected_label if expected_label is not UNSET else CLIENT_UNSET,
                tags=tags_value,
            ),
        )
        self.refresh()

    def add_tag(self, tag: str) -> None:
        """
        Add a tag to the prediction

        Params:
            tag: Tag to add to the prediction
        """
        self.update(tags=self.tags | {tag})

    def remove_tag(self, tag: str) -> None:
        """
        Remove a tag from the prediction

        Params:
            tag: Tag to remove from the prediction
        """
        self.update(tags=self.tags - {tag})

    def record_feedback(
        self,
        category: str,
        value: bool | float,
        *,
        comment: str | None = None,
    ):
        """
        Record feedback for the prediction.

        We support recording feedback in several categories for each prediction. A
        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
        the first time feedback with a new name is recorded. Categories are global across models.
        The value type of the category is inferred from the first recorded value. Subsequent
        feedback for the same category must be of the same type.

        Params:
            category: Name of the category under which to record the feedback.
            value: Feedback value to record, should be `True` for positive feedback and `False` for
                negative feedback or a [`float`][float] between `-1.0` and `+1.0` where negative
                values indicate negative feedback and positive values indicate positive feedback.
            comment: Optional comment to record with the feedback.

        Examples:
            Record whether a suggestion was accepted or rejected:
            >>> prediction.record_feedback("accepted", True)

            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
            >>> prediction.record_feedback("rating", -0.5, comment="2 stars")

        Raises:
            ValueError: If the value does not match previous value types for the category, or is a
                [`float`][float] that is not between `-1.0` and `+1.0`.
        """
        record_prediction_feedback(
            body=[
                _parse_feedback(
                    {"prediction_id": self.prediction_id, "category": category, "value": value, "comment": comment}
                )
            ]
        )
        self.refresh()

    def delete_feedback(self, category: str) -> None:
        """
        Delete prediction feedback for a specific category.

        Params:
            category: Name of the category of the feedback to delete.

        Raises:
            ValueError: If the category is not found.
        """
        # recording a `None` value is how this method requests removal of the feedback
        record_prediction_feedback(
            body=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)]
        )
        self.refresh()
|