orca-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +19 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +193 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +159 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +194 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +63 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +499 -0
- orca_sdk/classification_model_test.py +266 -0
- orca_sdk/conftest.py +117 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +95 -0
- orca_sdk/embedding_model.py +336 -0
- orca_sdk/embedding_model_test.py +173 -0
- orca_sdk/labeled_memoryset.py +1154 -0
- orca_sdk/labeled_memoryset_test.py +271 -0
- orca_sdk/orca_credentials.py +75 -0
- orca_sdk/orca_credentials_test.py +37 -0
- orca_sdk/telemetry.py +386 -0
- orca_sdk/telemetry_test.py +100 -0
- orca_sdk-0.1.0.dist-info/METADATA +39 -0
- orca_sdk-0.1.0.dist-info/RECORD +175 -0
- orca_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from ._generated_api_client.api import (
|
|
10
|
+
create_evaluation,
|
|
11
|
+
create_model,
|
|
12
|
+
delete_model,
|
|
13
|
+
get_evaluation,
|
|
14
|
+
get_model,
|
|
15
|
+
list_models,
|
|
16
|
+
list_predictions,
|
|
17
|
+
predict_gpu,
|
|
18
|
+
record_prediction_feedback,
|
|
19
|
+
)
|
|
20
|
+
from ._generated_api_client.models import (
|
|
21
|
+
ClassificationEvaluationResult,
|
|
22
|
+
CreateRACModelRequest,
|
|
23
|
+
EvaluationRequest,
|
|
24
|
+
ListPredictionsRequest,
|
|
25
|
+
)
|
|
26
|
+
from ._generated_api_client.models import (
|
|
27
|
+
ListPredictionsRequestSortItemItemType0 as PredictionSortColumns,
|
|
28
|
+
)
|
|
29
|
+
from ._generated_api_client.models import (
|
|
30
|
+
ListPredictionsRequestSortItemItemType1 as PredictionSortDirection,
|
|
31
|
+
)
|
|
32
|
+
from ._generated_api_client.models import RACHeadType, RACModelMetadata
|
|
33
|
+
from ._generated_api_client.models.prediction_request import PredictionRequest
|
|
34
|
+
from ._utils.common import CreateMode, DropMode
|
|
35
|
+
from ._utils.task import wait_for_task
|
|
36
|
+
from .datasource import Datasource
|
|
37
|
+
from .labeled_memoryset import LabeledMemoryset
|
|
38
|
+
from .telemetry import LabelPrediction, _parse_feedback
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ClassificationModel:
    """
    A handle to a classification model in OrcaCloud

    Attributes:
        id: Unique identifier for the model
        name: Unique name of the model
        memoryset: Memoryset that the model uses
        head_type: Classification head type of the model
        num_classes: Number of distinct classes the model can predict
        memory_lookup_count: Number of memories the model uses for each prediction
        weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
        min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
        version: Version number of the model as reported by the server
            (NOTE(review): increment semantics not visible here — confirm against the API)
        created_at: When the model was created
    """

    id: str
    name: str
    memoryset: LabeledMemoryset
    head_type: RACHeadType
    num_classes: int
    memory_lookup_count: int
    weigh_memories: bool | None
    min_memory_weight: float | None
    version: int
    created_at: datetime
|
|
67
|
+
|
|
68
|
+
    def __init__(self, metadata: RACModelMetadata):
        # for internal use only, do not document
        # Copy the server-provided metadata fields onto this handle. Resolving the
        # memoryset id into a LabeledMemoryset handle performs an API call.
        self.id = metadata.id
        self.name = metadata.name
        self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
        self.head_type = metadata.head_type
        self.num_classes = metadata.num_classes
        self.memory_lookup_count = metadata.memory_lookup_count
        self.weigh_memories = metadata.weigh_memories
        self.min_memory_weight = metadata.min_memory_weight
        self.version = metadata.version
        self.created_at = metadata.created_at

        # Client-side state: memoryset override used by use_memoryset(), and
        # bookkeeping for the last_prediction property.
        self._memoryset_override_id: str | None = None
        self._last_prediction: LabelPrediction | None = None
        self._last_prediction_was_batch: bool = False
|
|
84
|
+
|
|
85
|
+
def __eq__(self, other) -> bool:
|
|
86
|
+
return isinstance(other, ClassificationModel) and self.id == other.id
|
|
87
|
+
|
|
88
|
+
def __repr__(self):
|
|
89
|
+
return (
|
|
90
|
+
"ClassificationModel({\n"
|
|
91
|
+
f" name: '{self.name}',\n"
|
|
92
|
+
f" head_type: {self.head_type},\n"
|
|
93
|
+
f" num_classes: {self.num_classes},\n"
|
|
94
|
+
f" memory_lookup_count: {self.memory_lookup_count},\n"
|
|
95
|
+
f" memoryset: LabeledMemoryset.open('{self.memoryset.name}'),\n"
|
|
96
|
+
"})"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
@property
|
|
100
|
+
def last_prediction(self) -> LabelPrediction:
|
|
101
|
+
"""
|
|
102
|
+
Last prediction made by the model
|
|
103
|
+
|
|
104
|
+
Note:
|
|
105
|
+
If the last prediction was part of a batch prediction, the last prediction from the
|
|
106
|
+
batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
|
|
107
|
+
is raised.
|
|
108
|
+
"""
|
|
109
|
+
if self._last_prediction_was_batch:
|
|
110
|
+
logging.warning(
|
|
111
|
+
"Last prediction was part of a batch prediction, returning the last prediction from the batch"
|
|
112
|
+
)
|
|
113
|
+
if self._last_prediction is None:
|
|
114
|
+
raise LookupError("No prediction has been made yet")
|
|
115
|
+
return self._last_prediction
|
|
116
|
+
|
|
117
|
+
    @classmethod
    def create(
        cls,
        name: str,
        memoryset: LabeledMemoryset,
        head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
        *,
        num_classes: int | None = None,
        memory_lookup_count: int | None = None,
        weigh_memories: bool = True,
        min_memory_weight: float | None = None,
        if_exists: CreateMode = "error",
    ) -> ClassificationModel:
        """
        Create a new classification model

        Params:
            name: Name for the new model (must be unique)
            memoryset: Memoryset to attach the model to
            head_type: Type of model head to use
            num_classes: Number of classes this model can predict, will be inferred from memoryset if not specified
            memory_lookup_count: Number of memories to lookup for each prediction,
                by default the system uses a simple heuristic to choose a number of memories that works well in most cases
            weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
            min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
            if_exists: What to do if a model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing model.

        Returns:
            Handle to the new model in the OrcaCloud

        Raises:
            ValueError: If the model already exists and if_exists is `"error"` or if it is
                `"open"` and the existing model has different attributes.

        Examples:
            Create a new model using default options:
            >>> model = ClassificationModel.create(
            ...     "my_model",
            ...     LabeledMemoryset.open("my_memoryset"),
            ... )

            Create a new model with non-default model head and options:
            >>> model = ClassificationModel.create(
            ...     name="my_model",
            ...     memoryset=LabeledMemoryset.open("my_memoryset"),
            ...     head_type=RACHeadType.MMOE,
            ...     num_classes=5,
            ...     memory_lookup_count=20,
            ...     weigh_memories=True,
            ... )
        """
        if cls.exists(name):
            if if_exists == "error":
                raise ValueError(f"Model with name {name} already exists")
            elif if_exists == "open":
                existing = cls.open(name)
                # CAUTION: this loop reads the caller-supplied values via locals(), so
                # the strings below must exactly match the parameter names above.
                # Renaming a parameter without updating this set silently breaks the check.
                for attribute in {"head_type", "memory_lookup_count", "num_classes", "min_memory_weight"}:
                    local_attribute = locals()[attribute]
                    existing_attribute = getattr(existing, attribute)
                    # Only values the caller supplied (non-None) are compared; head_type
                    # always has a non-None default, so it is always compared.
                    # NOTE(review): this compares the literal string (e.g. "KNN") against
                    # existing.head_type (a RACHeadType) — assumes RACHeadType is a
                    # str-valued enum so equality holds; confirm in the generated models.
                    # NOTE(review): weigh_memories is not checked for consistency here.
                    if local_attribute is not None and local_attribute != existing_attribute:
                        raise ValueError(f"Model with name {name} already exists with different {attribute}")

                # special case for memoryset
                if existing.memoryset.id != memoryset.id:
                    raise ValueError(f"Model with name {name} already exists with different memoryset")

                return existing

        # Either the model does not exist yet, or if_exists held a value other than
        # "error"/"open" — in both cases we attempt the create call.
        metadata = create_model(
            body=CreateRACModelRequest(
                name=name,
                memoryset_id=memoryset.id,
                head_type=RACHeadType(head_type),
                memory_lookup_count=memory_lookup_count,
                num_classes=num_classes,
                weigh_memories=weigh_memories,
                min_memory_weight=min_memory_weight,
            ),
        )
        return cls(metadata)
|
|
197
|
+
|
|
198
|
+
@classmethod
|
|
199
|
+
def open(cls, name: str) -> ClassificationModel:
|
|
200
|
+
"""
|
|
201
|
+
Get a handle to a classification model in the OrcaCloud
|
|
202
|
+
|
|
203
|
+
Params:
|
|
204
|
+
name: Name or unique identifier of the classification model
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
Handle to the existing classification model in the OrcaCloud
|
|
208
|
+
|
|
209
|
+
Raises:
|
|
210
|
+
LookupError: If the classification model does not exist
|
|
211
|
+
"""
|
|
212
|
+
return cls(get_model(name))
|
|
213
|
+
|
|
214
|
+
@classmethod
|
|
215
|
+
def exists(cls, name_or_id: str) -> bool:
|
|
216
|
+
"""
|
|
217
|
+
Check if a classification model exists in the OrcaCloud
|
|
218
|
+
|
|
219
|
+
Params:
|
|
220
|
+
name_or_id: Name or id of the classification model
|
|
221
|
+
|
|
222
|
+
Returns:
|
|
223
|
+
`True` if the classification model exists, `False` otherwise
|
|
224
|
+
"""
|
|
225
|
+
try:
|
|
226
|
+
cls.open(name_or_id)
|
|
227
|
+
return True
|
|
228
|
+
except LookupError:
|
|
229
|
+
return False
|
|
230
|
+
|
|
231
|
+
@classmethod
|
|
232
|
+
def all(cls) -> list[ClassificationModel]:
|
|
233
|
+
"""
|
|
234
|
+
Get a list of handles to all classification models in the OrcaCloud
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
List of handles to all classification models in the OrcaCloud
|
|
238
|
+
"""
|
|
239
|
+
return [cls(metadata) for metadata in list_models()]
|
|
240
|
+
|
|
241
|
+
@classmethod
|
|
242
|
+
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
243
|
+
"""
|
|
244
|
+
Delete a classification model from the OrcaCloud
|
|
245
|
+
|
|
246
|
+
Warning:
|
|
247
|
+
This will delete the model and all associated data, including predictions, evaluations, and feedback.
|
|
248
|
+
|
|
249
|
+
Params:
|
|
250
|
+
name_or_id: Name or id of the classification model
|
|
251
|
+
if_not_exists: What to do if the classification model does not exist, defaults to `"error"`.
|
|
252
|
+
Other option is `"ignore"` to do nothing if the classification model does not exist.
|
|
253
|
+
|
|
254
|
+
Raises:
|
|
255
|
+
LookupError: If the classification model does not exist and if_not_exists is `"error"`
|
|
256
|
+
"""
|
|
257
|
+
try:
|
|
258
|
+
delete_model(name_or_id)
|
|
259
|
+
logging.info(f"Deleted model {name_or_id}")
|
|
260
|
+
except LookupError:
|
|
261
|
+
if if_not_exists == "error":
|
|
262
|
+
raise
|
|
263
|
+
|
|
264
|
+
@overload
|
|
265
|
+
def predict(
|
|
266
|
+
self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
|
|
267
|
+
) -> list[LabelPrediction]:
|
|
268
|
+
pass
|
|
269
|
+
|
|
270
|
+
@overload
|
|
271
|
+
def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
|
|
272
|
+
pass
|
|
273
|
+
|
|
274
|
+
def predict(
|
|
275
|
+
self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
|
|
276
|
+
) -> list[LabelPrediction] | LabelPrediction:
|
|
277
|
+
"""
|
|
278
|
+
Predict label(s) for the given input value(s) grounded in similar memories
|
|
279
|
+
|
|
280
|
+
Params:
|
|
281
|
+
value: Value(s) to get predict the labels of
|
|
282
|
+
expected_labels: Expected label(s) for the given input to record for model evaluation
|
|
283
|
+
tags: Tags to add to the prediction(s)
|
|
284
|
+
|
|
285
|
+
Returns:
|
|
286
|
+
Label prediction or list of label predictions
|
|
287
|
+
|
|
288
|
+
Examples:
|
|
289
|
+
Predict the label for a single value:
|
|
290
|
+
>>> prediction = model.predict("I am happy", tags={"test"})
|
|
291
|
+
LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy' })
|
|
292
|
+
|
|
293
|
+
Predict the labels for a list of values:
|
|
294
|
+
>>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
|
|
295
|
+
[
|
|
296
|
+
LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
|
|
297
|
+
LabelPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
|
|
298
|
+
]
|
|
299
|
+
"""
|
|
300
|
+
response = predict_gpu(
|
|
301
|
+
self.id,
|
|
302
|
+
body=PredictionRequest(
|
|
303
|
+
input_values=value if isinstance(value, list) else [value],
|
|
304
|
+
memoryset_override_id=self._memoryset_override_id,
|
|
305
|
+
expected_labels=(
|
|
306
|
+
expected_labels
|
|
307
|
+
if isinstance(expected_labels, list)
|
|
308
|
+
else [expected_labels]
|
|
309
|
+
if expected_labels is not None
|
|
310
|
+
else None
|
|
311
|
+
),
|
|
312
|
+
tags=list(tags),
|
|
313
|
+
),
|
|
314
|
+
)
|
|
315
|
+
predictions = [
|
|
316
|
+
LabelPrediction(
|
|
317
|
+
prediction_id=prediction.prediction_id,
|
|
318
|
+
label=prediction.label,
|
|
319
|
+
label_name=prediction.label_name,
|
|
320
|
+
confidence=prediction.confidence,
|
|
321
|
+
memoryset=self.memoryset,
|
|
322
|
+
model=self,
|
|
323
|
+
)
|
|
324
|
+
for prediction in response
|
|
325
|
+
]
|
|
326
|
+
self._last_prediction_was_batch = isinstance(value, list)
|
|
327
|
+
self._last_prediction = predictions[-1]
|
|
328
|
+
return predictions if isinstance(value, list) else predictions[0]
|
|
329
|
+
|
|
330
|
+
def predictions(
|
|
331
|
+
self,
|
|
332
|
+
limit: int = 100,
|
|
333
|
+
offset: int = 0,
|
|
334
|
+
tag: str | None = None,
|
|
335
|
+
sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
|
|
336
|
+
) -> list[LabelPrediction]:
|
|
337
|
+
"""
|
|
338
|
+
Get a list of predictions made by this model
|
|
339
|
+
|
|
340
|
+
Params:
|
|
341
|
+
limit: Optional maximum number of predictions to return
|
|
342
|
+
offset: Optional offset of the first prediction to return
|
|
343
|
+
tag: Optional tag to filter predictions by
|
|
344
|
+
sort: Optional list of columns and directions to sort the predictions by.
|
|
345
|
+
Predictions can be sorted by `timestamp` or `confidence`.
|
|
346
|
+
|
|
347
|
+
Returns:
|
|
348
|
+
List of label predictions
|
|
349
|
+
|
|
350
|
+
Examples:
|
|
351
|
+
Get the last 3 predictions:
|
|
352
|
+
>>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
|
|
353
|
+
[
|
|
354
|
+
LabeledPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
|
|
355
|
+
LabeledPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
|
|
356
|
+
LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am ecstatic'}),
|
|
357
|
+
]
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
Get second most confident prediction:
|
|
361
|
+
>>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
|
|
362
|
+
[LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am having a good day'})]
|
|
363
|
+
"""
|
|
364
|
+
predictions = list_predictions(
|
|
365
|
+
body=ListPredictionsRequest(
|
|
366
|
+
model_id=self.id,
|
|
367
|
+
limit=limit,
|
|
368
|
+
offset=offset,
|
|
369
|
+
sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
|
|
370
|
+
tag=tag,
|
|
371
|
+
),
|
|
372
|
+
)
|
|
373
|
+
return [
|
|
374
|
+
LabelPrediction(
|
|
375
|
+
prediction_id=prediction.prediction_id,
|
|
376
|
+
label=prediction.label,
|
|
377
|
+
label_name=prediction.label_name,
|
|
378
|
+
confidence=prediction.confidence,
|
|
379
|
+
memoryset=self.memoryset,
|
|
380
|
+
model=self,
|
|
381
|
+
telemetry=prediction,
|
|
382
|
+
)
|
|
383
|
+
for prediction in predictions
|
|
384
|
+
]
|
|
385
|
+
|
|
386
|
+
def evaluate(
|
|
387
|
+
self,
|
|
388
|
+
datasource: Datasource,
|
|
389
|
+
value_column: str = "value",
|
|
390
|
+
label_column: str = "label",
|
|
391
|
+
record_predictions: bool = False,
|
|
392
|
+
tags: set[str] | None = None,
|
|
393
|
+
) -> dict[str, float]:
|
|
394
|
+
"""
|
|
395
|
+
Evaluate the classification model on a given datasource
|
|
396
|
+
|
|
397
|
+
Params:
|
|
398
|
+
datasource: Datasource to evaluate the model on
|
|
399
|
+
value_column: Name of the column that contains the input values to the model
|
|
400
|
+
label_column: Name of the column containing the expected labels
|
|
401
|
+
record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
|
|
402
|
+
tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
|
|
403
|
+
|
|
404
|
+
Returns:
|
|
405
|
+
Dictionary with evaluation metrics
|
|
406
|
+
|
|
407
|
+
Examples:
|
|
408
|
+
>>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
|
|
409
|
+
{ "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35 }
|
|
410
|
+
"""
|
|
411
|
+
response = create_evaluation(
|
|
412
|
+
self.id,
|
|
413
|
+
body=EvaluationRequest(
|
|
414
|
+
datasource_id=datasource.id,
|
|
415
|
+
datasource_label_column=label_column,
|
|
416
|
+
datasource_value_column=value_column,
|
|
417
|
+
memoryset_override_id=self._memoryset_override_id,
|
|
418
|
+
record_telemetry=record_predictions,
|
|
419
|
+
telemetry_tags=list(tags) if tags else None,
|
|
420
|
+
),
|
|
421
|
+
)
|
|
422
|
+
wait_for_task(response.task_id, description="Running evaluation")
|
|
423
|
+
response = get_evaluation(self.id, UUID(response.task_id))
|
|
424
|
+
assert response.result is not None
|
|
425
|
+
return response.result.to_dict()
|
|
426
|
+
|
|
427
|
+
def finetune(self, datasource: Datasource):
|
|
428
|
+
# do not document until implemented
|
|
429
|
+
raise NotImplementedError("Finetuning is not supported yet")
|
|
430
|
+
|
|
431
|
+
@contextmanager
|
|
432
|
+
def use_memoryset(self, memoryset_override: LabeledMemoryset) -> Generator[None, None, None]:
|
|
433
|
+
"""
|
|
434
|
+
Temporarily override the memoryset used by the model for predictions
|
|
435
|
+
|
|
436
|
+
Params:
|
|
437
|
+
memoryset_override: Memoryset to override the default memoryset with
|
|
438
|
+
|
|
439
|
+
Examples:
|
|
440
|
+
>>> with model.use_memoryset(LabeledMemoryset.open("my_other_memoryset")):
|
|
441
|
+
... predictions = model.predict("I am happy")
|
|
442
|
+
"""
|
|
443
|
+
self._memoryset_override_id = memoryset_override.id
|
|
444
|
+
yield
|
|
445
|
+
self._memoryset_override_id = None
|
|
446
|
+
|
|
447
|
+
@overload
|
|
448
|
+
def record_feedback(self, feedback: dict[str, Any]) -> None:
|
|
449
|
+
pass
|
|
450
|
+
|
|
451
|
+
@overload
|
|
452
|
+
def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
|
|
453
|
+
pass
|
|
454
|
+
|
|
455
|
+
def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
|
|
456
|
+
"""
|
|
457
|
+
Record feedback for a list of predictions.
|
|
458
|
+
|
|
459
|
+
We support recording feedback in several categories for each prediction. A
|
|
460
|
+
[`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
|
|
461
|
+
the first time feedback with a new name is recorded. Categories are global across models.
|
|
462
|
+
The value type of the category is inferred from the first recorded value. Subsequent
|
|
463
|
+
feedback for the same category must be of the same type.
|
|
464
|
+
|
|
465
|
+
Params:
|
|
466
|
+
feedback: Feedback to record, this should be dictionaries with the following keys:
|
|
467
|
+
|
|
468
|
+
- `category`: Name of the category under which to record the feedback.
|
|
469
|
+
- `value`: Feedback value to record, should be `True` for positive feedback and
|
|
470
|
+
`False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
|
|
471
|
+
where negative values indicate negative feedback and positive values indicate
|
|
472
|
+
positive feedback.
|
|
473
|
+
- `comment`: Optional comment to record with the feedback.
|
|
474
|
+
|
|
475
|
+
Examples:
|
|
476
|
+
Record whether predictions were correct or incorrect:
|
|
477
|
+
>>> model.record_feedback({
|
|
478
|
+
... "prediction": p.prediction_id,
|
|
479
|
+
... "category": "correct",
|
|
480
|
+
... "value": p.label == p.expected_label,
|
|
481
|
+
... } for p in predictions)
|
|
482
|
+
|
|
483
|
+
Record star rating as normalized continuous score between `-1.0` and `+1.0`:
|
|
484
|
+
>>> model.record_feedback({
|
|
485
|
+
... "prediction": "123e4567-e89b-12d3-a456-426614174000",
|
|
486
|
+
... "category": "rating",
|
|
487
|
+
... "value": -0.5,
|
|
488
|
+
... "comment": "2 stars"
|
|
489
|
+
... })
|
|
490
|
+
|
|
491
|
+
Raises:
|
|
492
|
+
ValueError: If the value does not match previous value types for the category, or is a
|
|
493
|
+
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
494
|
+
"""
|
|
495
|
+
record_prediction_feedback(
|
|
496
|
+
body=[
|
|
497
|
+
_parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
|
|
498
|
+
],
|
|
499
|
+
)
|