datachain 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain has been flagged as potentially problematic.

datachain/lib/file.py CHANGED
@@ -1,5 +1,6 @@
 import io
 import json
+import logging
 import os
 import posixpath
 from abc import ABC, abstractmethod
@@ -15,6 +16,9 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
 from pydantic import Field, field_validator

+if TYPE_CHECKING:
+    from typing_extensions import Self
+
 from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
@@ -25,6 +29,8 @@ from datachain.utils import TIME_ZERO
 if TYPE_CHECKING:
     from datachain.catalog import Catalog

+logger = logging.getLogger("datachain")
+
 # how to create file path when exporting
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

@@ -251,14 +257,18 @@ class File(DataModel):
         dump = self.model_dump()
         return UniqueId(*(dump[k] for k in self._unique_id_keys))

-    def get_local_path(self) -> Optional[str]:
+    def get_local_path(self, download: bool = False) -> Optional[str]:
         """Returns path to a file in a local cache.
         Return None if file is not cached. Throws an exception if cache is not setup."""
         if self._catalog is None:
             raise RuntimeError(
                 "cannot resolve local file path because catalog is not setup"
             )
-        return self._catalog.cache.get_path(self.get_uid())
+        uid = self.get_uid()
+        if download:
+            client = self._catalog.get_client(self.source)
+            client.download(uid, callback=self._download_cb)
+        return self._catalog.cache.get_path(uid)

     def get_file_suffix(self):
         """Returns last part of file name with `.`."""
@@ -313,6 +323,70 @@ class File(DataModel):
         """Returns `fsspec` filesystem for the file."""
         return self._catalog.get_client(self.source).fs

+    def resolve(self) -> "Self":
+        """
+        Resolve a File object by checking its existence and updating its metadata.
+
+        Returns:
+            File: The resolved File object with updated metadata.
+        """
+        if self._catalog is None:
+            raise RuntimeError("Cannot resolve file: catalog is not set")
+
+        try:
+            client = self._catalog.get_client(self.source)
+        except NotImplementedError as e:
+            raise RuntimeError(
+                f"Unsupported protocol for file source: {self.source}"
+            ) from e
+
+        try:
+            info = client.fs.info(client.get_full_path(self.path))
+            converted_info = client.info_to_file(info, self.source)
+            return type(self)(
+                path=self.path,
+                source=self.source,
+                size=converted_info.size,
+                etag=converted_info.etag,
+                version=converted_info.version,
+                is_latest=converted_info.is_latest,
+                last_modified=converted_info.last_modified,
+                location=self.location,
+            )
+        except (FileNotFoundError, PermissionError, OSError) as e:
+            logger.warning("File system error when resolving %s: %s", self.path, str(e))
+
+        return type(self)(
+            path=self.path,
+            source=self.source,
+            size=0,
+            etag="",
+            version="",
+            is_latest=True,
+            last_modified=TIME_ZERO,
+            location=self.location,
+        )
+
+
+def resolve(file: File) -> File:
+    """
+    Resolve a File object by checking its existence and updating its metadata.
+
+    This function is a wrapper around the File.resolve() method, designed to be
+    used as a mapper in DataChain operations.
+
+    Args:
+        file (File): The File object to resolve.
+
+    Returns:
+        File: The resolved File object with updated metadata.
+
+    Raises:
+        RuntimeError: If the file's catalog is not set or if
+            the file source protocol is unsupported.
+    """
+    return file.resolve()
+

 class TextFile(File):
     """`DataModel` for reading text files."""
datachain/lib/hf.py CHANGED
@@ -15,7 +15,7 @@ try:
         Value,
         load_dataset,
     )
-    from datasets.features.features import string_to_arrow
+    from datasets.features.features import Features, string_to_arrow
     from datasets.features.image import image_to_bytes

 except ImportError as exc:
@@ -36,6 +36,7 @@ from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
+    import pyarrow as pa
     from pydantic import BaseModel


@@ -71,6 +72,15 @@ class HFGenerator(Generator):
         *args,
         **kwargs,
     ):
+        """
+        Generator for chain from huggingface datasets.
+
+        Parameters:
+
+        ds : Path or name of the dataset to read from Hugging Face Hub,
+            or an instance of `datasets.Dataset`-like object.
+        output_schema : Pydantic model for validation.
+        """
         super().__init__()
         self.ds = ds
         self.output_schema = output_schema
@@ -92,7 +102,7 @@
             output_dict["split"] = split
             for name, feat in ds.features.items():
                 anno = self.output_schema.model_fields[name].annotation
-                output_dict[name] = _convert_feature(row[name], feat, anno)
+                output_dict[name] = convert_feature(row[name], feat, anno)
             yield self.output_schema(**output_dict)
             pbar.update(1)

@@ -106,7 +116,7 @@ def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
     return {"": ds}


-def _convert_feature(val: Any, feat: Any, anno: Any) -> Any:
+def convert_feature(val: Any, feat: Any, anno: Any) -> Any:  # noqa: PLR0911
     if isinstance(feat, (Value, Array2D, Array3D, Array4D, Array5D)):
         return val
     if isinstance(feat, ClassLabel):
@@ -117,20 +127,23 @@ def _convert_feature(val: Any, feat: Any, anno: Any) -> Any:
             for sname in val:
                 sfeat = feat.feature[sname]
                 sanno = anno.model_fields[sname].annotation
-                sdict[sname] = [_convert_feature(v, sfeat, sanno) for v in val[sname]]
+                sdict[sname] = [convert_feature(v, sfeat, sanno) for v in val[sname]]
             return anno(**sdict)
         return val
     if isinstance(feat, Image):
+        if isinstance(val, dict):
+            return HFImage(img=val["bytes"])
         return HFImage(img=image_to_bytes(val))
     if isinstance(feat, Audio):
         return HFAudio(**val)


 def get_output_schema(
-    ds: Union[Dataset, IterableDataset], model_name: str = ""
+    features: Features, model_name: str = "", stream: bool = True
 ) -> dict[str, DataType]:
+    """Generate UDF output schema from huggingface datasets features."""
     fields_dict = {}
-    for name, val in ds.features.items():
+    for name, val in features.items():
         fields_dict[name] = _feature_to_chain_type(name, val)  # type: ignore[assignment]
     return fields_dict  # type: ignore[return-value]

@@ -165,3 +178,7 @@ def _feature_to_chain_type(name: str, val: Any) -> type:  # noqa: PLR0911
     if isinstance(val, Audio):
         return HFAudio
     raise TypeError(f"Unknown huggingface datasets type {type(val)}")
+
+
+def schema_from_arrow(schema: "pa.Schema"):
+    return Features.from_arrow_schema(schema)
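
Note: with `get_output_schema` now accepting a `datasets.Features` mapping and the new `schema_from_arrow` helper, a UDF output schema can be derived straight from an Arrow schema. A hedged sketch of how the two compose (field names are illustrative):

import pyarrow as pa

from datachain.lib.hf import get_output_schema, schema_from_arrow

# Build datasets.Features from an Arrow schema, then map them to chain types.
arrow_schema = pa.schema([("text", pa.string()), ("label", pa.int64())])
features = schema_from_arrow(arrow_schema)  # datasets.Features
udf_schema = get_output_schema(features)    # dict of field name -> DataType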
datachain/lib/listing.py CHANGED
@@ -20,7 +20,7 @@ LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
 LISTING_PREFIX = "lst__"  # listing datasets start with this name


-def list_bucket(uri: str, client_config=None) -> Callable:
+def list_bucket(uri: str, cache, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
     from bucket where each File represents one bucket entry.
@@ -28,10 +28,10 @@ def list_bucket(uri: str, client_config=None) -> Callable:

     def list_func() -> Iterator[File]:
         config = client_config or {}
-        client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
+        client = Client.get_client(uri, cache, **config)  # type: ignore[arg-type]
+        _, path = Client.parse_url(uri)
         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
-            for entry in entries:
-                yield entry.to_file(client.uri)
+            yield from entries

     return list_func

@@ -77,16 +77,17 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
     """
-    client, path = Client.parse_url(uri, cache, **client_config)
+    client = Client.get_client(uri, cache, **client_config)
+    storage_uri, path = Client.parse_url(uri)

     # clean path without globs
     lst_uri_path = (
         posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
     )

-    lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
+    lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
     ds_name = (
-        f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
+        f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
     )

     return ds_name, lst_uri, path
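
Note: the refactor splits responsibilities: `Client.parse_url` is now a pure URI parser returning `(storage_uri, path)`, while `Client.get_client` owns cache and config wiring, which is why `list_bucket` takes `cache` explicitly. A small sketch of the split (placeholder URI; obtaining the cache from `get_catalog()` is an assumption, not shown in this diff):

from datachain.catalog import get_catalog
from datachain.client import Client

uri = "s3://sample-bucket/images/"
storage_uri, path = Client.parse_url(uri)  # no client construction needed
client = Client.get_client(uri, get_catalog().cache)  # cache passed explicitly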
datachain/lib/listing_info.py CHANGED
@@ -13,8 +13,8 @@ class ListingInfo(DatasetInfo):

     @property
     def storage_uri(self) -> str:
-        client, _ = Client.parse_url(self.uri, None)  # type: ignore[arg-type]
-        return client.uri
+        uri, _ = Client.parse_url(self.uri)
+        return uri

     @property
     def expires(self) -> Optional[datetime]:
datachain/lib/model_store.py CHANGED
@@ -1,6 +1,6 @@
 import inspect
 import logging
-from typing import ClassVar, Optional
+from typing import Any, ClassVar, Optional

 from pydantic import BaseModel

@@ -69,7 +69,7 @@ class ModelStore:
             del cls.store[fr.__name__][version]

     @staticmethod
-    def is_pydantic(val):
+    def is_pydantic(val: Any) -> bool:
         return (
             not hasattr(val, "__origin__")
             and inspect.isclass(val)
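
Note: `is_pydantic` only gains annotations here, but the `__origin__` guard is worth illustrating: parameterized generics are rejected before the class checks run. A quick sketch (the model is illustrative, and the body truncated in this hunk is assumed to end with an `issubclass(val, BaseModel)` check):

from pydantic import BaseModel

from datachain.lib.model_store import ModelStore

class Pet(BaseModel):
    name: str

assert ModelStore.is_pydantic(Pet)            # plain BaseModel subclass
assert not ModelStore.is_pydantic(list[int])  # generics carry __origin__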
datachain/lib/pytorch.py CHANGED
@@ -7,6 +7,7 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
+from tqdm import tqdm

 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
@@ -93,33 +94,38 @@ class PytorchDataset(IterableDataset):
         if self.num_samples > 0:
             ds = ds.sample(self.num_samples)
         ds = ds.chunk(total_rank, total_workers)
-        for row_features in ds.collect():
-            row = []
-            for fr in row_features:
-                if hasattr(fr, "read"):
-                    row.append(fr.read())  # type: ignore[unreachable]
-                else:
-                    row.append(fr)
-            # Apply transforms
-            if self.transform:
-                try:
-                    if isinstance(self.transform, v2.Transform):
-                        row = self.transform(row)
+        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
+        with tqdm(desc=desc, unit=" rows") as pbar:
+            for row_features in ds.collect():
+                row = []
+                for fr in row_features:
+                    if hasattr(fr, "read"):
+                        row.append(fr.read())  # type: ignore[unreachable]
+                    else:
+                        row.append(fr)
+                # Apply transforms
+                if self.transform:
+                    try:
+                        if isinstance(self.transform, v2.Transform):
+                            row = self.transform(row)
+                        for i, val in enumerate(row):
+                            if isinstance(val, Image.Image):
+                                row[i] = self.transform(val)
+                    except ValueError:
+                        logger.warning(
+                            "Skipping transform due to unsupported data types."
+                        )
+                        self.transform = None
+                if self.tokenizer:
                     for i, val in enumerate(row):
-                        if isinstance(val, Image.Image):
-                            row[i] = self.transform(val)
-                except ValueError:
-                    logger.warning("Skipping transform due to unsupported data types.")
-                    self.transform = None
-            if self.tokenizer:
-                for i, val in enumerate(row):
-                    if isinstance(val, str) or (
-                        isinstance(val, list) and isinstance(val[0], str)
-                    ):
-                        row[i] = convert_text(
-                            val, self.tokenizer, self.tokenizer_kwargs
-                        ).squeeze(0)  # type: ignore[union-attr]
-            yield row
+                        if isinstance(val, str) or (
+                            isinstance(val, list) and isinstance(val[0], str)
+                        ):
+                            row[i] = convert_text(
+                                val, self.tokenizer, self.tokenizer_kwargs
+                            ).squeeze(0)  # type: ignore[union-attr]
+                yield row
+                pbar.update(1)

     @staticmethod
     def get_rank_and_workers() -> tuple[int, int]:
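
Note: iterating a `PytorchDataset` now ticks a per-worker `tqdm` progress bar; nothing else about the iteration protocol changes, so existing loaders keep working. A minimal sketch (dataset name is a placeholder; `to_pytorch()` is assumed as the usual way to obtain a `PytorchDataset` from a chain):

from torch.utils.data import DataLoader

from datachain.lib.dc import DataChain

ds = DataChain.from_dataset("my-dataset").to_pytorch()
loader = DataLoader(ds, batch_size=16)  # rows now report progress per worker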
datachain/lib/signal_schema.py CHANGED
@@ -4,11 +4,14 @@ from collections.abc import Iterator, Sequence
 from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
-from typing import (
+from typing import (  # noqa: UP035
     TYPE_CHECKING,
     Annotated,
     Any,
     Callable,
+    Dict,
+    Final,
+    List,
     Literal,
     Optional,
     Union,
@@ -42,8 +45,13 @@ NAMES_TO_TYPES = {
     "dict": dict,
     "bytes": bytes,
     "datetime": datetime,
-    "Literal": Literal,
+    "Final": Final,
     "Union": Union,
+    "Optional": Optional,
+    "List": list,
+    "Dict": dict,
+    "Literal": Any,
+    "Any": Any,
 }


@@ -146,35 +154,11 @@ class SignalSchema:
         return SignalSchema(signals)

     @staticmethod
-    def _get_name_original_type(fr_type: type) -> tuple[str, type]:
-        """Returns the name of and the original type for the given type,
-        based on whether the type is Optional or not."""
-        orig = get_origin(fr_type)
-        args = get_args(fr_type)
-        # Check if fr_type is Optional
-        if orig == Union and len(args) == 2 and (type(None) in args):
-            fr_type = args[0]
-            orig = get_origin(fr_type)
-        if orig in (Literal, LiteralEx):
-            # Literal has no __name__ in Python 3.9
-            type_name = "Literal"
-        elif orig == Union:
-            # Union also has no __name__ in Python 3.9
-            type_name = "Union"
-        else:
-            type_name = str(fr_type.__name__)  # type: ignore[union-attr]
-        return type_name, fr_type
-
-    @staticmethod
-    def serialize_custom_model_fields(
-        name: str, fr: type, custom_types: dict[str, Any]
+    def _serialize_custom_model_fields(
+        version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
     ) -> str:
         """This serializes any custom type information to the provided custom_types
-        dict, and returns the name of the type provided."""
-        if hasattr(fr, "__origin__") or not issubclass(fr, BaseModel):
-            # Don't store non-feature types.
-            return name
-        version_name = ModelStore.get_name(fr)
+        dict, and returns the name of the type serialized."""
         if version_name in custom_types:
             # This type is already stored in custom_types.
             return version_name
@@ -183,37 +167,102 @@ class SignalSchema:
             field_type = info.annotation
             # All fields should be typed.
             assert field_type
-            field_type_name, field_type = SignalSchema._get_name_original_type(
-                field_type
-            )
-            # Serialize this type to custom_types if it is a custom type as well.
-            fields[field_name] = SignalSchema.serialize_custom_model_fields(
-                field_type_name, field_type, custom_types
-            )
+            fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
         custom_types[version_name] = fields
         return version_name

+    @staticmethod
+    def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
+        """Serialize a given type to a string, including automatic ModelStore
+        registration, and save this type and subtypes to custom_types as well."""
+        subtypes: list[Any] = []
+        type_name = SignalSchema._type_to_str(fr, subtypes)
+        # Iterate over all subtypes (includes the input type).
+        for st in subtypes:
+            if st is None or not ModelStore.is_pydantic(st):
+                continue
+            # Register and save feature types.
+            ModelStore.register(st)
+            st_version_name = ModelStore.get_name(st)
+            if st is fr:
+                # If the main type is Pydantic, then use the ModelStore version name.
+                type_name = st_version_name
+            # Save this type to custom_types.
+            SignalSchema._serialize_custom_model_fields(
+                st_version_name, st, custom_types
+            )
+        return type_name
+
     def serialize(self) -> dict[str, Any]:
         signals: dict[str, Any] = {}
         custom_types: dict[str, Any] = {}
         for name, fr_type in self.values.items():
-            if (fr := ModelStore.to_pydantic(fr_type)) is not None:
-                ModelStore.register(fr)
-                signals[name] = ModelStore.get_name(fr)
-                type_name, fr_type = SignalSchema._get_name_original_type(fr)
-            else:
-                type_name, fr_type = SignalSchema._get_name_original_type(fr_type)
-                signals[name] = type_name
-            self.serialize_custom_model_fields(type_name, fr_type, custom_types)
+            signals[name] = self._serialize_type(fr_type, custom_types)
         if custom_types:
             signals["_custom_types"] = custom_types
         return signals

     @staticmethod
-    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
+    def _split_subtypes(type_name: str) -> list[str]:
+        """This splits a list of subtypes, including proper square bracket handling."""
+        start = 0
+        depth = 0
+        subtypes = []
+        for i, c in enumerate(type_name):
+            if c == "[":
+                depth += 1
+            elif c == "]":
+                if depth == 0:
+                    raise TypeError(
+                        "Extra closing square bracket when parsing subtype list"
+                    )
+                depth -= 1
+            elif c == "," and depth == 0:
+                subtypes.append(type_name[start:i].strip())
+                start = i + 1
+        if depth > 0:
+            raise TypeError("Unclosed square bracket when parsing subtype list")
+        subtypes.append(type_name[start:].strip())
+        return subtypes
+
+    @staticmethod
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:  # noqa: PLR0911
         """Convert a string-based type back into a python type."""
+        type_name = type_name.strip()
+        if not type_name:
+            raise TypeError("Type cannot be empty")
+        if type_name == "NoneType":
+            return None
+
+        bracket_idx = type_name.find("[")
+        subtypes: Optional[tuple[Optional[type], ...]] = None
+        if bracket_idx > -1:
+            if bracket_idx == 0:
+                raise TypeError("Type cannot start with '['")
+            close_bracket_idx = type_name.rfind("]")
+            if close_bracket_idx == -1:
+                raise TypeError("Unclosed square bracket when parsing type")
+            if close_bracket_idx < bracket_idx:
+                raise TypeError("Square brackets are out of order when parsing type")
+            if close_bracket_idx == bracket_idx + 1:
+                raise TypeError("Empty square brackets when parsing type")
+            subtype_names = SignalSchema._split_subtypes(
+                type_name[bracket_idx + 1 : close_bracket_idx]
+            )
+            # Types like Union require the parameters to be a tuple of types.
+            subtypes = tuple(
+                SignalSchema._resolve_type(st, custom_types) for st in subtype_names
+            )
+            type_name = type_name[:bracket_idx].strip()
+
         fr = NAMES_TO_TYPES.get(type_name)
         if fr:
+            if subtypes:
+                if len(subtypes) == 1:
+                    # Types like Optional require there to be only one argument.
+                    return fr[subtypes[0]]  # type: ignore[index]
+                # Other types like Union require the parameters to be a tuple of types.
+                return fr[subtypes]  # type: ignore[index]
             return fr  # type: ignore[return-value]

         model_name, version = ModelStore.parse_name_version(type_name)
@@ -228,7 +277,14 @@ class SignalSchema:
                 for field_name, field_type_str in fields.items()
             }
             return create_feature_model(type_name, fields)
-        return None
+        # This can occur if a third-party or custom type is used, which is not available
+        # when deserializing.
+        warnings.warn(
+            f"Could not resolve type: '{type_name}'.",
+            SignalSchemaWarning,
+            stacklevel=2,
+        )
+        return Any  # type: ignore[return-value]

@@ -242,9 +298,14 @@ class SignalSchema:
                 # This entry is used as a lookup for custom types,
                 # and is not an actual field.
                 continue
+            if not isinstance(type_name, str):
+                raise SignalSchemaError(
+                    f"cannot deserialize '{type_name}': "
+                    "serialized types must be a string"
+                )
             try:
                 fr = SignalSchema._resolve_type(type_name, custom_types)
-                if fr is None:
+                if fr is Any:
                     # Skip if the type is not found, so all data can be displayed.
                     warnings.warn(
                         f"In signal '{signal}': "
@@ -258,7 +319,7 @@ class SignalSchema:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
-            signals[signal] = fr
+            signals[signal] = fr  # type: ignore[assignment]

         return SignalSchema(signals)

@@ -325,11 +386,20 @@ class SignalSchema:
             else:
                 json, pos = unflatten_to_json_pos(fr, row, pos)  # type: ignore[union-attr]
                 obj = fr(**json)
-                if isinstance(obj, File):
-                    obj._set_stream(catalog, caching_enabled=cache)
+                SignalSchema._set_file_stream(obj, catalog, cache)
             res.append(obj)
         return res

+    @staticmethod
+    def _set_file_stream(
+        obj: BaseModel, catalog: "Catalog", cache: bool = False
+    ) -> None:
+        if isinstance(obj, File):
+            obj._set_stream(catalog, caching_enabled=cache)
+        for field, finfo in obj.model_fields.items():
+            if ModelStore.is_pydantic(finfo.annotation):
+                SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
+
     def db_signals(
         self, name: Optional[str] = None, as_columns=False
     ) -> Union[list[str], list[Column]]:
@@ -509,31 +579,58 @@ class SignalSchema:
         return self.values.pop(name)

     @staticmethod
-    def _type_to_str(type_):  # noqa: PLR0911
+    def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:  # noqa: PLR0911
+        """Convert a type to a string-based representation."""
+        if type_ is None:
+            return "NoneType"
+
         origin = get_origin(type_)

         if origin == Union:
             args = get_args(type_)
-            formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
+            formatted_types = ", ".join(
+                SignalSchema._type_to_str(arg, subtypes) for arg in args
+            )
             return f"Union[{formatted_types}]"
         if origin == Optional:
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"Optional[{type_str}]"
-        if origin is list:
+        if origin in (list, List):  # noqa: UP006
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"list[{type_str}]"
-        if origin is dict:
+        if origin in (dict, Dict):  # noqa: UP006
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else ""
-            vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
+            type_str = (
+                SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
+            )
+            vals = (
+                f", {SignalSchema._type_to_str(args[1], subtypes)}"
+                if len(args) > 1
+                else ""
+            )
             return f"dict[{type_str}{vals}]"
         if origin == Annotated:
             args = get_args(type_)
-            return SignalSchema._type_to_str(args[0])
-        if origin in (Literal, LiteralEx):
+            return SignalSchema._type_to_str(args[0], subtypes)
+        if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx):
             return "Literal"
+        if Any in (origin, type_):
+            return "Any"
+        if Final in (origin, type_):
+            return "Final"
+        if subtypes is not None:
+            # Include this type in the list of all subtypes, if requested.
+            subtypes.append(type_)
+        if not hasattr(type_, "__name__"):
+            # This can happen for some third-party or custom types, mostly on Python 3.9
+            warnings.warn(
+                f"Unable to determine name of type '{type_}'.",
+                SignalSchemaWarning,
+                stacklevel=2,
+            )
+            return "Any"
         return type_.__name__

     @staticmethod
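
Note: taken together, `_type_to_str`, `_split_subtypes`, and `_resolve_type` let parameterized and optional types round-trip through their string form, and unknown type names now degrade to `Any` with a `SignalSchemaWarning` instead of failing deserialization. A hedged round-trip sketch (the signal name is illustrative):

from typing import Optional

from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"scores": list[Optional[float]]})
serialized = schema.serialize()  # e.g. {"scores": "list[Union[float, NoneType]]"}
restored = SignalSchema.deserialize(serialized)  # back to usable python types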