datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (50)
  1. datachain/asyn.py +16 -6
  2. datachain/cache.py +32 -10
  3. datachain/catalog/catalog.py +17 -1
  4. datachain/cli/__init__.py +311 -0
  5. datachain/cli/commands/__init__.py +29 -0
  6. datachain/cli/commands/datasets.py +129 -0
  7. datachain/cli/commands/du.py +14 -0
  8. datachain/cli/commands/index.py +12 -0
  9. datachain/cli/commands/ls.py +169 -0
  10. datachain/cli/commands/misc.py +28 -0
  11. datachain/cli/commands/query.py +53 -0
  12. datachain/cli/commands/show.py +38 -0
  13. datachain/cli/parser/__init__.py +547 -0
  14. datachain/cli/parser/job.py +120 -0
  15. datachain/cli/parser/studio.py +126 -0
  16. datachain/cli/parser/utils.py +63 -0
  17. datachain/{cli_utils.py → cli/utils.py} +27 -1
  18. datachain/client/azure.py +6 -2
  19. datachain/client/fsspec.py +9 -3
  20. datachain/client/gcs.py +6 -2
  21. datachain/client/s3.py +16 -1
  22. datachain/data_storage/db_engine.py +9 -0
  23. datachain/data_storage/schema.py +4 -10
  24. datachain/data_storage/sqlite.py +7 -1
  25. datachain/data_storage/warehouse.py +6 -4
  26. datachain/{lib/diff.py → diff/__init__.py} +116 -12
  27. datachain/func/__init__.py +3 -2
  28. datachain/func/conditional.py +74 -0
  29. datachain/func/func.py +5 -1
  30. datachain/lib/arrow.py +7 -1
  31. datachain/lib/dc.py +8 -3
  32. datachain/lib/file.py +16 -5
  33. datachain/lib/hf.py +1 -1
  34. datachain/lib/listing.py +19 -1
  35. datachain/lib/pytorch.py +57 -13
  36. datachain/lib/signal_schema.py +89 -27
  37. datachain/lib/udf.py +82 -40
  38. datachain/listing.py +1 -0
  39. datachain/progress.py +20 -3
  40. datachain/query/dataset.py +122 -93
  41. datachain/query/dispatch.py +22 -16
  42. datachain/studio.py +58 -38
  43. datachain/utils.py +14 -3
  44. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
  45. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
  46. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
  47. datachain/cli.py +0 -1475
  48. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
  49. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/func/conditional.py CHANGED
@@ -1,9 +1,15 @@
 from typing import Union
 
+from sqlalchemy import case as sql_case
+from sqlalchemy.sql.elements import BinaryExpression
+
+from datachain.lib.utils import DataChainParamsError
 from datachain.sql.functions import conditional
 
 from .func import ColT, Func
 
+CaseT = Union[int, float, complex, bool, str]
+
 
 def greatest(*args: Union[ColT, float]) -> Func:
     """
@@ -79,3 +85,71 @@ def least(*args: Union[ColT, float]) -> Func:
     return Func(
         "least", inner=conditional.least, cols=cols, args=func_args, result_type=int
     )
+
+
+def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
+    """
+    Returns the case function that produces a case expression which has a list of
+    conditions and corresponding results. Results can only be python primitives
+    like string, numbers or booleans. Result type is inferred from condition results.
+
+    Args:
+        args (tuple(BinaryExpression, value(str | int | float | complex | bool))):
+            Tuple of binary expression and value pairs, each corresponding to one
+            case condition - value
+        else_ (str | int | float | complex | bool): else value in case expression
+
+    Returns:
+        Func: A Func object that represents the case function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
+        )
+        ```
+    """
+    supported_types = [int, float, complex, str, bool]
+
+    type_ = type(else_) if else_ else None
+
+    if not args:
+        raise DataChainParamsError("Missing statements")
+
+    for arg in args:
+        if type_ and not isinstance(arg[1], type_):
+            raise DataChainParamsError("Statement values must be of the same type")
+        type_ = type(arg[1])
+
+    if type_ not in supported_types:
+        raise DataChainParamsError(
+            f"Only python literals ({supported_types}) are supported for values"
+        )
+
+    kwargs = {"else_": else_}
+    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
+
+
+def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
+    """
+    Returns the ifelse function that produces an if expression which has a condition
+    and values for true and false outcomes. Results can only be python primitives
+    like string, numbers or booleans. Result type is inferred from the values.
+
+    Args:
+        condition (BinaryExpression): condition which is evaluated
+        if_val (str | int | float | complex | bool): value for true condition outcome
+        else_val (str | int | float | complex | bool): value for false condition
+            outcome
+
+    Returns:
+        Func: A Func object that represents the ifelse function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.ifelse(C("num") > 0, "P", "N"),
+        )
+        ```
+    """
+    return case((condition, if_val), else_=else_val)
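Taken together, `case` and `ifelse` let `mutate()` build conditional columns without raw SQL; `ifelse` is just sugar over a single-condition `case`, as the last added line shows. A minimal usage sketch, assuming the top-level `datachain` package exposes `C` and `func` and a chain built via `from_values`:

```py
from datachain import C, func
from datachain.lib.dc import DataChain

chain = DataChain.from_values(num=[-2, 0, 3])
labeled = chain.mutate(
    # branch values (and else_) must share one primitive type,
    # otherwise DataChainParamsError is raised
    sign=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
    positive=func.ifelse(C("num") > 0, True, False),
)
```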
datachain/func/func.py CHANGED
@@ -35,6 +35,7 @@ class Func(Function):
         inner: Callable,
         cols: Optional[Sequence[ColT]] = None,
         args: Optional[Sequence[Any]] = None,
+        kwargs: Optional[dict[str, Any]] = None,
         result_type: Optional["DataType"] = None,
         is_array: bool = False,
         is_window: bool = False,
@@ -45,6 +46,7 @@ class Func(Function):
         self.inner = inner
         self.cols = cols or []
         self.args = args or []
+        self.kwargs = kwargs or {}
         self.result_type = result_type
         self.is_array = is_array
         self.is_window = is_window
@@ -63,6 +65,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -333,6 +336,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -387,7 +391,7 @@ class Func(Function):
             return col
 
         cols = [get_col(col) for col in self._db_cols]
-        func_col = self.inner(*cols, *self.args)
+        func_col = self.inner(*cols, *self.args, **self.kwargs)
 
         if self.is_window:
             if not self.window:
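The new `kwargs` slot exists because keyword-only parameters such as SQLAlchemy's `else_` cannot be passed through positional `args`. A rough, standalone illustration of what `self.inner(*cols, *self.args, **self.kwargs)` evaluates to for `func.case` (plain SQLAlchemy 2.x, no datachain required):

```py
import sqlalchemy as sa

conditions = [(sa.column("num") > 0, "P"), (sa.column("num") < 0, "N")]
expr = sa.case(*conditions, else_="Z")  # kwargs={"else_": "Z"} is unpacked here
print(expr)  # e.g. CASE WHEN num > :num_1 THEN :param_1 WHEN ... ELSE :param_3 END
```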
datachain/lib/arrow.py CHANGED
@@ -91,7 +91,9 @@ class ArrowGenerator(Generator):
                 yield from record_batch.to_pylist()
 
         it = islice(iter_records(), self.nrows)
-        with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
+        with tqdm(
+            it, desc="Parsed by pyarrow", unit="rows", total=self.nrows, leave=False
+        ) as pbar:
             for index, record in enumerate(pbar):
                 yield self._process_record(
                     record, file, index, hf_schema, use_datachain_schema
@@ -149,6 +151,10 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
+    if not schemas:
+        raise ValueError(
+            "Cannot infer schema (no files to process or can't access them)"
+        )
     return pa.unify_schemas(schemas)
datachain/lib/dc.py CHANGED
@@ -451,6 +451,7 @@
             return dc
 
         if update or not list_ds_exists:
+            # disable prefetch for listing, as it pre-downloads all files
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
@@ -458,6 +459,7 @@
                     settings=settings,
                     in_memory=in_memory,
                 )
+                .settings(prefetch=0)
                 .gen(
                     list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
@@ -1534,7 +1536,7 @@
 
         Example:
             ```py
-            diff = persons.diff(
+            res = persons.compare(
                 new_persons,
                 on=["id"],
                 right_on=["other_id"],
@@ -1547,9 +1549,9 @@
             )
             ```
         """
-        from datachain.lib.diff import compare as chain_compare
+        from datachain.diff import _compare
 
-        return chain_compare(
+        return _compare(
             self,
             other,
             on,
@@ -1882,6 +1884,9 @@
                 "`nrows` only supported for csv and json formats.",
             )
 
+        if "file" not in self.schema or not self.count():
+            raise DatasetPrepareError(self.name, "no files to parse.")
+
         schema = None
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
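The `parse_tabular` guard complements the `infer_schema` change above: a chain with no `file` column or zero rows now fails fast, before any schema inference. A sketch (the bucket URI is hypothetical):

```py
from datachain.lib.dc import DataChain, DatasetPrepareError

chain = DataChain.from_storage("s3://bucket/no-such-prefix/")
try:
    chain.parse_tabular(format="parquet")
except DatasetPrepareError as exc:
    print(exc)  # ... no files to parse.
```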
datachain/lib/file.py CHANGED
@@ -269,10 +269,21 @@ class File(DataModel):
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
-    async def _prefetch(self) -> None:
-        if self._caching_enabled:
-            client = self._catalog.get_client(self.source)
-            await client._download(self, callback=self._download_cb)
+    async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+        from datachain.client.hf import HfClient
+
+        if self._catalog is None:
+            raise RuntimeError("cannot prefetch file because catalog is not setup")
+
+        client = self._catalog.get_client(self.source)
+        if client.protocol == HfClient.protocol:
+            return False
+
+        await client._download(self, callback=download_cb or self._download_cb)
+        self._set_stream(
+            self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
+        )
+        return True
 
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
@@ -364,7 +375,7 @@ class File(DataModel):
 
         try:
             info = client.fs.info(client.get_full_path(self.path))
-            converted_info = client.info_to_file(info, self.source)
+            converted_info = client.info_to_file(info, self.path)
             return type(self)(
                 path=self.path,
                 source=self.source,
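`_prefetch` is internal, but its new contract matters to the callers in `udf.py` and `pytorch.py`: it returns `False` for sources that stream directly (currently `hf://`) and `True` once the object has been downloaded and re-bound to read from the cache. A sketch of a caller, assuming an async context (not public API):

```py
from datachain.lib.file import File

async def read_with_prefetch(file: File) -> bytes:
    await file._prefetch()  # returns False (no-op) for hf:// sources
    with file.open(mode="rb") as f:  # served from the local cache if prefetched
        return f.read()
```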
datachain/lib/hf.py CHANGED
@@ -95,7 +95,7 @@ class HFGenerator(Generator):
             ds = self.ds_dict[split]
             if split:
                 desc += f" split '{split}'"
-            with tqdm(desc=desc, unit=" rows") as pbar:
+            with tqdm(desc=desc, unit=" rows", leave=False) as pbar:
                 for row in ds:
                     output_dict = {}
                     if split and "split" in self.output_schema.model_fields:
datachain/lib/listing.py CHANGED
@@ -85,6 +85,24 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
+def _isfile(client: "Client", path: str) -> bool:
+    """
+    Returns True if uri points to a file
+    """
+    try:
+        info = client.fs.info(path)
+        name = info.get("name")
+        # case for special simulated directories on some clouds
+        # e.g. Google creates a zero byte file with the same name as the
+        # directory with a trailing slash at the end
+        if not name or name.endswith("/"):
+            return False
+
+        return info["type"] == "file"
+    except:  # noqa: E722
+        return False
+
+
 def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -94,7 +112,7 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
     storage_uri, path = Client.parse_url(uri)
     telemetry.log_param("client", client.PREFIX)
 
-    if not uri.endswith("/") and client.fs.isfile(uri):
+    if not uri.endswith("/") and _isfile(client, uri):
         return None, f'{storage_uri}/{path.lstrip("/")}', path
     if uses_glob(path):
         lst_uri_path = posixpath.dirname(path)
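`_isfile` replaces a bare `client.fs.isfile(uri)` because some object stores simulate directories with zero-byte objects whose names end in `/`; `isfile()` can report those as files, which would make `parse_listing_uri` index a directory as a single file. A self-contained mirror of the new check over `fs.info()`-style dicts:

```py
def looks_like_file(info: dict) -> bool:
    """Mirror of _isfile()'s decision logic for an fs.info() result."""
    name = info.get("name")
    if not name or name.endswith("/"):
        return False  # missing name, or a simulated directory marker
    return info["type"] == "file"

assert looks_like_file({"name": "bucket/a.csv", "type": "file", "size": 10})
# GCS-style zero-byte "directory" marker is rejected:
assert not looks_like_file({"name": "bucket/dir/", "type": "file", "size": 0})
```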
datachain/lib/pytorch.py CHANGED
@@ -1,5 +1,8 @@
 import logging
-from collections.abc import Iterator
+import os
+import weakref
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,15 +12,19 @@ from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.asyn import AsyncMapper
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
 
+    from datachain.cache import DataChainCache as Cache
+
 
 logger = logging.getLogger("datachain")
 
@@ -75,6 +82,19 @@
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        self._cache = catalog.cache
+        self._prefetch_cache: Optional[Cache] = None
+        if prefetch and not self.cache:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+            self._cache = self._prefetch_cache
+            weakref.finalize(self, self._prefetch_cache.destroy)
+
+    def close(self) -> None:
+        if self._prefetch_cache:
+            self._prefetch_cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@
         ms = ms_cls(*ms_args, **ms_kwargs)
         wh_cls, wh_args, wh_kwargs = self._wh_params
         wh = wh_cls(*wh_args, **wh_kwargs)
-        return Catalog(ms, wh, **self._catalog_params)
+        catalog = Catalog(ms, wh, **self._catalog_params)
+        catalog.cache = self._cache
+        return catalog
 
-    def _rows_iter(self, total_rank: int, total_workers: int):
+    def _row_iter(
+        self,
+        total_rank: int,
+        total_workers: int,
+    ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
@@ -104,16 +130,34 @@
             ds = ds.chunk(total_rank, total_workers)
         yield from ds.collect()
 
-    def __iter__(self) -> Iterator[Any]:
-        total_rank, total_workers = self.get_rank_and_workers()
-        rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+    def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+        from datachain.lib.udf import _prefetch_inputs
 
-    def _process_row(self, row_features):
+        total_rank, total_workers = self.get_rank_and_workers()
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}",
+                position=total_rank,
+                leave=True,
+            )
+
+        rows = self._row_iter(total_rank, total_workers)
+        rows = _prefetch_inputs(
+            rows,
+            self.prefetch,
+            download_cb=download_cb,
+            after_prefetch=download_cb.increment_file_count,
+        )
+
+        with download_cb, closing(rows):
+            yield from rows
+
+    def __iter__(self) -> Iterator[list[Any]]:
+        with closing(self._iter_with_prefetch()) as rows:
+            yield from map(self._process_row, rows)
+
+    def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
         row = []
         for fr in row_features:
             if hasattr(fr, "read"):
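In practice this machinery is reached through `DataChain.to_pytorch()`. A usage sketch (the dataset name is hypothetical): with `cache=False`, prefetched files go to a temporary `prefetch-` cache that `close()` destroys eagerly (the `weakref.finalize` hook covers the case where it is never called), and setting `DATACHAIN_SHOW_PREFETCH_PROGRESS` enables a per-worker download bar:

```py
from torch.utils.data import DataLoader

from datachain.lib.dc import DataChain

chain = DataChain.from_dataset("my-dataset").settings(prefetch=4, cache=False)
ds = chain.to_pytorch()  # a PytorchDataset
try:
    for batch in DataLoader(ds, batch_size=16, num_workers=2):
        ...  # training step
finally:
    ds.close()  # drop the temporary prefetch cache
```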
datachain/lib/signal_schema.py CHANGED
@@ -13,13 +13,14 @@ from typing import (  # noqa: UP035
     Final,
     List,
     Literal,
+    Mapping,
     Optional,
     Union,
     get_args,
     get_origin,
 )
 
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 from sqlalchemy import ColumnElement
 from typing_extensions import Literal as LiteralEx
 
@@ -85,8 +86,31 @@ class SignalResolvingTypeError(SignalResolvingError):
         )
 
 
+class CustomType(BaseModel):
+    schema_version: int = Field(ge=1, le=2, strict=True)
+    name: str
+    fields: dict[str, str]
+    bases: list[tuple[str, str, Optional[str]]]
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
+        version = data.get("schema_version", 1)
+
+        if version == 1:
+            data = {
+                "schema_version": 1,
+                "name": type_name,
+                "fields": data,
+                "bases": [],
+            }
+
+        return cls(**data)
+
+
 def create_feature_model(
-    name: str, fields: dict[str, Union[type, tuple[type, Any]]]
+    name: str,
+    fields: Mapping[str, Union[type, None, tuple[type, Any]]],
+    base: Optional[type] = None,
 ) -> type[BaseModel]:
     """
     This gets or returns a dynamic feature model for use in restoring a model
@@ -98,7 +122,7 @@ def create_feature_model(
     name = name.replace("@", "_")
     return create_model(
         name,
-        __base__=DataModel,  # type: ignore[call-overload]
+        __base__=base or DataModel,  # type: ignore[call-overload]
         # These are tuples for each field of: annotation, default (if any)
         **{
             field_name: anno if isinstance(anno, tuple) else (anno, None)
@@ -156,7 +180,7 @@ class SignalSchema:
         return SignalSchema(signals)
 
     @staticmethod
-    def _serialize_custom_model_fields(
+    def _serialize_custom_model(
         version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
     ) -> str:
         """This serializes any custom type information to the provided custom_types
@@ -165,12 +189,23 @@ class SignalSchema:
             # This type is already stored in custom_types.
             return version_name
         fields = {}
+
         for field_name, info in fr.model_fields.items():
             field_type = info.annotation
             # All fields should be typed.
             assert field_type
             fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
-        custom_types[version_name] = fields
+
+        bases: list[tuple[str, str, Optional[str]]] = []
+        for type_ in fr.__mro__:
+            model_store_name = (
+                ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
+            )
+            bases.append((type_.__name__, type_.__module__, model_store_name))
+
+        ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases)
+        custom_types[version_name] = ct.model_dump()
+
         return version_name
 
     @staticmethod
@@ -184,15 +219,12 @@ class SignalSchema:
             if st is None or not ModelStore.is_pydantic(st):
                 continue
             # Register and save feature types.
-            ModelStore.register(st)
             st_version_name = ModelStore.get_name(st)
             if st is fr:
                 # If the main type is Pydantic, then use the ModelStore version name.
                 type_name = st_version_name
             # Save this type to custom_types.
-            SignalSchema._serialize_custom_model_fields(
-                st_version_name, st, custom_types
-            )
+            SignalSchema._serialize_custom_model(st_version_name, st, custom_types)
         return type_name
 
     def serialize(self) -> dict[str, Any]:
@@ -215,7 +247,7 @@ class SignalSchema:
                 depth += 1
             elif c == "]":
                 if depth == 0:
-                    raise TypeError(
+                    raise ValueError(
                         "Extra closing square bracket when parsing subtype list"
                     )
                 depth -= 1
@@ -223,16 +255,51 @@ class SignalSchema:
                 subtypes.append(type_name[start:i].strip())
                 start = i + 1
         if depth > 0:
-            raise TypeError("Unclosed square bracket when parsing subtype list")
+            raise ValueError("Unclosed square bracket when parsing subtype list")
         subtypes.append(type_name[start:].strip())
         return subtypes
 
     @staticmethod
-    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:  # noqa: PLR0911
+    def _deserialize_custom_type(
+        type_name: str, custom_types: dict[str, Any]
+    ) -> Optional[type]:
+        """Given a type name like MyType@v1 gets a type from ModelStore or recreates
+        it based on the information from the custom types dict that includes fields and
+        bases."""
+        model_name, version = ModelStore.parse_name_version(type_name)
+        fr = ModelStore.get(model_name, version)
+        if fr:
+            return fr
+
+        if type_name in custom_types:
+            ct = CustomType.deserialize(custom_types[type_name], type_name)
+
+            fields = {
+                field_name: SignalSchema._resolve_type(field_type_str, custom_types)
+                for field_name, field_type_str in ct.fields.items()
+            }
+
+            base_model = None
+            for base in ct.bases:
+                _, _, model_store_name = base
+                if model_store_name:
+                    model_name, version = ModelStore.parse_name_version(
+                        model_store_name
+                    )
+                    base_model = ModelStore.get(model_name, version)
+                    if base_model:
+                        break
+
+            return create_feature_model(type_name, fields, base=base_model)
+
+        return None
+
+    @staticmethod
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
         """Convert a string-based type back into a python type."""
         type_name = type_name.strip()
         if not type_name:
-            raise TypeError("Type cannot be empty")
+            raise ValueError("Type cannot be empty")
         if type_name == "NoneType":
             return None
 
@@ -240,14 +307,14 @@ class SignalSchema:
         subtypes: Optional[tuple[Optional[type], ...]] = None
         if bracket_idx > -1:
             if bracket_idx == 0:
-                raise TypeError("Type cannot start with '['")
+                raise ValueError("Type cannot start with '['")
             close_bracket_idx = type_name.rfind("]")
             if close_bracket_idx == -1:
-                raise TypeError("Unclosed square bracket when parsing type")
+                raise ValueError("Unclosed square bracket when parsing type")
             if close_bracket_idx < bracket_idx:
-                raise TypeError("Square brackets are out of order when parsing type")
+                raise ValueError("Square brackets are out of order when parsing type")
             if close_bracket_idx == bracket_idx + 1:
-                raise TypeError("Empty square brackets when parsing type")
+                raise ValueError("Empty square brackets when parsing type")
             subtype_names = SignalSchema._split_subtypes(
                 type_name[bracket_idx + 1 : close_bracket_idx]
             )
@@ -267,18 +334,10 @@ class SignalSchema:
             return fr[subtypes]  # type: ignore[index]
         return fr  # type: ignore[return-value]
 
-        model_name, version = ModelStore.parse_name_version(type_name)
-        fr = ModelStore.get(model_name, version)
+        fr = SignalSchema._deserialize_custom_type(type_name, custom_types)
         if fr:
             return fr
 
-        if type_name in custom_types:
-            fields = custom_types[type_name]
-            fields = {
-                field_name: SignalSchema._resolve_type(field_type_str, custom_types)
-                for field_name, field_type_str in fields.items()
-            }
-            return create_feature_model(type_name, fields)
         # This can occur if a third-party or custom type is used, which is not available
         # when deserializing.
         warnings.warn(
@@ -317,7 +376,7 @@ class SignalSchema:
                     stacklevel=2,
                 )
                 continue
-            except TypeError as err:
+            except ValueError as err:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
@@ -662,6 +721,9 @@ class SignalSchema:
                 stacklevel=2,
            )
            return "Any"
+        if ModelStore.is_pydantic(type_):
+            ModelStore.register(type_)
+            return ModelStore.get_name(type_)
         return type_.__name__
 
     @staticmethod
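Net effect of the `schema_version: 2` format: each custom-type entry now records its name, field types, and full base-class chain, so deserialization can rebuild an unknown model on its nearest ancestor known to `ModelStore` rather than on bare `DataModel`; version-1 entries (plain field dicts) still load via `CustomType.deserialize`. A round-trip sketch:

```py
from pydantic import BaseModel

from datachain.lib.signal_schema import SignalSchema

class Embedding(BaseModel):
    value: list[float]

serialized = SignalSchema({"emb": Embedding}).serialize()
# the entry under "_custom_types" is now a v2 dict with "name", "fields",
# and "bases" rather than a bare field mapping
restored = SignalSchema.deserialize(serialized)
```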