orca-sdk 0.0.94__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -175
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +338 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics.py
@@ -8,37 +8,25 @@ IMPORTANT:
 
 """
 
-from typing import Literal, Tuple, TypedDict
+from dataclasses import dataclass
+from typing import Any, Literal, TypedDict, cast
 
 import numpy as np
+import sklearn.metrics
 from numpy.typing import NDArray
-from scipy.special import softmax
-from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
-from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
-from sklearn.metrics import roc_auc_score
-from sklearn.metrics import roc_curve as sklearn_roc_curve
-from transformers.trainer_utils import EvalPrediction
 
 
-class ClassificationMetrics(TypedDict):
-    accuracy: float
-    f1_score: float
-    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
-    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
-    log_loss: float  # cross-entropy loss for probabilities
-
+# we don't want to depend on scipy or torch in orca_sdk
+def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
+    shifted = logits - np.max(logits, axis=axis, keepdims=True)
+    exps = np.exp(shifted)
+    return exps / np.sum(exps, axis=axis, keepdims=True)
 
-def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
-    """
-    Compute standard metrics for classifier with Hugging Face Trainer.
-
-    Args:
-        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
 
-    Returns:
-        A dictionary containing the accuracy, f1 score, and ROC AUC score.
-    """
-    logits, references = eval_pred
+# We don't want to depend on transformers just for the eval_pred type in orca_sdk
+def transform_eval_pred(eval_pred: Any) -> tuple[NDArray[np.int64], NDArray[np.float32]]:
+    # convert results from Trainer compute_metrics param for use in calculate_classification_metrics
+    logits, references = eval_pred  # transformers.trainer_utils.EvalPrediction
     if isinstance(logits, tuple):
         logits = logits[0]
     if not isinstance(logits, np.ndarray):
@@ -48,72 +36,20 @@ def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetri
             "Multiple label columns found, use the `label_names` training argument to specify which one to use"
         )
 
-    if not (logits > 0).all():
-        # convert logits to probabilities with softmax if necessary
-        probabilities = softmax(logits)
-    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
-        # convert logits to probabilities through normalization if necessary
-        probabilities = logits / logits.sum(-1, keepdims=True)
-    else:
-        probabilities = logits
-
-    return classification_scores(references, probabilities)
+    return (references, logits)
 
 
-def classification_scores(
-    references: NDArray[np.int64],
-    probabilities: NDArray[np.float32],
-    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
-    multi_class: Literal["ovr", "ovo"] = "ovr",
-) -> ClassificationMetrics:
-    if probabilities.ndim == 1:
-        # convert 1D probabilities (binary) to 2D logits
-        probabilities = np.column_stack([1 - probabilities, probabilities])
-    elif probabilities.ndim == 2:
-        if probabilities.shape[1] < 2:
-            raise ValueError("Use a different metric function for regression tasks")
-    else:
-        raise ValueError("Probabilities must be 1 or 2 dimensional")
-
-    predictions = np.argmax(probabilities, axis=-1)
-
-    num_classes_references = len(set(references))
-    num_classes_predictions = len(set(predictions))
-
-    if average is None:
-        average = "binary" if num_classes_references == 2 else "weighted"
-
-    accuracy = accuracy_score(references, predictions)
-    f1 = f1_score(references, predictions, average=average)
-    loss = log_loss(references, probabilities)
-
-    if num_classes_references == num_classes_predictions:
-        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
-        if num_classes_references == 2:
-            roc_auc = roc_auc_score(references, probabilities[:, 1])
-            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
-            pr_auc = auc(recalls, precisions)
-        else:
-            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
-            pr_auc = None
-    else:
-        roc_auc = None
-        pr_auc = None
-
-    return {
-        "accuracy": float(accuracy),
-        "f1_score": float(f1),
-        "roc_auc": float(roc_auc) if roc_auc is not None else None,
-        "pr_auc": float(pr_auc) if pr_auc is not None else None,
-        "log_loss": float(loss),
-    }
+class PRCurve(TypedDict):
+    thresholds: list[float]
+    precisions: list[float]
+    recalls: list[float]
 
 
 def calculate_pr_curve(
     references: NDArray[np.int64],
     probabilities: NDArray[np.float32],
     max_length: int = 100,
-) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+) -> PRCurve:
     if probabilities.ndim == 1:
         probabilities_slice = probabilities
     elif probabilities.ndim == 2:
@@ -124,7 +60,7 @@ def calculate_pr_curve(
     if len(probabilities_slice) != len(references):
         raise ValueError("Probabilities and references must have the same length")
 
-    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+    precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(references, probabilities_slice)
 
     # Convert all arrays to float32 immediately after getting them
     precisions = precisions.astype(np.float32)
@@ -148,14 +84,24 @@ def calculate_pr_curve(
         precisions = new_precisions
         recalls = new_recalls
 
-    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+    return PRCurve(
+        thresholds=cast(list[float], thresholds.tolist()),
+        precisions=cast(list[float], precisions.tolist()),
+        recalls=cast(list[float], recalls.tolist()),
+    )
+
+
+class ROCCurve(TypedDict):
+    thresholds: list[float]
+    false_positive_rates: list[float]
+    true_positive_rates: list[float]
 
 
 def calculate_roc_curve(
     references: NDArray[np.int64],
     probabilities: NDArray[np.float32],
     max_length: int = 100,
-) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+) -> ROCCurve:
     if probabilities.ndim == 1:
         probabilities_slice = probabilities
     elif probabilities.ndim == 2:
@@ -168,7 +114,7 @@ def calculate_roc_curve(
 
     # Convert probabilities to float32 before calling sklearn_roc_curve
     probabilities_slice = probabilities_slice.astype(np.float32)
-    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+    fpr, tpr, thresholds = sklearn.metrics.roc_curve(references, probabilities_slice)
 
     # Convert all arrays to float32 immediately after getting them
     fpr = fpr.astype(np.float32)
@@ -192,4 +138,228 @@ def calculate_roc_curve(
         fpr = new_fpr
         tpr = new_tpr
 
-    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
+    return ROCCurve(
+        false_positive_rates=cast(list[float], fpr.tolist()),
+        true_positive_rates=cast(list[float], tpr.tolist()),
+        thresholds=cast(list[float], thresholds.tolist()),
+    )
+
+
+@dataclass
+class ClassificationMetrics:
+    f1_score: float
+    """F1 score of the predictions"""
+
+    accuracy: float
+    """Accuracy of the predictions"""
+
+    loss: float
+    """Cross-entropy loss of the logits"""
+
+    anomaly_score_mean: float | None = None
+    """Mean of anomaly scores across the dataset"""
+
+    anomaly_score_median: float | None = None
+    """Median of anomaly scores across the dataset"""
+
+    anomaly_score_variance: float | None = None
+    """Variance of anomaly scores across the dataset"""
+
+    roc_auc: float | None = None
+    """Receiver operating characteristic area under the curve"""
+
+    pr_auc: float | None = None
+    """Average precision (area under the curve of the precision-recall curve)"""
+
+    pr_curve: PRCurve | None = None
+    """Precision-recall curve"""
+
+    roc_curve: ROCCurve | None = None
+    """Receiver operating characteristic curve"""
+
+    def __repr__(self) -> str:
+        return (
+            "ClassificationMetrics({\n"
+            + f" accuracy: {self.accuracy:.4f},\n"
+            + f" f1_score: {self.f1_score:.4f},\n"
+            + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
+            + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
+            + (
+                f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                if self.anomaly_score_mean
+                else ""
+            )
+            + "})"
+        )
+
+
+def calculate_classification_metrics(
+    expected_labels: list[int] | NDArray[np.int64],
+    logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
+    anomaly_scores: list[float] | None = None,
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+    include_curves: bool = False,
+) -> ClassificationMetrics:
+    references = np.array(expected_labels)
+
+    logits = np.array(logits)
+    if logits.ndim == 1:
+        if (logits > 1).any() or (logits < 0).any():
+            raise ValueError("Logits must be between 0 and 1 for binary classification")
+        # convert 1D probabilities (binary) to 2D logits
+        logits = np.column_stack([1 - logits, logits])
+        probabilities = logits  # no need to convert to probabilities
+    elif logits.ndim == 2:
+        if logits.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+        if not (logits > 0).all():
+            # convert logits to probabilities with softmax if necessary
+            probabilities = softmax(logits)
+        elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+            # convert logits to probabilities through normalization if necessary
+            probabilities = logits / logits.sum(-1, keepdims=True)
+        else:
+            probabilities = logits
+    else:
+        raise ValueError("Logits must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+    accuracy = sklearn.metrics.accuracy_score(references, predictions)
+    f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+    loss = sklearn.metrics.log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = sklearn.metrics.roc_auc_score(references, logits[:, 1])
+            roc_curve = calculate_roc_curve(references, logits[:, 1]) if include_curves else None
+            pr_auc = sklearn.metrics.average_precision_score(references, logits[:, 1])
+            pr_curve = calculate_pr_curve(references, logits[:, 1]) if include_curves else None
+        else:
+            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
+            roc_curve = None
+            pr_auc = None
+            pr_curve = None
+    else:
+        roc_auc = None
+        pr_auc = None
+        pr_curve = None
+        roc_curve = None
+
+    return ClassificationMetrics(
+        accuracy=float(accuracy),
+        f1_score=float(f1),
+        loss=float(loss),
+        anomaly_score_mean=anomaly_score_mean,
+        anomaly_score_median=anomaly_score_median,
+        anomaly_score_variance=anomaly_score_variance,
+        roc_auc=float(roc_auc) if roc_auc is not None else None,
+        pr_auc=float(pr_auc) if pr_auc is not None else None,
+        pr_curve=pr_curve,
+        roc_curve=roc_curve,
+    )
+
+
+@dataclass
+class RegressionMetrics:
+    mse: float
+    """Mean squared error of the predictions"""
+
+    rmse: float
+    """Root mean squared error of the predictions"""
+
+    mae: float
+    """Mean absolute error of the predictions"""
+
+    r2: float
+    """R-squared score (coefficient of determination) of the predictions"""
+
+    explained_variance: float
+    """Explained variance score of the predictions"""
+
+    loss: float
+    """Mean squared error loss of the predictions"""
+
+    anomaly_score_mean: float | None = None
+    """Mean of anomaly scores across the dataset"""
+
+    anomaly_score_median: float | None = None
+    """Median of anomaly scores across the dataset"""
+
+    anomaly_score_variance: float | None = None
+    """Variance of anomaly scores across the dataset"""
+
+    def __repr__(self) -> str:
+        return (
+            "RegressionMetrics({\n"
+            + f" mae: {self.mae:.4f},\n"
+            + f" rmse: {self.rmse:.4f},\n"
+            + f" r2: {self.r2:.4f},\n"
+            + (
+                f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                if self.anomaly_score_mean
+                else ""
+            )
+            + "})"
+        )
+
+
+def calculate_regression_metrics(
+    expected_scores: NDArray[np.float32] | list[float],
+    predicted_scores: NDArray[np.float32] | list[float],
+    anomaly_scores: list[float] | None = None,
+) -> RegressionMetrics:
+    """
+    Calculate regression metrics for model evaluation.
+
+    Params:
+        references: True target values
+        predictions: Predicted values from the model
+        anomaly_scores: Optional anomaly scores for each prediction
+
+    Returns:
+        Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance
+
+    Raises:
+        ValueError: If predictions and references have different lengths
+    """
+    references = np.array(expected_scores)
+    predictions = np.array(predicted_scores)
+
+    if len(predictions) != len(references):
+        raise ValueError("Predictions and references must have the same length")
+
+    anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+    anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+    # Calculate core regression metrics
+    mse = float(sklearn.metrics.mean_squared_error(references, predictions))
+    rmse = float(np.sqrt(mse))
+    mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
+    r2 = float(sklearn.metrics.r2_score(references, predictions))
+    explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))
+
+    return RegressionMetrics(
+        mse=mse,
+        rmse=rmse,
+        mae=mae,
+        r2=r2,
+        explained_variance=explained_var,
+        loss=mse,  # For regression, loss is typically MSE
+        anomaly_score_mean=anomaly_score_mean,
+        anomaly_score_median=anomaly_score_median,
+        anomaly_score_variance=anomaly_score_variance,
+    )
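Note on the metrics rewrite: the diff above replaces the old compute_classifier_metrics/classification_scores helpers (which imported scipy and transformers) with standalone calculate_classification_metrics and calculate_regression_metrics functions that return ClassificationMetrics and RegressionMetrics dataclasses. The following is a minimal usage sketch inferred from the signatures shown in the diff; the sample arrays are illustrative and not taken from the package.

import numpy as np

from orca_sdk._shared.metrics import (
    calculate_classification_metrics,
    calculate_regression_metrics,
)

# Binary classification: expected labels plus per-class logits (1D probabilities are also accepted).
clf_metrics = calculate_classification_metrics(
    expected_labels=[0, 1, 1, 0, 1],
    logits=np.array(
        [[2.0, -1.0], [-0.5, 1.5], [0.2, 0.9], [1.8, 0.1], [-1.0, 2.2]],
        dtype=np.float32,
    ),
    include_curves=True,  # also populate the pr_curve / roc_curve TypedDicts
)
print(clf_metrics.accuracy, clf_metrics.f1_score, clf_metrics.roc_auc)

# Regression: expected vs. predicted scores, with optional anomaly scores.
reg_metrics = calculate_regression_metrics(
    expected_scores=[1.0, 2.0, 3.0, 4.0],
    predicted_scores=[0.9, 2.1, 2.8, 4.3],
    anomaly_scores=[0.1, 0.05, 0.2, 0.15],
)
print(reg_metrics.rmse, reg_metrics.r2, reg_metrics.anomaly_score_mean)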