orca-sdk 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/async_client.py +156 -4
- orca_sdk/classification_model.py +202 -65
- orca_sdk/classification_model_test.py +16 -3
- orca_sdk/client.py +156 -4
- orca_sdk/conftest.py +10 -9
- orca_sdk/datasource.py +31 -13
- orca_sdk/embedding_model.py +8 -31
- orca_sdk/embedding_model_test.py +1 -1
- orca_sdk/memoryset.py +236 -321
- orca_sdk/memoryset_test.py +39 -13
- orca_sdk/regression_model.py +185 -64
- orca_sdk/regression_model_test.py +18 -3
- orca_sdk/telemetry.py +15 -6
- {orca_sdk-0.1.11.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +3 -5
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -137
- orca_sdk/_utils/data_parsing_disk_test.py +0 -91
- orca_sdk/_utils/data_parsing_torch_test.py +0 -159
- orca_sdk-0.1.11.dist-info/RECORD +0 -42
- {orca_sdk-0.1.11.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
orca_sdk/__init__.py
CHANGED
@@ -2,7 +2,7 @@
 OrcaSDK is a Python library for building and using retrieval augmented models in the OrcaCloud.
 """

-from ._utils.common import UNSET, CreateMode, DropMode
+from ._utils.common import UNSET, CreateMode, DropMode, logger
 from .classification_model import ClassificationMetrics, ClassificationModel
 from .client import OrcaClient
 from .credentials import OrcaCredentials
@@ -23,8 +23,8 @@ from .memoryset import (
     ScoredMemoryLookup,
     ScoredMemoryset,
 )
-from .regression_model import RegressionModel
+from .regression_model import RegressionMetrics, RegressionModel
 from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction

 # only specify things that should show up on the root page of the reference docs because they are in private modules
-__all__ = ["UNSET", "CreateMode", "DropMode"]
+__all__ = ["UNSET", "CreateMode", "DropMode", "logger"]
orca_sdk/_utils/auth.py
CHANGED
@@ -1,13 +1,12 @@
 """This module contains internal utils for managing api keys in tests"""

-import logging
 import os
 from typing import List, Literal

 from dotenv import load_dotenv

 from ..client import ApiKeyMetadata, OrcaClient
-from .common import DropMode
+from .common import DropMode, logger

 load_dotenv()  # this needs to be here to ensure env is populated before accessing it

@@ -59,7 +58,7 @@ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "
     client = OrcaClient._resolve_client()
     client.base_url = "http://localhost:1584"
     client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
-
+    logger.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")


 __all__ = ["_create_api_key", "_delete_api_key", "_delete_org", "_list_api_keys", "_authenticate_local_api"]
orca_sdk/_utils/common.py
CHANGED
@@ -1,4 +1,21 @@
-
+import logging
+from typing import Any, Iterable, Iterator, Literal, TypeVar
+
+try:
+    from itertools import batched
+except ImportError:
+    # Polyfill for Python <3.12
+
+    from itertools import islice
+
+    _BatchT = TypeVar("_BatchT")
+
+    def batched(iterable: Iterable[_BatchT], n: int) -> Iterator[tuple[_BatchT, ...]]:
+        """Batch an iterable into chunks of size n (backfill for Python <3.12)."""
+        it = iter(iterable)
+        while batch := tuple(islice(it, n)):
+            yield batch
+

 CreateMode = Literal["error", "open"]
 """
@@ -35,3 +52,9 @@ UNSET: Any = _UnsetSentinel()
 """
 Default value to indicate that no update should be applied to a field and it should not be set to None
 """
+
+logger = logging.getLogger("orca_sdk")
+"""
+Logger for the Orca SDK.
+"""
+logger.addHandler(logging.NullHandler())
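Note: this release introduces a package-level logger (`logging.getLogger("orca_sdk")`, silenced with a `NullHandler` and re-exported from the package root), which `_utils/auth.py` now uses for its authentication message. A minimal sketch of how a consuming application might surface those messages, using only the standard library:

```python
import logging

# The SDK only installs a NullHandler, so its log records are dropped unless
# the application attaches a handler to the "orca_sdk" logger (or configures root logging).
sdk_logger = logging.getLogger("orca_sdk")
sdk_logger.setLevel(logging.INFO)
sdk_logger.addHandler(logging.StreamHandler())

# The same logger object is re-exported from the package root:
# from orca_sdk import logger
```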
orca_sdk/_utils/torch_parsing.py
ADDED
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from dataclasses import asdict, is_dataclass
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    # peer dependencies that are used for types only
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+    from torch.utils.data import Dataset as TorchDataset  # type: ignore
+
+
+def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
+    if isinstance(item, dict):
+        return item
+
+    if isinstance(item, tuple):
+        if column_names is not None:
+            if len(item) != len(column_names):
+                raise ValueError(
+                    f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
+                )
+            return {column_names[i]: item[i] for i in range(len(item))}
+        elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
+            return {field: getattr(item, field) for field in item._fields}  # type: ignore
+        else:
+            raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
+
+    if is_dataclass(item) and not isinstance(item, type):
+        return asdict(item)
+
+    raise ValueError(f"Cannot parse {type(item)}")
+
+
+def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
+    if isinstance(batch, list):
+        return [parse_dict_like(item, column_names) for item in batch]
+
+    batch = parse_dict_like(batch, column_names)
+    keys = list(batch.keys())
+    batch_size = len(batch[keys[0]])
+    for key in keys:
+        if not len(batch[key]) == batch_size:
+            raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
+    return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
+
+
+def list_from_torch(
+    torch_data: TorchDataLoader | TorchDataset,
+    column_names: list[str] | None = None,
+) -> list[dict]:
+    """
+    Convert a PyTorch DataLoader or Dataset to a list of dictionaries.
+
+    Params:
+        torch_data: A PyTorch DataLoader or Dataset object to convert.
+        column_names: Optional list of column names to use for the data. If not provided,
+            the column names will be inferred from the data.
+    Returns:
+        A list of dictionaries containing the data from the PyTorch DataLoader or Dataset.
+    """
+    # peer dependency that is guaranteed to exist if the user provided a torch dataset
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+
+    if isinstance(torch_data, TorchDataLoader):
+        dataloader = torch_data
+    else:
+        dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
+
+    # Collect data from the dataloader into a list
+    data_list = []
+    try:
+        for batch in dataloader:
+            data_list.extend(parse_batch(batch, column_names=column_names))
+    except ValueError as e:
+        raise ValueError(str(e)) from e
+
+    return data_list
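Note: the new `list_from_torch` helper accepts dict-, namedtuple-, dataclass-, and plain-tuple-style PyTorch datasets (plain tuples require `column_names`). A minimal usage sketch; `Example` and `ToyDataset` are illustrative names, not part of the SDK, and downstream code would normally reach this helper through the public datasource APIs rather than the private `_utils` module:

```python
from dataclasses import dataclass

from torch.utils.data import Dataset

from orca_sdk._utils.torch_parsing import list_from_torch


@dataclass
class Example:  # hypothetical item type for illustration
    value: str
    label: int


class ToyDataset(Dataset):  # hypothetical dataset for illustration
    def __init__(self):
        self.items = [Example("great product", 1), Example("terrible product", 0)]

    def __getitem__(self, i):
        return self.items[i]

    def __len__(self):
        return len(self.items)


rows = list_from_torch(ToyDataset())
# rows == [{"value": "great product", "label": 1}, {"value": "terrible product", "label": 0}]
```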
orca_sdk/_utils/torch_parsing_test.py
ADDED
@@ -0,0 +1,142 @@
+from collections import namedtuple
+from dataclasses import dataclass
+
+import pytest
+
+from .torch_parsing import list_from_torch
+
+pytest.importorskip("torch")
+
+from torch.utils.data import DataLoader as TorchDataLoader  # noqa: E402
+from torch.utils.data import Dataset as TorchDataset  # noqa: E402
+
+
+def test_list_from_torch_dict_dataset(data: list[dict]):
+    class PytorchDictDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDictDataset()
+    data_list = list_from_torch(dataset)
+
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert set(list(data_list[0].keys())) == {"value", "label", "key", "score", "source_id", "partition_id"}
+
+
+def test_list_from_torch_dataloader(data: list[dict]):
+    class PytorchDictDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDictDataset()
+
+    def collate_fn(x: list[dict]):
+        return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
+
+    dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
+    data_list = list_from_torch(dataloader)
+
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_tuple_dataset(data: list[dict]):
+    class PytorchTupleDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]["value"], self.data[i]["label"]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchTupleDataset()
+
+    # raises error if no column names are passed in
+    with pytest.raises(ValueError):
+        list_from_torch(dataset)
+
+    # raises error if not enough column names are passed in
+    with pytest.raises(ValueError):
+        list_from_torch(dataset, column_names=["value"])
+
+    # creates list if correct number of column names are passed in
+    data_list = list_from_torch(dataset, column_names=["value", "label"])
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_named_tuple_dataset(data: list[dict]):
+    # Given a Pytorch dataset that returns a namedtuple for each item
+    DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
+
+    class PytorchNamedTupleDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchNamedTupleDataset()
+    data_list = list_from_torch(dataset)
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_dataclass_dataset(data: list[dict]):
+    @dataclass
+    class DatasetItem:
+        text: str
+        label: int
+
+    class PytorchDataclassDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDataclassDataset()
+    data_list = list_from_torch(dataset)
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["text", "label"]
+
+
+def test_list_from_torch_invalid_dataset(data: list[dict]):
+    class PytorchInvalidDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return [self.data[i]["value"], self.data[i]["label"]]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchInvalidDataset()
+    with pytest.raises(ValueError):
+        list_from_torch(dataset)
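Note: these tests take a `data: list[dict]` pytest fixture; the shared `orca_sdk/conftest.py` also changed in this release (+10 -9) but its contents are not shown here. A hypothetical fixture that would satisfy the key assertions above might look like the following sketch; the field values are invented:

```python
# Hypothetical sketch of a fixture compatible with the assertions above;
# the real fixture in orca_sdk/conftest.py is not shown in this diff.
import pytest


@pytest.fixture
def data() -> list[dict]:
    return [
        {
            "value": f"sample text {i}",
            "label": i % 2,
            "key": f"key-{i}",
            "score": float(i),
            "source_id": f"source-{i}",
            "partition_id": None,
        }
        for i in range(6)
    ]
```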
orca_sdk/async_client.py
CHANGED
@@ -85,7 +85,7 @@ class BaseLabelPredictionResult(TypedDict):
     anomaly_score: float | None
     label: int | None
     label_name: str | None
-    logits: list[float]
+    logits: list[float] | None


 class BaseModel(TypedDict):
@@ -160,6 +160,18 @@ The type of a column in a datasource
 """


+class ComputeClassificationMetricsRequest(TypedDict):
+    expected_labels: list[int]
+    logits: list[list[float] | None]
+    anomaly_scores: NotRequired[list[float] | None]
+
+
+class ComputeRegressionMetricsRequest(TypedDict):
+    expected_scores: list[float]
+    predicted_scores: list[float | None]
+    anomaly_scores: NotRequired[list[float] | None]
+
+
 class ConstraintViolationErrorResponse(TypedDict):
     status_code: Literal[409]
     constraint: str
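Note: the two new request types back the `/classification_model/metrics` and `/regression_model/metrics` endpoints added to the `POST` overloads later in this file. A hedged sketch of building such payloads; the label, logit, and score values are invented, and how the client instance is obtained is not shown in this diff:

```python
from orca_sdk.async_client import (
    ComputeClassificationMetricsRequest,
    ComputeRegressionMetricsRequest,
)

# Illustrative payloads only.
classification_request: ComputeClassificationMetricsRequest = {
    "expected_labels": [0, 1, 1],
    "logits": [[2.1, -0.3], [-1.0, 1.7], None],  # None marks a prediction without logits
    "anomaly_scores": [0.02, 0.10, 0.55],
}

regression_request: ComputeRegressionMetricsRequest = {
    "expected_scores": [1.0, 2.5, 4.0],
    "predicted_scores": [1.1, None, 3.8],  # None marks a missing prediction
}

# Per the POST overloads added below, these would be sent along the lines of:
# metrics = await client.POST("/classification_model/metrics", json=classification_request)
```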
@@ -322,6 +334,7 @@ class GetDatasourceRowsRequest(TypedDict):

 class GetMemoriesRequest(TypedDict):
     memory_ids: list[str]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]


 class HealthyResponse(TypedDict):
@@ -392,6 +405,7 @@ class ListMemoriesRequest(TypedDict):
     offset: NotRequired[int]
     limit: NotRequired[int]
     filters: NotRequired[list[FilterItem]]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]


 class LookupRequest(TypedDict):
@@ -400,6 +414,7 @@ class LookupRequest(TypedDict):
     prompt: NotRequired[str | None]
     partition_id: NotRequired[str | list[str | None] | None]
     partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]


 class LookupScoreMetrics(TypedDict):
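Note: this release threads an optional `consistency_level` through memory reads (`GetMemoriesRequest`, `ListMemoriesRequest`, `LookupRequest`, and the telemetry and prediction requests below). A minimal sketch using `GetMemoriesRequest`, whose fields are fully shown above; the memory IDs are invented:

```python
from orca_sdk.async_client import GetMemoriesRequest

# Illustrative only: memory IDs are made up.
request: GetMemoriesRequest = {
    "memory_ids": ["mem_123", "mem_456"],
    "consistency_level": "Strong",  # one of "Bounded", "Session", "Strong", "Eventual"
}
```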
@@ -570,8 +585,17 @@ class OrgPlan(TypedDict):

 class PRCurve(TypedDict):
     thresholds: list[float]
+    """
+    Threshold values for the curve
+    """
     precisions: list[float]
+    """
+    Precision values at each threshold
+    """
     recalls: list[float]
+    """
+    Recall values at each threshold
+    """


 class PredictionFeedback(TypedDict):
@@ -642,8 +666,17 @@ RARHeadType: TypeAlias = Literal["MMOE", "KNN"]

 class ROCCurve(TypedDict):
     thresholds: list[float]
+    """
+    Threshold values for the curve
+    """
     false_positive_rates: list[float]
+    """
+    False positive rate values at each threshold
+    """
     true_positive_rates: list[float]
+    """
+    True positive rate values at each threshold
+    """


 class ReadyResponse(TypedDict):
@@ -666,15 +699,49 @@ class RegressionEvaluationRequest(TypedDict):

 class RegressionMetrics(TypedDict):
     coverage: float
+    """
+    Percentage of predictions that are not none
+    """
     mse: float
+    """
+    Mean squared error of the predictions
+    """
     rmse: float
+    """
+    Root mean squared error of the predictions
+    """
     mae: float
+    """
+    Mean absolute error of the predictions
+    """
     r2: float
+    """
+    R-squared score (coefficient of determination) of the predictions
+    """
     explained_variance: float
+    """
+    Explained variance score of the predictions
+    """
     loss: float
+    """
+    Mean squared error loss of the predictions
+    """
     anomaly_score_mean: NotRequired[float | None]
+    """
+    Mean of anomaly scores across the dataset
+    """
     anomaly_score_median: NotRequired[float | None]
+    """
+    Median of anomaly scores across the dataset
+    """
     anomaly_score_variance: NotRequired[float | None]
+    """
+    Variance of anomaly scores across the dataset
+    """
+    warnings: NotRequired[list[str]]
+    """
+    Human-readable warnings about skipped or adjusted metrics
+    """


 class RegressionModelMetadata(TypedDict):
@@ -703,7 +770,7 @@ class RegressionPredictionRequest(TypedDict):
     save_telemetry_synchronously: NotRequired[bool]
     prompt: NotRequired[str | None]
     use_lookup_cache: NotRequired[bool]
-    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
     ignore_unlabeled: NotRequired[bool]
     partition_ids: NotRequired[str | list[str | None] | None]
     partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
@@ -927,6 +994,7 @@ class GetMemorysetByNameOrIdMemoryByMemoryIdParams(TypedDict):
     """
     ID of the memory
     """
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]


 class DeleteMemorysetByNameOrIdMemoryByMemoryIdParams(TypedDict):
@@ -1304,18 +1372,57 @@ class BootstrapLabeledMemoryDataResult(TypedDict):

 class ClassificationMetrics(TypedDict):
     coverage: float
+    """
+    Percentage of predictions that are not none
+    """
     f1_score: float
+    """
+    F1 score of the predictions
+    """
     accuracy: float
+    """
+    Accuracy of the predictions
+    """
     loss: float | None
+    """
+    Cross-entropy loss of the logits
+    """
     anomaly_score_mean: NotRequired[float | None]
+    """
+    Mean of anomaly scores across the dataset
+    """
     anomaly_score_median: NotRequired[float | None]
+    """
+    Median of anomaly scores across the dataset
+    """
     anomaly_score_variance: NotRequired[float | None]
+    """
+    Variance of anomaly scores across the dataset
+    """
     roc_auc: NotRequired[float | None]
+    """
+    Receiver operating characteristic area under the curve
+    """
     pr_auc: NotRequired[float | None]
+    """
+    Average precision (area under the curve of the precision-recall curve)
+    """
     pr_curve: NotRequired[PRCurve | None]
+    """
+    Precision-recall curve
+    """
     roc_curve: NotRequired[ROCCurve | None]
+    """
+    Receiver operating characteristic curve
+    """
     confusion_matrix: NotRequired[list[list[int]] | None]
+    """
+    Confusion matrix where the entry at row i, column j is the count of samples with true label i predicted as label j
+    """
     warnings: NotRequired[list[str]]
+    """
+    Human-readable warnings about skipped or adjusted metrics
+    """


 class ClassificationModelMetadata(TypedDict):
@@ -1348,7 +1455,7 @@ class ClassificationPredictionRequest(TypedDict):
     save_telemetry_synchronously: NotRequired[bool]
     prompt: NotRequired[str | None]
     use_lookup_cache: NotRequired[bool]
-    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]
     ignore_unlabeled: NotRequired[bool]
     partition_ids: NotRequired[str | list[str | None] | None]
     partition_filter_mode: NotRequired[Literal["ignore_partitions", "include_global", "exclude_global", "only_global"]]
@@ -1362,6 +1469,7 @@ class CloneMemorysetRequest(TypedDict):
     finetuned_embedding_model_name_or_id: NotRequired[str | None]
     max_seq_length_override: NotRequired[int | None]
     prompt: NotRequired[str]
+    is_partitioned: NotRequired[bool | None]


 class ColumnInfo(TypedDict):
@@ -1409,6 +1517,7 @@ class CreateMemorysetFromDatasourceRequest(TypedDict):
     prompt: NotRequired[str]
     hidden: NotRequired[bool]
     memory_type: NotRequired[MemoryType | None]
+    is_partitioned: NotRequired[bool]
     datasource_name_or_id: str
     datasource_label_column: NotRequired[str | None]
     datasource_score_column: NotRequired[str | None]
@@ -1433,6 +1542,7 @@ class CreateMemorysetRequest(TypedDict):
     prompt: NotRequired[str]
     hidden: NotRequired[bool]
     memory_type: NotRequired[MemoryType | None]
+    is_partitioned: NotRequired[bool]


 class CreateRegressionModelRequest(TypedDict):
@@ -1590,7 +1700,7 @@ class LabelPredictionWithMemoriesAndFeedback(TypedDict):
     anomaly_score: float | None
     label: int | None
     label_name: str | None
-    logits: list[float]
+    logits: list[float] | None
     timestamp: str
     input_value: str | bytes
     input_embedding: list[float]
@@ -1746,6 +1856,7 @@ class TelemetryMemoriesRequest(TypedDict):
     limit: NotRequired[int]
     filters: NotRequired[list[FilterItem | TelemetryFilterItem]]
     sort: NotRequired[list[TelemetrySortOptions] | None]
+    consistency_level: NotRequired[Literal["Bounded", "Session", "Strong", "Eventual"]]


 class WorkerInfo(TypedDict):
@@ -1812,6 +1923,7 @@ class MemorysetMetadata(TypedDict):
     document_prompt_override: str | None
     query_prompt_override: str | None
     hidden: bool
+    is_partitioned: bool
     insertion_task_id: str | None


@@ -3660,6 +3772,46 @@ class OrcaAsyncClient(AsyncClient):
     ) -> EvaluationResponse:
         pass

+    @overload
+    async def POST(
+        self,
+        path: Literal["/classification_model/metrics"],
+        *,
+        params: None = None,
+        json: ComputeClassificationMetricsRequest,
+        data: None = None,
+        files: None = None,
+        content: None = None,
+        parse_as: Literal["json"] = "json",
+        headers: HeaderTypes | None = None,
+        cookies: CookieTypes | None = None,
+        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        extensions: RequestExtensions | None = None,
+    ) -> ClassificationMetrics:
+        pass
+
+    @overload
+    async def POST(
+        self,
+        path: Literal["/regression_model/metrics"],
+        *,
+        params: None = None,
+        json: ComputeRegressionMetricsRequest,
+        data: None = None,
+        files: None = None,
+        content: None = None,
+        parse_as: Literal["json"] = "json",
+        headers: HeaderTypes | None = None,
+        cookies: CookieTypes | None = None,
+        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+        extensions: RequestExtensions | None = None,
+    ) -> RegressionMetrics:
+        pass
+
     @overload
     async def POST(
         self,