datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of datachain has been flagged as possibly problematic.
- datachain/asyn.py +16 -6
- datachain/cache.py +32 -10
- datachain/catalog/catalog.py +17 -1
- datachain/cli/__init__.py +311 -0
- datachain/cli/commands/__init__.py +29 -0
- datachain/cli/commands/datasets.py +129 -0
- datachain/cli/commands/du.py +14 -0
- datachain/cli/commands/index.py +12 -0
- datachain/cli/commands/ls.py +169 -0
- datachain/cli/commands/misc.py +28 -0
- datachain/cli/commands/query.py +53 -0
- datachain/cli/commands/show.py +38 -0
- datachain/cli/parser/__init__.py +547 -0
- datachain/cli/parser/job.py +120 -0
- datachain/cli/parser/studio.py +126 -0
- datachain/cli/parser/utils.py +63 -0
- datachain/{cli_utils.py → cli/utils.py} +27 -1
- datachain/client/azure.py +6 -2
- datachain/client/fsspec.py +9 -3
- datachain/client/gcs.py +6 -2
- datachain/client/s3.py +16 -1
- datachain/data_storage/db_engine.py +9 -0
- datachain/data_storage/schema.py +4 -10
- datachain/data_storage/sqlite.py +7 -1
- datachain/data_storage/warehouse.py +6 -4
- datachain/{lib/diff.py → diff/__init__.py} +116 -12
- datachain/func/__init__.py +3 -2
- datachain/func/conditional.py +74 -0
- datachain/func/func.py +5 -1
- datachain/lib/arrow.py +7 -1
- datachain/lib/dc.py +8 -3
- datachain/lib/file.py +16 -5
- datachain/lib/hf.py +1 -1
- datachain/lib/listing.py +19 -1
- datachain/lib/pytorch.py +57 -13
- datachain/lib/signal_schema.py +89 -27
- datachain/lib/udf.py +82 -40
- datachain/listing.py +1 -0
- datachain/progress.py +20 -3
- datachain/query/dataset.py +122 -93
- datachain/query/dispatch.py +22 -16
- datachain/studio.py +58 -38
- datachain/utils.py +14 -3
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
- datachain/cli.py +0 -1475
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/func/conditional.py
CHANGED
@@ -1,9 +1,15 @@
 from typing import Union
 
+from sqlalchemy import case as sql_case
+from sqlalchemy.sql.elements import BinaryExpression
+
+from datachain.lib.utils import DataChainParamsError
 from datachain.sql.functions import conditional
 
 from .func import ColT, Func
 
+CaseT = Union[int, float, complex, bool, str]
+
 
 def greatest(*args: Union[ColT, float]) -> Func:
     """
@@ -79,3 +85,71 @@ def least(*args: Union[ColT, float]) -> Func:
     return Func(
         "least", inner=conditional.least, cols=cols, args=func_args, result_type=int
     )
+
+
+def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
+    """
+    Returns the case function that produces case expression which has a list of
+    conditions and corresponding results. Results can only be python primitives
+    like string, numbes or booleans. Result type is inferred from condition results.
+
+    Args:
+        args (tuple(BinaryExpression, value(str | int | float | complex | bool):
+            - Tuple of binary expression and values pair which corresponds to one
+            case condition - value
+        else_ (str | int | float | complex | bool): else value in case expression
+
+    Returns:
+        Func: A Func object that represents the case function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
+        )
+        ```
+    """
+    supported_types = [int, float, complex, str, bool]
+
+    type_ = type(else_) if else_ else None
+
+    if not args:
+        raise DataChainParamsError("Missing statements")
+
+    for arg in args:
+        if type_ and not isinstance(arg[1], type_):
+            raise DataChainParamsError("Statement values must be of the same type")
+        type_ = type(arg[1])
+
+    if type_ not in supported_types:
+        raise DataChainParamsError(
+            f"Only python literals ({supported_types}) are supported for values"
+        )
+
+    kwargs = {"else_": else_}
+    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
+
+
+def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
+    """
+    Returns the ifelse function that produces if expression which has a condition
+    and values for true and false outcome. Results can only be python primitives
+    like string, numbes or booleans. Result type is inferred from the values.
+
+    Args:
+        condition: BinaryExpression - condition which is evaluated
+        if_val: (str | int | float | complex | bool): value for true condition outcome
+        else_val: (str | int | float | complex | bool): value for false condition
+            outcome
+
+    Returns:
+        Func: A Func object that represents the ifelse function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.ifelse(C("num") > 0, "P", "N"),
+        )
+        ```
+    """
+    return case((condition, if_val), else_=else_val)
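The new `case()` and `ifelse()` helpers become part of the public `datachain.func` API. A minimal usage sketch, assuming the top-level `DataChain`, `C`, and `func` imports available in 0.8.x; the column name and values are made up:

```py
from datachain import C, DataChain, func

chain = DataChain.from_values(num=[-2, 0, 3]).mutate(
    # result type is inferred from the literal values ("P"/"N"/"Z" -> str)
    sign=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
    positive=func.ifelse(C("num") > 0, "yes", "no"),
)
print(list(chain.collect("num", "sign", "positive")))
```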
datachain/func/func.py
CHANGED
@@ -35,6 +35,7 @@ class Func(Function):
         inner: Callable,
         cols: Optional[Sequence[ColT]] = None,
         args: Optional[Sequence[Any]] = None,
+        kwargs: Optional[dict[str, Any]] = None,
         result_type: Optional["DataType"] = None,
         is_array: bool = False,
         is_window: bool = False,
@@ -45,6 +46,7 @@ class Func(Function):
         self.inner = inner
         self.cols = cols or []
         self.args = args or []
+        self.kwargs = kwargs or {}
         self.result_type = result_type
         self.is_array = is_array
         self.is_window = is_window
@@ -63,6 +65,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -333,6 +336,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -387,7 +391,7 @@ class Func(Function):
             return col
 
         cols = [get_col(col) for col in self._db_cols]
-        func_col = self.inner(*cols, *self.args)
+        func_col = self.inner(*cols, *self.args, **self.kwargs)
 
         if self.is_window:
             if not self.window:
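The new `kwargs` field is what lets `case()` forward `else_` to SQLAlchemy: at apply time the diff calls `self.inner(*cols, *self.args, **self.kwargs)`. A standalone sketch of roughly what that call evaluates to, using plain SQLAlchemy rather than datachain's internals:

```py
from sqlalchemy import case as sql_case, column

# roughly what Func("case", inner=sql_case, args=args, kwargs={"else_": "Z"}) produces
args = [(column("num") > 0, "P"), (column("num") < 0, "N")]
expr = sql_case(*args, **{"else_": "Z"})
print(expr)  # prints a CASE WHEN ... WHEN ... ELSE ... END expression
```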
datachain/lib/arrow.py
CHANGED
@@ -91,7 +91,9 @@ class ArrowGenerator(Generator):
                 yield from record_batch.to_pylist()
 
         it = islice(iter_records(), self.nrows)
-        with tqdm(
+        with tqdm(
+            it, desc="Parsed by pyarrow", unit="rows", total=self.nrows, leave=False
+        ) as pbar:
             for index, record in enumerate(pbar):
                 yield self._process_record(
                     record, file, index, hf_schema, use_datachain_schema
@@ -149,6 +151,10 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
+    if not schemas:
+        raise ValueError(
+            "Cannot infer schema (no files to process or can't access them)"
+        )
     return pa.unify_schemas(schemas)
 
 
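The new guard in `infer_schema()` raises a clear `ValueError` before an empty schema list ever reaches `pa.unify_schemas()`, which would presumably fail with a less helpful pyarrow error. A standalone illustration of the case the guard covers (the printed message, if any, comes from pyarrow, not datachain):

```py
import pyarrow as pa

try:
    pa.unify_schemas([])  # what infer_schema() would hit with no readable files
except Exception as exc:
    print(type(exc).__name__, exc)
```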
datachain/lib/dc.py
CHANGED
@@ -451,6 +451,7 @@ class DataChain:
             return dc
 
         if update or not list_ds_exists:
+            # disable prefetch for listing, as it pre-downloads all files
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
@@ -458,6 +459,7 @@ class DataChain:
                     settings=settings,
                     in_memory=in_memory,
                 )
+                .settings(prefetch=0)
                 .gen(
                     list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
@@ -1534,7 +1536,7 @@ class DataChain:
 
         Example:
             ```py
-
+            res = persons.compare(
                 new_persons,
                 on=["id"],
                 right_on=["other_id"],
@@ -1547,9 +1549,9 @@ class DataChain:
             )
             ```
         """
-        from datachain.
+        from datachain.diff import _compare
 
-        return
+        return _compare(
             self,
             other,
             on,
@@ -1882,6 +1884,9 @@ class DataChain:
                 "`nrows` only supported for csv and json formats.",
             )
 
+        if "file" not in self.schema or not self.count():
+            raise DatasetPrepareError(self.name, "no files to parse.")
+
         schema = None
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
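The listing pipeline now chains `.settings(prefetch=0)` so that building a listing does not pre-download every file. The same setting can be applied to user chains; a hedged sketch, with the bucket URI and the per-file UDF as placeholders:

```py
from datachain import DataChain

chain = (
    DataChain.from_storage("s3://example-bucket/images/")  # placeholder URI
    .settings(prefetch=0)  # 0 disables pre-downloading of File objects for UDFs
    .map(size=lambda file: file.size, output=int)  # metadata-only UDF, no download needed
)
```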
datachain/lib/file.py
CHANGED
@@ -269,10 +269,21 @@ class File(DataModel):
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
-    async def _prefetch(self) ->
-
-
-
+    async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+        from datachain.client.hf import HfClient
+
+        if self._catalog is None:
+            raise RuntimeError("cannot prefetch file because catalog is not setup")
+
+        client = self._catalog.get_client(self.source)
+        if client.protocol == HfClient.protocol:
+            return False
+
+        await client._download(self, callback=download_cb or self._download_cb)
+        self._set_stream(
+            self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
+        )
+        return True
 
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
@@ -364,7 +375,7 @@ class File(DataModel):
 
         try:
             info = client.fs.info(client.get_full_path(self.path))
-            converted_info = client.info_to_file(info, self.
+            converted_info = client.info_to_file(info, self.path)
             return type(self)(
                 path=self.path,
                 source=self.source,
datachain/lib/hf.py
CHANGED
@@ -95,7 +95,7 @@ class HFGenerator(Generator):
         ds = self.ds_dict[split]
         if split:
             desc += f" split '{split}'"
-        with tqdm(desc=desc, unit=" rows") as pbar:
+        with tqdm(desc=desc, unit=" rows", leave=False) as pbar:
             for row in ds:
                 output_dict = {}
                 if split and "split" in self.output_schema.model_fields:
datachain/lib/listing.py
CHANGED
@@ -85,6 +85,24 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
+def _isfile(client: "Client", path: str) -> bool:
+    """
+    Returns True if uri points to a file
+    """
+    try:
+        info = client.fs.info(path)
+        name = info.get("name")
+        # case for special simulated directories on some clouds
+        # e.g. Google creates a zero byte file with the same name as the
+        # directory with a trailing slash at the end
+        if not name or name.endswith("/"):
+            return False
+
+        return info["type"] == "file"
+    except:  # noqa: E722
+        return False
+
+
 def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -94,7 +112,7 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
     storage_uri, path = Client.parse_url(uri)
     telemetry.log_param("client", client.PREFIX)
 
-    if not uri.endswith("/") and client
+    if not uri.endswith("/") and _isfile(client, uri):
         return None, f'{storage_uri}/{path.lstrip("/")}', path
     if uses_glob(path):
         lst_uri_path = posixpath.dirname(path)
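`_isfile()` decides whether a URI points to a file via an `fs.info()` lookup and also filters out the zero-byte "directory marker" objects some clouds create. A standalone sketch of the same check against a local fsspec filesystem instead of a datachain `Client` (the temporary file is just to have something to inspect):

```py
import os
import tempfile

import fsspec

fs = fsspec.filesystem("file")
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    path = tmp.name

info = fs.info(path)
name = info.get("name")
# mirrors _isfile(): reject missing names, trailing-slash directory markers,
# and anything the filesystem does not report as a plain file
print(bool(name) and not name.endswith("/") and info["type"] == "file")  # True
os.remove(path)
```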
datachain/lib/pytorch.py
CHANGED
@@ -1,5 +1,8 @@
 import logging
-
+import os
+import weakref
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,15 +12,19 @@ from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
 
+    from datachain.cache import DataChainCache as Cache
+
 
 logger = logging.getLogger("datachain")
 
@@ -75,6 +82,19 @@ class PytorchDataset(IterableDataset):
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        self._cache = catalog.cache
+        self._prefetch_cache: Optional[Cache] = None
+        if prefetch and not self.cache:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+            self._cache = self._prefetch_cache
+            weakref.finalize(self, self._prefetch_cache.destroy)
+
+    def close(self) -> None:
+        if self._prefetch_cache:
+            self._prefetch_cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@ class PytorchDataset(IterableDataset):
         ms = ms_cls(*ms_args, **ms_kwargs)
         wh_cls, wh_args, wh_kwargs = self._wh_params
         wh = wh_cls(*wh_args, **wh_kwargs)
-
+        catalog = Catalog(ms, wh, **self._catalog_params)
+        catalog.cache = self._cache
+        return catalog
 
-    def
+    def _row_iter(
+        self,
+        total_rank: int,
+        total_workers: int,
+    ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
@@ -104,16 +130,34 @@ class PytorchDataset(IterableDataset):
             ds = ds.chunk(total_rank, total_workers)
         yield from ds.collect()
 
-    def
-
-        rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+    def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+        from datachain.lib.udf import _prefetch_inputs
 
-
+        total_rank, total_workers = self.get_rank_and_workers()
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}",
+                position=total_rank,
+                leave=True,
+            )
+
+        rows = self._row_iter(total_rank, total_workers)
+        rows = _prefetch_inputs(
+            rows,
+            self.prefetch,
+            download_cb=download_cb,
+            after_prefetch=download_cb.increment_file_count,
+        )
+
+        with download_cb, closing(rows):
+            yield from rows
+
+    def __iter__(self) -> Iterator[list[Any]]:
+        with closing(self._iter_with_prefetch()) as rows:
+            yield from map(self._process_row, rows)
+
+    def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
         row = []
         for fr in row_features:
             if hasattr(fr, "read"):
datachain/lib/signal_schema.py
CHANGED
@@ -13,13 +13,14 @@ from typing import (  # noqa: UP035
     Final,
     List,
     Literal,
+    Mapping,
     Optional,
     Union,
     get_args,
     get_origin,
 )
 
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 from sqlalchemy import ColumnElement
 from typing_extensions import Literal as LiteralEx
 
@@ -85,8 +86,31 @@ class SignalResolvingTypeError(SignalResolvingError):
         )
 
 
+class CustomType(BaseModel):
+    schema_version: int = Field(ge=1, le=2, strict=True)
+    name: str
+    fields: dict[str, str]
+    bases: list[tuple[str, str, Optional[str]]]
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
+        version = data.get("schema_version", 1)
+
+        if version == 1:
+            data = {
+                "schema_version": 1,
+                "name": type_name,
+                "fields": data,
+                "bases": [],
+            }
+
+        return cls(**data)
+
+
 def create_feature_model(
-    name: str,
+    name: str,
+    fields: Mapping[str, Union[type, None, tuple[type, Any]]],
+    base: Optional[type] = None,
 ) -> type[BaseModel]:
     """
     This gets or returns a dynamic feature model for use in restoring a model
@@ -98,7 +122,7 @@ def create_feature_model(
     name = name.replace("@", "_")
     return create_model(
         name,
-        __base__=DataModel,  # type: ignore[call-overload]
+        __base__=base or DataModel,  # type: ignore[call-overload]
         # These are tuples for each field of: annotation, default (if any)
         **{
             field_name: anno if isinstance(anno, tuple) else (anno, None)
@@ -156,7 +180,7 @@ class SignalSchema:
         return SignalSchema(signals)
 
     @staticmethod
-    def
+    def _serialize_custom_model(
         version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
     ) -> str:
         """This serializes any custom type information to the provided custom_types
@@ -165,12 +189,23 @@ class SignalSchema:
             # This type is already stored in custom_types.
             return version_name
         fields = {}
+
         for field_name, info in fr.model_fields.items():
             field_type = info.annotation
             # All fields should be typed.
             assert field_type
             fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
-
+
+        bases: list[tuple[str, str, Optional[str]]] = []
+        for type_ in fr.__mro__:
+            model_store_name = (
+                ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
+            )
+            bases.append((type_.__name__, type_.__module__, model_store_name))
+
+        ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases)
+        custom_types[version_name] = ct.model_dump()
+
         return version_name
 
     @staticmethod
@@ -184,15 +219,12 @@ class SignalSchema:
             if st is None or not ModelStore.is_pydantic(st):
                 continue
             # Register and save feature types.
-            ModelStore.register(st)
             st_version_name = ModelStore.get_name(st)
             if st is fr:
                 # If the main type is Pydantic, then use the ModelStore version name.
                 type_name = st_version_name
             # Save this type to custom_types.
-            SignalSchema.
-                st_version_name, st, custom_types
-            )
+            SignalSchema._serialize_custom_model(st_version_name, st, custom_types)
         return type_name
 
     def serialize(self) -> dict[str, Any]:
@@ -215,7 +247,7 @@ class SignalSchema:
                 depth += 1
             elif c == "]":
                 if depth == 0:
-                    raise
+                    raise ValueError(
                         "Extra closing square bracket when parsing subtype list"
                     )
                 depth -= 1
@@ -223,16 +255,51 @@ class SignalSchema:
                 subtypes.append(type_name[start:i].strip())
                 start = i + 1
         if depth > 0:
-            raise
+            raise ValueError("Unclosed square bracket when parsing subtype list")
        subtypes.append(type_name[start:].strip())
        return subtypes
 
     @staticmethod
-    def
+    def _deserialize_custom_type(
+        type_name: str, custom_types: dict[str, Any]
+    ) -> Optional[type]:
+        """Given a type name like MyType@v1 gets a type from ModelStore or recreates
+        it based on the information from the custom types dict that includes fields and
+        bases."""
+        model_name, version = ModelStore.parse_name_version(type_name)
+        fr = ModelStore.get(model_name, version)
+        if fr:
+            return fr
+
+        if type_name in custom_types:
+            ct = CustomType.deserialize(custom_types[type_name], type_name)
+
+            fields = {
+                field_name: SignalSchema._resolve_type(field_type_str, custom_types)
+                for field_name, field_type_str in ct.fields.items()
+            }
+
+            base_model = None
+            for base in ct.bases:
+                _, _, model_store_name = base
+                if model_store_name:
+                    model_name, version = ModelStore.parse_name_version(
+                        model_store_name
+                    )
+                    base_model = ModelStore.get(model_name, version)
+                if base_model:
+                    break
+
+            return create_feature_model(type_name, fields, base=base_model)
+
+        return None
+
+    @staticmethod
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
         """Convert a string-based type back into a python type."""
         type_name = type_name.strip()
         if not type_name:
-            raise
+            raise ValueError("Type cannot be empty")
         if type_name == "NoneType":
             return None
 
@@ -240,14 +307,14 @@ class SignalSchema:
         subtypes: Optional[tuple[Optional[type], ...]] = None
         if bracket_idx > -1:
             if bracket_idx == 0:
-                raise
+                raise ValueError("Type cannot start with '['")
             close_bracket_idx = type_name.rfind("]")
             if close_bracket_idx == -1:
-                raise
+                raise ValueError("Unclosed square bracket when parsing type")
             if close_bracket_idx < bracket_idx:
-                raise
+                raise ValueError("Square brackets are out of order when parsing type")
             if close_bracket_idx == bracket_idx + 1:
-                raise
+                raise ValueError("Empty square brackets when parsing type")
             subtype_names = SignalSchema._split_subtypes(
                 type_name[bracket_idx + 1 : close_bracket_idx]
             )
@@ -267,18 +334,10 @@ class SignalSchema:
                 return fr[subtypes]  # type: ignore[index]
             return fr  # type: ignore[return-value]
 
-
-        fr = ModelStore.get(model_name, version)
+        fr = SignalSchema._deserialize_custom_type(type_name, custom_types)
         if fr:
             return fr
 
-        if type_name in custom_types:
-            fields = custom_types[type_name]
-            fields = {
-                field_name: SignalSchema._resolve_type(field_type_str, custom_types)
-                for field_name, field_type_str in fields.items()
-            }
-            return create_feature_model(type_name, fields)
         # This can occur if a third-party or custom type is used, which is not available
         # when deserializing.
         warnings.warn(
@@ -317,7 +376,7 @@ class SignalSchema:
                     stacklevel=2,
                 )
                 continue
-            except
+            except ValueError as err:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
@@ -662,6 +721,9 @@ class SignalSchema:
                 stacklevel=2,
             )
             return "Any"
+        if ModelStore.is_pydantic(type_):
+            ModelStore.register(type_)
+            return ModelStore.get_name(type_)
         return type_.__name__
 
     @staticmethod
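Custom-type entries written into a serialized schema now carry a schema_version, the type name, and the MRO-derived bases, while `CustomType.deserialize()` still accepts the old flat field mapping. An illustrative sketch of both entry shapes; the field names, module paths, and versions below are made up, not taken from the diff:

```py
# version 1 entry: just the field-name -> type-name mapping
v1_entry = {"name": "str", "age": "int"}

# version 2 entry, roughly as produced by _serialize_custom_model()
v2_entry = {
    "schema_version": 2,
    "name": "Person@v1",
    "fields": {"name": "str", "age": "int"},
    "bases": [
        ("Person", "my_module", "Person@v1"),  # the custom model itself
        ("DataModel", "datachain.lib.data_model", None),
        ("BaseModel", "pydantic.main", None),
        ("object", "builtins", None),
    ],
}

# CustomType.deserialize(v1_entry, "Person@v1") upgrades the old form by wrapping it:
# {"schema_version": 1, "name": "Person@v1", "fields": v1_entry, "bases": []}
```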