orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/classification_model.py
CHANGED
|
@@ -4,38 +4,58 @@ import logging
|
|
|
4
4
|
from contextlib import contextmanager
|
|
5
5
|
from datetime import datetime
|
|
6
6
|
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
)
|
|
20
|
-
from ._generated_api_client.models import (
|
|
21
|
-
ClassificationEvaluationResult,
|
|
22
|
-
CreateRACModelRequest,
|
|
23
|
-
EvaluationRequest,
|
|
24
|
-
ListPredictionsRequest,
|
|
7
|
+
|
|
8
|
+
from datasets import Dataset
|
|
9
|
+
|
|
10
|
+
from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
|
|
11
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
12
|
+
from .client import (
|
|
13
|
+
BootstrapClassificationModelMeta,
|
|
14
|
+
BootstrapClassificationModelResult,
|
|
15
|
+
ClassificationModelMetadata,
|
|
16
|
+
PredictiveModelUpdate,
|
|
17
|
+
RACHeadType,
|
|
18
|
+
orca_api,
|
|
25
19
|
)
|
|
26
|
-
from .
|
|
27
|
-
|
|
20
|
+
from .datasource import Datasource
|
|
21
|
+
from .job import Job
|
|
22
|
+
from .memoryset import (
|
|
23
|
+
FilterItem,
|
|
24
|
+
FilterItemTuple,
|
|
25
|
+
LabeledMemoryset,
|
|
26
|
+
_is_metric_column,
|
|
27
|
+
_parse_filter_item_from_tuple,
|
|
28
28
|
)
|
|
29
|
-
from .
|
|
30
|
-
|
|
29
|
+
from .telemetry import (
|
|
30
|
+
ClassificationPrediction,
|
|
31
|
+
TelemetryMode,
|
|
32
|
+
_get_telemetry_config,
|
|
33
|
+
_parse_feedback,
|
|
31
34
|
)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BootstrappedClassificationModel:
|
|
38
|
+
|
|
39
|
+
datasource: Datasource | None
|
|
40
|
+
memoryset: LabeledMemoryset | None
|
|
41
|
+
classification_model: ClassificationModel | None
|
|
42
|
+
agent_output: BootstrapClassificationModelResult | None
|
|
43
|
+
|
|
44
|
+
def __init__(self, metadata: BootstrapClassificationModelMeta):
|
|
45
|
+
self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
|
|
46
|
+
self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
|
|
47
|
+
self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
|
|
48
|
+
self.agent_output = metadata["agent_output"]
|
|
49
|
+
|
|
50
|
+
def __repr__(self):
|
|
51
|
+
return (
|
|
52
|
+
"BootstrappedClassificationModel({\n"
|
|
53
|
+
f" datasource: {self.datasource},\n"
|
|
54
|
+
f" memoryset: {self.memoryset},\n"
|
|
55
|
+
f" classification_model: {self.classification_model},\n"
|
|
56
|
+
f" agent_output: {self.agent_output},\n"
|
|
57
|
+
"})"
|
|
58
|
+
)
|
|
39
59
|
|
|
40
60
|
|
|
41
61
|
class ClassificationModel:
|
|
@@ -45,17 +65,20 @@ class ClassificationModel:
|
|
|
45
65
|
Attributes:
|
|
46
66
|
id: Unique identifier for the model
|
|
47
67
|
name: Unique name of the model
|
|
68
|
+
description: Optional description of the model
|
|
48
69
|
memoryset: Memoryset that the model uses
|
|
49
70
|
head_type: Classification head type of the model
|
|
50
71
|
num_classes: Number of distinct classes the model can predict
|
|
51
72
|
memory_lookup_count: Number of memories the model uses for each prediction
|
|
52
73
|
weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
|
|
53
74
|
min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
|
|
75
|
+
locked: Whether the model is locked to prevent accidental deletion
|
|
54
76
|
created_at: When the model was created
|
|
55
77
|
"""
|
|
56
78
|
|
|
57
79
|
id: str
|
|
58
80
|
name: str
|
|
81
|
+
description: str | None
|
|
59
82
|
memoryset: LabeledMemoryset
|
|
60
83
|
head_type: RACHeadType
|
|
61
84
|
num_classes: int
|
|
@@ -63,23 +86,26 @@ class ClassificationModel:
|
|
|
63
86
|
weigh_memories: bool | None
|
|
64
87
|
min_memory_weight: float | None
|
|
65
88
|
version: int
|
|
89
|
+
locked: bool
|
|
66
90
|
created_at: datetime
|
|
67
91
|
|
|
68
|
-
def __init__(self, metadata:
|
|
92
|
+
def __init__(self, metadata: ClassificationModelMetadata):
|
|
69
93
|
# for internal use only, do not document
|
|
70
|
-
self.id = metadata
|
|
71
|
-
self.name = metadata
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
77
|
-
self.
|
|
78
|
-
self.
|
|
79
|
-
self.
|
|
94
|
+
self.id = metadata["id"]
|
|
95
|
+
self.name = metadata["name"]
|
|
96
|
+
self.description = metadata["description"]
|
|
97
|
+
self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
|
|
98
|
+
self.head_type = metadata["head_type"]
|
|
99
|
+
self.num_classes = metadata["num_classes"]
|
|
100
|
+
self.memory_lookup_count = metadata["memory_lookup_count"]
|
|
101
|
+
self.weigh_memories = metadata["weigh_memories"]
|
|
102
|
+
self.min_memory_weight = metadata["min_memory_weight"]
|
|
103
|
+
self.version = metadata["version"]
|
|
104
|
+
self.locked = metadata["locked"]
|
|
105
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
80
106
|
|
|
81
107
|
self._memoryset_override_id: str | None = None
|
|
82
|
-
self._last_prediction:
|
|
108
|
+
self._last_prediction: ClassificationPrediction | None = None
|
|
83
109
|
self._last_prediction_was_batch: bool = False
|
|
84
110
|
|
|
85
111
|
def __eq__(self, other) -> bool:
|
|
@@ -97,7 +123,7 @@ class ClassificationModel:
|
|
|
97
123
|
)
|
|
98
124
|
|
|
99
125
|
@property
|
|
100
|
-
def last_prediction(self) ->
|
|
126
|
+
def last_prediction(self) -> ClassificationPrediction:
|
|
101
127
|
"""
|
|
102
128
|
Last prediction made by the model
|
|
103
129
|
|
|
@@ -119,8 +145,9 @@ class ClassificationModel:
|
|
|
119
145
|
cls,
|
|
120
146
|
name: str,
|
|
121
147
|
memoryset: LabeledMemoryset,
|
|
122
|
-
head_type:
|
|
148
|
+
head_type: RACHeadType = "KNN",
|
|
123
149
|
*,
|
|
150
|
+
description: str | None = None,
|
|
124
151
|
num_classes: int | None = None,
|
|
125
152
|
memory_lookup_count: int | None = None,
|
|
126
153
|
weigh_memories: bool = True,
|
|
@@ -141,6 +168,8 @@ class ClassificationModel:
|
|
|
141
168
|
min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
|
|
142
169
|
if_exists: What to do if a model with the same name already exists, defaults to
|
|
143
170
|
`"error"`. Other option is `"open"` to open the existing model.
|
|
171
|
+
description: Optional description for the model, this will be used in agentic flows,
|
|
172
|
+
so make sure it is concise and describes the purpose of your model.
|
|
144
173
|
|
|
145
174
|
Returns:
|
|
146
175
|
Handle to the new model in the OrcaCloud
|
|
@@ -182,16 +211,18 @@ class ClassificationModel:
|
|
|
182
211
|
|
|
183
212
|
return existing
|
|
184
213
|
|
|
185
|
-
metadata =
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
214
|
+
metadata = orca_api.POST(
|
|
215
|
+
"/classification_model",
|
|
216
|
+
json={
|
|
217
|
+
"name": name,
|
|
218
|
+
"memoryset_name_or_id": memoryset.id,
|
|
219
|
+
"head_type": head_type,
|
|
220
|
+
"memory_lookup_count": memory_lookup_count,
|
|
221
|
+
"num_classes": num_classes,
|
|
222
|
+
"weigh_memories": weigh_memories,
|
|
223
|
+
"min_memory_weight": min_memory_weight,
|
|
224
|
+
"description": description,
|
|
225
|
+
},
|
|
195
226
|
)
|
|
196
227
|
return cls(metadata)
|
|
197
228
|
|
|
@@ -209,7 +240,7 @@ class ClassificationModel:
|
|
|
209
240
|
Raises:
|
|
210
241
|
LookupError: If the classification model does not exist
|
|
211
242
|
"""
|
|
212
|
-
return cls(
|
|
243
|
+
return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
|
|
213
244
|
|
|
214
245
|
@classmethod
|
|
215
246
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -236,7 +267,7 @@ class ClassificationModel:
|
|
|
236
267
|
Returns:
|
|
237
268
|
List of handles to all classification models in the OrcaCloud
|
|
238
269
|
"""
|
|
239
|
-
return [cls(metadata) for metadata in
|
|
270
|
+
return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
|
|
240
271
|
|
|
241
272
|
@classmethod
|
|
242
273
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
@@ -255,73 +286,189 @@ class ClassificationModel:
|
|
|
255
286
|
LookupError: If the classification model does not exist and if_not_exists is `"error"`
|
|
256
287
|
"""
|
|
257
288
|
try:
|
|
258
|
-
|
|
289
|
+
orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
|
|
259
290
|
logging.info(f"Deleted model {name_or_id}")
|
|
260
291
|
except LookupError:
|
|
261
292
|
if if_not_exists == "error":
|
|
262
293
|
raise
|
|
263
294
|
|
|
295
|
+
def refresh(self):
|
|
296
|
+
"""Refresh the model data from the OrcaCloud"""
|
|
297
|
+
self.__dict__.update(self.open(self.name).__dict__)
|
|
298
|
+
|
|
299
|
+
def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
|
|
300
|
+
"""
|
|
301
|
+
Update editable attributes of the model.
|
|
302
|
+
|
|
303
|
+
Note:
|
|
304
|
+
If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
|
|
305
|
+
|
|
306
|
+
Params:
|
|
307
|
+
description: Value to set for the description
|
|
308
|
+
locked: Value to set for the locked status
|
|
309
|
+
|
|
310
|
+
Examples:
|
|
311
|
+
Update the description:
|
|
312
|
+
>>> model.set(description="New description")
|
|
313
|
+
|
|
314
|
+
Remove description:
|
|
315
|
+
>>> model.set(description=None)
|
|
316
|
+
|
|
317
|
+
Lock the model:
|
|
318
|
+
>>> model.set(locked=True)
|
|
319
|
+
"""
|
|
320
|
+
update: PredictiveModelUpdate = {}
|
|
321
|
+
if description is not UNSET:
|
|
322
|
+
update["description"] = description
|
|
323
|
+
if locked is not UNSET:
|
|
324
|
+
update["locked"] = locked
|
|
325
|
+
orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
|
|
326
|
+
self.refresh()
|
|
327
|
+
|
|
328
|
+
def lock(self) -> None:
|
|
329
|
+
"""Lock the model to prevent accidental deletion"""
|
|
330
|
+
self.set(locked=True)
|
|
331
|
+
|
|
332
|
+
def unlock(self) -> None:
|
|
333
|
+
"""Unlock the model to allow deletion"""
|
|
334
|
+
self.set(locked=False)
|
|
335
|
+
|
|
264
336
|
@overload
|
|
265
337
|
def predict(
|
|
266
|
-
self,
|
|
267
|
-
|
|
338
|
+
self,
|
|
339
|
+
value: list[str],
|
|
340
|
+
expected_labels: list[int] | None = None,
|
|
341
|
+
filters: list[FilterItemTuple] = [],
|
|
342
|
+
tags: set[str] | None = None,
|
|
343
|
+
save_telemetry: TelemetryMode = "on",
|
|
344
|
+
prompt: str | None = None,
|
|
345
|
+
use_lookup_cache: bool = True,
|
|
346
|
+
timeout_seconds: int = 10,
|
|
347
|
+
) -> list[ClassificationPrediction]:
|
|
268
348
|
pass
|
|
269
349
|
|
|
270
350
|
@overload
|
|
271
|
-
def predict(
|
|
351
|
+
def predict(
|
|
352
|
+
self,
|
|
353
|
+
value: str,
|
|
354
|
+
expected_labels: int | None = None,
|
|
355
|
+
filters: list[FilterItemTuple] = [],
|
|
356
|
+
tags: set[str] | None = None,
|
|
357
|
+
save_telemetry: TelemetryMode = "on",
|
|
358
|
+
prompt: str | None = None,
|
|
359
|
+
use_lookup_cache: bool = True,
|
|
360
|
+
timeout_seconds: int = 10,
|
|
361
|
+
) -> ClassificationPrediction:
|
|
272
362
|
pass
|
|
273
363
|
|
|
274
364
|
def predict(
|
|
275
|
-
self,
|
|
276
|
-
|
|
365
|
+
self,
|
|
366
|
+
value: list[str] | str,
|
|
367
|
+
expected_labels: list[int] | list[str] | int | str | None = None,
|
|
368
|
+
filters: list[FilterItemTuple] = [],
|
|
369
|
+
tags: set[str] | None = None,
|
|
370
|
+
save_telemetry: TelemetryMode = "on",
|
|
371
|
+
prompt: str | None = None,
|
|
372
|
+
use_lookup_cache: bool = True,
|
|
373
|
+
timeout_seconds: int = 10,
|
|
374
|
+
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
277
375
|
"""
|
|
278
376
|
Predict label(s) for the given input value(s) grounded in similar memories
|
|
279
377
|
|
|
280
378
|
Params:
|
|
281
379
|
value: Value(s) to get predict the labels of
|
|
282
380
|
expected_labels: Expected label(s) for the given input to record for model evaluation
|
|
381
|
+
filters: Optional filters to apply during memory lookup
|
|
283
382
|
tags: Tags to add to the prediction(s)
|
|
383
|
+
save_telemetry: Whether to save telemetry for the prediction(s). One of
|
|
384
|
+
* `"off"`: Do not save telemetry
|
|
385
|
+
* `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
|
|
386
|
+
environment variable is set.
|
|
387
|
+
* `"sync"`: Save telemetry synchronously
|
|
388
|
+
* `"async"`: Save telemetry asynchronously
|
|
389
|
+
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
390
|
+
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
391
|
+
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
284
392
|
|
|
285
393
|
Returns:
|
|
286
394
|
Label prediction or list of label predictions
|
|
287
395
|
|
|
396
|
+
Raises:
|
|
397
|
+
ValueError: If timeout_seconds is not a positive integer
|
|
398
|
+
TimeoutError: If the request times out after the specified duration
|
|
399
|
+
|
|
288
400
|
Examples:
|
|
289
401
|
Predict the label for a single value:
|
|
290
402
|
>>> prediction = model.predict("I am happy", tags={"test"})
|
|
291
|
-
|
|
403
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
|
|
292
404
|
|
|
293
405
|
Predict the labels for a list of values:
|
|
294
406
|
>>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
|
|
295
407
|
[
|
|
296
|
-
|
|
297
|
-
|
|
408
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
|
|
409
|
+
ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
|
|
298
410
|
]
|
|
411
|
+
|
|
412
|
+
Using a prompt with an instruction-tuned embedding model:
|
|
413
|
+
>>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
|
|
414
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
|
|
299
415
|
"""
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
416
|
+
|
|
417
|
+
if timeout_seconds <= 0:
|
|
418
|
+
raise ValueError("timeout_seconds must be a positive integer")
|
|
419
|
+
|
|
420
|
+
parsed_filters = [
|
|
421
|
+
_parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
|
|
422
|
+
]
|
|
423
|
+
|
|
424
|
+
if any(_is_metric_column(filter[0]) for filter in filters):
|
|
425
|
+
raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
|
|
426
|
+
|
|
427
|
+
if isinstance(expected_labels, int):
|
|
428
|
+
expected_labels = [expected_labels]
|
|
429
|
+
elif isinstance(expected_labels, str):
|
|
430
|
+
expected_labels = [self.memoryset.label_names.index(expected_labels)]
|
|
431
|
+
elif isinstance(expected_labels, list):
|
|
432
|
+
expected_labels = [
|
|
433
|
+
self.memoryset.label_names.index(label) if isinstance(label, str) else label
|
|
434
|
+
for label in expected_labels
|
|
435
|
+
]
|
|
436
|
+
|
|
437
|
+
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
438
|
+
response = orca_api.POST(
|
|
439
|
+
"/gpu/classification_model/{name_or_id}/prediction",
|
|
440
|
+
params={"name_or_id": self.id},
|
|
441
|
+
json={
|
|
442
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
443
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
444
|
+
"expected_labels": expected_labels,
|
|
445
|
+
"tags": list(tags or set()),
|
|
446
|
+
"save_telemetry": telemetry_on,
|
|
447
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
448
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
449
|
+
"prompt": prompt,
|
|
450
|
+
"use_lookup_cache": use_lookup_cache,
|
|
451
|
+
},
|
|
452
|
+
timeout=timeout_seconds,
|
|
314
453
|
)
|
|
454
|
+
|
|
455
|
+
if telemetry_on and any(p["prediction_id"] is None for p in response):
|
|
456
|
+
raise RuntimeError("Failed to save prediction to database.")
|
|
457
|
+
|
|
315
458
|
predictions = [
|
|
316
|
-
|
|
317
|
-
prediction_id=prediction
|
|
318
|
-
label=prediction
|
|
319
|
-
label_name=prediction
|
|
320
|
-
|
|
459
|
+
ClassificationPrediction(
|
|
460
|
+
prediction_id=prediction["prediction_id"],
|
|
461
|
+
label=prediction["label"],
|
|
462
|
+
label_name=prediction["label_name"],
|
|
463
|
+
score=None,
|
|
464
|
+
confidence=prediction["confidence"],
|
|
465
|
+
anomaly_score=prediction["anomaly_score"],
|
|
321
466
|
memoryset=self.memoryset,
|
|
322
467
|
model=self,
|
|
468
|
+
logits=prediction["logits"],
|
|
469
|
+
input_value=input_value,
|
|
323
470
|
)
|
|
324
|
-
for prediction in response
|
|
471
|
+
for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
|
|
325
472
|
]
|
|
326
473
|
self._last_prediction_was_batch = isinstance(value, list)
|
|
327
474
|
self._last_prediction = predictions[-1]
|
|
@@ -332,8 +479,9 @@ class ClassificationModel:
|
|
|
332
479
|
limit: int = 100,
|
|
333
480
|
offset: int = 0,
|
|
334
481
|
tag: str | None = None,
|
|
335
|
-
sort: list[tuple[
|
|
336
|
-
|
|
482
|
+
sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
|
|
483
|
+
expected_label_match: bool | None = None,
|
|
484
|
+
) -> list[ClassificationPrediction]:
|
|
337
485
|
"""
|
|
338
486
|
Get a list of predictions made by this model
|
|
339
487
|
|
|
@@ -343,6 +491,8 @@ class ClassificationModel:
|
|
|
343
491
|
tag: Optional tag to filter predictions by
|
|
344
492
|
sort: Optional list of columns and directions to sort the predictions by.
|
|
345
493
|
Predictions can be sorted by `timestamp`, `confidence`, or `anomaly_score`.
|
|
494
|
+
expected_label_match: Optional filter to only include predictions where the expected
|
|
495
|
+
label does (`True`) or doesn't (`False`) match the predicted label
|
|
346
496
|
|
|
347
497
|
Returns:
|
|
348
498
|
List of label predictions
|
|
@@ -351,78 +501,209 @@ class ClassificationModel:
|
|
|
351
501
|
Get the last 3 predictions:
|
|
352
502
|
>>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
|
|
353
503
|
[
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
504
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
|
|
505
|
+
ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
|
|
506
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
|
|
357
507
|
]
|
|
358
508
|
|
|
359
509
|
|
|
360
510
|
Get second most confident prediction:
|
|
361
511
|
>>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
|
|
362
|
-
[
|
|
512
|
+
[ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
|
|
513
|
+
|
|
514
|
+
Get predictions where the expected label doesn't match the predicted label:
|
|
515
|
+
>>> predictions = model.predictions(expected_label_match=False)
|
|
516
|
+
[ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
|
|
363
517
|
"""
|
|
364
|
-
predictions =
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
518
|
+
predictions = orca_api.POST(
|
|
519
|
+
"/telemetry/prediction",
|
|
520
|
+
json={
|
|
521
|
+
"model_id": self.id,
|
|
522
|
+
"limit": limit,
|
|
523
|
+
"offset": offset,
|
|
524
|
+
"sort": [list(sort_item) for sort_item in sort],
|
|
525
|
+
"tag": tag,
|
|
526
|
+
"expected_label_match": expected_label_match,
|
|
527
|
+
},
|
|
372
528
|
)
|
|
373
529
|
return [
|
|
374
|
-
|
|
375
|
-
prediction_id=prediction
|
|
376
|
-
label=prediction
|
|
377
|
-
label_name=prediction
|
|
378
|
-
|
|
530
|
+
ClassificationPrediction(
|
|
531
|
+
prediction_id=prediction["prediction_id"],
|
|
532
|
+
label=prediction["label"],
|
|
533
|
+
label_name=prediction["label_name"],
|
|
534
|
+
score=None,
|
|
535
|
+
confidence=prediction["confidence"],
|
|
536
|
+
anomaly_score=prediction["anomaly_score"],
|
|
379
537
|
memoryset=self.memoryset,
|
|
380
538
|
model=self,
|
|
381
539
|
telemetry=prediction,
|
|
382
540
|
)
|
|
383
541
|
for prediction in predictions
|
|
542
|
+
if "label" in prediction
|
|
384
543
|
]
|
|
385
544
|
|
|
386
|
-
def
|
|
545
|
+
def _evaluate_datasource(
    self,
    datasource: Datasource,
    value_column: str,
    label_column: str,
    record_predictions: bool,
    tags: set[str] | None,
    background: bool = False,
) -> ClassificationMetrics | Job[ClassificationMetrics]:
    """
    Start an evaluation of this model on a datasource via the API

    Params:
        datasource: Datasource to evaluate on; its `id` is sent to the API
        value_column: Name of the datasource column that holds the input values
        label_column: Name of the datasource column that holds the expected labels
        record_predictions: Sent to the API as `record_telemetry`
        tags: Sent to the API as `telemetry_tags`; `None` or an empty set is sent as `None`
        background: Whether to return the job handle instead of the job's result

    Returns:
        The [`Job`][orca_sdk.job.Job] handle if `background` is `True`, otherwise the
        value of `job.result()`

    Raises:
        RuntimeError: If the evaluation task is fetched but carries no result
    """
    response = orca_api.POST(
        "/classification_model/{model_name_or_id}/evaluation",
        params={"model_name_or_id": self.id},
        json={
            "datasource_name_or_id": datasource.id,
            "datasource_label_column": label_column,
            "datasource_value_column": value_column,
            "memoryset_override_name_or_id": self._memoryset_override_id,
            "record_telemetry": record_predictions,
            "telemetry_tags": list(tags) if tags else None,
        },
    )
    task_id = response["task_id"]

    def get_value() -> ClassificationMetrics:
        # Handed to Job as the result loader; presumably only invoked once the
        # task has completed -- confirm against Job.
        res = orca_api.GET(
            "/classification_model/{model_name_or_id}/evaluation/{task_id}",
            params={"model_name_or_id": self.id, "task_id": task_id},
        )
        result = res["result"]
        if result is None:
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would turn this into an AttributeError below.
            raise RuntimeError(f"Evaluation task {task_id} did not return a result.")
        return ClassificationMetrics(
            coverage=result.get("coverage"),
            f1_score=result.get("f1_score"),
            accuracy=result.get("accuracy"),
            loss=result.get("loss"),
            anomaly_score_mean=result.get("anomaly_score_mean"),
            anomaly_score_median=result.get("anomaly_score_median"),
            anomaly_score_variance=result.get("anomaly_score_variance"),
            roc_auc=result.get("roc_auc"),
            pr_auc=result.get("pr_auc"),
            pr_curve=result.get("pr_curve"),
            roc_curve=result.get("roc_curve"),
        )

    job = Job(task_id, get_value)
    return job if background else job.result()
|
|
589
|
+
|
|
590
|
+
def _evaluate_dataset(
    self,
    dataset: Dataset,
    value_column: str,
    label_column: str,
    record_predictions: bool,
    tags: set[str],
    batch_size: int,
) -> ClassificationMetrics:
    """
    Evaluate this model on a local dataset by predicting it in batches

    Params:
        dataset: Dataset to evaluate on
        value_column: Name of the column that holds the input values
        label_column: Name of the column that holds the expected labels
        record_predictions: Whether predictions are made with `save_telemetry="sync"`
            (`True`) or `save_telemetry="off"` (`False`)
        tags: Tags passed through to every `predict` call
        batch_size: Number of rows sent to `predict` per call, must be at least 1

    Returns:
        Metrics calculated from the expected labels and the logits and anomaly scores
        of the predictions, including curves

    Raises:
        ValueError: If the dataset is empty, the label column contains `None`, or
            `batch_size` is smaller than 1
    """
    if len(dataset) == 0:
        raise ValueError("Evaluation dataset cannot be empty")

    # Read the label column once; it is needed for validation and for the metrics.
    expected_labels = dataset[label_column]
    if any(x is None for x in expected_labels):
        raise ValueError("Evaluation dataset cannot contain None values in the label column")

    # A negative step would silently produce zero batches (and thus zero
    # predictions); a zero step would fail inside range() with an unrelated message.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    predictions = []
    for start in range(0, len(dataset), batch_size):
        # Slice the batch once and pick both columns from the same slice.
        batch = dataset[start : start + batch_size]
        predictions.extend(
            self.predict(
                batch[value_column],
                expected_labels=batch[label_column],
                tags=tags,
                save_telemetry="sync" if record_predictions else "off",
            )
        )

    return calculate_classification_metrics(
        expected_labels=expected_labels,
        logits=[p.logits for p in predictions],
        anomaly_scores=[p.anomaly_score for p in predictions],
        include_curves=True,
    )
|
|
622
|
+
|
|
623
|
+
@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    background: Literal[True],
) -> Job[ClassificationMetrics]: ...

@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    background: Literal[False] = False,
) -> ClassificationMetrics: ...

def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    background: bool = False,
) -> ClassificationMetrics | Job[ClassificationMetrics]:
    """
    Evaluate the classification model on a given dataset or datasource

    Params:
        data: Dataset or Datasource to evaluate the model on
        value_column: Name of the column that contains the input values to the model
        label_column: Name of the column containing the expected labels
        record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
        tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
        batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
        background: Whether to run the operation in the background and return a job handle
            (only used when input is a Datasource; Dataset inputs are always evaluated
            before returning and yield the metrics directly)

    Returns:
        ClassificationMetrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score
        statistics, or a job handle for them when a Datasource is evaluated in the background

    Raises:
        ValueError: If `data` is neither a Datasource nor a Dataset, or if a Dataset is
            empty or contains `None` in the label column

    Examples:
        >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
        ClassificationMetrics({
            accuracy: 0.8500,
            f1_score: 0.8500,
            roc_auc: 0.8500,
            pr_auc: 0.8500,
            anomaly_score: 0.3500 ± 0.0500,
        })
    """
    # NOTE: the default `tags` set is shared between calls; it is only forwarded
    # to the helpers here, so it must not be mutated.
    if isinstance(data, Datasource):
        return self._evaluate_datasource(
            datasource=data,
            value_column=value_column,
            label_column=label_column,
            record_predictions=record_predictions,
            tags=tags,
            background=background,
        )
    if isinstance(data, Dataset):
        # `background` is not forwarded: this path has no job handle to return.
        return self._evaluate_dataset(
            dataset=data,
            value_column=value_column,
            label_column=label_column,
            record_predictions=record_predictions,
            tags=tags,
            batch_size=batch_size,
        )
    raise ValueError(f"Invalid data type: {type(data)}")
|
|
426
707
|
|
|
427
708
|
def finetune(self, datasource: Datasource):
|
|
428
709
|
# do not document until implemented
|
|
@@ -492,8 +773,37 @@ class ClassificationModel:
|
|
|
492
773
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
493
774
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
494
775
|
"""
|
|
495
|
-
|
|
496
|
-
|
|
776
|
+
orca_api.PUT(
|
|
777
|
+
"/telemetry/prediction/feedback",
|
|
778
|
+
json=[
|
|
497
779
|
_parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
|
|
498
780
|
],
|
|
499
781
|
)
|
|
782
|
+
|
|
783
|
+
@staticmethod
def bootstrap_model(
    model_description: str,
    label_names: list[str],
    initial_examples: list[tuple[str, str]],
    num_examples_per_label: int,
    background: bool = False,
) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
    """
    Request a bootstrapped classification model from the API

    Params:
        model_description: Description of the model, sent to the API as-is
        label_names: Names of the labels, sent to the API as-is
        initial_examples: List of `(text, label_name)` tuples, sent to the API as
            `{"text": ..., "label_name": ...}` objects
        num_examples_per_label: Sent to the API as-is; presumably the number of
            examples to produce for each label
        background: Whether to return the job handle instead of the job's result

    Returns:
        The [`Job`][orca_sdk.job.Job] handle if `background` is `True`, otherwise the
        value of `job.result()`

    Raises:
        RuntimeError: If the bootstrap task is fetched but carries no result
    """
    response = orca_api.POST(
        "/agents/bootstrap_classification_model",
        json={
            "model_description": model_description,
            "label_names": label_names,
            "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
            "num_examples_per_label": num_examples_per_label,
        },
    )
    task_id = response["task_id"]

    def get_result() -> BootstrappedClassificationModel:
        # Handed to Job as the result loader; presumably only invoked once the
        # task has completed -- confirm against Job.
        res = orca_api.GET("/agents/bootstrap_classification_model/{task_id}", params={"task_id": task_id})
        result = res["result"]
        if result is None:
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
            raise RuntimeError(f"Bootstrap task {task_id} did not return a result.")
        return BootstrappedClassificationModel(result)

    job = Job(task_id, get_result)
    return job if background else job.result()
|