orca-sdk 0.0.94__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -175
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
+from uuid import UUID

-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)

 from ._generated_api_client.api import (
-
-
-
-
-
-
+    create_classification_model,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-
+    predict_label_gpu,
     record_prediction_feedback,
-
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-
-
-
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -43,19 +35,19 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
-from ._generated_api_client.models import (
-
-
-    RACModelUpdate,
-    ROCCurve,
-)
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._generated_api_client.models import PredictiveModelUpdate, RACHeadType
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .
-from .
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback


 class ClassificationModel:
@@ -72,6 +64,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """

@@ -85,9 +78,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime

-    def __init__(self, metadata:
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +93,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at

         self._memoryset_override_id: str | None = None
-        self._last_prediction:
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False

     def __eq__(self, other) -> bool:
@@ -120,7 +115,7 @@ class ClassificationModel:
         )

     @property
-    def last_prediction(self) ->
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model

@@ -208,8 +203,8 @@ class ClassificationModel:

             return existing

-        metadata =
-            body=
+        metadata = create_classification_model(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +231,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(
+        return cls(get_classification_model(name))

     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +258,7 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in
+        return [cls(metadata) for metadata in list_classification_models()]

     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +277,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
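At the SDK level the CRUD surface keeps its shape; only the generated client calls underneath are renamed. A rough sketch of the calls these hunks now route, with hypothetical model and memoryset names (the `create` signature is assumed from the request fields above):

    from orca_sdk.classification_model import ClassificationModel

    # memoryset: a LabeledMemoryset handle obtained elsewhere (hypothetical)
    model = ClassificationModel.create("my-model", memoryset)  # create_classification_model under the hood
    same = ClassificationModel.open("my-model")                # get_classification_model
    models = ClassificationModel.all()                         # list_classification_models
    ClassificationModel.drop("my-model")                       # delete_classification_model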
@@ -290,34 +285,53 @@ class ClassificationModel:

     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(
+        self.__dict__.update(self.open(self.name).__dict__)

-    def
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.

         Params:
-            description: Value to set for the description
+            description: Value to set for the description
+            locked: Value to set for the locked status

         Examples:
             Update the description:
-            >>> model.
+            >>> model.set(description="New description")

             Remove description:
-            >>> model.
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
         """
-
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()

+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass

     @overload
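The new `lock()`/`unlock()` helpers are thin wrappers over `set()`, which now funnels every update through a single `PredictiveModelUpdate` and distinguishes "not provided" (`UNSET`) from an explicit `None`. A minimal sketch of the flow, assuming an already-opened model handle:

    model.lock()                          # same as model.set(locked=True)
    model.set(description="baseline v2")  # locked is UNSET here, so it is left unchanged
    model.set(description=None)           # explicit None clears the description
    model.unlock()                        # allow deletion again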
@@ -325,20 +339,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-
-
-
-    ) ->
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass

     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-
-
-
-    ) -> list[
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories

@@ -346,10 +360,12 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to
-
-
-
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                  environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously

         Returns:
             Label prediction or list of label predictions
@@ -357,49 +373,51 @@ class ClassificationModel:
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })

             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """

-
-
-
-                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
-            )
-            save_telemetry_synchronously = env_var.lower() == "true"
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]

-
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        response = predict_label_gpu(
             self.id,
-            body=
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
                     expected_labels
                     if isinstance(expected_labels, list)
-                    else [expected_labels]
-
-
+                    else [expected_labels] if expected_labels is not None else None
+                ),
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
                 ),
-
-            save_telemetry=save_telemetry,
-            save_telemetry_synchronously=save_telemetry_synchronously,
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )

-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")

         predictions = [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
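The old `save_telemetry`/`save_telemetry_synchronously` boolean pair collapses into a single mode string, and memory filters can now be passed directly to `predict`. A sketch of calls under the new signature (the input texts, tag, and filter tuple are made up):

    prediction = model.predict("I am happy", tags={"demo"}, save_telemetry="sync")
    predictions = model.predict(
        ["I am happy", "I am sad"],
        expected_labels=[1, 0],
        filters=[("source", "==", "reviews")],  # hypothetical FilterItemTuple
        save_telemetry="off",  # skips the telemetry write; no prediction_id is assigned
    )

With the default `"on"`, telemetry is written asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY` environment variable forces synchronous writes.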
@@ -420,7 +438,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model

@@ -440,19 +458,19 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-
-
-
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]


             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]

             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -465,10 +483,11 @@ class ClassificationModel:
             ),
         )
         return [
-
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -476,59 +495,9 @@ class ClassificationModel:
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]

-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
@@ -536,10 +505,11 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -548,10 +518,13 @@ class ClassificationModel:
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-
-
-
-
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()

     def _evaluate_dataset(
         self,
@@ -561,34 +534,64 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) ->
-        predictions = [
-
-
-
-
-
-
-
-                expected_labels=batch[label_column],
-                tags=tags,
-                save_telemetry=record_predictions,
-                save_telemetry_synchronously=(not record_predictions),
-            )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass

-
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass

     def evaluate(
         self,
         data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource

@@ -596,21 +599,23 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`
-            tags: Optional tags to add to the recorded [`
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle

         Returns:
-
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics

         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            {
-
-
-
-
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
@@ -619,8 +624,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-
+        elif isinstance(data, Dataset):
             return self._evaluate_dataset(
                 dataset=data,
                 value_column=value_column,
@@ -629,6 +635,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")

     def finetune(self, datasource: Datasource):
         # do not document until implemented