orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +31 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +601 -129
- orca_sdk/classification_model_test.py +415 -117
- orca_sdk/client.py +3787 -0
- orca_sdk/conftest.py +184 -38
- orca_sdk/credentials.py +162 -20
- orca_sdk/credentials_test.py +100 -16
- orca_sdk/datasource.py +268 -68
- orca_sdk/datasource_test.py +266 -18
- orca_sdk/embedding_model.py +434 -82
- orca_sdk/embedding_model_test.py +66 -33
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1690 -324
- orca_sdk/memoryset_test.py +456 -119
- orca_sdk/regression_model.py +694 -0
- orca_sdk/regression_model_test.py +378 -0
- orca_sdk/telemetry.py +460 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
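
The headline change in this release is visible in the file list: the auto-generated API client under orca_sdk/_generated_api_client/ is removed entirely and replaced by hand-written sync and async clients (client.py, async_client.py), with new job, pagination, regression model, and shared metrics modules alongside. A minimal sketch of how the 0.1.3 surface fits together, based only on signatures visible in the diff below — the memoryset name is invented, and credential setup (see orca_sdk/credentials.py) is assumed to already be done:

    # Sketch only; "sentiment-memories" is a hypothetical memoryset name.
    from orca_sdk.classification_model import ClassificationModel
    from orca_sdk.memoryset import LabeledMemoryset

    memoryset = LabeledMemoryset.open("sentiment-memories")
    model = ClassificationModel.create(
        "sentiment-model",
        memoryset,
        head_type="KNN",
        description="Classifies short texts by sentiment",  # new in 0.1.3
        if_exists="open",  # reuse the model if it already exists
    )
    prediction = model.predict("I am happy", tags={"demo"})
    print(prediction.label_name, prediction.confidence)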
orca_sdk/classification_model.py
CHANGED
@@ -3,39 +3,67 @@ from __future__ import annotations
 import logging
 from contextlib import contextmanager
 from datetime import datetime
-from typing import …
-…
-…
-…
-…
-…
-…
-    get_evaluation,
-    get_model,
-    list_models,
-    list_predictions,
-    predict_gpu,
-    record_prediction_feedback,
+from typing import (
+    Any,
+    Generator,
+    Iterable,
+    Literal,
+    cast,
+    overload,
 )
-…
-…
-…
-…
-…
+
+from datasets import Dataset
+
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .async_client import OrcaAsyncClient
+from .client import (
+    BootstrapClassificationModelMeta,
+    BootstrapClassificationModelResult,
+    ClassificationModelMetadata,
+    OrcaClient,
+    PredictiveModelUpdate,
+    RACHeadType,
 )
-from .…
-…
+from .datasource import Datasource
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _is_metric_column,
+    _parse_filter_item_from_tuple,
 )
-from .…
-…
+from .telemetry import (
+    ClassificationPrediction,
+    TelemetryMode,
+    _get_telemetry_config,
+    _parse_feedback,
 )
-…
-…
-…
-…
-…
-…
-…
+
+
+class BootstrappedClassificationModel:
+
+    datasource: Datasource | None
+    memoryset: LabeledMemoryset | None
+    classification_model: ClassificationModel | None
+    agent_output: BootstrapClassificationModelResult | None
+
+    def __init__(self, metadata: BootstrapClassificationModelMeta):
+        self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
+        self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
+        self.agent_output = metadata["agent_output"]
+
+    def __repr__(self):
+        return (
+            "BootstrappedClassificationModel({\n"
+            f"    datasource: {self.datasource},\n"
+            f"    memoryset: {self.memoryset},\n"
+            f"    classification_model: {self.classification_model},\n"
+            f"    agent_output: {self.agent_output},\n"
+            "})"
+        )
 
 
 class ClassificationModel:
@@ -45,17 +73,20 @@ class ClassificationModel:
     Attributes:
         id: Unique identifier for the model
         name: Unique name of the model
+        description: Optional description of the model
         memoryset: Memoryset that the model uses
         head_type: Classification head type of the model
        num_classes: Number of distinct classes the model can predict
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """
 
     id: str
     name: str
+    description: str | None
     memoryset: LabeledMemoryset
     head_type: RACHeadType
     num_classes: int
@@ -63,23 +94,26 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime
 
-    def __init__(self, metadata: …
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
-        self.id = metadata…
-        self.name = metadata…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
-        self.…
+        self.id = metadata["id"]
+        self.name = metadata["name"]
+        self.description = metadata["description"]
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
+        self.head_type = metadata["head_type"]
+        self.num_classes = metadata["num_classes"]
+        self.memory_lookup_count = metadata["memory_lookup_count"]
+        self.weigh_memories = metadata["weigh_memories"]
+        self.min_memory_weight = metadata["min_memory_weight"]
+        self.version = metadata["version"]
+        self.locked = metadata["locked"]
+        self.created_at = datetime.fromisoformat(metadata["created_at"])
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction: …
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
@@ -97,7 +131,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) -> …
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
@@ -119,8 +153,9 @@ class ClassificationModel:
         cls,
         name: str,
         memoryset: LabeledMemoryset,
-        head_type: …
+        head_type: RACHeadType = "KNN",
         *,
+        description: str | None = None,
         num_classes: int | None = None,
         memory_lookup_count: int | None = None,
         weigh_memories: bool = True,
@@ -141,6 +176,8 @@ class ClassificationModel:
             min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
             if_exists: What to do if a model with the same name already exists, defaults to
                 `"error"`. Other option is `"open"` to open the existing model.
+            description: Optional description for the model, this will be used in agentic flows,
+                so make sure it is concise and describes the purpose of your model.
 
         Returns:
             Handle to the new model in the OrcaCloud
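
Illustratively, a create call using the new keyword might read as follows; the model and memoryset names here are invented, and if_exists="open" is the documented alternative to the default "error":

    model = ClassificationModel.create(
        "ticket-router",
        LabeledMemoryset.open("support-tickets"),
        description="Routes support tickets to the right team",
        if_exists="open",
    )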
@@ -182,16 +219,19 @@ class ClassificationModel:
 
             return existing
 
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
+            "/classification_model",
+            json={
+                "name": name,
+                "memoryset_name_or_id": memoryset.id,
+                "head_type": head_type,
+                "memory_lookup_count": memory_lookup_count,
+                "num_classes": num_classes,
+                "weigh_memories": weigh_memories,
+                "min_memory_weight": min_memory_weight,
+                "description": description,
+            },
         )
         return cls(metadata)
 
@@ -209,7 +249,8 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-…
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -236,7 +277,8 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-…
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/classification_model")]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -255,73 +297,334 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-…
+            client = OrcaClient._resolve_client()
+            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
                 raise
 
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(self.open(self.name).__dict__)
+
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
+        """
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+        Params:
+            description: Value to set for the description
+            locked: Value to set for the locked status
+
+        Examples:
+            Update the description:
+            >>> model.set(description="New description")
+
+            Remove description:
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
+        """
+        update: PredictiveModelUpdate = {}
+        if description is not UNSET:
+            update["description"] = description
+        if locked is not UNSET:
+            update["locked"] = locked
+        client = OrcaClient._resolve_client()
+        client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+        self.refresh()
+
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
-        self,
-…
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
         pass
 
     @overload
-    def predict(
+    def predict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
         pass
 
     def predict(
-        self,
-…
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
 
         Params:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
             tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
 
         Returns:
             Label prediction or list of label predictions
 
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-…
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
 
             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-…
-…
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
         """
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
         )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
         predictions = [
-…
-                prediction_id=prediction…
-                label=prediction…
-                label_name=prediction…
-…
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    @overload
+    async def apredict(
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
+        pass
+
+    @overload
+    async def apredict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
+        pass
+
+    async def apredict(
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
+        """
+        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+        Params:
+            value: Value(s) to get predict the labels of
+            expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
+            tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+
+        Returns:
+            Label prediction or list of label predictions.
+
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
+        Examples:
+            Predict the label for a single value:
+            >>> prediction = await model.apredict("I am happy", tags={"test"})
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+            Predict the labels for a list of values:
+            >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+            [
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+            ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+        """
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaAsyncClient._resolve_client()
+        response = await client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
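
The hunk above adds the new prediction options (filters, save_telemetry, prompt, use_lookup_cache, timeout_seconds), an async twin apredict backed by OrcaAsyncClient, and lock/unlock. A hedged sketch of how they combine, assuming model is an open ClassificationModel handle:

    import asyncio

    # Sync prediction with the new options.
    prediction = model.predict(
        "I am happy",
        prompt="Represent this text for sentiment classification:",
        save_telemetry="sync",  # block until telemetry is persisted
        timeout_seconds=5,
    )

    # The same call from async code via the new apredict.
    async def classify(texts: list[str]):
        return await model.apredict(texts, expected_labels=[1, 0])

    predictions = asyncio.run(classify(["I am happy", "I am sad"]))

    # Models can now be locked against accidental deletion.
    model.lock()
    assert model.locked
    model.unlock()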
@@ -332,8 +635,9 @@ class ClassificationModel:
         limit: int = 100,
         offset: int = 0,
         tag: str | None = None,
-        sort: list[tuple[…
-…
+        sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
+        expected_label_match: bool | None = None,
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -343,6 +647,8 @@ class ClassificationModel:
             tag: Optional tag to filter predictions by
             sort: Optional list of columns and directions to sort the predictions by.
                 Predictions can be sorted by `timestamp` or `confidence`.
+            expected_label_match: Optional filter to only include predictions where the expected
+                label does (`True`) or doesn't (`False`) match the predicted label
 
         Returns:
             List of label predictions
@@ -351,78 +657,212 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-…
-…
-…
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            […
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+
+            Get predictions where the expected label doesn't match the predicted label:
+            >>> predictions = model.predictions(expected_label_match=False)
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
-…
-…
-…
-…
-…
-…
-…
-…
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
+            "/telemetry/prediction",
+            json={
+                "model_id": self.id,
+                "limit": limit,
+                "offset": offset,
+                "sort": [list(sort_item) for sort_item in sort],
+                "tag": tag,
+                "expected_label_match": expected_label_match,
+            },
         )
         return [
-…
-            prediction_id=prediction…
-            label=prediction…
-            label_name=prediction…
-…
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
                 memoryset=self.memoryset,
                 model=self,
                 telemetry=prediction,
             )
             for prediction in predictions
+            if "label" in prediction
         ]
 
-    def …
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/classification_model/{model_name_or_id}/evaluation",
+            params={"model_name_or_id": self.id},
+            json={
+                "datasource_name_or_id": datasource.id,
+                "datasource_label_column": label_column,
+                "datasource_value_column": value_column,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "record_telemetry": record_predictions,
+                "telemetry_tags": list(tags) if tags else None,
+            },
+        )
+
+        def get_value():
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/classification_model/{model_name_or_id}/evaluation/{task_id}",
+                params={"model_name_or_id": self.id, "task_id": response["task_id"]},
+            )
+            assert res["result"] is not None
+            return ClassificationMetrics(
+                coverage=res["result"].get("coverage"),
+                f1_score=res["result"].get("f1_score"),
+                accuracy=res["result"].get("accuracy"),
+                loss=res["result"].get("loss"),
+                anomaly_score_mean=res["result"].get("anomaly_score_mean"),
+                anomaly_score_median=res["result"].get("anomaly_score_median"),
+                anomaly_score_variance=res["result"].get("anomaly_score_variance"),
+                roc_auc=res["result"].get("roc_auc"),
+                pr_auc=res["result"].get("pr_auc"),
+                pr_curve=res["result"].get("pr_curve"),
+                roc_curve=res["result"].get("roc_curve"),
+            )
+
+        job = Job(response["task_id"], get_value)
+        return job if background else job.result()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> ClassificationMetrics:
+        if len(dataset) == 0:
+            raise ValueError("Evaluation dataset cannot be empty")
+
+        if any(x is None for x in dataset[label_column]):
+            raise ValueError("Evaluation dataset cannot contain None values in the label column")
+
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
+            )
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str]…
-…
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-…
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`…
-            tags: Optional tags to add to the recorded [`…
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-…
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            {…
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                background=background,
+            )
+        elif isinstance(data, Dataset):
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
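
As the hunk above shows, evaluate() now accepts an in-memory Hugging Face Dataset (scored client-side in predict batches) as well as a Datasource (scored server-side, optionally in the background). A sketch with invented data, assuming ClassificationMetrics exposes its constructor fields as attributes:

    from datasets import Dataset

    dataset = Dataset.from_dict(
        {
            "value": ["I am happy", "I am sad", "I am ecstatic", "I am angry"],
            "label": [1, 0, 1, 0],
        }
    )
    metrics = model.evaluate(dataset, batch_size=2)  # client-side batched predict()
    print(metrics.accuracy, metrics.f1_score)

    # Datasource evaluations can run as a background task and return a Job
    # handle (see the new orca_sdk/job.py):
    job = model.evaluate(datasource, background=True)
    metrics = job.result()  # blocks until the evaluation task finishes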
@@ -492,8 +932,40 @@ class ClassificationModel:
             ValueError: If the value does not match previous value types for the category, or is a
                 [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-…
-…
+        client = OrcaClient._resolve_client()
+        client.PUT(
+            "/telemetry/prediction/feedback",
+            json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
             ],
         )
+
+    @staticmethod
+    def bootstrap_model(
+        model_description: str,
+        label_names: list[str],
+        initial_examples: list[tuple[str, str]],
+        num_examples_per_label: int,
+        background: bool = False,
+    ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/agents/bootstrap_classification_model",
+            json={
+                "model_description": model_description,
+                "label_names": label_names,
+                "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
+                "num_examples_per_label": num_examples_per_label,
+            },
+        )
+
+        def get_result() -> BootstrappedClassificationModel:
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
+            )
+            assert res["result"] is not None
+            return BootstrappedClassificationModel(res["result"])
+
+        job = Job(response["task_id"], get_result)
+        return job if background else job.result()