orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/__init__.py CHANGED
@@ -3,8 +3,8 @@ OrcaSDK is a Python library for building and using retrieval augmented models in
  """
 
  from ._utils.common import UNSET, CreateMode, DropMode
- from ._utils.task import TaskStatus
- from .classification_model import ClassificationModel
+ from .classification_model import ClassificationMetrics, ClassificationModel
+ from .client import OrcaClient
  from .credentials import OrcaCredentials
  from .datasource import Datasource
  from .embedding_model import (
@@ -12,13 +12,19 @@ from .embedding_model import (
      PretrainedEmbeddingModel,
      PretrainedEmbeddingModelName,
  )
+ from .job import Job, Status
  from .memoryset import (
+     CascadingEditSuggestion,
      FilterItemTuple,
      LabeledMemory,
      LabeledMemoryLookup,
      LabeledMemoryset,
+     ScoredMemory,
+     ScoredMemoryLookup,
+     ScoredMemoryset,
  )
- from .telemetry import LabelPrediction
+ from .regression_model import RegressionModel
+ from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction
 
  # only specify things that should show up on the root page of the reference docs because they are in private modules
- __all__ = ["TaskStatus", "UNSET", "CreateMode", "DropMode"]
+ __all__ = ["UNSET", "CreateMode", "DropMode"]
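
For orientation, here is a minimal sketch of the reworked top-level surface in 0.1.3, using only names that the __init__.py diff above re-exports; no constructor or method signatures are shown, since this diff does not confirm them.

from orca_sdk import (
    ClassificationMetrics,
    ClassificationModel,
    Job,
    LabeledMemoryset,
    OrcaClient,
    RegressionModel,
    ScoredMemoryset,
)

Note that TaskStatus is no longer exported at the top level; per the diff, Job and Status from orca_sdk.job take its place, and LabelPrediction gives way to ClassificationPrediction, RegressionPrediction, and FeedbackCategory from orca_sdk.telemetry.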
orca_sdk/_shared/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .metrics import (
+     ClassificationMetrics,
+     PRCurve,
+     RegressionMetrics,
+     ROCCurve,
+     calculate_classification_metrics,
+     calculate_pr_curve,
+     calculate_regression_metrics,
+     calculate_roc_curve,
+ )
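
Because _shared/__init__.py only re-exports the metric helpers, SDK-internal code can import them from the package rather than the module. A small sketch under that assumption (the leading underscore marks this package as internal, so end users would normally reach these metrics through the model evaluation APIs instead):

from orca_sdk._shared import ClassificationMetrics, calculate_classification_metrics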
orca_sdk/_shared/metrics.py ADDED
@@ -0,0 +1,393 @@
+ """
+ This module contains metrics for usage with the Hugging Face Trainer.
+
+ IMPORTANT:
+ - This is a shared file between OrcaLib and the OrcaSDK.
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+
+ """
+
+ from dataclasses import dataclass
+ from typing import Any, Literal, TypedDict, cast
+
+ import numpy as np
+ import sklearn.metrics
+ from numpy.typing import NDArray
+
+
+ # we don't want to depend on scipy or torch in orca_sdk
+ def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
+     shifted = logits - np.max(logits, axis=axis, keepdims=True)
+     exps = np.exp(shifted)
+     return exps / np.sum(exps, axis=axis, keepdims=True)
+
+
+ # We don't want to depend on transformers just for the eval_pred type in orca_sdk
+ def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
+     # convert results from Trainer compute_metrics param for use in calculate_classification_metrics
+     logits, references = eval_pred  # transformers.trainer_utils.EvalPrediction
+     if isinstance(logits, tuple):
+         logits = logits[0]
+     if not isinstance(logits, np.ndarray):
+         raise ValueError("Logits must be a numpy array")
+     if not isinstance(references, np.ndarray):
+         raise ValueError(
+             "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+         )
+
+     return (references, logits)
+
+
+ class PRCurve(TypedDict):
+     thresholds: list[float]
+     precisions: list[float]
+     recalls: list[float]
+
+
+ def calculate_pr_curve(
+     references: NDArray[np.int64],
+     probabilities: NDArray[np.float32],
+     max_length: int = 100,
+ ) -> PRCurve:
+     if probabilities.ndim == 1:
+         probabilities_slice = probabilities
+     elif probabilities.ndim == 2:
+         probabilities_slice = probabilities[:, 1]
+     else:
+         raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+     if len(probabilities_slice) != len(references):
+         raise ValueError("Probabilities and references must have the same length")
+
+     precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(references, probabilities_slice)
+
+     # Convert all arrays to float32 immediately after getting them
+     precisions = precisions.astype(np.float32)
+     recalls = recalls.astype(np.float32)
+     thresholds = thresholds.astype(np.float32)
+
+     # Concatenate with 0 to include the lowest threshold
+     thresholds = np.concatenate(([0], thresholds))
+
+     # Sort by threshold
+     sorted_indices = np.argsort(thresholds)
+     thresholds = thresholds[sorted_indices]
+     precisions = precisions[sorted_indices]
+     recalls = recalls[sorted_indices]
+
+     if len(precisions) > max_length:
+         new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+         new_precisions = np.interp(new_thresholds, thresholds, precisions)
+         new_recalls = np.interp(new_thresholds, thresholds, recalls)
+         thresholds = new_thresholds
+         precisions = new_precisions
+         recalls = new_recalls
+
+     return PRCurve(
+         thresholds=cast(list[float], thresholds.tolist()),
+         precisions=cast(list[float], precisions.tolist()),
+         recalls=cast(list[float], recalls.tolist()),
+     )
+
+
+ class ROCCurve(TypedDict):
+     thresholds: list[float]
+     false_positive_rates: list[float]
+     true_positive_rates: list[float]
+
+
+ def calculate_roc_curve(
+     references: NDArray[np.int64],
+     probabilities: NDArray[np.float32],
+     max_length: int = 100,
+ ) -> ROCCurve:
+     if probabilities.ndim == 1:
+         probabilities_slice = probabilities
+     elif probabilities.ndim == 2:
+         probabilities_slice = probabilities[:, 1]
+     else:
+         raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+     if len(probabilities_slice) != len(references):
+         raise ValueError("Probabilities and references must have the same length")
+
+     # Convert probabilities to float32 before calling sklearn.metrics.roc_curve
+     probabilities_slice = probabilities_slice.astype(np.float32)
+     fpr, tpr, thresholds = sklearn.metrics.roc_curve(references, probabilities_slice)
+
+     # Convert all arrays to float32 immediately after getting them
+     fpr = fpr.astype(np.float32)
+     tpr = tpr.astype(np.float32)
+     thresholds = thresholds.astype(np.float32)
+
+     # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+     thresholds[0] = 1.0
+
+     # Sort by threshold
+     sorted_indices = np.argsort(thresholds)
+     thresholds = thresholds[sorted_indices]
+     fpr = fpr[sorted_indices]
+     tpr = tpr[sorted_indices]
+
+     if len(fpr) > max_length:
+         new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+         new_fpr = np.interp(new_thresholds, thresholds, fpr)
+         new_tpr = np.interp(new_thresholds, thresholds, tpr)
+         thresholds = new_thresholds
+         fpr = new_fpr
+         tpr = new_tpr
+
+     return ROCCurve(
+         false_positive_rates=cast(list[float], fpr.tolist()),
+         true_positive_rates=cast(list[float], tpr.tolist()),
+         thresholds=cast(list[float], thresholds.tolist()),
+     )
+
+
+ @dataclass
+ class ClassificationMetrics:
+     coverage: float
+     """Percentage of predictions that are not none"""
+
+     f1_score: float
+     """F1 score of the predictions"""
+
+     accuracy: float
+     """Accuracy of the predictions"""
+
+     loss: float | None
+     """Cross-entropy loss of the logits"""
+
+     anomaly_score_mean: float | None = None
+     """Mean of anomaly scores across the dataset"""
+
+     anomaly_score_median: float | None = None
+     """Median of anomaly scores across the dataset"""
+
+     anomaly_score_variance: float | None = None
+     """Variance of anomaly scores across the dataset"""
+
+     roc_auc: float | None = None
+     """Receiver operating characteristic area under the curve"""
+
+     pr_auc: float | None = None
+     """Average precision (area under the curve of the precision-recall curve)"""
+
+     pr_curve: PRCurve | None = None
+     """Precision-recall curve"""
+
+     roc_curve: ROCCurve | None = None
+     """Receiver operating characteristic curve"""
+
+     def __repr__(self) -> str:
+         return (
+             "ClassificationMetrics({\n"
+             + f" accuracy: {self.accuracy:.4f},\n"
+             + f" f1_score: {self.f1_score:.4f},\n"
+             + (f" roc_auc: {self.roc_auc:.4f},\n" if self.roc_auc else "")
+             + (f" pr_auc: {self.pr_auc:.4f},\n" if self.pr_auc else "")
+             + (
+                 f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                 if self.anomaly_score_mean
+                 else ""
+             )
+             + "})"
+         )
+
+
+ def calculate_classification_metrics(
+     expected_labels: list[int] | NDArray[np.int64],
+     logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
+     anomaly_scores: list[float] | None = None,
+     average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+     multi_class: Literal["ovr", "ovo"] = "ovr",
+     include_curves: bool = False,
+ ) -> ClassificationMetrics:
+     references = np.array(expected_labels)
+
+     logits = np.array(logits)
+     if logits.ndim == 1:
+         if (logits > 1).any() or (logits < 0).any():
+             raise ValueError("Logits must be between 0 and 1 for binary classification")
+         # convert 1D probabilities (binary) to 2D logits
+         logits = np.column_stack([1 - logits, logits])
+         probabilities = logits  # no need to convert to probabilities
+     elif logits.ndim == 2:
+         if logits.shape[1] < 2:
+             raise ValueError("Use a different metric function for regression tasks")
+         if not (logits > 0).all():
+             # convert logits to probabilities with softmax if necessary
+             probabilities = softmax(logits)
+         elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+             # convert logits to probabilities through normalization if necessary
+             probabilities = logits / logits.sum(-1, keepdims=True)
+         else:
+             probabilities = logits
+     else:
+         raise ValueError("Logits must be 1 or 2 dimensional")
+
+     predictions = np.argmax(probabilities, axis=-1)
+     predictions[np.isnan(probabilities).all(axis=-1)] = -1  # set predictions to -1 for all nan logits
+
+     num_classes_references = len(set(references))
+     num_classes_predictions = len(set(predictions))
+     num_none_predictions = np.isnan(probabilities).all(axis=-1).sum()
+     coverage = 1 - num_none_predictions / len(probabilities)
+
+     if average is None:
+         average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"
+
+     anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+     anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+     anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+     accuracy = sklearn.metrics.accuracy_score(references, predictions)
+     f1 = sklearn.metrics.f1_score(references, predictions, average=average)
+     # Ensure sklearn sees the full class set corresponding to probability columns
+     # to avoid errors when y_true does not contain all classes.
+     loss = (
+         sklearn.metrics.log_loss(
+             references,
+             probabilities,
+             labels=list(range(probabilities.shape[1])),
+         )
+         if num_none_predictions == 0
+         else None
+     )
+
+     if num_classes_references == num_classes_predictions and num_none_predictions == 0:
+         # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+         if num_classes_references == 2:
+             roc_auc = sklearn.metrics.roc_auc_score(references, logits[:, 1])
+             roc_curve = calculate_roc_curve(references, logits[:, 1]) if include_curves else None
+             pr_auc = sklearn.metrics.average_precision_score(references, logits[:, 1])
+             pr_curve = calculate_pr_curve(references, logits[:, 1]) if include_curves else None
+         else:
+             roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)
+             roc_curve = None
+             pr_auc = None
+             pr_curve = None
+     else:
+         roc_auc = None
+         pr_auc = None
+         pr_curve = None
+         roc_curve = None
+
+     return ClassificationMetrics(
+         coverage=coverage,
+         accuracy=float(accuracy),
+         f1_score=float(f1),
+         loss=float(loss) if loss is not None else None,
+         anomaly_score_mean=anomaly_score_mean,
+         anomaly_score_median=anomaly_score_median,
+         anomaly_score_variance=anomaly_score_variance,
+         roc_auc=float(roc_auc) if roc_auc is not None else None,
+         pr_auc=float(pr_auc) if pr_auc is not None else None,
+         pr_curve=pr_curve,
+         roc_curve=roc_curve,
+     )
+
+
+ @dataclass
+ class RegressionMetrics:
+     coverage: float
+     """Percentage of predictions that are not none"""
+
+     mse: float
+     """Mean squared error of the predictions"""
+
+     rmse: float
+     """Root mean squared error of the predictions"""
+
+     mae: float
+     """Mean absolute error of the predictions"""
+
+     r2: float
+     """R-squared score (coefficient of determination) of the predictions"""
+
+     explained_variance: float
+     """Explained variance score of the predictions"""
+
+     loss: float
+     """Mean squared error loss of the predictions"""
+
+     anomaly_score_mean: float | None = None
+     """Mean of anomaly scores across the dataset"""
+
+     anomaly_score_median: float | None = None
+     """Median of anomaly scores across the dataset"""
+
+     anomaly_score_variance: float | None = None
+     """Variance of anomaly scores across the dataset"""
+
+     def __repr__(self) -> str:
+         return (
+             "RegressionMetrics({\n"
+             + f" mae: {self.mae:.4f},\n"
+             + f" rmse: {self.rmse:.4f},\n"
+             + f" r2: {self.r2:.4f},\n"
+             + (
+                 f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
+                 if self.anomaly_score_mean
+                 else ""
+             )
+             + "})"
+         )
+
+
+ def calculate_regression_metrics(
+     expected_scores: NDArray[np.float32] | list[float],
+     predicted_scores: NDArray[np.float32] | list[float],
+     anomaly_scores: list[float] | None = None,
+ ) -> RegressionMetrics:
+     """
+     Calculate regression metrics for model evaluation.
+
+     Params:
+         expected_scores: True target values
+         predicted_scores: Predicted values from the model
+         anomaly_scores: Optional anomaly scores for each prediction
+
+     Returns:
+         Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance
+
+     Raises:
+         ValueError: If predictions and references have different lengths
+     """
+     references = np.array(expected_scores)
+     predictions = np.array(predicted_scores)
+
+     if len(predictions) != len(references):
+         raise ValueError("Predictions and references must have the same length")
+
+     anomaly_score_mean = float(np.mean(anomaly_scores)) if anomaly_scores else None
+     anomaly_score_median = float(np.median(anomaly_scores)) if anomaly_scores else None
+     anomaly_score_variance = float(np.var(anomaly_scores)) if anomaly_scores else None
+
+     none_prediction_mask = np.isnan(predictions)
+     num_none_predictions = none_prediction_mask.sum()
+     coverage = 1 - num_none_predictions / len(predictions)
+     if num_none_predictions > 0:
+         references = references[~none_prediction_mask]
+         predictions = predictions[~none_prediction_mask]
+
+     # Calculate core regression metrics
+     mse = float(sklearn.metrics.mean_squared_error(references, predictions))
+     rmse = float(np.sqrt(mse))
+     mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
+     r2 = float(sklearn.metrics.r2_score(references, predictions))
+     explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))
+
+     return RegressionMetrics(
+         coverage=coverage,
+         mse=mse,
+         rmse=rmse,
+         mae=mae,
+         r2=r2,
+         explained_variance=explained_var,
+         loss=mse,  # For regression, loss is typically MSE
+         anomaly_score_mean=anomaly_score_mean,
+         anomaly_score_median=anomaly_score_median,
+         anomaly_score_variance=anomaly_score_variance,
+     )
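
To make the new helpers concrete, here is a short, self-contained sketch that exercises both metric calculators with toy data; the numbers are invented, and the calls follow the signatures visible in the metrics.py diff above.

import numpy as np

from orca_sdk._shared.metrics import (
    calculate_classification_metrics,
    calculate_regression_metrics,
)

# Binary classification: 1D probabilities are accepted and expanded to two columns internally.
labels = [0, 1, 1, 0, 1]
probabilities = np.array([0.1, 0.8, 0.6, 0.4, 0.9], dtype=np.float32)
cls_metrics = calculate_classification_metrics(labels, probabilities, include_curves=True)
print(cls_metrics.accuracy, cls_metrics.roc_auc)

# Regression: loss is reported as the mean squared error.
reg_metrics = calculate_regression_metrics([1.0, 2.0, 3.0], [1.1, 1.9, 3.2])
print(reg_metrics.rmse, reg_metrics.r2)

Both calculators return dataclasses with a compact __repr__, so printing cls_metrics or reg_metrics directly gives a readable summary.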