orca-sdk 0.0.91-py3-none-any.whl → 0.0.93-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (32)
  1. orca_sdk/_generated_api_client/api/__init__.py +4 -0
  2. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  3. orca_sdk/_generated_api_client/models/__init__.py +4 -0
  4. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
  5. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  6. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  7. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  8. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  9. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  10. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
  11. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  12. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  13. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
  15. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  16. orca_sdk/_shared/__init__.py +1 -0
  17. orca_sdk/_shared/metrics.py +195 -0
  18. orca_sdk/_shared/metrics_test.py +169 -0
  19. orca_sdk/_utils/data_parsing.py +31 -2
  20. orca_sdk/_utils/data_parsing_test.py +18 -15
  21. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  22. orca_sdk/classification_model.py +170 -27
  23. orca_sdk/classification_model_test.py +74 -32
  24. orca_sdk/conftest.py +86 -25
  25. orca_sdk/datasource.py +22 -12
  26. orca_sdk/embedding_model_test.py +6 -5
  27. orca_sdk/memoryset.py +78 -0
  28. orca_sdk/memoryset_test.py +197 -123
  29. orca_sdk/telemetry.py +3 -0
  30. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
  31. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
  32. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py

@@ -38,6 +38,7 @@ class LabelPredictionMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -54,6 +55,7 @@ class LabelPredictionMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabelPredictionMemoryLookup:

         updated_at = self.updated_at.isoformat()

+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()

         label = self.label
@@ -106,6 +110,7 @@ class LabelPredictionMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabelPredictionMemoryLookup:

         updated_at = isoparse(d.pop("updated_at"))

+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))

         label = d.pop("label")
@@ -174,6 +181,7 @@ class LabelPredictionMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
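
This edited_at round-trip repeats in each of the memory models below (and memories_updated_at in LabeledMemorysetMetadata follows the same pattern): to_dict() emits an ISO 8601 string via datetime.isoformat(), and from_dict() parses it back with dateutil's isoparse. A minimal sketch of the pattern in isolation (the payload dict is illustrative, not a real API response):

    import datetime
    from dateutil.parser import isoparse

    # Serialize: datetime -> ISO 8601 string, as in to_dict()
    edited_at = datetime.datetime(2024, 5, 1, 12, 30, tzinfo=datetime.timezone.utc)
    payload = {"edited_at": edited_at.isoformat()}  # '2024-05-01T12:30:00+00:00'

    # Deserialize: pop the key and parse back, as in from_dict()
    restored = isoparse(payload.pop("edited_at"))
    assert restored == edited_at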
orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py

@@ -34,10 +34,10 @@ class LabelPredictionWithMemoriesAndFeedback:
         anomaly_score (Union[None, float]):
         label (int):
         label_name (Union[None, str]):
+        logits (List[float]):
         timestamp (datetime.datetime):
         input_value (str):
         input_embedding (List[float]):
-        logits (List[float]):
         expected_label (Union[None, int]):
         expected_label_name (Union[None, str]):
         memories (List['LabelPredictionMemoryLookup']):
@@ -56,10 +56,10 @@ class LabelPredictionWithMemoriesAndFeedback:
     anomaly_score: Union[None, float]
     label: int
     label_name: Union[None, str]
+    logits: List[float]
     timestamp: datetime.datetime
     input_value: str
     input_embedding: List[float]
-    logits: List[float]
     expected_label: Union[None, int]
     expected_label_name: Union[None, str]
     memories: List["LabelPredictionMemoryLookup"]
@@ -86,6 +86,8 @@ class LabelPredictionWithMemoriesAndFeedback:
         label_name: Union[None, str]
         label_name = self.label_name

+        logits = self.logits
+
         timestamp = self.timestamp.isoformat()

         input_value: str
@@ -93,8 +95,6 @@ class LabelPredictionWithMemoriesAndFeedback:

         input_embedding = self.input_embedding

-        logits = self.logits
-
         expected_label: Union[None, int]
         expected_label = self.expected_label

@@ -136,10 +136,10 @@ class LabelPredictionWithMemoriesAndFeedback:
                 "anomaly_score": anomaly_score,
                 "label": label,
                 "label_name": label_name,
+                "logits": logits,
                 "timestamp": timestamp,
                 "input_value": input_value,
                 "input_embedding": input_embedding,
-                "logits": logits,
                 "expected_label": expected_label,
                 "expected_label_name": expected_label_name,
                 "memories": memories,
@@ -182,6 +182,8 @@ class LabelPredictionWithMemoriesAndFeedback:

         label_name = _parse_label_name(d.pop("label_name"))

+        logits = cast(List[float], d.pop("logits"))
+
         timestamp = isoparse(d.pop("timestamp"))

         def _parse_input_value(data: object) -> str:
@@ -191,8 +193,6 @@ class LabelPredictionWithMemoriesAndFeedback:

         input_embedding = cast(List[float], d.pop("input_embedding"))

-        logits = cast(List[float], d.pop("logits"))
-
         def _parse_expected_label(data: object) -> Union[None, int]:
             if data is None:
                 return data
@@ -251,10 +251,10 @@ class LabelPredictionWithMemoriesAndFeedback:
             anomaly_score=anomaly_score,
             label=label,
             label_name=label_name,
+            logits=logits,
             timestamp=timestamp,
             input_value=input_value,
             input_embedding=input_embedding,
-            logits=logits,
             expected_label=expected_label,
             expected_label_name=expected_label_name,
             memories=memories,
orca_sdk/_generated_api_client/models/labeled_memory.py

@@ -38,6 +38,7 @@ class LabeledMemory:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -51,6 +52,7 @@ class LabeledMemory:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -75,6 +77,8 @@ class LabeledMemory:

         updated_at = self.updated_at.isoformat()

+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()

         label = self.label
@@ -94,6 +98,7 @@ class LabeledMemory:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -133,6 +138,8 @@ class LabeledMemory:

         updated_at = isoparse(d.pop("updated_at"))

+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))

         label = d.pop("label")
@@ -153,6 +160,7 @@ class LabeledMemory:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_lookup.py

@@ -38,6 +38,7 @@ class LabeledMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -52,6 +53,7 @@ class LabeledMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -77,6 +79,8 @@ class LabeledMemoryLookup:

         updated_at = self.updated_at.isoformat()

+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()

         label = self.label
@@ -98,6 +102,7 @@ class LabeledMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -138,6 +143,8 @@ class LabeledMemoryLookup:

         updated_at = isoparse(d.pop("updated_at"))

+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))

         label = d.pop("label")
@@ -160,6 +167,7 @@ class LabeledMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py

@@ -40,6 +40,7 @@ class LabeledMemoryWithFeedbackMetrics:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -55,6 +56,7 @@ class LabeledMemoryWithFeedbackMetrics:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabeledMemoryWithFeedbackMetrics:

         updated_at = self.updated_at.isoformat()

+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()

         label = self.label
@@ -104,6 +108,7 @@ class LabeledMemoryWithFeedbackMetrics:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabeledMemoryWithFeedbackMetrics:

         updated_at = isoparse(d.pop("updated_at"))

+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))

         label = d.pop("label")
@@ -172,6 +179,7 @@ class LabeledMemoryWithFeedbackMetrics:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py

@@ -43,6 +43,7 @@ class LabeledMemorysetMetadata:
         label_names (List[str]):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        memories_updated_at (datetime.datetime):
         insertion_task_id (str):
         insertion_status (TaskStatus): Status of task in the task queue
         metrics (MemorysetMetrics):
@@ -59,6 +60,7 @@ class LabeledMemorysetMetadata:
     label_names: List[str]
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    memories_updated_at: datetime.datetime
     insertion_task_id: str
     insertion_status: TaskStatus
     metrics: "MemorysetMetrics"
@@ -97,6 +99,8 @@ class LabeledMemorysetMetadata:

         updated_at = self.updated_at.isoformat()

+        memories_updated_at = self.memories_updated_at.isoformat()
+
         insertion_task_id = self.insertion_task_id

         insertion_status = (
@@ -120,6 +124,7 @@ class LabeledMemorysetMetadata:
                 "label_names": label_names,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "memories_updated_at": memories_updated_at,
                 "insertion_task_id": insertion_task_id,
                 "insertion_status": insertion_status,
                 "metrics": metrics,
@@ -180,6 +185,8 @@ class LabeledMemorysetMetadata:

         updated_at = isoparse(d.pop("updated_at"))

+        memories_updated_at = isoparse(d.pop("memories_updated_at"))
+
         insertion_task_id = d.pop("insertion_task_id")

         insertion_status = TaskStatus(d.pop("insertion_status"))
@@ -198,6 +205,7 @@ class LabeledMemorysetMetadata:
             label_names=label_names,
             created_at=created_at,
             updated_at=updated_at,
+            memories_updated_at=memories_updated_at,
             insertion_task_id=insertion_task_id,
             insertion_status=insertion_status,
             metrics=metrics,
orca_sdk/_generated_api_client/models/prediction_request.py

@@ -28,14 +28,16 @@ class PredictionRequest:
         expected_labels (Union[List[int], None, Unset]):
         tags (Union[Unset, List[str]]):
         memoryset_override_id (Union[None, Unset, str]):
-        disable_telemetry (Union[Unset, bool]): Default: False.
+        save_telemetry (Union[Unset, bool]): Default: True.
+        save_telemetry_synchronously (Union[Unset, bool]): Default: False.
     """

     input_values: List[str]
     expected_labels: Union[List[int], None, Unset] = UNSET
     tags: Union[Unset, List[str]] = UNSET
     memoryset_override_id: Union[None, Unset, str] = UNSET
-    disable_telemetry: Union[Unset, bool] = False
+    save_telemetry: Union[Unset, bool] = True
+    save_telemetry_synchronously: Union[Unset, bool] = False
     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

     def to_dict(self) -> dict[str, Any]:
@@ -62,7 +64,9 @@ class PredictionRequest:
         else:
             memoryset_override_id = self.memoryset_override_id

-        disable_telemetry = self.disable_telemetry
+        save_telemetry = self.save_telemetry
+
+        save_telemetry_synchronously = self.save_telemetry_synchronously

         field_dict: dict[str, Any] = {}
         field_dict.update(self.additional_properties)
@@ -77,8 +81,10 @@ class PredictionRequest:
             field_dict["tags"] = tags
         if memoryset_override_id is not UNSET:
             field_dict["memoryset_override_id"] = memoryset_override_id
-        if disable_telemetry is not UNSET:
-            field_dict["disable_telemetry"] = disable_telemetry
+        if save_telemetry is not UNSET:
+            field_dict["save_telemetry"] = save_telemetry
+        if save_telemetry_synchronously is not UNSET:
+            field_dict["save_telemetry_synchronously"] = save_telemetry_synchronously

         return field_dict

@@ -156,14 +162,17 @@ class PredictionRequest:

         memoryset_override_id = _parse_memoryset_override_id(d.pop("memoryset_override_id", UNSET))

-        disable_telemetry = d.pop("disable_telemetry", UNSET)
+        save_telemetry = d.pop("save_telemetry", UNSET)
+
+        save_telemetry_synchronously = d.pop("save_telemetry_synchronously", UNSET)

         prediction_request = cls(
             input_values=input_values,
             expected_labels=expected_labels,
             tags=tags,
             memoryset_override_id=memoryset_override_id,
-            disable_telemetry=disable_telemetry,
+            save_telemetry=save_telemetry,
+            save_telemetry_synchronously=save_telemetry_synchronously,
         )

         prediction_request.additional_properties = d
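
The net effect of this rename: callers that previously sent disable_telemetry=True now send save_telemetry=False, and a second flag, save_telemetry_synchronously, controls (judging by its name) whether telemetry is persisted before the prediction response returns. A minimal sketch of the new request body, assuming PredictionRequest is re-exported from the generated client's models package as is typical for openapi-python-client output:

    from orca_sdk._generated_api_client.models import PredictionRequest

    # old client: PredictionRequest(input_values=[...], disable_telemetry=True)
    request = PredictionRequest(
        input_values=["is this spam?"],
        save_telemetry=False,  # replaces disable_telemetry=True
    )
    body = request.to_dict()
    # Both flags have concrete (non-UNSET) defaults, so both are serialized:
    # {'input_values': ['is this spam?'],
    #  'save_telemetry': False,
    #  'save_telemetry_synchronously': False}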
orca_sdk/_shared/__init__.py

@@ -0,0 +1 @@
+from .metrics import calculate_pr_curve, calculate_roc_curve, compute_classifier_metrics
orca_sdk/_shared/metrics.py

@@ -0,0 +1,195 @@
+"""
+This module contains metrics for usage with the Hugging Face Trainer.
+
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+
+"""
+
+from typing import Literal, Tuple, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.special import softmax
+from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
+from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve as sklearn_roc_curve
+from transformers.trainer_utils import EvalPrediction
+
+
+class ClassificationMetrics(TypedDict):
+    accuracy: float
+    f1_score: float
+    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
+    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
+    log_loss: float  # cross-entropy loss for probabilities
+
+
+def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
+    """
+    Compute standard metrics for a classifier with the Hugging Face Trainer.
+
+    Args:
+        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
+
+    Returns:
+        A dictionary containing the accuracy, f1 score, ROC AUC, PR AUC, and log loss.
+    """
+    logits, references = eval_pred
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    if not isinstance(logits, np.ndarray):
+        raise ValueError("Logits must be a numpy array")
+    if not isinstance(references, np.ndarray):
+        raise ValueError(
+            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+        )
+
+    if not (logits > 0).all():
+        # convert logits to probabilities with softmax if necessary
+        probabilities = softmax(logits)
+    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+        # convert logits to probabilities through normalization if necessary
+        probabilities = logits / logits.sum(-1, keepdims=True)
+    else:
+        probabilities = logits
+
+    return classification_scores(references, probabilities)
+
+
+def classification_scores(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+) -> ClassificationMetrics:
+    if probabilities.ndim == 1:
+        # convert 1D probabilities (binary) to 2D logits
+        probabilities = np.column_stack([1 - probabilities, probabilities])
+    elif probabilities.ndim == 2:
+        if probabilities.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    accuracy = accuracy_score(references, predictions)
+    f1 = f1_score(references, predictions, average=average)
+    loss = log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = roc_auc_score(references, probabilities[:, 1])
+            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
+            pr_auc = auc(recalls, precisions)
+        else:
+            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
+            pr_auc = None
+    else:
+        roc_auc = None
+        pr_auc = None
+
+    return {
+        "accuracy": float(accuracy),
+        "f1_score": float(f1),
+        "roc_auc": float(roc_auc) if roc_auc is not None else None,
+        "pr_auc": float(pr_auc) if pr_auc is not None else None,
+        "log_loss": float(loss),
+    }
+
+
+def calculate_pr_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    precisions = precisions.astype(np.float32)
+    recalls = recalls.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # Concatenate with 0 to include the lowest threshold
+    thresholds = np.concatenate(([0], thresholds))
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    precisions = precisions[sorted_indices]
+    recalls = recalls[sorted_indices]
+
+    if len(precisions) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_precisions = np.interp(new_thresholds, thresholds, precisions)
+        new_recalls = np.interp(new_thresholds, thresholds, recalls)
+        thresholds = new_thresholds
+        precisions = new_precisions
+        recalls = new_recalls
+
+    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+
+
+def calculate_roc_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    # Convert probabilities to float32 before calling sklearn_roc_curve
+    probabilities_slice = probabilities_slice.astype(np.float32)
+    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    fpr = fpr.astype(np.float32)
+    tpr = tpr.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+    thresholds[0] = 1.0
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    fpr = fpr[sorted_indices]
+    tpr = tpr[sorted_indices]
+
+    if len(fpr) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_fpr = np.interp(new_thresholds, thresholds, fpr)
+        new_tpr = np.interp(new_thresholds, thresholds, tpr)
+        thresholds = new_thresholds
+        fpr = new_fpr
+        tpr = new_tpr
+
+    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
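
A quick usage sketch of the new shared helpers with synthetic data (the numbers are illustrative; compute_classifier_metrics is meant to be wired into the Hugging Face Trainer as its compute_metrics callback, while the functions below can be called directly):

    import numpy as np
    from orca_sdk._shared import calculate_pr_curve, calculate_roc_curve
    from orca_sdk._shared.metrics import classification_scores

    # Binary ground truth and predicted probabilities for the positive class
    references = np.array([0, 1, 1, 0, 1], dtype=np.int64)
    probabilities = np.array([0.1, 0.8, 0.65, 0.3, 0.9], dtype=np.float32)

    # TypedDict with accuracy, f1_score, roc_auc, pr_auc, and log_loss
    scores = classification_scores(references, probabilities)

    # Both curve helpers return float32 arrays sorted by threshold and
    # interpolated down to at most max_length (default 100) points
    precisions, recalls, thresholds = calculate_pr_curve(references, probabilities)
    fpr, tpr, roc_thresholds = calculate_roc_curve(references, probabilities)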