orca-sdk 0.0.91__py3-none-any.whl → 0.0.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. orca_sdk/_generated_api_client/api/__init__.py +4 -0
  2. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
  3. orca_sdk/_generated_api_client/models/__init__.py +4 -0
  4. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
  5. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
  6. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
  7. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
  8. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
  9. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
  10. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
  11. orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
  12. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
  13. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
  14. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
  15. orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
  16. orca_sdk/_shared/__init__.py +1 -0
  17. orca_sdk/_shared/metrics.py +195 -0
  18. orca_sdk/_shared/metrics_test.py +169 -0
  19. orca_sdk/_utils/data_parsing.py +31 -2
  20. orca_sdk/_utils/data_parsing_test.py +18 -15
  21. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  22. orca_sdk/classification_model.py +170 -27
  23. orca_sdk/classification_model_test.py +74 -32
  24. orca_sdk/conftest.py +86 -25
  25. orca_sdk/datasource.py +22 -12
  26. orca_sdk/embedding_model_test.py +6 -5
  27. orca_sdk/memoryset.py +78 -0
  28. orca_sdk/memoryset_test.py +197 -123
  29. orca_sdk/telemetry.py +3 -0
  30. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
  31. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
  32. {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
@@ -0,0 +1,169 @@
+ """
+ IMPORTANT:
+ - This is a shared file between OrcaLib and the Orca SDK.
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+ """
+
+ from typing import Literal
+
+ import numpy as np
+ import pytest
+
+ from .metrics import (
+     EvalPrediction,
+     calculate_pr_curve,
+     calculate_roc_curve,
+     classification_scores,
+     compute_classifier_metrics,
+     softmax,
+ )
+
+
+ def test_binary_metrics():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+     metrics = classification_scores(y_true, y_score)
+
+     assert metrics["accuracy"] == 0.8
+     assert metrics["f1_score"] == 0.8
+     assert metrics["roc_auc"] is not None
+     assert metrics["roc_auc"] > 0.8
+     assert metrics["roc_auc"] < 1.0
+     assert metrics["pr_auc"] is not None
+     assert metrics["pr_auc"] > 0.8
+     assert metrics["pr_auc"] < 1.0
+     assert metrics["log_loss"] is not None
+     assert metrics["log_loss"] > 0.0
+
+
+ def test_multiclass_metrics_with_2_classes():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+     metrics = classification_scores(y_true, y_score)
+
+     assert metrics["accuracy"] == 0.8
+     assert metrics["f1_score"] == 0.8
+     assert metrics["roc_auc"] is not None
+     assert metrics["roc_auc"] > 0.8
+     assert metrics["roc_auc"] < 1.0
+     assert metrics["pr_auc"] is not None
+     assert metrics["pr_auc"] > 0.8
+     assert metrics["pr_auc"] < 1.0
+     assert metrics["log_loss"] is not None
+     assert metrics["log_loss"] > 0.0
+
+
+ @pytest.mark.parametrize(
+     "average, multiclass",
+     [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+ )
+ def test_multiclass_metrics_with_3_classes(
+     average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+ ):
+     y_true = np.array([0, 1, 1, 0, 2])
+     y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+     metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+
+     assert metrics["accuracy"] == 1.0
+     assert metrics["f1_score"] == 1.0
+     assert metrics["roc_auc"] is not None
+     assert metrics["roc_auc"] > 0.8
+     assert metrics["pr_auc"] is None
+     assert metrics["log_loss"] is not None
+     assert metrics["log_loss"] > 0.0
+
+
+ def test_does_not_modify_logits_unless_necessary():
+     logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+     references = np.array([0, 1, 0, 1])
+     metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+     assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+
+
+ def test_normalizes_logits_if_necessary():
+     logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+     references = np.array([0, 1, 0, 1])
+     metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+     assert (
+         metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+     )
+
+
+ def test_softmaxes_logits_if_necessary():
+     logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+     references = np.array([0, 1, 0, 1])
+     metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+     assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+
+
+ def test_precision_recall_curve():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
+     assert precision is not None
+     assert recall is not None
+     assert thresholds is not None
+
+     assert len(precision) == len(recall) == len(thresholds) == 6
+     assert precision[0] == 0.6
+     assert recall[0] == 1.0
+     assert precision[-1] == 1.0
+     assert recall[-1] == 0.0
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(thresholds) >= 0)
+
+
+ def test_roc_curve():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
+     assert fpr is not None
+     assert tpr is not None
+     assert thresholds is not None
+
+     assert len(fpr) == len(tpr) == len(thresholds) == 6
+     assert fpr[0] == 1.0
+     assert tpr[0] == 1.0
+     assert fpr[-1] == 0.0
+     assert tpr[-1] == 0.0
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(thresholds) >= 0)
+
+
+ def test_precision_recall_curve_max_length():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
+     assert len(precision) == len(recall) == len(thresholds) == 5
+
+     assert precision[0] == 0.6
+     assert recall[0] == 1.0
+     assert precision[-1] == 1.0
+     assert recall[-1] == 0.0
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(thresholds) >= 0)
+
+
+ def test_roc_curve_max_length():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
+     assert len(fpr) == len(tpr) == len(thresholds) == 5
+     assert fpr[0] == 1.0
+     assert tpr[0] == 1.0
+     assert fpr[-1] == 0.0
+     assert tpr[-1] == 0.0
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(thresholds) >= 0)
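For orientation, here is a minimal usage sketch of the shared helpers these tests exercise. It only uses the signatures and result keys the tests above assert on; importing from the private `orca_sdk._shared.metrics` module is an assumption about packaging, not a documented public API.

```python
# Illustrative only: mirrors the calls exercised by the tests above.
# Anything beyond the asserted keys and parameters is an assumption.
import numpy as np

from orca_sdk._shared.metrics import calculate_pr_curve, classification_scores

y_true = np.array([0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])  # scores for the positive class

metrics = classification_scores(y_true, y_score)
print(metrics["accuracy"], metrics["f1_score"], metrics["roc_auc"], metrics["pr_auc"], metrics["log_loss"])

# max_length caps the number of curve points, as test_precision_recall_curve_max_length shows
precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
```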
@@ -1,12 +1,16 @@
+ import logging
  import pickle
  from dataclasses import asdict, is_dataclass
  from os import PathLike
+ from tempfile import TemporaryDirectory
  from typing import Any, cast

  from datasets import Dataset
  from torch.utils.data import DataLoader as TorchDataLoader
  from torch.utils.data import Dataset as TorchDataset

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

  def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
      if isinstance(item, dict):
@@ -40,7 +44,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
      return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]


- def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
+ def hf_dataset_from_torch(
+     torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
+ ) -> Dataset:
+     """
+     Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+     NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+     cached results can mask changes you've made to tests. This can make a test appear to succeed
+     when it's actually broken or vice versa.
+
+     Params:
+         torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+         column_names: Optional list of column names to use for the dataset. If not provided,
+             the column names will be inferred from the data.
+         ignore_cache: If True, the dataset will not be cached on disk.
+     Returns:
+         A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+     """
      if isinstance(torch_data, TorchDataLoader):
          dataloader = torch_data
      else:
@@ -50,7 +71,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
          for batch in dataloader:
              yield from parse_batch(batch, column_names=column_names)

-     return cast(Dataset, Dataset.from_generator(generator))
+     if ignore_cache:
+         with TemporaryDirectory() as temp_dir:
+             ds = Dataset.from_generator(generator, cache_dir=temp_dir)
+     else:
+         ds = Dataset.from_generator(generator)
+
+     if not isinstance(ds, Dataset):
+         raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
+     return ds


  def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
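The new `ignore_cache` flag routes `Dataset.from_generator` through a throwaway `TemporaryDirectory` cache so stale generator output cannot leak between runs. A hedged usage sketch follows; `ToyDataset` is a stand-in for any torch dataset that yields dicts, not part of the package.

```python
# Sketch of calling the updated helper; ToyDataset is illustrative only.
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_torch


class ToyDataset(TorchDataset):
    def __init__(self):
        self.items = [{"value": "hello", "label": 0}, {"value": "world", "label": 1}]

    def __getitem__(self, i):
        return self.items[i]

    def __len__(self):
        return len(self.items)


# ignore_cache=True builds the dataset inside a TemporaryDirectory, so cached
# generator output from a previous run cannot mask changes.
hf_dataset = hf_dataset_from_torch(ToyDataset(), ignore_cache=True)
print(hf_dataset.column_names)  # e.g. ['value', 'label']
```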
@@ -1,4 +1,5 @@
  import json
+ import logging
  import pickle
  import tempfile
  from collections import namedtuple
@@ -14,6 +15,8 @@ from torch.utils.data import Dataset as TorchDataset
  from ..conftest import SAMPLE_DATA
  from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

  class PytorchDictDataset(TorchDataset):
      def __init__(self):
@@ -29,11 +32,11 @@ class PytorchDictDataset(TorchDataset):
  def test_hf_dataset_from_torch_dict():
      # Given a Pytorch dataset that returns a dictionary for each item
      dataset = PytorchDictDataset()
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
-     assert set(hf_dataset.column_names) == {"text", "label", "key", "score", "source_id"}
+     assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}


  class PytorchTupleDataset(TorchDataset):
@@ -41,7 +44,7 @@ class PytorchTupleDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return self.data[i]["text"], self.data[i]["label"]
+         return self.data[i]["value"], self.data[i]["label"]

      def __len__(self):
          return len(self.data)
@@ -51,11 +54,11 @@ def test_hf_dataset_from_torch_tuple():
      # Given a Pytorch dataset that returns a tuple for each item
      dataset = PytorchTupleDataset()
      # And the correct number of column names passed in
-     hf_dataset = hf_dataset_from_torch(dataset, column_names=["text", "label"])
+     hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
-     assert hf_dataset.column_names == ["text", "label"]
+     assert hf_dataset.column_names == ["value", "label"]


  def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +66,7 @@ def test_hf_dataset_from_torch_tuple_error():
      dataset = PytorchTupleDataset()
      # Then the HF dataset should raise an error if no column names are passed in
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset)
+         hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +74,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
      dataset = PytorchTupleDataset()
      # Then the HF dataset should raise an error if not enough column names are passed in
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset, column_names=["value"])
+         hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)


  DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +85,7 @@ class PytorchNamedTupleDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return DatasetTuple(self.data[i]["text"], self.data[i]["label"])
+         return DatasetTuple(self.data[i]["value"], self.data[i]["label"])

      def __len__(self):
          return len(self.data)
@@ -92,7 +95,7 @@ def test_hf_dataset_from_torch_named_tuple():
      # Given a Pytorch dataset that returns a namedtuple for each item
      dataset = PytorchNamedTupleDataset()
      # And no column names are passed in
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
@@ -110,7 +113,7 @@ class PytorchDataclassDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return DatasetItem(text=self.data[i]["text"], label=self.data[i]["label"])
+         return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])

      def __len__(self):
          return len(self.data)
@@ -119,7 +122,7 @@ class PytorchDataclassDataset(TorchDataset):
  def test_hf_dataset_from_torch_dataclass():
      # Given a Pytorch dataset that returns a dataclass for each item
      dataset = PytorchDataclassDataset()
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
@@ -131,7 +134,7 @@ class PytorchInvalidDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return [self.data[i]["text"], self.data[i]["label"]]
+         return [self.data[i]["value"], self.data[i]["label"]]

      def __len__(self):
          return len(self.data)
@@ -142,7 +145,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
      dataset = PytorchInvalidDataset()
      # Then the HF dataset should raise an error
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset)
+         hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torchdataloader():
@@ -150,10 +153,10 @@ def test_hf_dataset_from_torchdataloader():
      dataset = PytorchDictDataset()

      def collate_fn(x: list[dict]):
-         return {"value": [item["text"] for item in x], "label": [item["label"] for item in x]}
+         return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}

      dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
-     hf_dataset = hf_dataset_from_torch(dataloader)
+     hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
@@ -0,0 +1,12 @@
+ class TqdmFileReader:
+     def __init__(self, file_obj, pbar):
+         self.file_obj = file_obj
+         self.pbar = pbar
+
+     def read(self, size=-1):
+         data = self.file_obj.read(size)
+         self.pbar.update(len(data))
+         return data
+
+     def __getattr__(self, attr):
+         return getattr(self.file_obj, attr)
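`TqdmFileReader` wraps a file object, advances a progress bar on every `read()`, and delegates every other attribute via `__getattr__`. A minimal sketch of wiring it to a tqdm bar; the file path and chunk size below are placeholders.

```python
# Minimal sketch: any consumer that calls .read() on the wrapper advances the bar.
import os

from tqdm import tqdm

from orca_sdk._utils.tqdm_file_reader import TqdmFileReader

path = "data.csv"  # hypothetical local file
with open(path, "rb") as f, tqdm(total=os.path.getsize(path), unit="B", unit_scale=True) as pbar:
    reader = TqdmFileReader(f, pbar)
    while reader.read(1024 * 1024):  # each chunk read bumps the progress bar
        pass
```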
@@ -1,10 +1,22 @@
  from __future__ import annotations

  import logging
+ import os
  from contextlib import contextmanager
  from datetime import datetime
  from typing import Any, Generator, Iterable, Literal, cast, overload
- from uuid import UUID
+ from uuid import UUID, uuid4
+
+ import numpy as np
+
+ import numpy as np
+ from datasets import Dataset
+ from sklearn.metrics import (
+     accuracy_score,
+     auc,
+     f1_score,
+     roc_auc_score,
+ )

  from ._generated_api_client.api import (
      create_evaluation,
@@ -19,9 +31,11 @@ from ._generated_api_client.api import (
      update_model,
  )
  from ._generated_api_client.models import (
+     ClassificationEvaluationResult,
      CreateRACModelRequest,
      EvaluationRequest,
      ListPredictionsRequest,
+     PrecisionRecallCurve,
  )
  from ._generated_api_client.models import (
      PredictionSortItemItemType0 as PredictionSortColumns,
@@ -33,8 +47,10 @@ from ._generated_api_client.models import (
      RACHeadType,
      RACModelMetadata,
      RACModelUpdate,
+     ROCCurve,
  )
  from ._generated_api_client.models.prediction_request import PredictionRequest
+ from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
  from ._utils.common import UNSET, CreateMode, DropMode
  from ._utils.task import wait_for_task
  from .datasource import Datasource
@@ -299,7 +315,8 @@ class ClassificationModel:
          value: list[str],
          expected_labels: list[int] | None = None,
          tags: set[str] = set(),
-         disable_telemetry: bool = False,
+         save_telemetry: bool = True,
+         save_telemetry_synchronously: bool = False,
      ) -> list[LabelPrediction]:
          pass

@@ -309,7 +326,8 @@ class ClassificationModel:
          value: str,
          expected_labels: int | None = None,
          tags: set[str] = set(),
-         disable_telemetry: bool = False,
+         save_telemetry: bool = True,
+         save_telemetry_synchronously: bool = False,
      ) -> LabelPrediction:
          pass

@@ -318,7 +336,8 @@ class ClassificationModel:
          value: list[str] | str,
          expected_labels: list[int] | int | None = None,
          tags: set[str] = set(),
-         disable_telemetry: bool = False,
+         save_telemetry: bool = True,
+         save_telemetry_synchronously: bool = False,
      ) -> list[LabelPrediction] | LabelPrediction:
          """
          Predict label(s) for the given input value(s) grounded in similar memories
@@ -327,7 +346,10 @@ class ClassificationModel:
              value: Value(s) to predict the labels of
              expected_labels: Expected label(s) for the given input to record for model evaluation
              tags: Tags to add to the prediction(s)
-             disable_telemetry: Whether to disable telemetry for the prediction(s)
+             save_telemetry: Whether to enable telemetry for the prediction(s)
+             save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
+                 asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
+                 may be overridden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.

          Returns:
              Label prediction or list of label predictions
@@ -345,6 +367,13 @@ class ClassificationModel:
              ]
          """

+         if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
+             env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
+             logging.info(
+                 f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var}, which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
+             )
+             save_telemetry_synchronously = env_var.lower() == "true"
+
          response = predict_gpu(
              self.id,
              body=PredictionRequest(
@@ -356,11 +385,12 @@ class ClassificationModel:
                      else [expected_labels] if expected_labels is not None else None
                  ),
                  tags=list(tags),
-                 disable_telemetry=disable_telemetry,
+                 save_telemetry=save_telemetry,
+                 save_telemetry_synchronously=save_telemetry_synchronously,
              ),
          )

-         if not disable_telemetry and any(p.prediction_id is None for p in response):
+         if save_telemetry and any(p.prediction_id is None for p in response):
              raise RuntimeError("Failed to save prediction to database.")

          predictions = [
@@ -372,6 +402,7 @@ class ClassificationModel:
                  anomaly_score=prediction.anomaly_score,
                  memoryset=self.memoryset,
                  model=self,
+                 logits=prediction.logits,
              )
              for prediction in response
          ]
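`predict()` now takes `save_telemetry` and `save_telemetry_synchronously` in place of `disable_telemetry`, and the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY` environment variable, when set, overrides the keyword argument. A hedged sketch, assuming an already-created `ClassificationModel`; the input text, label, and tag are placeholders.

```python
# Hedged sketch: `model` is assumed to be an existing ClassificationModel.
import os

from orca_sdk.classification_model import ClassificationModel


def run_smoke_prediction(model: ClassificationModel) -> None:
    # When set, the env var wins over the keyword argument (see the override logic above).
    os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"

    prediction = model.predict(
        "the flight was delayed for three hours",
        expected_labels=0,
        tags={"smoke-test"},
        save_telemetry=True,                 # replaces the old disable_telemetry=False
        save_telemetry_synchronously=False,  # forced to True by the env var above
    )
    print(prediction.label, prediction.logits)
```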
@@ -444,46 +475,158 @@ class ClassificationModel:
              for prediction in predictions
          ]

-     def evaluate(
+     def _calculate_metrics(
+         self,
+         predictions: list[LabelPrediction],
+         expected_labels: list[int],
+     ) -> ClassificationEvaluationResult:
+
+         targets_array = np.array(expected_labels)
+         predictions_array = np.array([p.label for p in predictions])
+
+         logits_array = np.array([p.logits for p in predictions])
+
+         f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
+         accuracy = float(accuracy_score(targets_array, predictions_array))
+
+         # Only compute ROC AUC and PR AUC for binary classification
+         unique_classes = np.unique(targets_array)
+
+         pr_curve = None
+         roc_curve = None
+
+         if len(unique_classes) == 2:
+             try:
+                 precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
+                 pr_auc = float(auc(recalls, precisions))
+
+                 pr_curve = PrecisionRecallCurve(
+                     precisions=precisions.tolist(),
+                     recalls=recalls.tolist(),
+                     thresholds=pr_thresholds.tolist(),
+                     auc=pr_auc,
+                 )
+
+                 fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
+                 roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
+
+                 roc_curve = ROCCurve(
+                     false_positive_rates=fpr.tolist(),
+                     true_positive_rates=tpr.tolist(),
+                     thresholds=roc_thresholds.tolist(),
+                     auc=roc_auc,
+                 )
+             except ValueError as e:
+                 logging.warning(f"Error calculating PR and ROC curves: {e}")
+
+         return ClassificationEvaluationResult(
+             f1_score=f1,
+             accuracy=accuracy,
+             loss=0.0,
+             precision_recall_curve=pr_curve,
+             roc_curve=roc_curve,
+         )
+
+     def _evaluate_datasource(
          self,
          datasource: Datasource,
+         value_column: str,
+         label_column: str,
+         record_predictions: bool,
+         tags: set[str] | None,
+     ) -> dict[str, Any]:
+         response = create_evaluation(
+             self.id,
+             body=EvaluationRequest(
+                 datasource_id=datasource.id,
+                 datasource_label_column=label_column,
+                 datasource_value_column=value_column,
+                 memoryset_override_id=self._memoryset_override_id,
+                 record_telemetry=record_predictions,
+                 telemetry_tags=list(tags) if tags else None,
+             ),
+         )
+         wait_for_task(response.task_id, description="Running evaluation")
+         response = get_evaluation(self.id, UUID(response.task_id))
+         assert response.result is not None
+         return response.result.to_dict()
+
+     def _evaluate_dataset(
+         self,
+         dataset: Dataset,
+         value_column: str,
+         label_column: str,
+         record_predictions: bool,
+         tags: set[str],
+         batch_size: int,
+     ) -> dict[str, Any]:
+         predictions = []
+         expected_labels = []
+
+         for i in range(0, len(dataset), batch_size):
+             batch = dataset[i : i + batch_size]
+             predictions.extend(
+                 self.predict(
+                     batch[value_column],
+                     expected_labels=batch[label_column],
+                     tags=tags,
+                     save_telemetry=record_predictions,
+                     save_telemetry_synchronously=(not record_predictions),
+                 )
+             )
+             expected_labels.extend(batch[label_column])
+
+         return self._calculate_metrics(predictions, expected_labels).to_dict()
+
+     def evaluate(
+         self,
+         data: Datasource | Dataset,
          value_column: str = "value",
          label_column: str = "label",
          record_predictions: bool = False,
-         tags: set[str] | None = None,
+         tags: set[str] = {"evaluation"},
+         batch_size: int = 100,
      ) -> dict[str, Any]:
          """
-         Evaluate the classification model on a given datasource
+         Evaluate the classification model on a given dataset or datasource

          Params:
-             datasource: Datasource to evaluate the model on
+             data: Dataset or Datasource to evaluate the model on
              value_column: Name of the column that contains the input values to the model
              label_column: Name of the column containing the expected labels
              record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
              tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)

          Returns:
-             Dictionary with evaluation metrics
+             Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)

          Examples:
+             Evaluate using a Datasource:
              >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
              { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+
+             Evaluate using a Dataset:
+             >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
+             { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
          """
-         response = create_evaluation(
-             self.id,
-             body=EvaluationRequest(
-                 datasource_id=datasource.id,
-                 datasource_label_column=label_column,
-                 datasource_value_column=value_column,
-                 memoryset_override_id=self._memoryset_override_id,
-                 record_telemetry=record_predictions,
-                 telemetry_tags=list(tags) if tags else None,
-             ),
-         )
-         wait_for_task(response.task_id, description="Running evaluation")
-         response = get_evaluation(self.id, UUID(response.task_id))
-         assert response.result is not None
-         return response.result.to_dict()
+         if isinstance(data, Datasource):
+             return self._evaluate_datasource(
+                 datasource=data,
+                 value_column=value_column,
+                 label_column=label_column,
+                 record_predictions=record_predictions,
+                 tags=tags,
+             )
+         else:
+             return self._evaluate_dataset(
+                 dataset=data,
+                 value_column=value_column,
+                 label_column=label_column,
+                 record_predictions=record_predictions,
+                 tags=tags,
+                 batch_size=batch_size,
+             )

      def finetune(self, datasource: Datasource):
          # do not document until implemented
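`evaluate()` now dispatches on input type: a `Datasource` still runs the server-side evaluation task, while a HuggingFace `Dataset` is predicted client-side in batches and scored with `_calculate_metrics`. A hedged sketch of the Dataset path; the columns, data, and `model` variable are placeholders, and the printed keys follow the docstring above.

```python
# Hedged sketch of the new Dataset path through evaluate().
from datasets import Dataset

from orca_sdk.classification_model import ClassificationModel


def evaluate_locally(model: ClassificationModel) -> dict:
    dataset = Dataset.from_dict(
        {
            "text": ["great flight", "lost my luggage", "on time", "rude staff"],
            "sentiment": [1, 0, 1, 0],
        }
    )
    # Dataset inputs are predicted client-side in batches of `batch_size`
    # and scored via _calculate_metrics; Datasource inputs stay server-side.
    results = model.evaluate(
        dataset,
        value_column="text",
        label_column="sentiment",
        batch_size=2,
    )
    print(results["accuracy"], results["f1_score"])
    return results
```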