orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +84 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +90 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +184 -174
- orca_sdk/classification_model_test.py +178 -142
- orca_sdk/conftest.py +77 -26
- orca_sdk/datasource.py +34 -0
- orca_sdk/datasource_test.py +9 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +58 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +225 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/validation_error.py
ADDED
@@ -0,0 +1,99 @@
+"""
+This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+It is a customized template from the openapi-python-client tool's default template:
+https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+The main change is:
+- Fix typing issues
+"""
+
+# flake8: noqa: C901
+
+from typing import Any, List, Type, TypeVar, Union, cast
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+T = TypeVar("T", bound="ValidationError")
+
+
+@_attrs_define
+class ValidationError:
+    """
+    Attributes:
+        loc (List[Union[int, str]]):
+        msg (str):
+        type (str):
+    """
+
+    loc: List[Union[int, str]]
+    msg: str
+    type: str
+    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        loc = []
+        for loc_item_data in self.loc:
+            loc_item: Union[int, str]
+            loc_item = loc_item_data
+            loc.append(loc_item)
+
+        msg = self.msg
+
+        type = self.type
+
+        field_dict: dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update(
+            {
+                "loc": loc,
+                "msg": msg,
+                "type": type,
+            }
+        )
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+        d = src_dict.copy()
+        loc = []
+        _loc = d.pop("loc")
+        for loc_item_data in _loc:
+
+            def _parse_loc_item(data: object) -> Union[int, str]:
+                return cast(Union[int, str], data)
+
+            loc_item = _parse_loc_item(loc_item_data)
+
+            loc.append(loc_item)
+
+        msg = d.pop("msg")
+
+        type = d.pop("type")
+
+        validation_error = cls(
+            loc=loc,
+            msg=msg,
+            type=type,
+        )
+
+        validation_error.additional_properties = d
+        return validation_error
+
+    @property
+    def additional_keys(self) -> list[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
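For orientation, a minimal sketch of how this generated model round-trips one item of a 422 validation error response; the payload values are illustrative, and the import path simply mirrors the file location above.

```python
from orca_sdk._generated_api_client.models.validation_error import ValidationError

# Illustrative payload shaped like one item of a FastAPI 422 validation error response
payload = {"loc": ["body", "name"], "msg": "field required", "type": "missing", "input": None}

error = ValidationError.from_dict(payload)
print(error.loc, error.msg, error.type)  # ['body', 'name'] field required missing
print(error.additional_keys)             # unknown keys such as 'input' are kept as additional properties
assert error.to_dict() == payload        # to_dict() restores the original payload
```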
orca_sdk/_shared/__init__.py
CHANGED
@@ -1 +1,9 @@
-from .metrics import
+from .metrics import (
+    ClassificationMetrics,
+    PRCurve,
+    RegressionMetrics,
+    ROCCurve,
+    calculate_classification_metrics,
+    calculate_pr_curve,
+    calculate_roc_curve,
+)
orca_sdk/_shared/metrics.py
CHANGED
@@ -8,37 +8,25 @@ IMPORTANT:
 
 """
 
-from
+from dataclasses import dataclass
+from typing import Any, Literal, TypedDict, cast
 
 import numpy as np
+import sklearn.metrics
 from numpy.typing import NDArray
-from scipy.special import softmax
-from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
-from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
-from sklearn.metrics import roc_auc_score
-from sklearn.metrics import roc_curve as sklearn_roc_curve
-from transformers.trainer_utils import EvalPrediction
 
 
-
-
-
-
-
-    log_loss: float  # cross-entropy loss for probabilities
-
+# we don't want to depend on scipy or torch in orca_sdk
+def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
+    shifted = logits - np.max(logits, axis=axis, keepdims=True)
+    exps = np.exp(shifted)
+    return exps / np.sum(exps, axis=axis, keepdims=True)
 
-def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
-    """
-    Compute standard metrics for classifier with Hugging Face Trainer.
-
-    Args:
-        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
 
-
-
-
-    logits, references = eval_pred
+# We don't want to depend on transformers just for the eval_pred type in orca_sdk
+def transform_eval_pred(eval_pred: Any) -> tuple[NDArray[np.int64], NDArray[np.float32]]:
+    # convert results from Trainer compute_metrics param for use in calculate_classification_metrics
+    logits, references = eval_pred  # transformers.trainer_utils.EvalPrediction
     if isinstance(logits, tuple):
         logits = logits[0]
     if not isinstance(logits, np.ndarray):
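The dependency-free softmax above subtracts the row maximum before exponentiating, so large logits do not overflow, and transform_eval_pred unpacks a Hugging Face-style (logits, labels) pair without importing transformers. A small sketch of the intended behaviour, with illustrative values:

```python
import numpy as np

from orca_sdk._shared.metrics import softmax, transform_eval_pred

logits = np.array([[2.0, 1.0, 0.1], [1000.0, 999.0, 998.0]])  # second row would overflow a naive np.exp
probs = softmax(logits)
assert np.allclose(probs.sum(axis=-1), 1.0)  # each row is a probability distribution
assert np.isfinite(probs).all()              # the max-shift keeps the large row finite

# transform_eval_pred accepts the (logits, labels) pair the Trainer passes to compute_metrics
references, raw_logits = transform_eval_pred((logits, np.array([0, 1])))
```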
@@ -48,72 +36,20 @@ def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetri
             "Multiple label columns found, use the `label_names` training argument to specify which one to use"
         )
 
-
-        # convert logits to probabilities with softmax if necessary
-        probabilities = softmax(logits)
-    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
-        # convert logits to probabilities through normalization if necessary
-        probabilities = logits / logits.sum(-1, keepdims=True)
-    else:
-        probabilities = logits
-
-    return classification_scores(references, probabilities)
+    return (references, logits)
 
 
-
-
-
-
-    multi_class: Literal["ovr", "ovo"] = "ovr",
-) -> ClassificationMetrics:
-    if probabilities.ndim == 1:
-        # convert 1D probabilities (binary) to 2D logits
-        probabilities = np.column_stack([1 - probabilities, probabilities])
-    elif probabilities.ndim == 2:
-        if probabilities.shape[1] < 2:
-            raise ValueError("Use a different metric function for regression tasks")
-    else:
-        raise ValueError("Probabilities must be 1 or 2 dimensional")
-
-    predictions = np.argmax(probabilities, axis=-1)
-
-    num_classes_references = len(set(references))
-    num_classes_predictions = len(set(predictions))
-
-    if average is None:
-        average = "binary" if num_classes_references == 2 else "weighted"
-
-    accuracy = accuracy_score(references, predictions)
-    f1 = f1_score(references, predictions, average=average)
-    loss = log_loss(references, probabilities)
-
-    if num_classes_references == num_classes_predictions:
-        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
-        if num_classes_references == 2:
-            roc_auc = roc_auc_score(references, probabilities[:, 1])
-            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
-            pr_auc = auc(recalls, precisions)
-        else:
-            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
-            pr_auc = None
-    else:
-        roc_auc = None
-        pr_auc = None
-
-    return {
-        "accuracy": float(accuracy),
-        "f1_score": float(f1),
-        "roc_auc": float(roc_auc) if roc_auc is not None else None,
-        "pr_auc": float(pr_auc) if pr_auc is not None else None,
-        "log_loss": float(loss),
-    }
+class PRCurve(TypedDict):
+    thresholds: list[float]
+    precisions: list[float]
+    recalls: list[float]
 
 
 def calculate_pr_curve(
     references: NDArray[np.int64],
     probabilities: NDArray[np.float32],
     max_length: int = 100,
-) ->
+) -> PRCurve:
     if probabilities.ndim == 1:
         probabilities_slice = probabilities
     elif probabilities.ndim == 2:
@@ -124,7 +60,7 @@ def calculate_pr_curve(
     if len(probabilities_slice) != len(references):
         raise ValueError("Probabilities and references must have the same length")
 
-    precisions, recalls, thresholds =
+    precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(references, probabilities_slice)
 
     # Convert all arrays to float32 immediately after getting them
     precisions = precisions.astype(np.float32)
@@ -148,14 +84,24 @@ def calculate_pr_curve(
         precisions = new_precisions
         recalls = new_recalls
 
-    return
+    return PRCurve(
+        thresholds=cast(list[float], thresholds.tolist()),
+        precisions=cast(list[float], precisions.tolist()),
+        recalls=cast(list[float], recalls.tolist()),
+    )
+
+
+class ROCCurve(TypedDict):
+    thresholds: list[float]
+    false_positive_rates: list[float]
+    true_positive_rates: list[float]
 
 
 def calculate_roc_curve(
     references: NDArray[np.int64],
     probabilities: NDArray[np.float32],
     max_length: int = 100,
-) ->
+) -> ROCCurve:
     if probabilities.ndim == 1:
         probabilities_slice = probabilities
     elif probabilities.ndim == 2:
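Both curve helpers now return plain TypedDicts of Python floats rather than tuples of NumPy arrays. A hedged usage sketch with a tiny binary example (values illustrative):

```python
import numpy as np

from orca_sdk._shared import calculate_pr_curve, calculate_roc_curve

references = np.array([0, 1, 1, 0, 1], dtype=np.int64)
probabilities = np.array([0.1, 0.8, 0.65, 0.3, 0.9], dtype=np.float32)

pr = calculate_pr_curve(references, probabilities)    # PRCurve with "thresholds", "precisions", "recalls"
roc = calculate_roc_curve(references, probabilities)  # ROCCurve with "thresholds", "false_positive_rates", "true_positive_rates"

print(pr["precisions"][-1], roc["true_positive_rates"][-1])  # plain floats, JSON-serializable as-is
```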
@@ -168,7 +114,7 @@ def calculate_roc_curve(
 
     # Convert probabilities to float32 before calling sklearn_roc_curve
     probabilities_slice = probabilities_slice.astype(np.float32)
-    fpr, tpr, thresholds =
+    fpr, tpr, thresholds = sklearn.metrics.roc_curve(references, probabilities_slice)
 
     # Convert all arrays to float32 immediately after getting them
     fpr = fpr.astype(np.float32)
@@ -192,4 +138,228 @@ def calculate_roc_curve(
         fpr = new_fpr
         tpr = new_tpr
 
-    return
+    return ROCCurve(
+        false_positive_rates=cast(list[float], fpr.tolist()),
+        true_positive_rates=cast(list[float], tpr.tolist()),
+        thresholds=cast(list[float], thresholds.tolist()),
+    )
+
+
+@dataclass
+class ClassificationMetrics:
+    f1_score: float
+    """F1 score of the predictions"""
+
+    accuracy: float
+    """Accuracy of the predictions"""
+
+    loss: float
+    """Cross-entropy loss of the logits"""
+
+    anomaly_score_mean: float | None = None
+    """Mean of anomaly scores across the dataset"""
+
+    anomaly_score_median: float | None = None
+    """Median of anomaly scores across the dataset"""
+
+    anomaly_score_variance: float | None = None
+    """Variance of anomaly scores across the dataset"""
+
+    roc_auc: float | None = None
+    """Receiver operating characteristic area under the curve"""
+
+    pr_auc: float | None = None
+    """Average precision (area under the curve of the precision-recall curve)"""
+
+    pr_curve: PRCurve | None = None
+    """Precision-recall curve"""
+
+    roc_curve: ROCCurve | None = None
+    """Receiver operating characteristic curve"""
+
+    def __repr__(self) -> str:
+        return (
+            "ClassificationMetrics({\n"
+            + f" accuracy: {self.accuracy:.4f},\n"
+            + f" f1_score: {self.f1_score:.4f},\n"
+            + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
+            + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
+            + (
+                f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                if self.anomaly_score_mean
+                else ""
+            )
+            + "})"
+        )
+
+
+def calculate_classification_metrics(
+    expected_labels: list[int] | NDArray[np.int64],
+    logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
+    anomaly_scores: list[float] | None = None,
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+    include_curves: bool = False,
+) -> ClassificationMetrics:
+    references = np.array(expected_labels)
+
+    logits = np.array(logits)
+    if logits.ndim == 1:
+        if (logits > 1).any() or (logits < 0).any():
+            raise ValueError("Logits must be between 0 and 1 for binary classification")
+        # convert 1D probabilities (binary) to 2D logits
+        logits = np.column_stack([1 - logits, logits])
+        probabilities = logits  # no need to convert to probabilities
+    elif logits.ndim == 2:
+        if logits.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+        if not (logits > 0).all():
+            # convert logits to probabilities with softmax if necessary
+            probabilities = softmax(logits)
+        elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+            # convert logits to probabilities through normalization if necessary
+            probabilities = logits / logits.sum(-1, keepdims=True)
+        else:
+            probabilities = logits
+    else:
+        raise ValueError("Logits must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+    accuracy = sklearn.metrics.accuracy_score(references, predictions)
+    f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+    loss = sklearn.metrics.log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = sklearn.metrics.roc_auc_score(references, logits[:, 1])
+            roc_curve = calculate_roc_curve(references, logits[:, 1]) if include_curves else None
+            pr_auc = sklearn.metrics.average_precision_score(references, logits[:, 1])
+            pr_curve = calculate_pr_curve(references, logits[:, 1]) if include_curves else None
+        else:
+            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
+            roc_curve = None
+            pr_auc = None
+            pr_curve = None
+    else:
+        roc_auc = None
+        pr_auc = None
+        pr_curve = None
+        roc_curve = None
+
+    return ClassificationMetrics(
+        accuracy=float(accuracy),
+        f1_score=float(f1),
+        loss=float(loss),
+        anomaly_score_mean=anomaly_score_mean,
+        anomaly_score_median=anomaly_score_median,
+        anomaly_score_variance=anomaly_score_variance,
+        roc_auc=float(roc_auc) if roc_auc is not None else None,
+        pr_auc=float(pr_auc) if pr_auc is not None else None,
+        pr_curve=pr_curve,
+        roc_curve=roc_curve,
+    )
+
+
+@dataclass
+class RegressionMetrics:
+    mse: float
+    """Mean squared error of the predictions"""
+
+    rmse: float
+    """Root mean squared error of the predictions"""
+
+    mae: float
+    """Mean absolute error of the predictions"""
+
+    r2: float
+    """R-squared score (coefficient of determination) of the predictions"""
+
+    explained_variance: float
+    """Explained variance score of the predictions"""
+
+    loss: float
+    """Mean squared error loss of the predictions"""
+
+    anomaly_score_mean: float | None = None
+    """Mean of anomaly scores across the dataset"""
+
+    anomaly_score_median: float | None = None
+    """Median of anomaly scores across the dataset"""
+
+    anomaly_score_variance: float | None = None
+    """Variance of anomaly scores across the dataset"""
+
+    def __repr__(self) -> str:
+        return (
+            "RegressionMetrics({\n"
+            + f" mae: {self.mae:.4f},\n"
+            + f" rmse: {self.rmse:.4f},\n"
+            + f" r2: {self.r2:.4f},\n"
+            + (
+                f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                if self.anomaly_score_mean
+                else ""
+            )
+            + "})"
+        )
+
+
+def calculate_regression_metrics(
+    expected_scores: NDArray[np.float32] | list[float],
+    predicted_scores: NDArray[np.float32] | list[float],
+    anomaly_scores: list[float] | None = None,
+) -> RegressionMetrics:
+    """
+    Calculate regression metrics for model evaluation.
+
+    Params:
+        references: True target values
+        predictions: Predicted values from the model
+        anomaly_scores: Optional anomaly scores for each prediction
+
+    Returns:
+        Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance
+
+    Raises:
+        ValueError: If predictions and references have different lengths
+    """
+    references = np.array(expected_scores)
+    predictions = np.array(predicted_scores)
+
+    if len(predictions) != len(references):
+        raise ValueError("Predictions and references must have the same length")
+
+    anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+    # Calculate core regression metrics
+    mse = float(sklearn.metrics.mean_squared_error(references, predictions))
+    rmse = float(np.sqrt(mse))
+    mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
+    r2 = float(sklearn.metrics.r2_score(references, predictions))
+    explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))
+
+    return RegressionMetrics(
+        mse=mse,
+        rmse=rmse,
+        mae=mae,
+        r2=r2,
+        explained_variance=explained_var,
+        loss=mse,  # For regression, loss is typically MSE
+        anomaly_score_mean=anomaly_score_mean,
+        anomaly_score_median=anomaly_score_median,
+        anomaly_score_variance=anomaly_score_variance,
+    )
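Putting the two new entry points together, a hedged sketch with toy data; note that calculate_regression_metrics is imported from the metrics module directly, since the _shared re-export shown earlier does not include it:

```python
from orca_sdk._shared import calculate_classification_metrics
from orca_sdk._shared.metrics import calculate_regression_metrics

# Binary classification: a 1D logits array is treated as probabilities of the positive class
clf = calculate_classification_metrics(
    expected_labels=[0, 1, 1, 0, 1],
    logits=[0.2, 0.9, 0.7, 0.4, 0.6],
    include_curves=True,  # also populate pr_curve and roc_curve
)
print(clf.accuracy, clf.f1_score, clf.roc_auc)

# Regression: plain lists of floats are accepted and converted to arrays internally
reg = calculate_regression_metrics(
    expected_scores=[1.0, 2.0, 3.0, 4.0],
    predicted_scores=[1.1, 1.9, 3.2, 3.8],
)
print(reg.mae, reg.rmse, reg.r2)  # loss mirrors mse for regression
```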