datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/hf.py CHANGED
@@ -11,7 +11,7 @@ try:
         Image,
         IterableDataset,
         IterableDatasetDict,
-        Sequence,
+        List,
         Value,
         load_dataset,
     )
@@ -26,7 +26,7 @@ except ImportError as exc:
     ) from exc

 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, TypeAlias

 import PIL
 from tqdm.auto import tqdm
@@ -34,13 +34,16 @@ from tqdm.auto import tqdm
 from datachain.lib.arrow import arrow_type_mapper
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names

 if TYPE_CHECKING:
     import pyarrow as pa
     from pydantic import BaseModel


-HFDatasetType = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]
+HFDatasetType: TypeAlias = (
+    str | DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+)


 class HFClassLabel(DataModel):
@@ -59,7 +62,6 @@ class HFImage(DataModel):


 class HFAudio(DataModel):
-    path: str
     array: list[float]
     sampling_rate: int

@@ -67,23 +69,27 @@ class HFAudio(DataModel):
 class HFGenerator(Generator):
     def __init__(
         self,
-        ds: Union[str, HFDatasetType],
+        ds: HFDatasetType,
         output_schema: type["BaseModel"],
+        limit: int = 0,
         *args,
         **kwargs,
     ):
         """
-        Generator for chain from huggingface datasets.
+        Generator for chain from Hugging Face datasets.

         Parameters:

-        ds : Path or name of the dataset to read from Hugging Face Hub,
-            or an instance of `datasets.Dataset`-like object.
-        output_schema : Pydantic model for validation.
+        ds : Path or name of the dataset to read from Hugging Face Hub,
+            or an instance of `datasets.Dataset`-like object.
+        limit : Limit the number of items to read from the HF dataset.
+            Defaults to 0 (no limit).
+        output_schema : Pydantic model for validation.
         """
         super().__init__()
         self.ds = ds
         self.output_schema = output_schema
+        self.limit = limit
         self.args = args
         self.kwargs = kwargs

@@ -93,57 +99,81 @@ class HFGenerator(Generator):
     def process(self, split: str = ""):
         desc = "Parsed Hugging Face dataset"
         ds = self.ds_dict[split]
+        if self.limit > 0:
+            ds = ds.take(self.limit)
         if split:
             desc += f" split '{split}'"
+        model_fields = self.output_schema._model_fields_by_aliases()  # type: ignore[attr-defined]
         with tqdm(desc=desc, unit=" rows", leave=False) as pbar:
             for row in ds:
                 output_dict = {}
                 if split and "split" in self.output_schema.model_fields:
                     output_dict["split"] = split
                 for name, feat in ds.features.items():
-                    anno = self.output_schema.model_fields[name].annotation
-                    output_dict[name] = convert_feature(row[name], feat, anno)
+                    normalized_name, info = model_fields[name]
+                    anno = info.annotation
+                    output_dict[normalized_name] = convert_feature(
+                        row[name], feat, anno
+                    )
                 yield self.output_schema(**output_dict)
                 pbar.update(1)


-def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
+def stream_splits(ds: HFDatasetType, *args, **kwargs):
     if isinstance(ds, str):
-        kwargs["streaming"] = True
         ds = load_dataset(ds, *args, **kwargs)
     if isinstance(ds, (DatasetDict, IterableDatasetDict)):
         return ds
     return {"": ds}


-def convert_feature(val: Any, feat: Any, anno: Any) -> Any:  # noqa: PLR0911
-    if isinstance(feat, (Value, Array2D, Array3D, Array4D, Array5D)):
+def convert_feature(val: Any, feat: Any, anno: Any) -> Any:
+    if isinstance(feat, (Value, Array2D, Array3D, Array4D, Array5D, List)):
         return val
     if isinstance(feat, ClassLabel):
         return HFClassLabel(string=feat.names[val], integer=val)
-    if isinstance(feat, Sequence):
-        if isinstance(feat.feature, dict):
-            sdict = {}
-            for sname in val:
-                sfeat = feat.feature[sname]
-                sanno = anno.model_fields[sname].annotation
-                sdict[sname] = [convert_feature(v, sfeat, sanno) for v in val[sname]]
-            return anno(**sdict)
-        return val
+    if isinstance(feat, dict):
+        sdict = {}
+        model_fields = anno._model_fields_by_aliases()  # type: ignore[attr-defined]
+        for sname in val:
+            sfeat = feat[sname]
+            norm_name, info = model_fields[sname]
+            sanno = info.annotation
+            if isinstance(val[sname], list):
+                sdict[norm_name] = [
+                    convert_feature(v, sfeat, sanno) for v in val[sname]
+                ]
+            else:
+                sdict[norm_name] = convert_feature(val[sname], sfeat, sanno)
+        return anno(**sdict)
     if isinstance(feat, Image):
         if isinstance(val, dict):
             return HFImage(img=val["bytes"])
         return HFImage(img=image_to_bytes(val))
     if isinstance(feat, Audio):
-        return HFAudio(**val)
-
-
-def get_output_schema(features: Features) -> dict[str, DataType]:
-    """Generate UDF output schema from huggingface datasets features."""
+        return HFAudio(array=val["array"], sampling_rate=val["sampling_rate"])
+
+
+def get_output_schema(
+    features: Features, existing_column_names: list[str] | None = None
+) -> tuple[dict[str, DataType], dict[str, str]]:
+    """
+    Generate UDF output schema from Hugging Face datasets features. It normalizes the
+    column names and returns a mapping of normalized names to original names along with
+    the data types. `existing_column_names` is the list of column names that already
+    exist in the dataset (to avoid name collisions due to normalization).
+    """
+    existing_column_names = existing_column_names or []
     fields_dict = {}
-    for name, val in features.items():
-        fields_dict[name] = _feature_to_chain_type(name, val)
-    return fields_dict
+    normalized_names = normalize_col_names(
+        existing_column_names + list(features.keys())
+    )
+    # List of tuple(str, str) for HF dataset feature names, (normalized, original)
+    new_feature_names = list(normalized_names.items())[len(existing_column_names) :]
+    for idx, feat in enumerate(features.items()):
+        name, val = feat
+        fields_dict[new_feature_names[idx][0]] = _feature_to_chain_type(name, val)
+    return fields_dict, normalized_names


 def _feature_to_chain_type(name: str, val: Any) -> DataType:  # noqa: PLR0911
@@ -151,13 +181,13 @@ def _feature_to_chain_type(name: str, val: Any) -> DataType:  # noqa: PLR0911
         return arrow_type_mapper(val.pa_type)
     if isinstance(val, ClassLabel):
         return HFClassLabel
-    if isinstance(val, Sequence):
-        if isinstance(val.feature, dict):
-            sequence_dict = {}
-            for sname, sval in val.feature.items():
-                dtype = _feature_to_chain_type(sname, sval)
-                sequence_dict[sname] = list[dtype]  # type: ignore[valid-type]
-            return dict_to_data_model(name, sequence_dict)  # type: ignore[arg-type]
+    if isinstance(val, dict):
+        sequence_dict = {}
+        for sname, sval in val.items():
+            dtype = _feature_to_chain_type(sname, sval)
+            sequence_dict[sname] = dtype  # type: ignore[valid-type]
+        return dict_to_data_model(f"HFDataModel_{name}", sequence_dict)  # type: ignore[arg-type]
+    if isinstance(val, List):
         return list[_feature_to_chain_type(name, val.feature)]  # type: ignore[arg-type,misc,return-value]
     if isinstance(val, Array2D):
         dtype = arrow_type_mapper(string_to_arrow(val.dtype))
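The hunks above add a `limit` pass-through and column-name normalization to the Hugging Face reader. A minimal usage sketch, assuming the public `read_hf` entry point from `datachain/lib/dc/hf.py` (dataset name and split are illustrative):

```py
import datachain as dc

# Stream a Hugging Face dataset into a chain; feature names that are not valid
# identifiers are normalized by get_output_schema() above.
chain = dc.read_hf("beans", split="train")
chain.limit(5).show()
```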
datachain/lib/image.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from collections.abc import Callable

 import torch
 from PIL import Image as PILImage
@@ -6,7 +6,7 @@ from PIL import Image as PILImage
 from datachain.lib.file import File, FileError, Image, ImageFile


-def image_info(file: Union[File, ImageFile]) -> Image:
+def image_info(file: File | ImageFile) -> Image:
     """
     Returns image file information.

@@ -19,7 +19,7 @@ def image_info(file: Union[File, ImageFile]) -> Image:
     try:
         img = file.as_image_file().read()
     except Exception as exc:
-        raise FileError(file, "unable to open image file") from exc
+        raise FileError("unable to open image file", file.source, file.path) from exc

     return Image(
         width=img.width,
@@ -31,11 +31,11 @@ def image_info(file: Union[File, ImageFile]) -> Image:
 def convert_image(
     img: PILImage.Image,
     mode: str = "RGB",
-    size: Optional[tuple[int, int]] = None,
-    transform: Optional[Callable] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[PILImage.Image, torch.Tensor]:
+    size: tuple[int, int] | None = None,
+    transform: Callable | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> PILImage.Image | torch.Tensor:
     """
     Resize, transform, and otherwise convert an image.

@@ -71,13 +71,13 @@ def convert_image(


 def convert_images(
-    images: Union[PILImage.Image, list[PILImage.Image]],
+    images: PILImage.Image | list[PILImage.Image],
     mode: str = "RGB",
-    size: Optional[tuple[int, int]] = None,
-    transform: Optional[Callable] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[PILImage.Image], torch.Tensor]:
+    size: tuple[int, int] | None = None,
+    transform: Callable | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> list[PILImage.Image] | torch.Tensor:
     """
     Resize, transform, and otherwise convert one or more images.

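The retyped helpers keep the same behavior; a small sketch of `convert_image` using only parameters shown in this diff (the input path is illustrative):

```py
from PIL import Image as PILImage

from datachain.lib.image import convert_image

img = PILImage.open("sample.jpg")  # illustrative local file
# With no transform or encoder the result stays a PIL image, resized to 224x224.
thumb = convert_image(img, mode="RGB", size=(224, 224))
```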
datachain/lib/listing.py CHANGED
@@ -2,10 +2,10 @@ import glob
 import logging
 import os
 import posixpath
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, TypeVar

 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
@@ -56,6 +56,8 @@ def list_bucket(uri: str, cache, client_config=None) -> Callable:
         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
             yield from entries

+    list_func.__name__ = "read_storage"
+
     return list_func


@@ -71,8 +73,8 @@ def get_file_info(uri: str, cache, client_config=None) -> File:
 def ls(
     dc: D,
     path: str,
-    recursive: Optional[bool] = True,
-    object_name="file",
+    recursive: bool | None = True,
+    column="file",
 ) -> D:
     """
     Return files by some path from DataChain instance which contains bucket listing.
@@ -82,7 +84,7 @@ def ls(
     """

     def _file_c(name: str) -> Column:
-        return Column(f"{object_name}.{name}")
+        return Column(f"{column}.{name}")

     dc = dc.filter(_file_c("is_latest") == true())

@@ -105,11 +107,10 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))


-def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
+def parse_listing_uri(uri: str) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
     """
-    client_config = client_config or {}
     storage_uri, path = Client.parse_url(uri)
     if uses_glob(path):
         lst_uri_path = posixpath.dirname(path)
@@ -122,6 +123,9 @@ def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
         f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
     )

+    # we should remove dots from the name
+    ds_name = ds_name.replace(".", "_")
+
     return ds_name, lst_uri, path


@@ -146,8 +150,8 @@ def _reraise_as_client_error() -> Iterator[None]:


 def get_listing(
-    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
-) -> tuple[Optional[str], str, str, bool]:
+    uri: str | os.PathLike[str], session: "Session", update: bool = False
+) -> tuple[str | None, str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
     It also returns boolean saying if returned dataset name is reused / already
@@ -173,7 +177,7 @@ def get_listing(
         _, path = Client.parse_url(uri)
         return None, uri, path, False

-    ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
+    ds_name, list_uri, list_path = parse_listing_uri(uri)
     listing = None
     listings = [
         ls for ls in catalog.listings() if not ls.is_expired and ls.contains(ds_name)
@@ -194,5 +198,4 @@ def get_listing(
         list_path = f"{ds_name.strip('/').removeprefix(listing.name)}/{list_path}"

     ds_name = listing.name if listing else ds_name
-
     return ds_name, list_uri, list_path, bool(listing)
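A brief sketch of the new single-argument `parse_listing_uri` and the dot replacement it applies to listing dataset names (the URI is illustrative; the exact listing-name prefix comes from `LISTING_PREFIX`, which is not shown in this diff):

```py
from datachain.lib.listing import parse_listing_uri

# Dots in the bucket name would otherwise end up in the listing dataset name,
# so parse_listing_uri() now replaces them with underscores.
ds_name, lst_uri, path = parse_listing_uri("s3://my.bucket/images/")
assert "." not in ds_name
```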
datachain/lib/listing_info.py CHANGED
@@ -1,5 +1,4 @@
 from datetime import datetime, timedelta, timezone
-from typing import Optional

 from datachain.client import Client
 from datachain.lib.dataset_info import DatasetInfo
@@ -17,7 +16,7 @@ class ListingInfo(DatasetInfo):
         return uri

     @property
-    def expires(self) -> Optional[datetime]:
+    def expires(self) -> datetime | None:
         if not self.finished_at:
             return None
         return self.finished_at + timedelta(seconds=LISTING_TTL)
datachain/lib/meta_formats.py CHANGED
@@ -1,14 +1,13 @@
 import csv
-import json
 import tempfile
 import uuid
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from pathlib import Path
-from typing import Callable

 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

+from datachain import json
 from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import TextFile

@@ -106,7 +105,7 @@ def read_meta(  # noqa: C901
     from datachain import read_storage

     if schema_from:
-        file = next(read_storage(schema_from, type="text").limit(1).collect("file"))
+        file = read_storage(schema_from, type="text").limit(1).to_values("file")[0]
         model_code = gen_datamodel_code(
             file, format=format, jmespath=jmespath, model_name=model_name
         )
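The `collect` call above is replaced with `to_values`, which returns a materialized list instead of a generator. A hedged equivalent at the API level (the bucket URI is illustrative):

```py
import datachain as dc

# Old: next(chain.collect("file"))
# New: chain.to_values("file")[0]
chain = dc.read_storage("s3://example-bucket/docs/", type="text")
file = chain.limit(1).to_values("file")[0]
```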
datachain/lib/model_store.py CHANGED
@@ -1,11 +1,8 @@
 import inspect
-import logging
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar

 from pydantic import BaseModel

-logger = logging.getLogger(__name__)
-

 class ModelStore:
     store: ClassVar[dict[str, dict[int, type[BaseModel]]]] = {}
@@ -14,7 +11,7 @@ class ModelStore:
     def get_version(cls, model: type[BaseModel]) -> int:
         if not hasattr(model, "_version"):
             return 0
-        return model._version
+        return model._version  # type: ignore[attr-defined]

     @classmethod
     def get_name(cls, model) -> str:
@@ -39,7 +36,7 @@ class ModelStore:
             cls.register(anno)

     @classmethod
-    def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
+    def get(cls, name: str, version: int | None = None) -> type | None:
         class_dict = cls.store.get(name, None)
         if class_dict is None:
             return None
@@ -77,7 +74,42 @@ class ModelStore:
         )

     @staticmethod
-    def to_pydantic(val) -> Optional[type[BaseModel]]:
+    def to_pydantic(val) -> type[BaseModel] | None:
         if val is None or not ModelStore.is_pydantic(val):
             return None
         return val
+
+    @staticmethod
+    def is_partial(parent_type) -> bool:
+        return (
+            parent_type
+            and ModelStore.is_pydantic(parent_type)
+            and "@" in ModelStore.get_name(parent_type)
+        )
+
+    @classmethod
+    def rebuild_all(cls) -> None:
+        """Ensure pydantic schemas are (re)built for all registered models.
+
+        Uses ``force=True`` to avoid subtle cases where a deserialized class
+        (e.g. from by-value cloudpickle in workers) reports built state but
+        nested model field schemas aren't fully resolved yet.
+        """
+        visited: set[type[BaseModel]] = set()
+        visiting: set[type[BaseModel]] = set()
+
+        def visit(model: type[BaseModel]) -> None:
+            if model in visited or model in visiting:
+                return
+            visiting.add(model)
+            for field in model.model_fields.values():
+                child = cls.to_pydantic(field.annotation)
+                if child is not None:
+                    visit(child)
+            visiting.remove(model)
+            model.model_rebuild(force=True)
+            visited.add(model)
+
+        for versions in cls.store.values():
+            for model in versions.values():
+                visit(model)
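A short sketch of the new `ModelStore.rebuild_all` walking registered models and their nested pydantic fields (the model classes are illustrative):

```py
from pydantic import BaseModel

from datachain.lib.model_store import ModelStore

class Inner(BaseModel):
    value: int

class Outer(BaseModel):
    inner: Inner
    name: str

ModelStore.register(Outer)
# Recursively visits Outer and its nested Inner field, then calls
# model_rebuild(force=True) on each registered model.
ModelStore.rebuild_all()
```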
datachain/lib/namespaces.py ADDED
@@ -0,0 +1,125 @@
+from datachain.error import (
+    NamespaceCreateNotAllowedError,
+    NamespaceDeleteNotAllowedError,
+)
+from datachain.lib.projects import delete as delete_project
+from datachain.namespace import Namespace, parse_name
+from datachain.query import Session
+
+
+def create(
+    name: str, descr: str | None = None, session: Session | None = None
+) -> Namespace:
+    """
+    Creates a new namespace.
+
+    Namespaces organize projects, which in turn organize datasets. A default
+    namespace always exists and is used if none is specified. Multiple namespaces
+    can be created in Studio, but only the default is available in the CLI.
+
+    Parameters:
+        name: Name of the new namespace.
+        descr: Optional description of the namespace.
+        session: Optional session to use for the operation.
+
+    Example:
+        ```py
+        from datachain.lib.namespaces import create as create_namespace
+        namespace = create_namespace("dev", "Dev namespace")
+        ```
+    """
+    session = Session.get(session)
+
+    from datachain.lib.dc.utils import is_studio
+
+    if not is_studio():
+        raise NamespaceCreateNotAllowedError("Creating namespace is not allowed")
+
+    Namespace.validate_name(name)
+
+    return session.catalog.metastore.create_namespace(name, descr)
+
+
+def get(name: str, session: Session | None = None) -> Namespace:
+    """
+    Gets a namespace by name.
+    If the namespace is not found, a `NamespaceNotFoundError` is raised.
+
+    Parameters:
+        name : The name of the namespace.
+        session : Session to use for getting namespace.
+
+    Example:
+        ```py
+        import datachain as dc
+        namespace = dc.get_namespace("local")
+        ```
+    """
+    session = Session.get(session)
+    return session.catalog.metastore.get_namespace(name)
+
+
+def ls(session: Session | None = None) -> list[Namespace]:
+    """
+    Gets a list of all namespaces.
+
+    Parameters:
+        session : Session to use for getting namespaces.
+
+    Example:
+        ```py
+        from datachain.lib.namespaces import ls as ls_namespaces
+        namespaces = ls_namespaces()
+        ```
+    """
+    return Session.get(session).catalog.metastore.list_namespaces()
+
+
+def delete_namespace(name: str, session: Session | None = None) -> None:
+    """
+    Removes a namespace by name.
+
+    Raises:
+        NamespaceNotFoundError: If the namespace does not exist.
+        NamespaceDeleteNotAllowedError: If the namespace is non-empty,
+            is the default namespace, or is a system namespace,
+            as these cannot be removed.
+
+    Parameters:
+        name: The name of the namespace.
+        session: Session to use for getting project.
+
+    Example:
+        ```py
+        import datachain as dc
+        dc.delete_namespace("dev")
+        ```
+    """
+    session = Session.get(session)
+    metastore = session.catalog.metastore
+
+    namespace_name, project_name = parse_name(name)
+
+    if project_name:
+        return delete_project(project_name, namespace_name, session)
+
+    namespace = metastore.get_namespace(name)
+
+    if name == metastore.system_namespace_name:
+        raise NamespaceDeleteNotAllowedError(
+            f"Namespace {metastore.system_namespace_name} cannot be removed"
+        )
+
+    if name == metastore.default_namespace_name:
+        raise NamespaceDeleteNotAllowedError(
+            f"Namespace {metastore.default_namespace_name} cannot be removed"
+        )
+
+    num_projects = metastore.count_projects(namespace.id)
+    if num_projects > 0:
+        raise NamespaceDeleteNotAllowedError(
+            f"Namespace cannot be removed. It contains {num_projects} project(s). "
+            "Please remove the project(s) first."
+        )
+
+    metastore.remove_namespace(namespace.id)
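As the code above shows, `delete_namespace` accepts either a bare namespace name or a dotted `namespace.project` name, which `parse_name` splits and routes to project deletion. A hedged sketch (names are illustrative):

```py
import datachain as dc

# Dotted form: removes the "analytics" project inside the "dev" namespace.
dc.delete_namespace("dev.analytics")
# Bare form: removes the now-empty "dev" namespace itself.
dc.delete_namespace("dev")
```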