datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/utils.py CHANGED
@@ -1,6 +1,9 @@
+ import inspect
  import re
  from abc import ABC, abstractmethod
  from collections.abc import Sequence
+ from pathlib import PurePosixPath
+ from urllib.parse import urlparse


  class AbstractUDF(ABC):
@@ -18,13 +21,11 @@ class AbstractUDF(ABC):


  class DataChainError(Exception):
-     def __init__(self, message):
-         super().__init__(message)
+     pass


  class DataChainParamsError(DataChainError):
-     def __init__(self, message):
-         super().__init__(message)
+     pass


  class DataChainColumnError(DataChainParamsError):
@@ -32,6 +33,25 @@ class DataChainColumnError(DataChainParamsError):
          super().__init__(f"Error for column {col_name}: {msg}")


+ def callable_name(obj: object) -> str:
+     """Return a friendly name for a callable or UDF-like instance."""
+     # UDF classes in DataChain inherit from AbstractUDF; prefer class name
+     if isinstance(obj, AbstractUDF):
+         return obj.__class__.__name__
+
+     # Plain functions and bound/unbound methods
+     if inspect.ismethod(obj) or inspect.isfunction(obj):
+         # __name__ exists for functions/methods; includes "<lambda>" for lambdas
+         return obj.__name__  # type: ignore[attr-defined]
+
+     # Generic callable object
+     if callable(obj):
+         return obj.__class__.__name__
+
+     # Fallback for non-callables
+     return str(obj)
+
+
  def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
      """Returns normalized_name -> original_name dict."""
      gen_col_counter = 0
@@ -59,3 +79,97 @@ def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
          new_col_names[generated_column] = org_column

      return new_col_names
+
+
+ def rebase_path(
+     src_path: str,
+     old_base: str,
+     new_base: str,
+     suffix: str = "",
+     extension: str = "",
+ ) -> str:
+     """
+     Rebase a file path from one base directory to another.
+
+     Args:
+         src_path: Source file path (can include URI scheme like s3://)
+         old_base: Base directory to remove from src_path
+         new_base: New base directory to prepend
+         suffix: Optional suffix to add before file extension
+         extension: Optional new file extension (without dot)
+
+     Returns:
+         str: Rebased path with new base directory
+
+     Raises:
+         ValueError: If old_base is not found in src_path
+     """
+     # Parse URIs to handle schemes properly
+     src_parsed = urlparse(src_path)
+     old_base_parsed = urlparse(old_base)
+     new_base_parsed = urlparse(new_base)
+
+     # Get the path component (without scheme)
+     if src_parsed.scheme:
+         src_path_only = src_parsed.netloc + src_parsed.path
+     else:
+         src_path_only = src_path
+
+     if old_base_parsed.scheme:
+         old_base_only = old_base_parsed.netloc + old_base_parsed.path
+     else:
+         old_base_only = old_base
+
+     # Normalize paths
+     src_path_norm = PurePosixPath(src_path_only).as_posix()
+     old_base_norm = PurePosixPath(old_base_only).as_posix()
+
+     # Find where old_base appears in src_path
+     if old_base_norm in src_path_norm:
+         # Find the index where old_base appears
+         idx = src_path_norm.find(old_base_norm)
+         if idx == -1:
+             raise ValueError(f"old_base '{old_base}' not found in src_path")
+
+         # Extract the relative path after old_base
+         relative_start = idx + len(old_base_norm)
+         # Skip leading slash if present
+         if relative_start < len(src_path_norm) and src_path_norm[relative_start] == "/":
+             relative_start += 1
+         relative_path = src_path_norm[relative_start:]
+     else:
+         raise ValueError(f"old_base '{old_base}' not found in src_path")
+
+     # Parse the filename
+     path_obj = PurePosixPath(relative_path)
+     stem = path_obj.stem
+     current_ext = path_obj.suffix
+
+     # Apply suffix and extension changes
+     new_stem = stem + suffix if suffix else stem
+     if extension:
+         new_ext = f".{extension}"
+     elif current_ext:
+         new_ext = current_ext
+     else:
+         new_ext = ""
+
+     # Build new filename
+     new_name = new_stem + new_ext
+
+     # Reconstruct path with new base
+     parent = str(path_obj.parent)
+     if parent == ".":
+         new_relative_path = new_name
+     else:
+         new_relative_path = str(PurePosixPath(parent) / new_name)
+
+     # Handle new_base URI scheme
+     if new_base_parsed.scheme:
+         # Has schema like s3://
+         base_path = new_base_parsed.netloc + new_base_parsed.path
+         base_path = PurePosixPath(base_path).as_posix()
+         full_path = str(PurePosixPath(base_path) / new_relative_path)
+         return f"{new_base_parsed.scheme}://{full_path}"
+     # Regular path
+     return str(PurePosixPath(new_base) / new_relative_path)
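A minimal usage sketch of the two helpers added above, callable_name and rebase_path. The inputs are hypothetical and the expected results are traced from the code and docstrings, not taken from the package's tests:

    from datachain.lib.utils import callable_name, rebase_path

    def area(w, h):
        return w * h

    callable_name(area)          # "area" (plain function)
    callable_name(lambda x: x)   # "<lambda>"

    # Move a path under a new base, adding a suffix and swapping the extension
    rebase_path(
        "s3://bucket/raw/images/cat.jpg",  # src_path (hypothetical object)
        "raw/images",                      # old_base
        "s3://bucket/thumbnails",          # new_base
        suffix="_small",
        extension="png",
    )
    # -> "s3://bucket/thumbnails/cat_small.png"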
datachain/lib/video.py CHANGED
@@ -1,7 +1,6 @@
  import posixpath
  import shutil
  import tempfile
- from typing import Optional, Union

  from numpy import ndarray

@@ -18,7 +17,7 @@ except ImportError as exc:
      ) from exc


- def video_info(file: Union[File, VideoFile]) -> Video:
+ def video_info(file: File | VideoFile) -> Video:
      """
      Returns video file information.

@@ -34,21 +33,27 @@ def video_info(file: Union[File, VideoFile]) -> Video:
      file.ensure_cached()
      file_path = file.get_local_path()
      if not file_path:
-         raise FileError(file, "unable to download video file")
+         raise FileError("unable to download video file", file.source, file.path)

      try:
          probe = ffmpeg.probe(file_path)
      except Exception as exc:
-         raise FileError(file, "unable to extract metadata from video file") from exc
+         raise FileError(
+             "unable to extract metadata from video file", file.source, file.path
+         ) from exc

      all_streams = probe.get("streams")
      video_format = probe.get("format")
      if not all_streams or not video_format:
-         raise FileError(file, "unable to extract metadata from video file")
+         raise FileError(
+             "unable to extract metadata from video file", file.source, file.path
+         )

      video_streams = [s for s in all_streams if s["codec_type"] == "video"]
      if len(video_streams) == 0:
-         raise FileError(file, "unable to extract metadata from video file")
+         raise FileError(
+             "unable to extract metadata from video file", file.source, file.path
+         )

      video_stream = video_streams[0]

@@ -102,7 +107,7 @@ def video_frame_np(video: VideoFile, frame: int) -> ndarray:
  def validate_frame_range(
      video: VideoFile,
      start: int = 0,
-     end: Optional[int] = None,
+     end: int | None = None,
      step: int = 1,
  ) -> tuple[int, int, int]:
      """
@@ -180,7 +185,7 @@ def save_video_fragment(
      start: float,
      end: float,
      output: str,
-     format: Optional[str] = None,
+     format: str | None = None,
  ) -> VideoFile:
      """
      Saves video interval as a new video file. If output is a remote path,
@@ -199,7 +204,10 @@ def save_video_fragment(
          VideoFile: Video fragment model.
      """
      if start < 0 or end < 0 or start >= end:
-         raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")
+         raise ValueError(
+             f"Can't save video fragment for '{video.path}', "
+             f"invalid time range: ({start:.3f}, {end:.3f})"
+         )

      if format is None:
          format = video.get_file_ext()
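Every FileError call site in this file changes the same way: the File object argument is replaced by an explicit (message, source, path) triple. A minimal before/after sketch of a call site under that assumption (the surrounding function is hypothetical):

    def ensure_local_copy(file) -> str:
        local_path = file.get_local_path()
        if not local_path:
            # 0.14.2 style: raise FileError(file, "unable to download video file")
            # 0.39.0 style: message first, then the file's storage source and path
            raise FileError("unable to download video file", file.source, file.path)
        return local_path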
datachain/lib/webdataset.py CHANGED
@@ -1,20 +1,13 @@
- import json
  import tarfile
+ import types
  import warnings
- from collections.abc import Iterator, Sequence
+ from collections.abc import Callable, Iterator, Sequence
  from pathlib import Path
- from typing import (
-     Any,
-     Callable,
-     ClassVar,
-     Optional,
-     Union,
-     get_args,
-     get_origin,
- )
+ from typing import Any, ClassVar, Union, get_args, get_origin

  from pydantic import Field

+ from datachain import json
  from datachain.lib.data_model import DataModel
  from datachain.lib.file import File
  from datachain.lib.tar import build_tar_member
@@ -34,29 +27,29 @@ warnings.filterwarnings(


  class WDSError(DataChainError):
-     def __init__(self, tar_stream, message: str):
-         super().__init__(f"WebDataset error '{tar_stream.get_full_name()}': {message}")
+     def __init__(self, tar_name: str, message: str):
+         super().__init__(f"WebDataset error '{tar_name}': {message}")


  class CoreFileDuplicationError(WDSError):
-     def __init__(self, tar_stream, file1: str, file2: str):
+     def __init__(self, tar_name: str, file1: str, file2: str):
          super().__init__(
-             tar_stream, f"duplication of files with core extensions: {file1}, {file2}"
+             tar_name, f"duplication of files with core extensions: {file1}, {file2}"
          )


  class CoreFileNotFoundError(WDSError):
-     def __init__(self, tar_stream, extensions, stem):
+     def __init__(self, tar_name: str, extensions: Sequence[str], stem: str):
          super().__init__(
-             tar_stream,
+             tar_name,
              f"no files with the extensions '{','.join(extensions)}'"
              f" were found for file stem {stem}",
          )


  class UnknownFileExtensionError(WDSError):
-     def __init__(self, tar_stream, name, ext):
-         super().__init__(tar_stream, f"unknown extension '{ext}' for file '{name}'")
+     def __init__(self, tar_name, name: str, ext: str):
+         super().__init__(tar_name, f"unknown extension '{ext}' for file '{name}'")


  class WDSBasic(DataModel):
@@ -64,28 +57,28 @@ class WDSBasic(DataModel):


  class WDSAllFile(WDSBasic):
-     txt: Optional[str] = Field(default=None)
-     text: Optional[str] = Field(default=None)
-     cap: Optional[str] = Field(default=None)
-     transcript: Optional[str] = Field(default=None)
-     cls: Optional[int] = Field(default=None)
-     cls2: Optional[int] = Field(default=None)
-     index: Optional[int] = Field(default=None)
-     inx: Optional[int] = Field(default=None)
-     id: Optional[int] = Field(default=None)
-     json: Optional[dict] = Field(default=None)  # type: ignore[assignment]
-     jsn: Optional[dict] = Field(default=None)
-
-     pyd: Optional[bytes] = Field(default=None)
-     pickle: Optional[bytes] = Field(default=None)
-     pth: Optional[bytes] = Field(default=None)
-     ten: Optional[bytes] = Field(default=None)
-     tb: Optional[bytes] = Field(default=None)
-     mp: Optional[bytes] = Field(default=None)
-     msg: Optional[bytes] = Field(default=None)
-     npy: Optional[bytes] = Field(default=None)
-     npz: Optional[bytes] = Field(default=None)
-     cbor: Optional[bytes] = Field(default=None)
+     txt: str | None = Field(default=None)
+     text: str | None = Field(default=None)
+     cap: str | None = Field(default=None)
+     transcript: str | None = Field(default=None)
+     cls: int | None = Field(default=None)
+     cls2: int | None = Field(default=None)
+     index: int | None = Field(default=None)
+     inx: int | None = Field(default=None)
+     id: int | None = Field(default=None)
+     json: dict | None = Field(default=None)  # type: ignore[assignment]
+     jsn: dict | None = Field(default=None)
+
+     pyd: bytes | None = Field(default=None)
+     pickle: bytes | None = Field(default=None)
+     pth: bytes | None = Field(default=None)
+     ten: bytes | None = Field(default=None)
+     tb: bytes | None = Field(default=None)
+     mp: bytes | None = Field(default=None)
+     msg: bytes | None = Field(default=None)
+     npy: bytes | None = Field(default=None)
+     npz: bytes | None = Field(default=None)
+     cbor: bytes | None = Field(default=None)


  class WDSReadableSubclass(DataModel):
@@ -113,10 +106,10 @@ class Builder:
      def __init__(
          self,
          tar_stream: File,
-         core_extensions: list[str],
+         core_extensions: Sequence[str],
          wds_class: type[WDSBasic],
-         tar,
-         encoding="utf-8",
+         tar: tarfile.TarFile,
+         encoding: str = "utf-8",
      ):
          self._core_extensions = core_extensions
          self._tar_stream = tar_stream
@@ -145,18 +138,20 @@
          if ext in self._core_extensions:
              if self.state.core_file is not None:
                  raise CoreFileDuplicationError(
-                     self._tar_stream, file.name, self.state.core_file.name
+                     self._tar_stream.name, file.name, self.state.core_file.name
                  )
              self.state.core_file = file
          elif ext in self.state.data:
              raise WDSError(
-                 self._tar_stream,
+                 self._tar_stream.name,
                  f"file with extension '.{ext}' already exists in the archive",
              )
          else:
              type_ = self._get_type(ext)
              if type_ is None:
-                 raise UnknownFileExtensionError(self._tar_stream, fstream.name, ext)
+                 raise UnknownFileExtensionError(
+                     self._tar_stream.name, fstream.name, ext
+                 )

              if issubclass(type_, WDSReadableSubclass):
                  reader = type_._reader
@@ -165,7 +160,7 @@

              if reader is None:
                  raise WDSError(
-                     self._tar_stream,
+                     self._tar_stream.name,
                      f"unable to find a reader for type {type_}, extension .{ext}",
                  )
              self.state.data[ext] = reader(self, file)
@@ -173,7 +168,7 @@
      def produce(self):
          if self.state.core_file is None:
              raise CoreFileNotFoundError(
-                 self._tar_stream, self._core_extensions, self.state.stem
+                 self._tar_stream.name, self._core_extensions, self.state.stem
              )

          file = build_tar_member(self._tar_stream, self.state.core_file)
@@ -187,14 +182,22 @@
              return

          anno = field.annotation
-         if get_origin(anno) == Union:
-             args = get_args(anno)
-             anno = args[0]
+         anno_origin = get_origin(anno)
+         if anno_origin in (Union, types.UnionType):
+             anno_args = get_args(anno)
+             if len(anno_args) == 2 and type(None) in anno_args:
+                 return anno_args[0] if anno_args[1] is type(None) else anno_args[1]

          return anno


- def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
+ def get_tar_groups(
+     stream: File,
+     tar: tarfile.TarFile,
+     core_extensions: Sequence[str],
+     spec: type[WDSBasic],
+     encoding: str = "utf-8",
+ ) -> Iterator[WDSBasic]:
      builder = Builder(stream, core_extensions, spec, tar, encoding)

      for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
@@ -210,9 +213,11 @@ def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):


  def process_webdataset(
-     core_extensions: Sequence[str] = ("jpg", "png"), spec=WDSAllFile, encoding="utf-8"
- ) -> Callable:
-     def wds_func(file: File) -> Iterator[spec]:
+     core_extensions: Sequence[str] = ("jpg", "png"),
+     spec: type[WDSBasic] = WDSAllFile,
+     encoding: str = "utf-8",
+ ) -> Callable[[File], Iterator]:
+     def wds_func(file: File) -> Iterator[spec]:  # type: ignore[valid-type]
          with file.open() as fd:
              with tarfile.open(fileobj=fd) as tar:
                  yield from get_tar_groups(file, tar, core_extensions, spec, encoding)
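The rewritten annotation check above is needed because get_origin() reports typing.Union for Optional[X] but types.UnionType for the PEP 604 spelling X | None, which the model fields in this release now use. A small standard-library-only sketch of that behaviour (not DataChain code):

    import types
    from typing import Optional, Union, get_args, get_origin

    old_style = Optional[int]   # typing.Optional / typing.Union
    new_style = int | None      # PEP 604 union (Python 3.10+)

    assert get_origin(old_style) is Union
    assert get_origin(new_style) is types.UnionType

    # Both spellings unwrap to int once NoneType is stripped from the args
    for anno in (old_style, new_style):
        args = [a for a in get_args(anno) if a is not type(None)]
        assert args == [int]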
datachain/lib/webdataset_laion.py CHANGED
@@ -1,6 +1,5 @@
  import warnings
  from collections.abc import Iterator
- from typing import Optional

  import numpy as np
  from pydantic import BaseModel, Field
@@ -23,18 +22,18 @@ warnings.filterwarnings(

  class Laion(WDSReadableSubclass):
      uid: str = Field(default="")
-     face_bboxes: Optional[list[list[float]]] = Field(default=None)
-     caption: Optional[str] = Field(default=None)
-     url: Optional[str] = Field(default=None)
-     key: Optional[str] = Field(default=None)
-     status: Optional[str] = Field(default=None)
-     error_message: Optional[str] = Field(default=None)
-     width: Optional[int] = Field(default=None)
-     height: Optional[int] = Field(default=None)
-     original_width: Optional[int] = Field(default=None)
-     original_height: Optional[int] = Field(default=None)
-     exif: Optional[str] = Field(default=None)
-     sha256: Optional[str] = Field(default=None)
+     face_bboxes: list[list[float]] | None = Field(default=None)
+     caption: str | None = Field(default=None)
+     url: str | None = Field(default=None)
+     key: str | None = Field(default=None)
+     status: str | None = Field(default=None)
+     error_message: str | None = Field(default=None)
+     width: int | None = Field(default=None)
+     height: int | None = Field(default=None)
+     original_width: int | None = Field(default=None)
+     original_height: int | None = Field(default=None)
+     exif: str | None = Field(default=None)
+     sha256: str | None = Field(default=None)

      @staticmethod
      def _reader(builder, item):
@@ -42,13 +41,13 @@ class Laion(WDSReadableSubclass):


  class WDSLaion(WDSBasic):
-     txt: Optional[str] = Field(default=None)
-     json: Laion  # type: ignore[assignment]
+     txt: str | None = Field(default=None)
+     json: Laion = Field(default_factory=Laion)  # type: ignore[assignment]


  class LaionMeta(BaseModel):
      file: File
-     index: Optional[int] = Field(default=None)
+     index: int | None = Field(default=None)
      b32_img: list[float] = Field(default=[])
      b32_txt: list[float] = Field(default=[])
      l14_img: list[float] = Field(default=[])
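One behavioural detail in the WDSLaion change above: json switches from a required field to one with default_factory=Laion, so the model can now be constructed without an explicit json value. A small pydantic-only sketch of that difference (simplified field types, not the actual DataChain models):

    from pydantic import BaseModel, Field

    class Meta(BaseModel):
        uid: str = Field(default="")

    class Required(BaseModel):
        json_field: Meta                                # must be provided

    class Defaulted(BaseModel):
        json_field: Meta = Field(default_factory=Meta)  # built automatically

    Defaulted()    # ok: json_field == Meta(uid="")
    # Required()   # would raise a ValidationError: json_field missing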
datachain/listing.py CHANGED
@@ -2,7 +2,7 @@ import glob
  import os
  from collections.abc import Iterable, Iterator
  from functools import cached_property
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING

  from sqlalchemy import Column
  from sqlalchemy.sql import func
@@ -25,16 +25,17 @@ class Listing:
          metastore: "AbstractMetastore",
          warehouse: "AbstractWarehouse",
          client: "Client",
-         dataset_name: Optional["str"] = None,
-         dataset_version: Optional[int] = None,
-         object_name: str = "file",
+         dataset_name: str | None = None,
+         dataset_version: str | None = None,
+         column: str = "file",
      ):
          self.metastore = metastore
          self.warehouse = warehouse
          self.client = client
          self.dataset_name = dataset_name  # dataset representing bucket listing
          self.dataset_version = dataset_version  # dataset representing bucket listing
-         self.object_name = object_name
+         self.column = column
+         self._closed = False

      def clone(self) -> "Listing":
          return self.__class__(
@@ -43,7 +44,7 @@
              self.client,
              self.dataset_name,
              self.dataset_version,
-             self.object_name,
+             self.column,
          )

      def __enter__(self) -> "Listing":
@@ -53,7 +54,13 @@
          self.close()

      def close(self) -> None:
-         self.warehouse.close()
+         if self._closed:
+             return
+         self._closed = True
+         try:
+             self.warehouse.close_on_exit()
+         finally:
+             self.metastore.close_on_exit()

      @property
      def uri(self):
@@ -66,7 +73,12 @@
      @cached_property
      def dataset(self) -> "DatasetRecord":
          assert self.dataset_name
-         return self.metastore.get_dataset(self.dataset_name)
+         project = self.metastore.listing_project
+         return self.metastore.get_dataset(
+             self.dataset_name,
+             namespace_name=project.namespace.name,
+             project_name=project.name,
+         )

      @cached_property
      def dataset_rows(self):
@@ -74,7 +86,7 @@
          return self.warehouse.dataset_rows(
              dataset,
              self.dataset_version or dataset.latest_version,
-             object_name=self.object_name,
+             column=self.column,
          )

      def expand_path(self, path, use_glob=True) -> list[Node]:
@@ -97,7 +109,7 @@
      def collect_nodes_to_instantiate(
          self,
          sources: Iterable["DataSource"],
-         copy_to_filename: Optional[str],
+         copy_to_filename: str | None,
          recursive=False,
          copy_dir_contents=False,
          from_dataset=False,
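The new Listing.close() above is a guard-plus-try/finally pattern: the _closed flag makes repeated calls (an explicit close() followed by __exit__, for example) a no-op, and the finally block releases the metastore even if closing the warehouse raises. The same pattern in isolation, with generic resource names rather than the DataChain classes:

    class Session:
        def __init__(self, warehouse, metastore):
            self.warehouse = warehouse
            self.metastore = metastore
            self._closed = False

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.close()

        def close(self) -> None:
            if self._closed:        # second call is a no-op
                return
            self._closed = True
            try:
                self.warehouse.close()
            finally:                # always runs, even if warehouse.close() raises
                self.metastore.close()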
datachain/model/bbox.py CHANGED
@@ -198,7 +198,9 @@ class BBox(DataModel):
      def pose_inside(self, pose: Union["Pose", "Pose3D"]) -> bool:
          """Return True if the pose is inside the bounding box."""
          return all(
-             self.point_inside(x, y) for x, y in zip(pose.x, pose.y) if x > 0 or y > 0
+             self.point_inside(x, y)
+             for x, y in zip(pose.x, pose.y, strict=False)
+             if x > 0 or y > 0
          )

      @staticmethod
datachain/model/ultralytics/bbox.py CHANGED
@@ -31,11 +31,11 @@ class YoloBBox(DataModel):
          if not summary:
              return YoloBBox(box=BBox())
          name = summary[0].get("name", "")
-         box = (
-             BBox.from_dict(summary[0]["box"], title=name)
-             if "box" in summary[0]
-             else BBox()
-         )
+         if summary[0].get("box"):
+             assert isinstance(summary[0]["box"], dict)
+             box = BBox.from_dict(summary[0]["box"], title=name)
+         else:
+             box = BBox()
          return YoloBBox(
              cls=summary[0]["class"],
              name=name,
@@ -69,7 +69,9 @@ class YoloBBoxes(DataModel):
              cls.append(s["class"])
              names.append(name)
              confidence.append(s["confidence"])
-             box.append(BBox.from_dict(s.get("box", {}), title=name))
+             if s.get("box"):
+                 assert isinstance(s["box"], dict)
+                 box.append(BBox.from_dict(s["box"], title=name))
          return YoloBBoxes(
              cls=cls,
              name=names,
@@ -100,11 +102,11 @@ class YoloOBBox(DataModel):
          if not summary:
              return YoloOBBox(box=OBBox())
          name = summary[0].get("name", "")
-         box = (
-             OBBox.from_dict(summary[0]["box"], title=name)
-             if "box" in summary[0]
-             else OBBox()
-         )
+         if summary[0].get("box"):
+             assert isinstance(summary[0]["box"], dict)
+             box = OBBox.from_dict(summary[0]["box"], title=name)
+         else:
+             box = OBBox()
          return YoloOBBox(
              cls=summary[0]["class"],
              name=name,
@@ -138,7 +140,9 @@ class YoloOBBoxes(DataModel):
              cls.append(s["class"])
              names.append(name)
              confidence.append(s["confidence"])
-             box.append(OBBox.from_dict(s.get("box", {}), title=name))
+             if s.get("box"):
+                 assert isinstance(s["box"], dict)
+                 box.append(OBBox.from_dict(s["box"], title=name))
          return YoloOBBoxes(
              cls=cls,
              name=names,
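The guard change repeated across all four classes above is more than a refactor: the old check "box" in summary[0] (or s.get("box", {})) accepts a key whose value is None or an empty dict, while the new summary[0].get("box") is only truthy for a non-empty mapping. A small self-contained sketch of the difference, using plain dicts rather than ultralytics result objects:

    summaries = [
        {"class": 0, "name": "cat", "box": {"x1": 1, "y1": 2, "x2": 3, "y2": 4}},
        {"class": 1, "name": "dog", "box": {}},   # key present but empty
        {"class": 2, "name": "fox"},              # key missing entirely
    ]

    membership = ["box" in s for s in summaries]           # [True, True, False]
    truthiness = [bool(s.get("box")) for s in summaries]   # [True, False, False]

    # In 0.39.0 a box is parsed only when it is a non-empty dict; the single-box
    # classes fall back to a default BBox()/OBBox(), and the list classes skip it.
    print(membership, truthiness)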