orca-sdk 0.0.92__py3-none-any.whl → 0.0.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +8 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/models/__init__.py +10 -0
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +32 -12
- orca_sdk/classification_model_test.py +95 -34
- orca_sdk/conftest.py +87 -25
- orca_sdk/datasource.py +56 -12
- orca_sdk/datasource_test.py +9 -0
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +199 -123
- orca_sdk/telemetry.py +5 -3
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +36 -28
- {orca_sdk-0.0.92.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/_utils/data_parsing_test.py
CHANGED

@@ -1,4 +1,5 @@
 import json
+import logging
 import pickle
 import tempfile
 from collections import namedtuple
@@ -14,6 +15,8 @@ from torch.utils.data import Dataset as TorchDataset
 from ..conftest import SAMPLE_DATA
 from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch

+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+

 class PytorchDictDataset(TorchDataset):
     def __init__(self):
@@ -29,11 +32,11 @@ class PytorchDictDataset(TorchDataset):
 def test_hf_dataset_from_torch_dict():
     # Given a Pytorch dataset that returns a dictionary for each item
     dataset = PytorchDictDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert set(hf_dataset.column_names) == {"
+    assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}


 class PytorchTupleDataset(TorchDataset):
@@ -41,7 +44,7 @@ class PytorchTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return self.data[i]["
+        return self.data[i]["value"], self.data[i]["label"]

     def __len__(self):
         return len(self.data)
@@ -51,11 +54,11 @@ def test_hf_dataset_from_torch_tuple():
     # Given a Pytorch dataset that returns a tuple for each item
     dataset = PytorchTupleDataset()
     # And the correct number of column names passed in
-    hf_dataset = hf_dataset_from_torch(dataset, column_names=["
+    hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert hf_dataset.column_names == ["
+    assert hf_dataset.column_names == ["value", "label"]


 def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +66,7 @@ def test_hf_dataset_from_torch_tuple_error():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if no column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)


 def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +74,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if not enough column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset, column_names=["value"])
+        hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)


 DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +85,7 @@ class PytorchNamedTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return DatasetTuple(self.data[i]["
+        return DatasetTuple(self.data[i]["value"], self.data[i]["label"])

     def __len__(self):
         return len(self.data)
@@ -92,7 +95,7 @@ def test_hf_dataset_from_torch_named_tuple():
     # Given a Pytorch dataset that returns a namedtuple for each item
     dataset = PytorchNamedTupleDataset()
     # And no column names are passed in
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -110,7 +113,7 @@ class PytorchDataclassDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return DatasetItem(text=self.data[i]["
+        return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])

     def __len__(self):
         return len(self.data)
@@ -119,7 +122,7 @@ class PytorchDataclassDataset(TorchDataset):
 def test_hf_dataset_from_torch_dataclass():
     # Given a Pytorch dataset that returns a dataclass for each item
     dataset = PytorchDataclassDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -131,7 +134,7 @@ class PytorchInvalidDataset(TorchDataset):
         self.data = SAMPLE_DATA

     def __getitem__(self, i):
-        return [self.data[i]["
+        return [self.data[i]["value"], self.data[i]["label"]]

     def __len__(self):
         return len(self.data)
@@ -142,7 +145,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
     dataset = PytorchInvalidDataset()
     # Then the HF dataset should raise an error
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)


 def test_hf_dataset_from_torchdataloader():
@@ -150,10 +153,10 @@ def test_hf_dataset_from_torchdataloader():
     dataset = PytorchDictDataset()

     def collate_fn(x: list[dict]):
-        return {"value": [item["
+        return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}

     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
-    hf_dataset = hf_dataset_from_torch(dataloader)
+    hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
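The updated tests pass `ignore_cache=True` on every call, which suggests `hf_dataset_from_torch` now accepts that flag alongside `column_names`. Below is a minimal usage sketch based only on the calls shown in these tests; the wrapping dataset class is a hypothetical stand-in, and importing from the private `orca_sdk._utils.data_parsing` module mirrors the test's relative import rather than a documented public API.

# Sketch based on the calls in the tests above; MyTupleDataset is a hypothetical
# stand-in for any torch Dataset that yields (value, label) tuples.
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_torch


class MyTupleDataset(TorchDataset):
    def __init__(self, rows):
        self.rows = rows

    def __getitem__(self, i):
        return self.rows[i]

    def __len__(self):
        return len(self.rows)


rows = [("i love soup", 0), ("cats are cute", 1)]
hf_dataset = hf_dataset_from_torch(
    MyTupleDataset(rows),
    column_names=["value", "label"],  # required for plain tuples, per the tests
    ignore_cache=True,  # flag exercised throughout the updated tests
)
assert hf_dataset.column_names == ["value", "label"]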
orca_sdk/_utils/tqdm_file_reader.py
ADDED

@@ -0,0 +1,12 @@
+class TqdmFileReader:
+    def __init__(self, file_obj, pbar):
+        self.file_obj = file_obj
+        self.pbar = pbar
+
+    def read(self, size=-1):
+        data = self.file_obj.read(size)
+        self.pbar.update(len(data))
+        return data
+
+    def __getattr__(self, attr):
+        return getattr(self.file_obj, attr)
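A wrapper like this is the usual trick for surfacing upload progress: it proxies an open file so that each `read()` advances a tqdm bar while every other attribute falls through to the underlying file object. A minimal usage sketch follows; the file path and the read loop are illustrative assumptions, and the SDK's actual call site lives in `orca_sdk/datasource.py` per the summary above.

# Illustrative sketch only: wraps a file handle so every read() advances a tqdm bar.
import os

from tqdm import tqdm

from orca_sdk._utils.tqdm_file_reader import TqdmFileReader

path = "data.parquet"  # hypothetical local file
total = os.path.getsize(path)
with open(path, "rb") as f, tqdm(total=total, unit="B", unit_scale=True) as pbar:
    reader = TqdmFileReader(f, pbar)
    while chunk := reader.read(1024 * 1024):  # progress bar updates per chunk read
        pass  # e.g. stream `chunk` to an upload endpoint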
orca_sdk/classification_model.py
CHANGED
@@ -1,10 +1,13 @@
 from __future__ import annotations

 import logging
+import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
+from uuid import UUID, uuid4
+
+import numpy as np

 import numpy as np
 from datasets import Dataset
@@ -312,7 +315,8 @@ class ClassificationModel:
         value: list[str],
         expected_labels: list[int] | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> list[LabelPrediction]:
         pass

@@ -322,7 +326,8 @@ class ClassificationModel:
         value: str,
         expected_labels: int | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> LabelPrediction:
         pass

@@ -331,7 +336,8 @@ class ClassificationModel:
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
         tags: set[str] = set(),
-
+        save_telemetry: bool = True,
+        save_telemetry_synchronously: bool = False,
     ) -> list[LabelPrediction] | LabelPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -340,7 +346,10 @@
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-
+            save_telemetry: Whether to enable telemetry for the prediction(s)
+            save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
+                asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
+                may be overriden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.

         Returns:
             Label prediction or list of label predictions
@@ -358,6 +367,13 @@ class ClassificationModel:
             ]
         """

+        if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
+            env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
+            logging.info(
+                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
+            )
+            save_telemetry_synchronously = env_var.lower() == "true"
+
         response = predict_gpu(
             self.id,
             body=PredictionRequest(
@@ -366,14 +382,17 @@
                 expected_labels=(
                     expected_labels
                     if isinstance(expected_labels, list)
-                    else [expected_labels]
+                    else [expected_labels]
+                    if expected_labels is not None
+                    else None
                 ),
                 tags=list(tags),
-
+                save_telemetry=save_telemetry,
+                save_telemetry_synchronously=save_telemetry_synchronously,
             ),
         )

-        if
+        if save_telemetry and any(p.prediction_id is None for p in response):
            raise RuntimeError("Failed to save prediction to database.")

         predictions = [
@@ -386,8 +405,9 @@
                 memoryset=self.memoryset,
                 model=self,
                 logits=prediction.logits,
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
@@ -463,7 +483,6 @@
         predictions: list[LabelPrediction],
         expected_labels: list[int],
     ) -> ClassificationEvaluationResult:
-
         targets_array = np.array(expected_labels)
         predictions_array = np.array([p.label for p in predictions])

@@ -553,7 +572,8 @@
                 batch[value_column],
                 expected_labels=batch[label_column],
                 tags=tags,
-
+                save_telemetry=record_predictions,
+                save_telemetry_synchronously=(not record_predictions),
             )
         )
         expected_labels.extend(batch[label_column])
@@ -581,7 +601,7 @@
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)

         Returns:
-            Dictionary with evaluation metrics
+            Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)

         Examples:
             Evaluate using a Datasource:
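Taken together, these hunks add `save_telemetry` and `save_telemetry_synchronously` keyword arguments to `ClassificationModel.predict`, forward them on the prediction request, and let the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY` environment variable override the keyword. A minimal usage sketch under those assumptions follows; the model name is hypothetical and the import path simply mirrors the module shown in this diff.

# Sketch of the predict() telemetry flags introduced in 0.0.94; assumes a
# ClassificationModel named "my_model" already exists (hypothetical name).
import os

from orca_sdk.classification_model import ClassificationModel

model = ClassificationModel.open("my_model")

# Skip telemetry entirely: per the tests, prediction_id will be None on the results.
predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)

# Save telemetry synchronously so it is queryable immediately after the call.
prediction = model.predict("Do you love soup?", save_telemetry_synchronously=True)

# Per the diff, this environment variable overrides the keyword argument.
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
prediction = model.predict("Are cats cute?", save_telemetry_synchronously=False)  # still saved synchronously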
orca_sdk/classification_model_test.py
CHANGED

@@ -1,3 +1,5 @@
+import logging
+import os
 from uuid import uuid4

 import numpy as np
@@ -9,46 +11,51 @@ from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset

+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

-
+
+SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+
+
+def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
     assert model is not None
     assert model.name == "test_model"
-    assert model.memoryset ==
+    assert model.memoryset == readonly_memoryset
     assert model.num_classes == 2
     assert model.memory_lookup_count == 3


-def test_create_model_already_exists_error(
+def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset)
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")


-def test_create_model_already_exists_return(
+def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")

     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)

     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)

     with pytest.raises(ValueError):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)

-    new_model = ClassificationModel.create("test_model",
+    new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
     assert new_model is not None
     assert new_model.name == "test_model"
-    assert new_model.memoryset ==
+    assert new_model.memoryset == readonly_memoryset
     assert new_model.num_classes == 2
     assert new_model.memory_lookup_count == 3


-def test_create_model_unauthenticated(unauthenticated,
+def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset)


 def test_get_model(model: ClassificationModel):
@@ -107,8 +114,8 @@ def test_update_model_no_description(model: ClassificationModel):
     assert model.description is None


-def test_delete_model(
-    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(
+def test_delete_model(readonly_memoryset: LabeledMemoryset):
+    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
     assert ClassificationModel.open("model_to_delete")
     ClassificationModel.drop("model_to_delete")
     with pytest.raises(LookupError):
@@ -133,25 +140,38 @@ def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):


 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
-    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset
+    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
     ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
     with pytest.raises(RuntimeError):
         LabeledMemoryset.drop(memoryset.id)


-def
-
-
-
-
-
-]
-
-
-
-
-
-
+def test_evaluate(model, eval_datasource: Datasource):
+    result = model.evaluate(eval_datasource)
+    assert result is not None
+    assert isinstance(result, dict)
+    # And anomaly score statistics are present and valid
+    assert isinstance(result["anomaly_score_mean"], float)
+    assert isinstance(result["anomaly_score_median"], float)
+    assert isinstance(result["anomaly_score_variance"], float)
+    assert -1.0 <= result["anomaly_score_mean"] <= 1.0
+    assert -1.0 <= result["anomaly_score_median"] <= 1.0
+    assert -1.0 <= result["anomaly_score_variance"] <= 1.0
+    assert isinstance(result["accuracy"], float)
+    assert isinstance(result["f1_score"], float)
+    assert isinstance(result["loss"], float)
+    assert len(result["precision_recall_curve"]["thresholds"]) == 4
+    assert len(result["precision_recall_curve"]["precisions"]) == 4
+    assert len(result["precision_recall_curve"]["recalls"]) == 4
+    assert len(result["roc_curve"]["thresholds"]) == 4
+    assert len(result["roc_curve"]["false_positive_rates"]) == 4
+    assert len(result["roc_curve"]["true_positive_rates"]) == 4
+
+
+def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
+    result_datasource = model.evaluate(eval_datasource)
+
+    result_dataset = model.evaluate(eval_dataset)

     for result in [result_datasource, result_dataset]:
         assert result is not None
@@ -217,7 +237,7 @@ def test_predict(model: ClassificationModel, label_names: list[str]):


 def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
-    predictions = model.predict(["Do you love soup?", "Are cats cute?"],
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
     assert len(predictions) == 2
     assert predictions[0].prediction_id is None
     assert predictions[1].prediction_id is None
@@ -239,9 +259,12 @@ def test_predict_unauthorized(unauthorized, model: ClassificationModel):
         model.predict(["Do you love soup?", "Are cats cute?"])


-def test_predict_constraint_violation(
+def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
     model = ClassificationModel.create(
-        "test_model_lookup_count_too_high",
+        "test_model_lookup_count_too_high",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=readonly_memoryset.length + 2,
     )
     with pytest.raises(RuntimeError):
         model.predict("test")
@@ -281,7 +304,6 @@ def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset:
     inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
         "test_memoryset_inverted_labels",
         hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
-        value_column="text",
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
     )
     with model.use_memoryset(inverted_labeled_memoryset):
@@ -323,3 +345,42 @@ def test_last_prediction_with_single(model: ClassificationModel):
     assert model.last_prediction.prediction_id == prediction.prediction_id
     assert model.last_prediction.input_value == "Do you love soup?"
     assert model._last_prediction_was_batch is False
+
+
+@pytest.mark.skipif(
+    SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
+)
+def test_explain(writable_memoryset: LabeledMemoryset):
+
+    writable_memoryset.analyze(
+        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_explain",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for explain",
+    )
+
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+
+    try:
+        explanation = predictions[0].explanation
+        print(explanation)
+        assert explanation is not None
+        assert len(explanation) > 10
+        assert "soup" in explanation.lower()
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set on server")
+        else:
+            raise e
+    finally:
+        try:
+            ClassificationModel.drop("test_model_for_explain")
+        except Exception as e:
+            logging.info(f"Failed to drop test model for explain: {e}")
orca_sdk/conftest.py
CHANGED
@@ -17,6 +17,8 @@ logging.basicConfig(level=logging.INFO)

 os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")

+os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
+

 def _create_org_id():
     # UUID start to identify test data (0xtest...)
@@ -69,22 +71,22 @@ def label_names():


 SAMPLE_DATA = [
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
-    {"
+    {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
+    {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
+    {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
+    {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
+    {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
+    {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
+    {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
+    {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
+    {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
+    {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
+    {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
+    {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
+    {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
+    {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
+    {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
+    {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
 ]


@@ -94,7 +96,7 @@ def hf_dataset(label_names):
         SAMPLE_DATA,
         features=Features(
             {
-                "
+                "value": Value("string"),
                 "label": ClassLabel(names=label_names),
                 "key": Value("string"),
                 "score": Value("float"),
@@ -106,23 +108,83 @@ def hf_dataset(label_names):

 @pytest.fixture(scope="session")
 def datasource(hf_dataset) -> Datasource:
-
+    datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
+    return datasource
+
+
+EVAL_DATASET = [
+    {"value": "chicken noodle soup is the best", "label": 1},
+    {"value": "cats are cute", "label": 0},
+    {"value": "soup is great for the winter", "label": 0},
+    {"value": "i love cats", "label": 1},
+]


 @pytest.fixture(scope="session")
-def
-
-
+def eval_datasource() -> Datasource:
+    eval_datasource = Datasource.from_list("eval_datasource", EVAL_DATASET)
+    return eval_datasource
+
+
+@pytest.fixture(scope="session")
+def eval_dataset() -> Dataset:
+    eval_dataset = Dataset.from_list(EVAL_DATASET)
+    return eval_dataset
+
+
+@pytest.fixture(scope="session")
+def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
+    memoryset = LabeledMemoryset.create(
+        "test_readonly_memoryset",
         datasource=datasource,
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
-        value_column="text",
         source_id_column="source_id",
         max_seq_length_override=32,
     )
+    return memoryset
+
+
+@pytest.fixture(scope="function")
+def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[LabeledMemoryset, None, None]:
+    """
+    Function-scoped fixture that provides a writable memoryset for tests that mutate state.
+
+    This fixture creates a fresh `LabeledMemoryset` named 'test_writable_memoryset' before each test.
+    After the test, it attempts to restore the memoryset to its initial state by deleting any added entries
+    and reinserting sample data — unless the memoryset has been dropped by the test itself, in which case
+    it will be recreated on the next invocation.
+
+    Note: Re-creating the memoryset from scratch is surprisingly more expensive than cleaning it up.
+    """
+    # It shouldn't be possible for this memoryset to already exist
+    memoryset = LabeledMemoryset.create(
+        "test_writable_memoryset",
+        datasource=datasource,
+        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+        source_id_column="source_id",
+        max_seq_length_override=32,
+        if_exists="open",
+    )
+    try:
+        yield memoryset
+    finally:
+        # Restore the memoryset to a clean state for the next test.
+        OrcaCredentials.set_api_key(api_key, check_validity=False)
+
+        if LabeledMemoryset.exists("test_writable_memoryset"):
+            memory_ids = [memoryset[i].memory_id for i in range(len(memoryset))]
+
+            if memory_ids:
+                memoryset.delete(memory_ids)
+                memoryset.refresh()
+                assert len(memoryset) == 0
+            memoryset.insert(SAMPLE_DATA)
+        # If the test dropped the memoryset, do nothing — it will be recreated on the next use.


 @pytest.fixture(scope="session")
-def model(
-
-    "test_model",
+def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+    model = ClassificationModel.create(
+        "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
     )
+    return model