orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -172
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +337 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
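Reading the listing as a whole: the RAC-prefixed classification API is renamed to `ClassificationModel` terminology throughout the generated client, a `RegressionModel` (with scored memories and regression metrics) is added, the polling helper in `orca_sdk/_utils/task.py` is deleted in favor of a new `orca_sdk/job.py`, and memoryset/telemetry types are split by prediction kind. A hedged sketch of the resulting surface; the exact exports of `orca_sdk/__init__.py` are not shown in this diff, so the import line is an assumption:

```python
# Hypothetical 0.0.96 usage, inferred from the file listing and the
# classification_model.py diff below; treat the import as an assumption.
from orca_sdk import ClassificationModel

model = ClassificationModel.open("my-classifier")  # open/exists/all/drop shown below
prediction = model.predict("I am happy")           # now returns a ClassificationPrediction
print(prediction.label, prediction.confidence)
```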
orca_sdk/classification_model.py
CHANGED
```diff
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
+from uuid import UUID
 
-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)
 
 from ._generated_api_client.api import (
-
-
-
-
-
-
+    create_classification_model_gpu,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-
+    predict_label_gpu,
     record_prediction_feedback,
-
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-
-
-
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -44,18 +36,21 @@ from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
 from ._generated_api_client.models import (
+    PredictiveModelUpdate,
     RACHeadType,
-    RACModelMetadata,
-    RACModelUpdate,
-    ROCCurve,
 )
-from ._generated_api_client.
-from ._shared.metrics import
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .
-from .
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback
 
 
 class ClassificationModel:
```
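The removed import names were truncated in this rendering (shown above as bare `-` lines), but the renamed module paths in the file listing imply a mechanical mapping for callers of the generated client. A hedged before/after sketch; the old names are inferred from the renamed files, not shown verbatim here:

```python
# 0.0.94 (inferred from renames such as
# get_model_classification_model_name_or_id_get.py -> get_classification_model_...):
# from orca_sdk._generated_api_client.api import get_model, list_models, predict_gpu

# 0.0.96, as imported in the hunk above:
from orca_sdk._generated_api_client.api import (
    get_classification_model,
    list_classification_models,
    predict_label_gpu,
)
```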
```diff
@@ -72,6 +67,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """
 
@@ -85,9 +81,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime
 
-    def __init__(self, metadata:
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +96,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction:
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
@@ -120,7 +118,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) ->
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
```
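`__init__` and the `last_prediction` property now use the renamed `ClassificationModelMetadata` and telemetry `ClassificationPrediction` types. A minimal sketch, assuming an existing model handle:

```python
model.predict("I am happy")
latest = model.last_prediction  # ClassificationPrediction from the most recent predict() call
```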
```diff
@@ -208,8 +206,8 @@ class ClassificationModel:
 
             return existing
 
-        metadata =
-            body=
+        metadata = create_classification_model_gpu(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +234,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(
+        return cls(get_classification_model(name))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +261,7 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in
+        return [cls(metadata) for metadata in list_classification_models()]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +280,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
```
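These four hunks rewire the lifecycle classmethods to the renamed client functions (`create_classification_model_gpu`, `get_classification_model`, `list_classification_models`, `delete_classification_model`). A hedged sketch using only the signatures visible above:

```python
# Hedged sketch; create()'s full signature is not shown in this diff.
model = ClassificationModel.open("sentiment")  # raises LookupError if missing
assert ClassificationModel.exists("sentiment")
handles = ClassificationModel.all()            # one handle per model in the OrcaCloud
ClassificationModel.drop("stale-model", if_not_exists="error")
```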
```diff
@@ -290,34 +288,53 @@ class ClassificationModel:
 
     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(
+        self.__dict__.update(self.open(self.name).__dict__)
 
-    def
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
 
         Params:
-            description: Value to set for the description
+            description: Value to set for the description
+            locked: Value to set for the locked status
 
         Examples:
             Update the description:
-            >>> model.
+            >>> model.set(description="New description")
 
             Remove description:
-            >>> model.
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
         """
-
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()
 
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass
 
     @overload
```
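The new `set()` relies on a two-sentinel pattern: the SDK-level `UNSET` marks arguments the caller omitted, and omitted fields are translated to the generated client's `CLIENT_UNSET` so the PATCH body only carries fields that should actually change; this is what lets `description=None` clear the field rather than being ignored. A self-contained sketch of the same pattern, with hypothetical names:

```python
from typing import Any

UNSET: Any = object()  # sentinel distinct from None

def build_patch(description: str | None = UNSET, locked: bool = UNSET) -> dict[str, Any]:
    """Collect only the fields the caller actually passed."""
    patch: dict[str, Any] = {}
    if description is not UNSET:
        patch["description"] = description  # None is a real value here: it clears the field
    if locked is not UNSET:
        patch["locked"] = locked
    return patch

assert build_patch(locked=True) == {"locked": True}
assert build_patch(description=None) == {"description": None}
```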
```diff
@@ -325,20 +342,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-
-
-
-    ) ->
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass
 
     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
 
@@ -346,10 +363,12 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to
-
-
-
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
 
         Returns:
             Label prediction or list of label predictions
@@ -357,49 +376,51 @@ class ClassificationModel:
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
 
             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """
 
-
-
-
-
-
-
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
-        response =
+        response = predict_label_gpu(
             self.id,
-            body=
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
                     expected_labels
                     if isinstance(expected_labels, list)
-                    else [expected_labels]
-
-
+                    else [expected_labels] if expected_labels is not None else None
+                ),
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
                 ),
-
-                save_telemetry=save_telemetry,
-                save_telemetry_synchronously=save_telemetry_synchronously,
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )
 
-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")
 
         predictions = [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
```
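`predict()` folds the old `save_telemetry`/`save_telemetry_synchronously` booleans into a single mode string and gains memory `filters`. A hedged usage sketch; the `("field", "op", value)` tuple shape is an assumption based on `_parse_filter_item_from_tuple` above:

```python
# Save telemetry synchronously for this call, regardless of the
# ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.
prediction = model.predict("I am happy", tags={"demo"}, save_telemetry="sync")

# Restrict memory lookups with filter tuples (shape assumed, see above).
predictions = model.predict(
    ["I am happy", "I am sad"],
    expected_labels=[1, 0],
    filters=[("source", "==", "reviews")],
)
```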
```diff
@@ -420,7 +441,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -440,19 +461,19 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
 
             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -465,10 +486,11 @@ class ClassificationModel:
             ),
         )
         return [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -476,59 +498,9 @@ class ClassificationModel:
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]
 
-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
```
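The inline sklearn-based `_calculate_metrics` is deleted in favor of the shared helper in `orca_sdk/_shared/metrics.py` (which grows by about 170 lines in this release). Its call shape, as used by `_evaluate_dataset` below, looks roughly like this; the attribute names on the result follow the example output in the `evaluate` docstring, but are otherwise assumptions:

```python
from orca_sdk._shared.metrics import calculate_classification_metrics

metrics = calculate_classification_metrics(
    expected_labels=[1, 0, 1],
    logits=[[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]],
    anomaly_scores=[0.10, 0.20, 0.15],
    include_curves=True,  # also compute the ROC and PR curves
)
print(metrics.accuracy, metrics.f1_score)
```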
```diff
@@ -536,10 +508,11 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -548,10 +521,13 @@ class ClassificationModel:
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-
-
-
-
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()
 
     def _evaluate_dataset(
         self,
```
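Datasource evaluation now kicks off a server-side task and wraps it in the new `Job` handle (replacing `wait_for_task` from the deleted `_utils/task.py`); with `background=True` the handle is returned immediately, otherwise `result()` is awaited inline. Hedged usage, assuming `datasource` is an existing Datasource handle:

```python
job = model.evaluate(datasource, value_column="text", label_column="label", background=True)
# ... do other work while the evaluation task runs ...
metrics = job.result()  # blocks until the task finishes, then parses ClassificationMetrics
```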
```diff
@@ -561,34 +537,64 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) ->
-        predictions = [
-
-
-
-
-
-
-
-                expected_labels=batch[label_column],
-                tags=tags,
-                save_telemetry=record_predictions,
-                save_telemetry_synchronously=(not record_predictions),
-            )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
 
-
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
 
     def evaluate(
         self,
         data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource
 
@@ -596,21 +602,23 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`
-            tags: Optional tags to add to the recorded [`
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            {
-
-
-
-
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
@@ -619,8 +627,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-
+        elif isinstance(data, Dataset):
             return self._evaluate_dataset(
                 dataset=data,
                 value_column=value_column,
@@ -629,6 +638,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
```
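For in-memory data, `evaluate()` still runs client-side batched predictions and computes metrics locally; only the Datasource path goes through the background job. A hedged end-to-end sketch using the signatures shown above:

```python
from datasets import Dataset

ds = Dataset.from_dict({"value": ["I am happy", "I am sad"], "label": [1, 0]})
metrics = model.evaluate(ds, batch_size=50)  # ClassificationMetrics, computed client-side
```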