orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +84 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +90 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +184 -174
- orca_sdk/classification_model_test.py +178 -142
- orca_sdk/conftest.py +77 -26
- orca_sdk/datasource.py +34 -0
- orca_sdk/datasource_test.py +9 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +58 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +225 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
+from uuid import UUID

-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)

 from ._generated_api_client.api import (
-
-
-
-
-
-
+    create_classification_model,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-
+    predict_label_gpu,
     record_prediction_feedback,
-
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-
-
-
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -43,19 +35,19 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
-from ._generated_api_client.models import
-
-
-    RACModelUpdate,
-    ROCCurve,
-)
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._generated_api_client.models import PredictiveModelUpdate, RACHeadType
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .
-from .
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback


 class ClassificationModel:
@@ -72,6 +64,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """

@@ -85,9 +78,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime

-    def __init__(self, metadata:
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +93,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at

         self._memoryset_override_id: str | None = None
-        self._last_prediction:
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False

     def __eq__(self, other) -> bool:
@@ -120,7 +115,7 @@ class ClassificationModel:
         )

     @property
-    def last_prediction(self) ->
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model

@@ -208,8 +203,8 @@ class ClassificationModel:

             return existing

-        metadata =
-            body=
+        metadata = create_classification_model(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +231,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(
+        return cls(get_classification_model(name))

     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +258,7 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in
+        return [cls(metadata) for metadata in list_classification_models()]

     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +277,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
@@ -290,34 +285,53 @@ class ClassificationModel:

     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(
+        self.__dict__.update(self.open(self.name).__dict__)

-    def
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.

         Params:
-            description: Value to set for the description
+            description: Value to set for the description
+            locked: Value to set for the locked status

         Examples:
             Update the description:
-            >>> model.
+            >>> model.set(description="New description")

             Remove description:
-            >>> model.
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
         """
-
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()

+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass

     @overload
@@ -325,20 +339,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-
-
-
-    ) ->
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass

     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories

@@ -346,10 +360,12 @@
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to
-
-
-
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously

         Returns:
             Label prediction or list of label predictions
@@ -357,26 +373,26 @@
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })

             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """

-
-
-
-                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
-            )
-            save_telemetry_synchronously = env_var.lower() == "true"
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]

-
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        response = predict_label_gpu(
             self.id,
-            body=
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
@@ -384,27 +400,32 @@ class ClassificationModel:
                     if isinstance(expected_labels, list)
                     else [expected_labels] if expected_labels is not None else None
                 ),
-                tags=list(tags),
-                save_telemetry=save_telemetry,
-                save_telemetry_synchronously=
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
+                ),
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )

-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")

         predictions = [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
                 logits=prediction.logits,
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
@@ -417,7 +438,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model

@@ -437,19 +458,19 @@
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]


             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]

             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -462,10 +483,11 @@
             ),
         )
         return [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -473,60 +495,9 @@
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]

-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
@@ -534,10 +505,11 @@
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -546,10 +518,13 @@
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-
-
-
-
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()

     def _evaluate_dataset(
         self,
@@ -559,34 +534,64 @@
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) ->
-        predictions = [
-
-
-
-
-
-
-
-                expected_labels=batch[label_column],
-                tags=tags,
-                save_telemetry=record_predictions,
-                save_telemetry_synchronously=(not record_predictions),
-            )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass

-
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass

     def evaluate(
         self,
         data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource

@@ -594,21 +599,23 @@
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`
-            tags: Optional tags to add to the recorded [`
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle

         Returns:
-
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics

         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            {
-
-
-
-
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
@@ -617,8 +624,9 @@
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-
+        elif isinstance(data, Dataset):
             return self._evaluate_dataset(
                 dataset=data,
                 value_column=value_column,
@@ -627,6 +635,8 @@
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")

     def finetune(self, datasource: Datasource):
         # do not document until implemented