orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/analysis_ui.py +4 -1
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/prediction_result_ui.py +4 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/_utils/value_parser.py +44 -17
- orca_sdk/_utils/value_parser_test.py +6 -5
- orca_sdk/async_client.py +234 -22
- orca_sdk/classification_model.py +203 -66
- orca_sdk/classification_model_test.py +85 -25
- orca_sdk/client.py +234 -20
- orca_sdk/conftest.py +97 -16
- orca_sdk/credentials_test.py +5 -8
- orca_sdk/datasource.py +44 -21
- orca_sdk/datasource_test.py +8 -2
- orca_sdk/embedding_model.py +15 -33
- orca_sdk/embedding_model_test.py +30 -1
- orca_sdk/memoryset.py +558 -425
- orca_sdk/memoryset_test.py +120 -185
- orca_sdk/regression_model.py +186 -65
- orca_sdk/regression_model_test.py +62 -3
- orca_sdk/telemetry.py +16 -7
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +4 -8
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -129
- orca_sdk/_utils/data_parsing_test.py +0 -244
- orca_sdk-0.1.10.dist-info/RECORD +0 -41
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
orca_sdk/__init__.py
CHANGED

```diff
@@ -2,7 +2,7 @@
 OrcaSDK is a Python library for building and using retrieval augmented models in the OrcaCloud.
 """

-from ._utils.common import UNSET, CreateMode, DropMode
+from ._utils.common import UNSET, CreateMode, DropMode, logger
 from .classification_model import ClassificationMetrics, ClassificationModel
 from .client import OrcaClient
 from .credentials import OrcaCredentials
@@ -23,8 +23,8 @@ from .memoryset import (
     ScoredMemoryLookup,
     ScoredMemoryset,
 )
-from .regression_model import RegressionModel
+from .regression_model import RegressionMetrics, RegressionModel
 from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction

 # only specify things that should show up on the root page of the reference docs because they are in private modules
-__all__ = ["UNSET", "CreateMode", "DropMode"]
+__all__ = ["UNSET", "CreateMode", "DropMode", "logger"]
```
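The newly exported `logger` is a standard library `logging.Logger` named `"orca_sdk"` with a `NullHandler` attached (see the `common.py` hunk below), so the SDK stays silent until the host application opts in. A minimal sketch of how a consumer might surface SDK logs; the handler, format, and level here are illustrative choices, not part of the SDK:

```python
import logging

from orca_sdk import logger

# Attach a handler so SDK messages (e.g. authentication info) become visible;
# without one, the NullHandler discards them.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Equivalent configuration by name, without importing the SDK:
logging.getLogger("orca_sdk").setLevel(logging.INFO)
```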
orca_sdk/_utils/analysis_ui.py
CHANGED

```diff
@@ -5,7 +5,10 @@ import re
 from pathlib import Path
 from typing import TypedDict, cast

-import gradio as gr
+try:
+    import gradio as gr  # type: ignore
+except ImportError as e:
+    raise ImportError("gradio is required for UI features. Install it with: pip install orca_sdk[ui]") from e

 from ..memoryset import LabeledMemory, LabeledMemoryset

```
orca_sdk/_utils/auth.py
CHANGED

```diff
@@ -1,13 +1,12 @@
 """This module contains internal utils for managing api keys in tests"""

-import logging
 import os
 from typing import List, Literal

 from dotenv import load_dotenv

 from ..client import ApiKeyMetadata, OrcaClient
-from .common import DropMode
+from .common import DropMode, logger

 load_dotenv()  # this needs to be here to ensure env is populated before accessing it

@@ -59,7 +58,7 @@ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "
     client = OrcaClient._resolve_client()
     client.base_url = "http://localhost:1584"
     client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
-    logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
+    logger.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")


 __all__ = ["_create_api_key", "_delete_api_key", "_delete_org", "_list_api_keys", "_authenticate_local_api"]
```
orca_sdk/_utils/common.py
CHANGED

```diff
@@ -1,4 +1,21 @@
-from typing import Any, Literal
+import logging
+from typing import Any, Iterable, Iterator, Literal, TypeVar
+
+try:
+    from itertools import batched
+except ImportError:
+    # Polyfill for Python <3.12
+
+    from itertools import islice
+
+    _BatchT = TypeVar("_BatchT")
+
+    def batched(iterable: Iterable[_BatchT], n: int) -> Iterator[tuple[_BatchT, ...]]:
+        """Batch an iterable into chunks of size n (backfill for Python <3.12)."""
+        it = iter(iterable)
+        while batch := tuple(islice(it, n)):
+            yield batch
+

 CreateMode = Literal["error", "open"]
 """
@@ -35,3 +52,9 @@ UNSET: Any = _UnsetSentinel()
 """
 Default value to indicate that no update should be applied to a field and it should not be set to None
 """
+
+logger = logging.getLogger("orca_sdk")
+"""
+Logger for the Orca SDK.
+"""
+logger.addHandler(logging.NullHandler())
```
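The polyfill above matches the semantics of `itertools.batched` from Python 3.12: it yields consecutive tuples of up to `n` items, with a shorter final tuple when the input length is not a multiple of `n`. A quick illustration (the import path assumes this module's location in the diff):

```python
from orca_sdk._utils.common import batched  # stdlib version on Python >= 3.12, polyfill otherwise

assert list(batched("ABCDEFG", 3)) == [("A", "B", "C"), ("D", "E", "F"), ("G",)]
assert list(batched(range(4), 2)) == [(0, 1), (2, 3)]
assert list(batched([], 3)) == []
```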
orca_sdk/_utils/prediction_result_ui.py
CHANGED

```diff
@@ -5,7 +5,10 @@ import re
 from pathlib import Path
 from typing import TYPE_CHECKING

-import gradio as gr
+try:
+    import gradio as gr  # type: ignore
+except ImportError as e:
+    raise ImportError("gradio is required for UI features. Install it with: pip install orca_sdk[ui]") from e

 from ..memoryset import LabeledMemoryLookup, LabeledMemoryset, ScoredMemoryLookup

```
orca_sdk/_utils/torch_parsing.py
ADDED

```diff
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from dataclasses import asdict, is_dataclass
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    # peer dependencies that are used for types only
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+    from torch.utils.data import Dataset as TorchDataset  # type: ignore
+
+
+def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
+    if isinstance(item, dict):
+        return item
+
+    if isinstance(item, tuple):
+        if column_names is not None:
+            if len(item) != len(column_names):
+                raise ValueError(
+                    f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
+                )
+            return {column_names[i]: item[i] for i in range(len(item))}
+        elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
+            return {field: getattr(item, field) for field in item._fields}  # type: ignore
+        else:
+            raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
+
+    if is_dataclass(item) and not isinstance(item, type):
+        return asdict(item)
+
+    raise ValueError(f"Cannot parse {type(item)}")
+
+
+def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
+    if isinstance(batch, list):
+        return [parse_dict_like(item, column_names) for item in batch]
+
+    batch = parse_dict_like(batch, column_names)
+    keys = list(batch.keys())
+    batch_size = len(batch[keys[0]])
+    for key in keys:
+        if not len(batch[key]) == batch_size:
+            raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
+    return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
+
+
+def list_from_torch(
+    torch_data: TorchDataLoader | TorchDataset,
+    column_names: list[str] | None = None,
+) -> list[dict]:
+    """
+    Convert a PyTorch DataLoader or Dataset to a list of dictionaries.
+
+    Params:
+        torch_data: A PyTorch DataLoader or Dataset object to convert.
+        column_names: Optional list of column names to use for the data. If not provided,
+            the column names will be inferred from the data.
+    Returns:
+        A list of dictionaries containing the data from the PyTorch DataLoader or Dataset.
+    """
+    # peer dependency that is guaranteed to exist if the user provided a torch dataset
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+
+    if isinstance(torch_data, TorchDataLoader):
+        dataloader = torch_data
+    else:
+        dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
+
+    # Collect data from the dataloader into a list
+    data_list = []
+    try:
+        for batch in dataloader:
+            data_list.extend(parse_batch(batch, column_names=column_names))
+    except ValueError as e:
+        raise ValueError(str(e)) from e
+
+    return data_list
```
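Based on the signatures above, usage might look like the sketch below. The dataset is a stand-in, and since `torch_parsing` lives under `_utils` the import path is shown for illustration only:

```python
from torch.utils.data import Dataset

from orca_sdk._utils.torch_parsing import list_from_torch


class ReviewDataset(Dataset):
    """Toy dataset that yields unnamed (value, label) tuples."""

    def __init__(self):
        self.rows = [("great product", 1), ("arrived broken", 0)]

    def __getitem__(self, i):
        return self.rows[i]

    def __len__(self):
        return len(self.rows)


# Unnamed tuples require explicit column names, per parse_dict_like above
rows = list_from_torch(ReviewDataset(), column_names=["value", "label"])
assert rows == [
    {"value": "great product", "label": 1},
    {"value": "arrived broken", "label": 0},
]
```

Dict-returning datasets, namedtuples, and dataclasses convert without `column_names`, as the new test file below demonstrates.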
orca_sdk/_utils/torch_parsing_test.py
ADDED

```diff
@@ -0,0 +1,142 @@
+from collections import namedtuple
+from dataclasses import dataclass
+
+import pytest
+
+from .torch_parsing import list_from_torch
+
+pytest.importorskip("torch")
+
+from torch.utils.data import DataLoader as TorchDataLoader  # noqa: E402
+from torch.utils.data import Dataset as TorchDataset  # noqa: E402
+
+
+def test_list_from_torch_dict_dataset(data: list[dict]):
+    class PytorchDictDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDictDataset()
+    data_list = list_from_torch(dataset)
+
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert set(list(data_list[0].keys())) == {"value", "label", "key", "score", "source_id", "partition_id"}
+
+
+def test_list_from_torch_dataloader(data: list[dict]):
+    class PytorchDictDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDictDataset()
+
+    def collate_fn(x: list[dict]):
+        return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
+
+    dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
+    data_list = list_from_torch(dataloader)
+
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_tuple_dataset(data: list[dict]):
+    class PytorchTupleDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return self.data[i]["value"], self.data[i]["label"]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchTupleDataset()
+
+    # raises error if no column names are passed in
+    with pytest.raises(ValueError):
+        list_from_torch(dataset)
+
+    # raises error if not enough column names are passed in
+    with pytest.raises(ValueError):
+        list_from_torch(dataset, column_names=["value"])
+
+    # creates list if correct number of column names are passed in
+    data_list = list_from_torch(dataset, column_names=["value", "label"])
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_named_tuple_dataset(data: list[dict]):
+    # Given a Pytorch dataset that returns a namedtuple for each item
+    DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
+
+    class PytorchNamedTupleDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchNamedTupleDataset()
+    data_list = list_from_torch(dataset)
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["value", "label"]
+
+
+def test_list_from_torch_dataclass_dataset(data: list[dict]):
+    @dataclass
+    class DatasetItem:
+        text: str
+        label: int
+
+    class PytorchDataclassDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchDataclassDataset()
+    data_list = list_from_torch(dataset)
+    assert isinstance(data_list, list)
+    assert len(data_list) == len(dataset)
+    assert list(data_list[0].keys()) == ["text", "label"]
+
+
+def test_list_from_torch_invalid_dataset(data: list[dict]):
+    class PytorchInvalidDataset(TorchDataset):
+        def __init__(self):
+            self.data = data
+
+        def __getitem__(self, i):
+            return [self.data[i]["value"], self.data[i]["label"]]
+
+        def __len__(self):
+            return len(self.data)
+
+    dataset = PytorchInvalidDataset()
+    with pytest.raises(ValueError):
+        list_from_torch(dataset)
```
orca_sdk/_utils/value_parser.py
CHANGED

```diff
@@ -1,27 +1,43 @@
+from __future__ import annotations
+
 import base64
 import io
-from typing import
+from typing import TYPE_CHECKING, Any

-import numpy as np
-from numpy.typing import NDArray
-from PIL import Image as pil
+if TYPE_CHECKING:
+    # peer dependencies that are used for types only
+    import numpy as np  # type: ignore
+    from numpy.typing import NDArray  # type: ignore
+    from PIL import Image as pil  # type: ignore

-ValueType = str | pil.Image | NDArray[np.float32]
-"""
-The type of a value in a memoryset
+    ValueType = str | pil.Image | NDArray[np.float32]
+    """
+    The type of a value in a memoryset

-- `str`: string
-- `pil.Image`: image
-- `NDArray[np.float32]`: univariate or multivariate timeseries
-"""
+    - `str`: string
+    - `pil.Image`: image
+    - `NDArray[np.float32]`: univariate or multivariate timeseries
+    """
+else:
+    ValueType = Any


 def decode_value(value: str) -> ValueType:
     if value.startswith("data:image"):
+        try:
+            from PIL import Image as pil  # type: ignore
+        except ImportError as e:
+            raise ImportError("Install Pillow to use image values") from e
+
         header, data = value.split(",", 1)
         return pil.open(io.BytesIO(base64.b64decode(data)))

     if value.startswith("data:numpy"):
+        try:
+            import numpy as np  # type: ignore
+        except ImportError as e:
+            raise ImportError("Install numpy to use timeseries values") from e
+
         header, data = value.split(",", 1)
         return np.load(io.BytesIO(base64.b64decode(data)))

@@ -29,17 +45,28 @@ def decode_value(value: str) -> ValueType:


 def encode_value(value: ValueType) -> str:
-    if isinstance(value, pil.Image):
-        header = f"data:image/{value.format.lower()};base64," if value.format else "data:image;base64,"
+    try:
+        from PIL import Image as pil  # type: ignore
+    except ImportError:
+        pil = None  # type: ignore[assignment]
+
+    try:
+        import numpy as np  # type: ignore
+    except ImportError:
+        np = None  # type: ignore[assignment]
+
+    if pil is not None and isinstance(value, pil.Image):
+        header = f"data:image/{value.format.lower()};base64," if value.format else "data:image;base64,"  # type: ignore[union-attr]
         buffer = io.BytesIO()
-        value.save(buffer, format=value.format)
+        value.save(buffer, format=value.format)  # type: ignore[union-attr]
         bytes = buffer.getvalue()
         return header + base64.b64encode(bytes).decode("utf-8")

-    if isinstance(value, np.ndarray):
-        header = f"data:numpy/{value.dtype.name};base64,"
+    if np is not None and isinstance(value, np.ndarray):
+        header = f"data:numpy/{value.dtype.name};base64,"  # type: ignore[union-attr]
         buffer = io.BytesIO()
         np.save(buffer, value)
         return header + base64.b64encode(buffer.getvalue()).decode("utf-8")

-    return value
+    # Value is already a string, or an unhandled type (fall back to str conversion)
+    return value if isinstance(value, str) else str(value)
```
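Round-tripping shows the wire format these helpers produce: non-string values become data-URI style strings (`data:image/<format>;base64,...` or `data:numpy/<dtype>;base64,...`), while strings pass through `encode_value` unchanged. A small sketch, assuming numpy is installed:

```python
import numpy as np

from orca_sdk._utils.value_parser import decode_value, encode_value

series = np.arange(6, dtype=np.float32).reshape(2, 3)

encoded = encode_value(series)
assert encoded.startswith("data:numpy/float32;base64,")

decoded = decode_value(encoded)
assert np.array_equal(decoded, series)

# Plain strings are returned as-is
assert encode_value("hello") == "hello"
```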
orca_sdk/_utils/value_parser_test.py
CHANGED

```diff
@@ -1,5 +1,4 @@
-import numpy as np
-from PIL import Image as pil
+import pytest

 from .value_parser import decode_value, encode_value

@@ -13,6 +12,7 @@ def test_string_parsing():


 def test_image_parsing():
+    pil = pytest.importorskip("PIL.Image")
     img = pil.new("RGB", (10, 10), color="red")
     img.format = "PNG"

@@ -22,10 +22,11 @@ def test_image_parsing():

     decoded = decode_value(encoded)
     assert isinstance(decoded, pil.Image)
-    assert decoded.size == img.size
+    assert decoded.size == img.size  # type: ignore[union-attr]


 def test_timeseries_parsing():
+    np = pytest.importorskip("numpy")
     timeseries = np.random.rand(20, 3).astype(np.float32)

     encoded = encode_value(timeseries)
@@ -34,6 +35,6 @@ def test_timeseries_parsing():

     decoded = decode_value(encoded)
     assert isinstance(decoded, np.ndarray)
-    assert decoded.shape == timeseries.shape
-    assert decoded.dtype == timeseries.dtype
+    assert decoded.shape == timeseries.shape  # type: ignore[union-attr]
+    assert decoded.dtype == timeseries.dtype  # type: ignore[union-attr]
     assert np.allclose(decoded, timeseries)
```