orca-sdk 0.0.91__py3-none-any.whl → 0.0.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +4 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/models/__init__.py +4 -0
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +195 -0
- orca_sdk/_shared/metrics_test.py +169 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +170 -27
- orca_sdk/classification_model_test.py +74 -32
- orca_sdk/conftest.py +86 -25
- orca_sdk/datasource.py +22 -12
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +197 -123
- orca_sdk/telemetry.py +3 -0
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics_test.py
ADDED

@@ -0,0 +1,169 @@
+"""
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pytest
+
+from .metrics import (
+    EvalPrediction,
+    calculate_pr_curve,
+    calculate_roc_curve,
+    classification_scores,
+    compute_classifier_metrics,
+    softmax,
+)
+
+
+def test_binary_metrics():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_multiclass_metrics_with_2_classes():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+@pytest.mark.parametrize(
+    "average, multiclass",
+    [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+)
+def test_multiclass_metrics_with_3_classes(
+    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+):
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+    metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+
+    assert metrics["accuracy"] == 1.0
+    assert metrics["f1_score"] == 1.0
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["pr_auc"] is None
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_does_not_modify_logits_unless_necessary():
+    logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+
+
+def test_normalizes_logits_if_necessary():
+    logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert (
+        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    )
+
+
+def test_softmaxes_logits_if_necessary():
+    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+
+
+def test_precision_recall_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
+    assert precision is not None
+    assert recall is not None
+    assert thresholds is not None
+
+    assert len(precision) == len(recall) == len(thresholds) == 6
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
+    assert fpr is not None
+    assert tpr is not None
+    assert thresholds is not None
+
+    assert len(fpr) == len(tpr) == len(thresholds) == 6
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_precision_recall_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(precision) == len(recall) == len(thresholds) == 5
+
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert len(fpr) == len(tpr) == len(thresholds) == 5
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
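The new shared test suite above pins down the contract of the helpers in `orca_sdk/_shared/metrics.py` (`softmax`, `classification_scores`, `compute_classifier_metrics`, `calculate_pr_curve`, `calculate_roc_curve`). The actual implementations ship in that module and are not reproduced in this diff; as a rough illustration of what `test_softmaxes_logits_if_necessary` relies on, a minimal, numerically stable `softmax` sketch (an assumption, not the shipped code) could look like:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Sketch only: turn each row of logits into probabilities that sum to 1,
    # shifting by the row max first for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# The test then compares classification_scores(references, softmax(logits))["log_loss"]
# against what compute_classifier_metrics produces from the raw logits.
```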
orca_sdk/_utils/data_parsing.py
CHANGED
@@ -1,12 +1,16 @@
+import logging
 import pickle
 from dataclasses import asdict, is_dataclass
 from os import PathLike
+from tempfile import TemporaryDirectory
 from typing import Any, cast

 from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset

+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

 def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
     if isinstance(item, dict):
@@ -40,7 +44,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]


-def hf_dataset_from_torch(
+def hf_dataset_from_torch(
+    torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
+) -> Dataset:
+    """
+    Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+    NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+    cached results can ignore changes you've made to tests. This can make a test appear to succeed
+    when it's actually broken or vice versa.
+
+    Params:
+        torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+        column_names: Optional list of column names to use for the dataset. If not provided,
+            the column names will be inferred from the data.
+        ignore_cache: If True, the dataset will not be cached on disk.
+    Returns:
+        A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+    """
     if isinstance(torch_data, TorchDataLoader):
         dataloader = torch_data
     else:
@@ -50,7 +71,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
         for batch in dataloader:
             yield from parse_batch(batch, column_names=column_names)

-
+    if ignore_cache:
+        with TemporaryDirectory() as temp_dir:
+            ds = Dataset.from_generator(generator, cache_dir=temp_dir)
+    else:
+        ds = Dataset.from_generator(generator)
+
+    if not isinstance(ds, Dataset):
+        raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
+    return ds


 def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
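For reference, the new `ignore_cache` flag is exercised like this (a usage sketch based on the signature above; `hf_dataset_from_torch` lives in the private `orca_sdk._utils.data_parsing` module, so the example data and class names are illustrative, not SDK code):

```python
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_torch


class ToyDataset(TorchDataset):
    # Minimal PyTorch dataset returning a dict per item, mirroring the tests below.
    def __init__(self):
        self.items = [{"value": "hello", "label": 0}, {"value": "world", "label": 1}]

    def __getitem__(self, i):
        return self.items[i]

    def __len__(self):
        return len(self.items)


# ignore_cache=True writes the intermediate Arrow files into a TemporaryDirectory,
# so repeated runs never reuse a stale on-disk cache.
hf_dataset = hf_dataset_from_torch(ToyDataset(), ignore_cache=True)
print(hf_dataset.column_names)  # e.g. ["value", "label"]
```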
orca_sdk/_utils/data_parsing_test.py
CHANGED

@@ -1,4 +1,5 @@
 import json
+import logging
 import pickle
 import tempfile
 from collections import namedtuple
@@ -14,6 +15,8 @@ from torch.utils.data import Dataset as TorchDataset
 from ..conftest import SAMPLE_DATA
 from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch

+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

 class PytorchDictDataset(TorchDataset):
     def __init__(self):
@@ -29,11 +32,11 @@ class PytorchDictDataset(TorchDataset):
 def test_hf_dataset_from_torch_dict():
     # Given a Pytorch dataset that returns a dictionary for each item
     dataset = PytorchDictDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert set(hf_dataset.column_names) == {"
+    assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}


 class PytorchTupleDataset(TorchDataset):
@@ -41,7 +44,7 @@ class PytorchTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return self.data[i]["
+        return self.data[i]["value"], self.data[i]["label"]

     def __len__(self):
         return len(self.data)
@@ -51,11 +54,11 @@ def test_hf_dataset_from_torch_tuple():
     # Given a Pytorch dataset that returns a tuple for each item
     dataset = PytorchTupleDataset()
     # And the correct number of column names passed in
-    hf_dataset = hf_dataset_from_torch(dataset, column_names=["
+    hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert hf_dataset.column_names == ["
+    assert hf_dataset.column_names == ["value", "label"]


 def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +66,7 @@ def test_hf_dataset_from_torch_tuple_error():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if no column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)


 def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +74,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if not enough column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset, column_names=["value"])
+        hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)


 DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +85,7 @@ class PytorchNamedTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return DatasetTuple(self.data[i]["
+        return DatasetTuple(self.data[i]["value"], self.data[i]["label"])

     def __len__(self):
         return len(self.data)
@@ -92,7 +95,7 @@ def test_hf_dataset_from_torch_named_tuple():
     # Given a Pytorch dataset that returns a namedtuple for each item
     dataset = PytorchNamedTupleDataset()
     # And no column names are passed in
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -110,7 +113,7 @@ class PytorchDataclassDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return DatasetItem(text=self.data[i]["
+        return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])

     def __len__(self):
         return len(self.data)
@@ -119,7 +122,7 @@ class PytorchDataclassDataset(TorchDataset):
 def test_hf_dataset_from_torch_dataclass():
     # Given a Pytorch dataset that returns a dataclass for each item
     dataset = PytorchDataclassDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -131,7 +134,7 @@ class PytorchInvalidDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return [self.data[i]["
+        return [self.data[i]["value"], self.data[i]["label"]]

     def __len__(self):
         return len(self.data)
@@ -142,7 +145,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
     dataset = PytorchInvalidDataset()
     # Then the HF dataset should raise an error
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)


 def test_hf_dataset_from_torchdataloader():
@@ -150,10 +153,10 @@ def test_hf_dataset_from_torchdataloader():
     dataset = PytorchDictDataset()

     def collate_fn(x: list[dict]):
-        return {"value": [item["
+        return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}

     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
-    hf_dataset = hf_dataset_from_torch(dataloader)
+    hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
orca_sdk/_utils/tqdm_file_reader.py
ADDED

@@ -0,0 +1,12 @@
+class TqdmFileReader:
+    def __init__(self, file_obj, pbar):
+        self.file_obj = file_obj
+        self.pbar = pbar
+
+    def read(self, size=-1):
+        data = self.file_obj.read(size)
+        self.pbar.update(len(data))
+        return data
+
+    def __getattr__(self, attr):
+        return getattr(self.file_obj, attr)
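`TqdmFileReader` is a thin wrapper that advances a progress bar on every `read()` and proxies all other attribute access to the wrapped file object. A usage sketch (the file path, chunk size, and the `tqdm` progress bar are illustrative assumptions, not SDK code):

```python
import os

from tqdm import tqdm

from orca_sdk._utils.tqdm_file_reader import TqdmFileReader

path = "upload.bin"  # hypothetical file to stream with progress reporting
with open(path, "rb") as f, tqdm(total=os.path.getsize(path), unit="B", unit_scale=True) as pbar:
    reader = TqdmFileReader(f, pbar)
    # Each read() call updates the bar by the number of bytes actually read.
    while chunk := reader.read(1024 * 1024):
        pass  # hand the chunk to whatever consumes the stream, e.g. an upload request
```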
orca_sdk/classification_model.py
CHANGED
@@ -1,10 +1,22 @@
 from __future__ import annotations

 import logging
+import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
+from uuid import UUID, uuid4
+
+import numpy as np
+
+import numpy as np
+from datasets import Dataset
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    roc_auc_score,
+)

 from ._generated_api_client.api import (
     create_evaluation,
@@ -19,9 +31,11 @@ from ._generated_api_client.api import (
     update_model,
 )
 from ._generated_api_client.models import (
+    ClassificationEvaluationResult,
     CreateRACModelRequest,
     EvaluationRequest,
     ListPredictionsRequest,
+    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -33,8 +47,10 @@ from ._generated_api_client.models import (
     RACHeadType,
     RACModelMetadata,
     RACModelUpdate,
+    ROCCurve,
 )
 from ._generated_api_client.models.prediction_request import PredictionRequest
+from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
 from ._utils.common import UNSET, CreateMode, DropMode
 from ._utils.task import wait_for_task
 from .datasource import Datasource
@@ -299,7 +315,8 @@ class ClassificationModel:
         value: list[str],
         expected_labels: list[int] | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> list[LabelPrediction]:
         pass

@@ -309,7 +326,8 @@ class ClassificationModel:
         value: str,
         expected_labels: int | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> LabelPrediction:
         pass

@@ -318,7 +336,8 @@ class ClassificationModel:
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> list[LabelPrediction] | LabelPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -327,7 +346,10 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-
+            save_telemetry: Whether to enable telemetry for the prediction(s)
+            save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
+                asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
+                may be overridden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.

         Returns:
             Label prediction or list of label predictions
@@ -345,6 +367,13 @@ class ClassificationModel:
         ]
         """

+        if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
+            env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
+            logging.info(
+                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
+            )
+            save_telemetry_synchronously = env_var.lower() == "true"
+
         response = predict_gpu(
             self.id,
             body=PredictionRequest(
@@ -356,11 +385,12 @@ class ClassificationModel:
                     else [expected_labels] if expected_labels is not None else None
                 ),
                 tags=list(tags),
-
+                save_telemetry=save_telemetry,
+                save_telemetry_synchronously=save_telemetry_synchronously,
             ),
         )

-        if
+        if save_telemetry and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")

         predictions = [
@@ -372,6 +402,7 @@ class ClassificationModel:
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction.logits,
             )
             for prediction in response
         ]
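Taken together, the new prediction parameters can be exercised like the sketch below (the model handle and input text are made up, and how the model object is obtained is outside this diff):

```python
import os

from orca_sdk.classification_model import ClassificationModel

model: ClassificationModel = ...  # an existing model handle obtained via the SDK

# Per-call: keep telemetry on and block until the prediction has been persisted.
prediction = model.predict(
    "the flight was delayed for three hours",
    expected_labels=0,
    tags={"smoke-test"},
    save_telemetry=True,
    save_telemetry_synchronously=True,
)

# Global override: when set, this environment variable takes precedence over the
# per-call argument, as the logging.info call above points out.
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
```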
@@ -444,46 +475,158 @@ class ClassificationModel:
             for prediction in predictions
         ]

-    def
+    def _calculate_metrics(
+        self,
+        predictions: list[LabelPrediction],
+        expected_labels: list[int],
+    ) -> ClassificationEvaluationResult:
+
+        targets_array = np.array(expected_labels)
+        predictions_array = np.array([p.label for p in predictions])
+
+        logits_array = np.array([p.logits for p in predictions])
+
+        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
+        accuracy = float(accuracy_score(targets_array, predictions_array))
+
+        # Only compute ROC AUC and PR AUC for binary classification
+        unique_classes = np.unique(targets_array)
+
+        pr_curve = None
+        roc_curve = None
+
+        if len(unique_classes) == 2:
+            try:
+                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
+                pr_auc = float(auc(recalls, precisions))
+
+                pr_curve = PrecisionRecallCurve(
+                    precisions=precisions.tolist(),
+                    recalls=recalls.tolist(),
+                    thresholds=pr_thresholds.tolist(),
+                    auc=pr_auc,
+                )
+
+                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
+                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
+
+                roc_curve = ROCCurve(
+                    false_positive_rates=fpr.tolist(),
+                    true_positive_rates=tpr.tolist(),
+                    thresholds=roc_thresholds.tolist(),
+                    auc=roc_auc,
+                )
+            except ValueError as e:
+                logging.warning(f"Error calculating PR and ROC curves: {e}")
+
+        return ClassificationEvaluationResult(
+            f1_score=f1,
+            accuracy=accuracy,
+            loss=0.0,
+            precision_recall_curve=pr_curve,
+            roc_curve=roc_curve,
+        )
+
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+    ) -> dict[str, Any]:
+        response = create_evaluation(
+            self.id,
+            body=EvaluationRequest(
+                datasource_id=datasource.id,
+                datasource_label_column=label_column,
+                datasource_value_column=value_column,
+                memoryset_override_id=self._memoryset_override_id,
+                record_telemetry=record_predictions,
+                telemetry_tags=list(tags) if tags else None,
+            ),
+        )
+        wait_for_task(response.task_id, description="Running evaluation")
+        response = get_evaluation(self.id, UUID(response.task_id))
+        assert response.result is not None
+        return response.result.to_dict()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> dict[str, Any]:
+        predictions = []
+        expected_labels = []
+
+        for i in range(0, len(dataset), batch_size):
+            batch = dataset[i : i + batch_size]
+            predictions.extend(
+                self.predict(
+                    batch[value_column],
+                    expected_labels=batch[label_column],
+                    tags=tags,
+                    save_telemetry=record_predictions,
+                    save_telemetry_synchronously=(not record_predictions),
+                )
+            )
+            expected_labels.extend(batch[label_column])
+
+        return self._calculate_metrics(predictions, expected_labels).to_dict()
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str]
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
     ) -> dict[str, Any]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource

         Params:
-
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
             record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
             tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)

         Returns:
-            Dictionary with evaluation metrics
+            Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)

         Examples:
+            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
             { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+
+            Evaluate using a Dataset:
+            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
+            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
         """
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+            )
+        else:
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )

     def finetune(self, datasource: Datasource):
         # do not document until implemented