datachain 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/asyn.py +4 -9
- datachain/catalog/catalog.py +20 -31
- datachain/client/azure.py +1 -13
- datachain/client/fsspec.py +16 -15
- datachain/client/gcs.py +2 -13
- datachain/client/hf.py +0 -10
- datachain/client/local.py +3 -12
- datachain/client/s3.py +9 -19
- datachain/data_storage/sqlite.py +10 -1
- datachain/data_storage/warehouse.py +11 -17
- datachain/dataset.py +1 -1
- datachain/lib/arrow.py +51 -16
- datachain/lib/dc.py +7 -2
- datachain/lib/file.py +76 -2
- datachain/lib/hf.py +23 -6
- datachain/lib/listing.py +8 -7
- datachain/lib/listing_info.py +2 -2
- datachain/lib/model_store.py +2 -2
- datachain/lib/pytorch.py +32 -26
- datachain/lib/signal_schema.py +157 -60
- datachain/lib/tar.py +33 -0
- datachain/lib/webdataset.py +3 -59
- datachain/listing.py +6 -8
- datachain/node.py +0 -43
- datachain/query/dataset.py +2 -6
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/METADATA +1 -1
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/RECORD +31 -30
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/WHEEL +1 -1
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/LICENSE +0 -0
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.13.dist-info → datachain-0.3.15.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED

@@ -1,5 +1,6 @@
 import io
 import json
+import logging
 import os
 import posixpath
 from abc import ABC, abstractmethod
@@ -15,6 +16,9 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
 from pydantic import Field, field_validator

+if TYPE_CHECKING:
+    from typing_extensions import Self
+
 from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
@@ -25,6 +29,8 @@ from datachain.utils import TIME_ZERO
 if TYPE_CHECKING:
     from datachain.catalog import Catalog

+logger = logging.getLogger("datachain")
+
 # how to create file path when exporting
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

@@ -251,14 +257,18 @@ class File(DataModel):
         dump = self.model_dump()
         return UniqueId(*(dump[k] for k in self._unique_id_keys))

-    def get_local_path(self) -> Optional[str]:
+    def get_local_path(self, download: bool = False) -> Optional[str]:
         """Returns path to a file in a local cache.
         Return None if file is not cached. Throws an exception if cache is not setup."""
         if self._catalog is None:
             raise RuntimeError(
                 "cannot resolve local file path because catalog is not setup"
             )
-        return self._catalog.cache.get_path(self.get_uid())
+        uid = self.get_uid()
+        if download:
+            client = self._catalog.get_client(self.source)
+            client.download(uid, callback=self._download_cb)
+        return self._catalog.cache.get_path(uid)

     def get_file_suffix(self):
         """Returns last part of file name with `.`."""
@@ -313,6 +323,70 @@ class File(DataModel):
         """Returns `fsspec` filesystem for the file."""
         return self._catalog.get_client(self.source).fs

+    def resolve(self) -> "Self":
+        """
+        Resolve a File object by checking its existence and updating its metadata.
+
+        Returns:
+            File: The resolved File object with updated metadata.
+        """
+        if self._catalog is None:
+            raise RuntimeError("Cannot resolve file: catalog is not set")
+
+        try:
+            client = self._catalog.get_client(self.source)
+        except NotImplementedError as e:
+            raise RuntimeError(
+                f"Unsupported protocol for file source: {self.source}"
+            ) from e
+
+        try:
+            info = client.fs.info(client.get_full_path(self.path))
+            converted_info = client.info_to_file(info, self.source)
+            return type(self)(
+                path=self.path,
+                source=self.source,
+                size=converted_info.size,
+                etag=converted_info.etag,
+                version=converted_info.version,
+                is_latest=converted_info.is_latest,
+                last_modified=converted_info.last_modified,
+                location=self.location,
+            )
+        except (FileNotFoundError, PermissionError, OSError) as e:
+            logger.warning("File system error when resolving %s: %s", self.path, str(e))
+
+        return type(self)(
+            path=self.path,
+            source=self.source,
+            size=0,
+            etag="",
+            version="",
+            is_latest=True,
+            last_modified=TIME_ZERO,
+            location=self.location,
+        )
+
+
+def resolve(file: File) -> File:
+    """
+    Resolve a File object by checking its existence and updating its metadata.
+
+    This function is a wrapper around the File.resolve() method, designed to be
+    used as a mapper in DataChain operations.
+
+    Args:
+        file (File): The File object to resolve.
+
+    Returns:
+        File: The resolved File object with updated metadata.
+
+    Raises:
+        RuntimeError: If the file's catalog is not set or if
+            the file source protocol is unsupported.
+    """
+    return file.resolve()
+

 class TextFile(File):
     """`DataModel` for reading text files."""
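The new File.resolve() method, its module-level resolve() wrapper, and the download flag on get_local_path() all deal with refreshing or materializing files from a chain. A minimal sketch of the wrapper used as a mapper; the bucket URI is a placeholder, and the from_storage/map/collect calls are assumed from the library's public API rather than from this diff:

from datachain.lib.dc import DataChain
from datachain.lib.file import resolve

# Re-check each listed file and refresh its size/etag/version metadata.
chain = DataChain.from_storage("s3://example-bucket/images/").map(file=resolve)

for file in chain.collect("file"):
    # Per File.resolve() above, missing files come back with size=0, etag=""
    # and last_modified=TIME_ZERO instead of raising.
    print(file.path, file.size)

With the same setup, file.get_local_path(download=True) would first pull the file into the local cache instead of returning None when it is not cached yet.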
datachain/lib/hf.py
CHANGED

@@ -15,7 +15,7 @@ try:
         Value,
         load_dataset,
     )
-    from datasets.features.features import string_to_arrow
+    from datasets.features.features import Features, string_to_arrow
    from datasets.features.image import image_to_bytes

 except ImportError as exc:
@@ -36,6 +36,7 @@ from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
+    import pyarrow as pa
     from pydantic import BaseModel


@@ -71,6 +72,15 @@ class HFGenerator(Generator):
         *args,
         **kwargs,
     ):
+        """
+        Generator for chain from huggingface datasets.
+
+        Parameters:
+
+        ds : Path or name of the dataset to read from Hugging Face Hub,
+            or an instance of `datasets.Dataset`-like object.
+        output_schema : Pydantic model for validation.
+        """
         super().__init__()
         self.ds = ds
         self.output_schema = output_schema
@@ -92,7 +102,7 @@ class HFGenerator(Generator):
                     output_dict["split"] = split
                 for name, feat in ds.features.items():
                     anno = self.output_schema.model_fields[name].annotation
-                    output_dict[name] = _convert_feature(row[name], feat, anno)
+                    output_dict[name] = convert_feature(row[name], feat, anno)
                 yield self.output_schema(**output_dict)
                 pbar.update(1)

@@ -106,7 +116,7 @@ def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
     return {"": ds}


-def _convert_feature(val: Any, feat: Any, anno: Any) -> Any:
+def convert_feature(val: Any, feat: Any, anno: Any) -> Any:  # noqa: PLR0911
     if isinstance(feat, (Value, Array2D, Array3D, Array4D, Array5D)):
         return val
     if isinstance(feat, ClassLabel):
@@ -117,20 +127,23 @@ def _convert_feature(val: Any, feat: Any, anno: Any) -> Any:
             for sname in val:
                 sfeat = feat.feature[sname]
                 sanno = anno.model_fields[sname].annotation
-                sdict[sname] = [_convert_feature(v, sfeat, sanno) for v in val[sname]]
+                sdict[sname] = [convert_feature(v, sfeat, sanno) for v in val[sname]]
             return anno(**sdict)
         return val
     if isinstance(feat, Image):
+        if isinstance(val, dict):
+            return HFImage(img=val["bytes"])
         return HFImage(img=image_to_bytes(val))
     if isinstance(feat, Audio):
         return HFAudio(**val)


 def get_output_schema(
-
+    features: Features, model_name: str = "", stream: bool = True
 ) -> dict[str, DataType]:
+    """Generate UDF output schema from huggingface datasets features."""
     fields_dict = {}
-    for name, val in
+    for name, val in features.items():
         fields_dict[name] = _feature_to_chain_type(name, val)  # type: ignore[assignment]
     return fields_dict  # type: ignore[return-value]

@@ -165,3 +178,7 @@ def _feature_to_chain_type(name: str, val: Any) -> type:  # noqa: PLR0911
     if isinstance(val, Audio):
         return HFAudio
     raise TypeError(f"Unknown huggingface datasets type {type(val)}")
+
+
+def schema_from_arrow(schema: "pa.Schema"):
+    return Features.from_arrow_schema(schema)
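Together, get_output_schema and HFGenerator turn a Hugging Face dataset into chain rows. A rough sketch, assuming the datasets package is installed; the dataset name "beans" and model name are placeholders, and the dict_to_data_model call follows the import shown in this file:

from datasets import load_dataset

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.hf import HFGenerator, get_output_schema

ds = load_dataset("beans", split="train")
# Build a pydantic output model from the dataset's Features (the new
# Features-based signature above), then feed both into the generator.
fields = get_output_schema(ds.features)
model = dict_to_data_model("beans", fields)
gen = HFGenerator(ds, model)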
datachain/lib/listing.py
CHANGED

@@ -20,7 +20,7 @@ LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
 LISTING_PREFIX = "lst__"  # listing datasets start with this name


-def list_bucket(uri: str, client_config=None) -> Callable:
+def list_bucket(uri: str, cache, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
     from bucket where each File represents one bucket entry.
@@ -28,10 +28,10 @@ def list_bucket(uri: str, client_config=None) -> Callable:

     def list_func() -> Iterator[File]:
         config = client_config or {}
-        client
+        client = Client.get_client(uri, cache, **config)  # type: ignore[arg-type]
+        _, path = Client.parse_url(uri)
         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
-            for entry in entries:
-                yield entry.to_file(client.uri)
+            yield from entries

     return list_func

@@ -77,16 +77,17 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
     """
-    client
+    client = Client.get_client(uri, cache, **client_config)
+    storage_uri, path = Client.parse_url(uri)

     # clean path without globs
     lst_uri_path = (
         posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
     )

-    lst_uri = f"{
+    lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
     ds_name = (
-        f"{LISTING_PREFIX}{
+        f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
     )

     return ds_name, lst_uri, path
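list_bucket now receives the cache explicitly instead of deriving everything from the URI and client config. A sketch of the updated call, assuming an initialized catalog; the bucket URI is a placeholder:

from datachain.catalog import get_catalog
from datachain.lib.listing import list_bucket

catalog = get_catalog()
# The cache is now the second positional argument; client_config stays optional.
list_func = list_bucket("s3://example-bucket/data/", catalog.cache)
for file in list_func():
    print(file.path)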
datachain/lib/listing_info.py
CHANGED

@@ -13,8 +13,8 @@ class ListingInfo(DatasetInfo):

     @property
     def storage_uri(self) -> str:
-
-        return
+        uri, _ = Client.parse_url(self.uri)
+        return uri

     @property
     def expires(self) -> Optional[datetime]:
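The property now leans on Client.parse_url, which, per the listing.py changes above, simply splits a URI into a storage URI and a path without needing a client or cache. A small illustration; the URI is a placeholder and the import path is assumed:

from datachain.client import Client

storage_uri, path = Client.parse_url("s3://example-bucket/some/prefix")
# storage_uri == "s3://example-bucket", path == "some/prefix"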
datachain/lib/model_store.py
CHANGED

@@ -1,6 +1,6 @@
 import inspect
 import logging
-from typing import ClassVar, Optional
+from typing import Any, ClassVar, Optional

 from pydantic import BaseModel

@@ -69,7 +69,7 @@ class ModelStore:
             del cls.store[fr.__name__][version]

     @staticmethod
-    def is_pydantic(val):
+    def is_pydantic(val: Any) -> bool:
         return (
             not hasattr(val, "__origin__")
             and inspect.isclass(val)
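With the new annotations, is_pydantic reads as a plain predicate over arbitrary values. For example:

from typing import Optional

from pydantic import BaseModel

from datachain.lib.model_store import ModelStore

class Person(BaseModel):
    name: str

ModelStore.is_pydantic(Person)         # True: a plain BaseModel subclass
ModelStore.is_pydantic(Optional[str])  # False: has __origin__, not a plain class
ModelStore.is_pydantic(int)            # False: a class, but not a BaseModel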
datachain/lib/pytorch.py
CHANGED

@@ -7,6 +7,7 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
+from tqdm import tqdm

 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
@@ -93,33 +94,38 @@ class PytorchDataset(IterableDataset):
         if self.num_samples > 0:
             ds = ds.sample(self.num_samples)
         ds = ds.chunk(total_rank, total_workers)
-        for
-
-        for
-
-
-
-
-
-
-
-
-
+        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
+        with tqdm(desc=desc, unit=" rows") as pbar:
+            for row_features in ds.collect():
+                row = []
+                for fr in row_features:
+                    if hasattr(fr, "read"):
+                        row.append(fr.read())  # type: ignore[unreachable]
+                    else:
+                        row.append(fr)
+                # Apply transforms
+                if self.transform:
+                    try:
+                        if isinstance(self.transform, v2.Transform):
+                            row = self.transform(row)
+                        for i, val in enumerate(row):
+                            if isinstance(val, Image.Image):
+                                row[i] = self.transform(val)
+                    except ValueError:
+                        logger.warning(
+                            "Skipping transform due to unsupported data types."
+                        )
+                        self.transform = None
+                if self.tokenizer:
                     for i, val in enumerate(row):
-                        if isinstance(val,
-
-
-
-
-
-
-
-                            isinstance(val, list) and isinstance(val[0], str)
-                        ):
-                            row[i] = convert_text(
-                                val, self.tokenizer, self.tokenizer_kwargs
-                            ).squeeze(0)  # type: ignore[union-attr]
-                yield row
+                        if isinstance(val, str) or (
+                            isinstance(val, list) and isinstance(val[0], str)
+                        ):
+                            row[i] = convert_text(
+                                val, self.tokenizer, self.tokenizer_kwargs
+                            ).squeeze(0)  # type: ignore[union-attr]
+                yield row
+                pbar.update(1)

     @staticmethod
     def get_rank_and_workers() -> tuple[int, int]:
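The tqdm bar now reports rows as any consumer iterates the dataset, e.g. a standard DataLoader. A minimal sketch, assuming the chain's usual to_pytorch entry point; the bucket URI is a placeholder:

from torch.utils.data import DataLoader

from datachain.lib.dc import DataChain

chain = DataChain.from_storage("s3://example-bucket/images/")
# Iterating the loader drives the per-rank/worker progress bar added above.
loader = DataLoader(chain.to_pytorch(), batch_size=16)
for batch in loader:
    ...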
datachain/lib/signal_schema.py
CHANGED

@@ -4,11 +4,14 @@ from collections.abc import Iterator, Sequence
 from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
-from typing import (
+from typing import (  # noqa: UP035
     TYPE_CHECKING,
     Annotated,
     Any,
     Callable,
+    Dict,
+    Final,
+    List,
     Literal,
     Optional,
     Union,
@@ -42,8 +45,13 @@ NAMES_TO_TYPES = {
     "dict": dict,
     "bytes": bytes,
     "datetime": datetime,
-    "
+    "Final": Final,
     "Union": Union,
+    "Optional": Optional,
+    "List": list,
+    "Dict": dict,
+    "Literal": Any,
+    "Any": Any,
 }


@@ -146,35 +154,11 @@ class SignalSchema:
         return SignalSchema(signals)

     @staticmethod
-    def
-
-        based on whether the type is Optional or not."""
-        orig = get_origin(fr_type)
-        args = get_args(fr_type)
-        # Check if fr_type is Optional
-        if orig == Union and len(args) == 2 and (type(None) in args):
-            fr_type = args[0]
-            orig = get_origin(fr_type)
-        if orig in (Literal, LiteralEx):
-            # Literal has no __name__ in Python 3.9
-            type_name = "Literal"
-        elif orig == Union:
-            # Union also has no __name__ in Python 3.9
-            type_name = "Union"
-        else:
-            type_name = str(fr_type.__name__)  # type: ignore[union-attr]
-        return type_name, fr_type
-
-    @staticmethod
-    def serialize_custom_model_fields(
-        name: str, fr: type, custom_types: dict[str, Any]
+    def _serialize_custom_model_fields(
+        version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
     ) -> str:
         """This serializes any custom type information to the provided custom_types
-        dict, and returns the name of the type
-        if hasattr(fr, "__origin__") or not issubclass(fr, BaseModel):
-            # Don't store non-feature types.
-            return name
-        version_name = ModelStore.get_name(fr)
+        dict, and returns the name of the type serialized."""
         if version_name in custom_types:
             # This type is already stored in custom_types.
             return version_name
@@ -183,37 +167,102 @@ class SignalSchema:
             field_type = info.annotation
             # All fields should be typed.
             assert field_type
-
-                field_type
-            )
-            # Serialize this type to custom_types if it is a custom type as well.
-            fields[field_name] = SignalSchema.serialize_custom_model_fields(
-                field_type_name, field_type, custom_types
-            )
+            fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
         custom_types[version_name] = fields
         return version_name

+    @staticmethod
+    def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
+        """Serialize a given type to a string, including automatic ModelStore
+        registration, and save this type and subtypes to custom_types as well."""
+        subtypes: list[Any] = []
+        type_name = SignalSchema._type_to_str(fr, subtypes)
+        # Iterate over all subtypes (includes the input type).
+        for st in subtypes:
+            if st is None or not ModelStore.is_pydantic(st):
+                continue
+            # Register and save feature types.
+            ModelStore.register(st)
+            st_version_name = ModelStore.get_name(st)
+            if st is fr:
+                # If the main type is Pydantic, then use the ModelStore version name.
+                type_name = st_version_name
+            # Save this type to custom_types.
+            SignalSchema._serialize_custom_model_fields(
+                st_version_name, st, custom_types
+            )
+        return type_name
+
     def serialize(self) -> dict[str, Any]:
         signals: dict[str, Any] = {}
         custom_types: dict[str, Any] = {}
         for name, fr_type in self.values.items():
-
-            ModelStore.register(fr)
-            signals[name] = ModelStore.get_name(fr)
-            type_name, fr_type = SignalSchema._get_name_original_type(fr)
-            else:
-                type_name, fr_type = SignalSchema._get_name_original_type(fr_type)
-            signals[name] = type_name
-            self.serialize_custom_model_fields(type_name, fr_type, custom_types)
+            signals[name] = self._serialize_type(fr_type, custom_types)
         if custom_types:
             signals["_custom_types"] = custom_types
         return signals

     @staticmethod
-    def
+    def _split_subtypes(type_name: str) -> list[str]:
+        """This splits a list of subtypes, including proper square bracket handling."""
+        start = 0
+        depth = 0
+        subtypes = []
+        for i, c in enumerate(type_name):
+            if c == "[":
+                depth += 1
+            elif c == "]":
+                if depth == 0:
+                    raise TypeError(
+                        "Extra closing square bracket when parsing subtype list"
+                    )
+                depth -= 1
+            elif c == "," and depth == 0:
+                subtypes.append(type_name[start:i].strip())
+                start = i + 1
+        if depth > 0:
+            raise TypeError("Unclosed square bracket when parsing subtype list")
+        subtypes.append(type_name[start:].strip())
+        return subtypes
+
+    @staticmethod
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:  # noqa: PLR0911
         """Convert a string-based type back into a python type."""
+        type_name = type_name.strip()
+        if not type_name:
+            raise TypeError("Type cannot be empty")
+        if type_name == "NoneType":
+            return None
+
+        bracket_idx = type_name.find("[")
+        subtypes: Optional[tuple[Optional[type], ...]] = None
+        if bracket_idx > -1:
+            if bracket_idx == 0:
+                raise TypeError("Type cannot start with '['")
+            close_bracket_idx = type_name.rfind("]")
+            if close_bracket_idx == -1:
+                raise TypeError("Unclosed square bracket when parsing type")
+            if close_bracket_idx < bracket_idx:
+                raise TypeError("Square brackets are out of order when parsing type")
+            if close_bracket_idx == bracket_idx + 1:
+                raise TypeError("Empty square brackets when parsing type")
+            subtype_names = SignalSchema._split_subtypes(
+                type_name[bracket_idx + 1 : close_bracket_idx]
+            )
+            # Types like Union require the parameters to be a tuple of types.
+            subtypes = tuple(
+                SignalSchema._resolve_type(st, custom_types) for st in subtype_names
+            )
+            type_name = type_name[:bracket_idx].strip()
+
         fr = NAMES_TO_TYPES.get(type_name)
         if fr:
+            if subtypes:
+                if len(subtypes) == 1:
+                    # Types like Optional require there to be only one argument.
+                    return fr[subtypes[0]]  # type: ignore[index]
+                # Other types like Union require the parameters to be a tuple of types.
+                return fr[subtypes]  # type: ignore[index]
             return fr  # type: ignore[return-value]

         model_name, version = ModelStore.parse_name_version(type_name)
@@ -228,7 +277,14 @@ class SignalSchema:
                 for field_name, field_type_str in fields.items()
             }
             return create_feature_model(type_name, fields)
-
+        # This can occur if a third-party or custom type is used, which is not available
+        # when deserializing.
+        warnings.warn(
+            f"Could not resolve type: '{type_name}'.",
+            SignalSchemaWarning,
+            stacklevel=2,
+        )
+        return Any  # type: ignore[return-value]

     @staticmethod
     def deserialize(schema: dict[str, Any]) -> "SignalSchema":
@@ -242,9 +298,14 @@ class SignalSchema:
                 # This entry is used as a lookup for custom types,
                 # and is not an actual field.
                 continue
+            if not isinstance(type_name, str):
+                raise SignalSchemaError(
+                    f"cannot deserialize '{type_name}': "
+                    "serialized types must be a string"
+                )
             try:
                 fr = SignalSchema._resolve_type(type_name, custom_types)
-                if fr is None:
+                if fr is Any:
                     # Skip if the type is not found, so all data can be displayed.
                     warnings.warn(
                         f"In signal '{signal}': "
@@ -258,7 +319,7 @@ class SignalSchema:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
-            signals[signal] = fr
+            signals[signal] = fr  # type: ignore[assignment]

         return SignalSchema(signals)

@@ -325,11 +386,20 @@ class SignalSchema:
             else:
                 json, pos = unflatten_to_json_pos(fr, row, pos)  # type: ignore[union-attr]
                 obj = fr(**json)
-                if isinstance(obj, File):
-                    obj._set_stream(catalog, caching_enabled=cache)
+                SignalSchema._set_file_stream(obj, catalog, cache)
             res.append(obj)
         return res

+    @staticmethod
+    def _set_file_stream(
+        obj: BaseModel, catalog: "Catalog", cache: bool = False
+    ) -> None:
+        if isinstance(obj, File):
+            obj._set_stream(catalog, caching_enabled=cache)
+        for field, finfo in obj.model_fields.items():
+            if ModelStore.is_pydantic(finfo.annotation):
+                SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
+
     def db_signals(
         self, name: Optional[str] = None, as_columns=False
     ) -> Union[list[str], list[Column]]:
@@ -509,31 +579,58 @@ class SignalSchema:
         return self.values.pop(name)

     @staticmethod
-    def _type_to_str(type_):  # noqa: PLR0911
+    def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:  # noqa: PLR0911
+        """Convert a type to a string-based representation."""
+        if type_ is None:
+            return "NoneType"
+
         origin = get_origin(type_)

         if origin == Union:
             args = get_args(type_)
-            formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
+            formatted_types = ", ".join(
+                SignalSchema._type_to_str(arg, subtypes) for arg in args
+            )
             return f"Union[{formatted_types}]"
         if origin == Optional:
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"Optional[{type_str}]"
-        if origin
+        if origin in (list, List):  # noqa: UP006
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"list[{type_str}]"
-        if origin
+        if origin in (dict, Dict):  # noqa: UP006
             args = get_args(type_)
-            type_str =
-
+            type_str = (
+                SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
+            )
+            vals = (
+                f", {SignalSchema._type_to_str(args[1], subtypes)}"
+                if len(args) > 1
+                else ""
+            )
             return f"dict[{type_str}{vals}]"
         if origin == Annotated:
             args = get_args(type_)
-            return SignalSchema._type_to_str(args[0])
-        if origin in (Literal, LiteralEx):
+            return SignalSchema._type_to_str(args[0], subtypes)
+        if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx):
             return "Literal"
+        if Any in (origin, type_):
+            return "Any"
+        if Final in (origin, type_):
+            return "Final"
+        if subtypes is not None:
+            # Include this type in the list of all subtypes, if requested.
+            subtypes.append(type_)
+        if not hasattr(type_, "__name__"):
+            # This can happen for some third-party or custom types, mostly on Python 3.9
+            warnings.warn(
+                f"Unable to determine name of type '{type_}'.",
+                SignalSchemaWarning,
+                stacklevel=2,
+            )
+            return "Any"
         return type_.__name__

     @staticmethod
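The net effect of these serializer changes is that parametrized and nested types now survive a round trip through their string form. A small sketch; the signal names are made up:

from typing import Optional

from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"file": File, "score": Optional[float], "tags": list[str]})
serialized = schema.serialize()
# Each signal maps to a string type name; custom pydantic models are captured
# under the "_custom_types" key so deserialize() can rebuild them.
restored = SignalSchema.deserialize(serialized)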