orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/__init__.py
CHANGED
|
@@ -3,8 +3,8 @@ OrcaSDK is a Python library for building and using retrieval augmented models in
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from ._utils.common import UNSET, CreateMode, DropMode
|
|
6
|
-
from .
|
|
7
|
-
from .
|
|
6
|
+
from .classification_model import ClassificationMetrics, ClassificationModel
|
|
7
|
+
from .client import orca_api
|
|
8
8
|
from .credentials import OrcaCredentials
|
|
9
9
|
from .datasource import Datasource
|
|
10
10
|
from .embedding_model import (
|
|
@@ -12,13 +12,19 @@ from .embedding_model import (
|
|
|
12
12
|
PretrainedEmbeddingModel,
|
|
13
13
|
PretrainedEmbeddingModelName,
|
|
14
14
|
)
|
|
15
|
+
from .job import Job, Status
|
|
15
16
|
from .memoryset import (
|
|
17
|
+
CascadingEditSuggestion,
|
|
16
18
|
FilterItemTuple,
|
|
17
19
|
LabeledMemory,
|
|
18
20
|
LabeledMemoryLookup,
|
|
19
21
|
LabeledMemoryset,
|
|
22
|
+
ScoredMemory,
|
|
23
|
+
ScoredMemoryLookup,
|
|
24
|
+
ScoredMemoryset,
|
|
20
25
|
)
|
|
21
|
-
from .
|
|
26
|
+
from .regression_model import RegressionModel
|
|
27
|
+
from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction
|
|
22
28
|
|
|
23
29
|
# only specify things that should show up on the root page of the reference docs because they are in private modules
|
|
24
|
-
__all__ = ["
|
|
30
|
+
__all__ = ["UNSET", "CreateMode", "DropMode"]
|
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains metrics for usage with the Hugging Face Trainer.
|
|
3
|
+
|
|
4
|
+
IMPORTANT:
|
|
5
|
+
- This is a shared file between OrcaLib and the OrcaSDK.
|
|
6
|
+
- Please ensure that it does not have any dependencies on the OrcaLib code.
|
|
7
|
+
- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Literal, TypedDict, cast
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import sklearn.metrics
|
|
16
|
+
from numpy.typing import NDArray
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Hand-rolled softmax so orca_sdk does not need scipy or torch as dependencies.
def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """
    Numerically stable softmax along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that large
    logits do not overflow.

    Params:
        logits: Array of unnormalized scores
        axis: Axis to normalize along, defaults to the last axis

    Returns:
        Array with the same shape as ``logits`` whose slices along ``axis`` sum to 1
    """
    slice_max = np.max(logits, axis=axis, keepdims=True)
    exponentials = np.exp(logits - slice_max)
    normalizer = np.sum(exponentials, axis=axis, keepdims=True)
    return exponentials / normalizer
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Typed as Any so orca_sdk does not have to depend on transformers just for EvalPrediction.
def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
    """
    Unpack the value the Hugging Face Trainer passes to ``compute_metrics`` so it can be
    used with ``calculate_classification_metrics``.

    Params:
        eval_pred: A ``(logits, references)`` pair, e.g. a
            ``transformers.trainer_utils.EvalPrediction``. If the logits entry is a tuple,
            only its first element is used.

    Returns:
        A ``(references, logits)`` tuple (note the swapped order)

    Raises:
        ValueError: If the logits are not a numpy array, or if the references are not a
            single numpy array (which is the case when multiple label columns were found)
    """
    raw_logits, references = eval_pred
    logits = raw_logits[0] if isinstance(raw_logits, tuple) else raw_logits

    if not isinstance(logits, np.ndarray):
        raise ValueError("Logits must be a numpy array")
    if not isinstance(references, np.ndarray):
        raise ValueError(
            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
        )

    return references, logits
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PRCurve(TypedDict):
    """
    Precision-recall curve stored as parallel lists: index ``i`` of every list describes
    the same point on the curve.
    """

    # decision thresholds on the positive-class score
    thresholds: list[float]
    # precision at the matching threshold
    precisions: list[float]
    # recall at the matching threshold
    recalls: list[float]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def calculate_pr_curve(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    max_length: int = 100,
) -> PRCurve:
    """
    Calculate a precision-recall curve for binary classification.

    Params:
        references: True binary label of each sample
        probabilities: Positive-class scores, either 1D or 2D with one column per class
            (column 1 is taken as the positive class)
        max_length: Maximum number of points to return; longer curves are linearly
            interpolated onto ``max_length`` evenly spaced thresholds in [0, 1]

    Returns:
        Parallel lists of thresholds (ascending), precisions, and recalls

    Raises:
        ValueError: If probabilities are not 1 or 2 dimensional, or if their length does
            not match the references
    """
    if probabilities.ndim not in (1, 2):
        raise ValueError("Probabilities must be 1 or 2 dimensional")
    positive_scores = probabilities if probabilities.ndim == 1 else probabilities[:, 1]

    if len(positive_scores) != len(references):
        raise ValueError("Probabilities and references must have the same length")

    raw_precisions, raw_recalls, raw_thresholds = sklearn.metrics.precision_recall_curve(
        references, positive_scores
    )
    precisions = raw_precisions.astype(np.float32)
    recalls = raw_recalls.astype(np.float32)
    # sklearn returns one more precision/recall value than thresholds; prepending 0 as
    # the lowest threshold gives every precision/recall point a threshold
    thresholds = np.concatenate(([0], raw_thresholds.astype(np.float32)))

    order = np.argsort(thresholds)
    thresholds, precisions, recalls = thresholds[order], precisions[order], recalls[order]

    if len(precisions) > max_length:
        # resample onto a fixed grid so the payload size is bounded
        grid = np.linspace(0, 1, max_length, dtype=np.float32)
        precisions = np.interp(grid, thresholds, precisions)
        recalls = np.interp(grid, thresholds, recalls)
        thresholds = grid

    return PRCurve(
        thresholds=cast(list[float], thresholds.tolist()),
        precisions=cast(list[float], precisions.tolist()),
        recalls=cast(list[float], recalls.tolist()),
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ROCCurve(TypedDict):
    """
    Receiver operating characteristic curve stored as parallel lists: index ``i`` of
    every list describes the same point on the curve.
    """

    # decision thresholds on the positive-class score
    thresholds: list[float]
    # false positive rate at the matching threshold
    false_positive_rates: list[float]
    # true positive rate at the matching threshold
    true_positive_rates: list[float]
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def calculate_roc_curve(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    max_length: int = 100,
) -> ROCCurve:
    """
    Calculate a receiver operating characteristic curve for binary classification.

    Params:
        references: True binary label of each sample
        probabilities: Positive-class scores, either 1D or 2D with one column per class
            (column 1 is taken as the positive class)
        max_length: Maximum number of points to return; longer curves are linearly
            interpolated onto ``max_length`` evenly spaced thresholds in [0, 1]

    Returns:
        Parallel lists of false positive rates, true positive rates, and thresholds
        (ascending)

    Raises:
        ValueError: If probabilities are not 1 or 2 dimensional, or if their length does
            not match the references
    """
    if probabilities.ndim not in (1, 2):
        raise ValueError("Probabilities must be 1 or 2 dimensional")
    positive_scores = probabilities if probabilities.ndim == 1 else probabilities[:, 1]

    if len(positive_scores) != len(references):
        raise ValueError("Probabilities and references must have the same length")

    raw_fpr, raw_tpr, raw_thresholds = sklearn.metrics.roc_curve(
        references, positive_scores.astype(np.float32)
    )
    fpr = raw_fpr.astype(np.float32)
    tpr = raw_tpr.astype(np.float32)
    thresholds = raw_thresholds.astype(np.float32)
    # sklearn's first threshold is a sentinel above every score (inf in recent releases);
    # clamp it to 1.0 so interpolation over [0, 1] stays finite
    thresholds[0] = 1.0

    order = np.argsort(thresholds)
    thresholds, fpr, tpr = thresholds[order], fpr[order], tpr[order]

    if len(fpr) > max_length:
        # resample onto a fixed grid so the payload size is bounded
        grid = np.linspace(0, 1, max_length, dtype=np.float32)
        fpr = np.interp(grid, thresholds, fpr)
        tpr = np.interp(grid, thresholds, tpr)
        thresholds = grid

    return ROCCurve(
        false_positive_rates=cast(list[float], fpr.tolist()),
        true_positive_rates=cast(list[float], tpr.tolist()),
        thresholds=cast(list[float], thresholds.tolist()),
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@dataclass
class ClassificationMetrics:
    """Metrics describing how well classification predictions match the expected labels."""

    coverage: float
    """Percentage of predictions that are not none"""

    f1_score: float
    """F1 score of the predictions"""

    accuracy: float
    """Accuracy of the predictions"""

    loss: float | None
    """Cross-entropy loss of the logits"""

    anomaly_score_mean: float | None = None
    """Mean of anomaly scores across the dataset"""

    anomaly_score_median: float | None = None
    """Median of anomaly scores across the dataset"""

    anomaly_score_variance: float | None = None
    """Variance of anomaly scores across the dataset"""

    roc_auc: float | None = None
    """Receiver operating characteristic area under the curve"""

    pr_auc: float | None = None
    """Average precision (area under the curve of the precision-recall curve)"""

    pr_curve: PRCurve | None = None
    """Precision-recall curve"""

    roc_curve: ROCCurve | None = None
    """Receiver operating characteristic curve"""

    def __repr__(self) -> str:
        # Optional metrics are compared against None explicitly: a legitimate score of 0.0
        # must still be shown, and a missing variance must never reach a float format spec
        # (formatting None with ":.4f" raises TypeError).
        body = f" accuracy: {self.accuracy:.4f},\n" + f" f1_score: {self.f1_score:.4f},\n"
        if self.roc_auc is not None:
            body += f" roc_auc: {self.roc_auc:.4f},\n"
        if self.pr_auc is not None:
            body += f" pr_auc: {self.pr_auc:.4f},\n"
        if self.anomaly_score_mean is not None:
            spread = (
                f" ± {self.anomaly_score_variance:.4f}" if self.anomaly_score_variance is not None else ""
            )
            body += f" anomaly_score: {self.anomaly_score_mean:.4f}{spread},\n"
        return "ClassificationMetrics({\n" + body + "})"
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def calculate_classification_metrics(
    expected_labels: list[int] | NDArray[np.int64],
    logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
    anomaly_scores: list[float] | None = None,
    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
    multi_class: Literal["ovr", "ovo"] = "ovr",
    include_curves: bool = False,
) -> ClassificationMetrics:
    """
    Calculate classification metrics for model evaluation.

    Params:
        expected_labels: True class index of each sample
        logits: Model outputs for each sample. Either a 1D array of positive-class
            probabilities for binary classification (values must be in [0, 1]), or a 2D
            array with one column per class. If any 2D value is not positive (or is NaN),
            the whole array is treated as raw logits and passed through softmax; otherwise,
            if the rows do not all sum to 1, each row is divided by its sum. Rows that are
            entirely NaN count as missing predictions.
        anomaly_scores: Optional anomaly scores for each prediction
        average: Averaging strategy for the F1 score; defaults to ``"binary"`` when the
            references contain exactly two classes and no prediction is missing, otherwise
            ``"weighted"``
        multi_class: Strategy used for the ROC AUC when there are more than two classes
        include_curves: Whether to also compute ROC and precision-recall curves (only for
            two classes)

    Returns:
        Classification metrics. ``loss`` is None when any prediction is missing. ``roc_auc``
        is None unless references and predictions contain the same number of classes and no
        prediction is missing. ``pr_auc`` and the curves are only computed for two classes.

    Raises:
        ValueError: If 1D logits are outside [0, 1], 2D logits have fewer than two columns,
            or logits are not 1 or 2 dimensional
    """
    references = np.array(expected_labels)

    logits = np.array(logits)
    if logits.ndim == 1:
        if (logits > 1).any() or (logits < 0).any():
            raise ValueError("Logits must be between 0 and 1 for binary classification")
        # convert 1D probabilities (binary) to 2D logits
        logits = np.column_stack([1 - logits, logits])
        probabilities = logits  # no need to convert to probabilities
    elif logits.ndim == 2:
        if logits.shape[1] < 2:
            raise ValueError("Use a different metric function for regression tasks")
        if not (logits > 0).all():
            # convert logits to probabilities with softmax if necessary
            # NOTE(review): valid probabilities containing an exact 0.0 also take this branch
            # and get softmaxed, which changes the loss (not the argmax predictions).
            probabilities = softmax(logits)
        elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
            # convert logits to probabilities through normalization if necessary
            probabilities = logits / logits.sum(-1, keepdims=True)
        else:
            probabilities = logits
    else:
        raise ValueError("Logits must be 1 or 2 dimensional")

    # rows that are entirely NaN mark samples without a prediction; they get label -1
    none_prediction_mask = np.isnan(probabilities).all(axis=-1)
    predictions = np.argmax(probabilities, axis=-1)
    predictions[none_prediction_mask] = -1

    num_classes_references = len(set(references))
    num_classes_predictions = len(set(predictions))
    num_none_predictions = int(none_prediction_mask.sum())
    coverage = 1 - num_none_predictions / len(probabilities)

    if average is None:
        average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"

    # explicit None/length check so numpy arrays of scores work as well as lists
    has_anomaly_scores = anomaly_scores is not None and len(anomaly_scores) > 0
    anomaly_score_mean = float(np.mean(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_median = float(np.median(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_variance = float(np.var(anomaly_scores)) if has_anomaly_scores else None

    accuracy = sklearn.metrics.accuracy_score(references, predictions)
    f1 = sklearn.metrics.f1_score(references, predictions, average=average)
    # Ensure sklearn sees the full class set corresponding to probability columns
    # to avoid errors when y_true does not contain all classes.
    loss = (
        sklearn.metrics.log_loss(
            references,
            probabilities,
            labels=list(range(probabilities.shape[1])),
        )
        if num_none_predictions == 0
        else None
    )

    roc_auc = pr_auc = None
    roc_curve = pr_curve = None
    if num_classes_references == num_classes_predictions and num_none_predictions == 0:
        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
        if num_classes_references == 2:
            # Score with the positive-class probability, not the raw logit column: after
            # softmax or normalization, logits[:, 1] alone does not rank samples by their
            # probability of being positive, and the curves assume scores in [0, 1].
            positive_probabilities = probabilities[:, 1]
            roc_auc = sklearn.metrics.roc_auc_score(references, positive_probabilities)
            roc_curve = calculate_roc_curve(references, positive_probabilities) if include_curves else None
            pr_auc = sklearn.metrics.average_precision_score(references, positive_probabilities)
            pr_curve = calculate_pr_curve(references, positive_probabilities) if include_curves else None
        else:
            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)

    return ClassificationMetrics(
        coverage=coverage,
        accuracy=float(accuracy),
        f1_score=float(f1),
        loss=float(loss) if loss is not None else None,
        anomaly_score_mean=anomaly_score_mean,
        anomaly_score_median=anomaly_score_median,
        anomaly_score_variance=anomaly_score_variance,
        roc_auc=float(roc_auc) if roc_auc is not None else None,
        pr_auc=float(pr_auc) if pr_auc is not None else None,
        pr_curve=pr_curve,
        roc_curve=roc_curve,
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@dataclass
|
|
293
|
+
class RegressionMetrics:
|
|
294
|
+
coverage: float
|
|
295
|
+
"""Percentage of predictions that are not none"""
|
|
296
|
+
|
|
297
|
+
mse: float
|
|
298
|
+
"""Mean squared error of the predictions"""
|
|
299
|
+
|
|
300
|
+
rmse: float
|
|
301
|
+
"""Root mean squared error of the predictions"""
|
|
302
|
+
|
|
303
|
+
mae: float
|
|
304
|
+
"""Mean absolute error of the predictions"""
|
|
305
|
+
|
|
306
|
+
r2: float
|
|
307
|
+
"""R-squared score (coefficient of determination) of the predictions"""
|
|
308
|
+
|
|
309
|
+
explained_variance: float
|
|
310
|
+
"""Explained variance score of the predictions"""
|
|
311
|
+
|
|
312
|
+
loss: float
|
|
313
|
+
"""Mean squared error loss of the predictions"""
|
|
314
|
+
|
|
315
|
+
anomaly_score_mean: float | None = None
|
|
316
|
+
"""Mean of anomaly scores across the dataset"""
|
|
317
|
+
|
|
318
|
+
anomaly_score_median: float | None = None
|
|
319
|
+
"""Median of anomaly scores across the dataset"""
|
|
320
|
+
|
|
321
|
+
anomaly_score_variance: float | None = None
|
|
322
|
+
"""Variance of anomaly scores across the dataset"""
|
|
323
|
+
|
|
324
|
+
def __repr__(self) -> str:
|
|
325
|
+
return (
|
|
326
|
+
"RegressionMetrics({\n"
|
|
327
|
+
+ f" mae: {self.mae:.4f},\n"
|
|
328
|
+
+ f" rmse: {self.rmse:.4f},\n"
|
|
329
|
+
+ f" r2: {self.r2:.4f},\n"
|
|
330
|
+
+ (
|
|
331
|
+
f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
|
|
332
|
+
if self.anomaly_score_mean
|
|
333
|
+
else ""
|
|
334
|
+
)
|
|
335
|
+
+ "})"
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def calculate_regression_metrics(
    expected_scores: NDArray[np.float32] | list[float],
    predicted_scores: NDArray[np.float32] | list[float],
    anomaly_scores: list[float] | None = None,
) -> RegressionMetrics:
    """
    Calculate regression metrics for model evaluation.

    NaN predictions are treated as missing: they lower ``coverage`` and are excluded from
    every error metric.

    Params:
        expected_scores: True target values
        predicted_scores: Predicted values from the model, NaN where no prediction was made
        anomaly_scores: Optional anomaly scores for each prediction

    Returns:
        Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance

    Raises:
        ValueError: If predictions and references have different lengths, or if there is
            no non-NaN prediction to evaluate
    """
    references = np.array(expected_scores)
    predictions = np.array(predicted_scores)

    if len(predictions) != len(references):
        raise ValueError("Predictions and references must have the same length")

    # explicit None/length check so numpy arrays of scores work as well as lists
    has_anomaly_scores = anomaly_scores is not None and len(anomaly_scores) > 0
    anomaly_score_mean = float(np.mean(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_median = float(np.median(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_variance = float(np.var(anomaly_scores)) if has_anomaly_scores else None

    none_prediction_mask = np.isnan(predictions)
    num_none_predictions = int(none_prediction_mask.sum())
    # Fail clearly instead of dividing 0 by 0 below and letting sklearn reject empty arrays.
    if num_none_predictions == len(predictions):
        raise ValueError("At least one non-NaN prediction is required to calculate regression metrics")
    coverage = 1 - num_none_predictions / len(predictions)
    if num_none_predictions > 0:
        references = references[~none_prediction_mask]
        predictions = predictions[~none_prediction_mask]

    # Calculate core regression metrics
    mse = float(sklearn.metrics.mean_squared_error(references, predictions))
    rmse = float(np.sqrt(mse))
    mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
    r2 = float(sklearn.metrics.r2_score(references, predictions))
    explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))

    return RegressionMetrics(
        coverage=coverage,
        mse=mse,
        rmse=rmse,
        mae=mae,
        r2=r2,
        explained_variance=explained_var,
        loss=mse,  # For regression, loss is typically MSE
        anomaly_score_mean=anomaly_score_mean,
        anomaly_score_median=anomaly_score_median,
        anomaly_score_variance=anomaly_score_variance,
    )
|