datachain 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (39)
  1. datachain/catalog/catalog.py +33 -5
  2. datachain/catalog/loader.py +19 -13
  3. datachain/cli/__init__.py +3 -1
  4. datachain/cli/commands/show.py +12 -1
  5. datachain/cli/parser/studio.py +13 -1
  6. datachain/cli/parser/utils.py +6 -0
  7. datachain/client/fsspec.py +12 -16
  8. datachain/client/hf.py +36 -14
  9. datachain/client/local.py +1 -4
  10. datachain/data_storage/warehouse.py +3 -8
  11. datachain/dataset.py +8 -0
  12. datachain/error.py +0 -12
  13. datachain/fs/utils.py +30 -0
  14. datachain/func/__init__.py +5 -0
  15. datachain/func/func.py +2 -1
  16. datachain/lib/data_model.py +6 -0
  17. datachain/lib/dc.py +114 -28
  18. datachain/lib/file.py +100 -25
  19. datachain/lib/image.py +30 -6
  20. datachain/lib/listing.py +21 -39
  21. datachain/lib/signal_schema.py +194 -15
  22. datachain/lib/video.py +7 -5
  23. datachain/model/bbox.py +209 -58
  24. datachain/model/pose.py +49 -37
  25. datachain/model/segment.py +22 -18
  26. datachain/model/ultralytics/bbox.py +9 -9
  27. datachain/model/ultralytics/pose.py +7 -7
  28. datachain/model/ultralytics/segment.py +7 -7
  29. datachain/model/utils.py +191 -0
  30. datachain/nodes_thread_pool.py +32 -11
  31. datachain/query/dataset.py +4 -2
  32. datachain/studio.py +8 -6
  33. datachain/utils.py +3 -16
  34. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/METADATA +6 -4
  35. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/RECORD +39 -37
  36. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/WHEEL +1 -1
  37. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/LICENSE +0 -0
  38. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/entry_points.txt +0 -0
  39. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/top_level.txt +0 -0
datachain/lib/listing.py CHANGED
@@ -1,19 +1,21 @@
+import glob
 import logging
 import os
 import posixpath
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true

+import datachain.fs.utils as fsutils
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import REMOTE_ERRORS, ClientError
+from datachain.error import ClientError
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import uses_glob

 if TYPE_CHECKING:
@@ -92,38 +94,6 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))


-def _isfile(client: "Client", path: str) -> bool:
-    """
-    Returns True if uri points to a file
-    """
-    try:
-        if "://" in path:
-            # This makes sure that the uppercase scheme is converted to lowercase
-            scheme, path = path.split("://", 1)
-            path = f"{scheme.lower()}://{path}"
-
-        if os.name == "nt" and "*" in path:
-            # On Windows, the glob pattern "*" is not supported
-            return False
-
-        info = client.fs.info(path)
-        name = info.get("name")
-        # case for special simulated directories on some clouds
-        # e.g. Google creates a zero byte file with the same name as the
-        # directory with a trailing slash at the end
-        if not name or name.endswith("/"):
-            return False
-
-        return info["type"] == "file"
-    except FileNotFoundError:
-        return False
-    except REMOTE_ERRORS as e:
-        raise ClientError(
-            message=str(e),
-            error_code=getattr(e, "code", None),
-        ) from e
-
-
 def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -156,8 +126,16 @@ def listing_uri_from_name(dataset_name: str) -> str:
     return dataset_name.removeprefix(LISTING_PREFIX)


+@contextmanager
+def _reraise_as_client_error() -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        raise ClientError(message=str(e), error_code=getattr(e, "code", None)) from e
+
+
 def get_listing(
-    uri: str, session: "Session", update: bool = False
+    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
 ) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
@@ -167,6 +145,7 @@ def get_listing(
     be used to find rows based on uri.
     """
     from datachain.client.local import FileClient
+    from datachain.telemetry import telemetry

     catalog = session.catalog
     cache = catalog.cache
@@ -174,11 +153,14 @@ def get_listing(

     client = Client.get_client(uri, cache, **client_config)
     telemetry.log_param("client", client.PREFIX)
+    if not isinstance(uri, str):
+        uri = os.fspath(uri)

     # we don't want to use cached dataset (e.g. for a single file listing)
-    if not uri.endswith("/") and _isfile(client, uri):
-        storage_uri, path = Client.parse_url(uri)
-        return None, f"{storage_uri}/{path.lstrip('/')}", path, False
+    isfile = _reraise_as_client_error()(fsutils.isfile)
+    if not glob.has_magic(uri) and not uri.endswith("/") and isfile(client.fs, uri):
+        _, path = Client.parse_url(uri)
+        return None, uri, path, False

     ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
     listing = None
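
A note on the `isfile = _reraise_as_client_error()(fsutils.isfile)` line above: objects produced by a `@contextmanager`-decorated function inherit from `contextlib.ContextDecorator`, so an instance can also wrap a callable and run each call inside a fresh context. A minimal sketch of that pattern, with a hypothetical `RuntimeError` wrapper standing in for `ClientError`:

```python
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def _reraise_as_runtime_error() -> Iterator[None]:
    # Stand-in for _reraise_as_client_error: any exception raised in
    # the managed block is re-raised under a single error type.
    try:
        yield
    except Exception as e:
        raise RuntimeError(str(e)) from e


def flaky_isfile(path: str) -> bool:
    # Hypothetical filesystem probe that fails with a backend error.
    raise OSError(f"backend unavailable: {path}")


# @contextmanager results are also ContextDecorators, so an instance
# can wrap an existing function; the context is re-entered per call.
isfile = _reraise_as_runtime_error()(flaky_isfile)

try:
    isfile("s3://bucket/key")
except RuntimeError as e:
    print("re-raised:", e)  # re-raised: backend unavailable: s3://bucket/key
```
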
datachain/lib/signal_schema.py CHANGED
@@ -91,6 +91,7 @@ class CustomType(BaseModel):
     name: str
     fields: dict[str, str]
     bases: list[tuple[str, str, Optional[str]]]
+    hidden_fields: Optional[list[str]] = None

     @classmethod
     def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
@@ -102,6 +103,7 @@ class CustomType(BaseModel):
                 "name": type_name,
                 "fields": data,
                 "bases": [],
+                "hidden_fields": [],
             }

         return cls(**data)
@@ -179,6 +181,16 @@ class SignalSchema:
         )
         return SignalSchema(signals)

+    @staticmethod
+    def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
+        bases: list[tuple[str, str, Optional[str]]] = []
+        for base in fr.__mro__:
+            model_store_name = (
+                ModelStore.get_name(base) if issubclass(base, DataModel) else None
+            )
+            bases.append((base.__name__, base.__module__, model_store_name))
+        return bases
+
     @staticmethod
     def _serialize_custom_model(
         version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
@@ -196,14 +208,15 @@ class SignalSchema:
             assert field_type
             fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)

-        bases: list[tuple[str, str, Optional[str]]] = []
-        for type_ in fr.__mro__:
-            model_store_name = (
-                ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
-            )
-            bases.append((type_.__name__, type_.__module__, model_store_name))
+        bases = SignalSchema._get_bases(fr)

-        ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases)
+        ct = CustomType(
+            schema_version=2,
+            name=version_name,
+            fields=fields,
+            bases=bases,
+            hidden_fields=getattr(fr, "_hidden_fields", []),
+        )
         custom_types[version_name] = ct.model_dump()

         return version_name
@@ -384,6 +397,37 @@ class SignalSchema:

         return SignalSchema(signals)

+    @staticmethod
+    def get_flatten_hidden_fields(schema):
+        custom_types = schema.get("_custom_types", {})
+        if not custom_types:
+            return []
+
+        hidden_by_types = {
+            name: schema.get("hidden_fields", [])
+            for name, schema in custom_types.items()
+        }
+
+        hidden_fields = []
+
+        def traverse(prefix, schema_info):
+            for field, field_type in schema_info.items():
+                if field == "_custom_types":
+                    continue
+
+                if field_type in custom_types:
+                    hidden_fields.extend(
+                        f"{prefix}{field}__{f}" for f in hidden_by_types[field_type]
+                    )
+                    traverse(
+                        prefix + field + "__",
+                        custom_types[field_type].get("fields", {}),
+                    )
+
+        traverse("", schema)
+
+        return hidden_fields
+
     def to_udf_spec(self) -> dict[str, type]:
         res = {}
         for path, type_, has_subtree, _ in self.get_flat_tree():
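
To make the new helper concrete: `get_flatten_hidden_fields` walks a serialized schema and returns flat column names joined with `__`, matching the flat DB-column naming used elsewhere in the schema code. A small sketch with a hand-built serialized schema (the type and field names are made up for illustration):

```python
from datachain.lib.signal_schema import SignalSchema

# Hypothetical serialized schema: "file" is a custom type whose "aux"
# field is itself a custom type; each declares one hidden field.
schema = {
    "file": "MyFile@v1",
    "_custom_types": {
        "MyFile@v1": {
            "fields": {"path": "str", "etag": "str", "aux": "Aux@v1"},
            "hidden_fields": ["etag"],
        },
        "Aux@v1": {
            "fields": {"token": "str"},
            "hidden_fields": ["token"],
        },
    },
}

# Per the traversal above, nested hidden fields are prefixed with the
# full "__"-joined path down to their parent signal:
print(SignalSchema.get_flatten_hidden_fields(schema))
# ['file__etag', 'file__aux__token']
```
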
@@ -479,7 +523,7 @@ class SignalSchema:
         raise SignalResolvingError([col_name], "is not found")

     def db_signals(
-        self, name: Optional[str] = None, as_columns=False
+        self, name: Optional[str] = None, as_columns=False, include_hidden: bool = True
     ) -> Union[list[str], list[Column]]:
         """
         Returns DB columns as strings or Column objects with proper types
@@ -489,7 +533,9 @@ class SignalSchema:
             DEFAULT_DELIMITER.join(path)
             if not as_columns
             else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type))
-            for path, _type, has_subtree, _ in self.get_flat_tree()
+            for path, _type, has_subtree, _ in self.get_flat_tree(
+                include_hidden=include_hidden
+            )
             if not has_subtree
         ]
@@ -624,19 +670,31 @@ class SignalSchema:
             for name, val in values.items()
         }

-    def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
-        yield from self._get_flat_tree(self.tree, [], 0)
+    def get_flat_tree(
+        self, include_hidden: bool = True
+    ) -> Iterator[tuple[list[str], DataType, bool, int]]:
+        yield from self._get_flat_tree(self.tree, [], 0, include_hidden)

     def _get_flat_tree(
-        self, tree: dict, prefix: list[str], depth: int
+        self, tree: dict, prefix: list[str], depth: int, include_hidden: bool
     ) -> Iterator[tuple[list[str], DataType, bool, int]]:
         for name, (type_, substree) in tree.items():
             suffix = name.split(".")
             new_prefix = prefix + suffix
+            hidden_fields = getattr(type_, "_hidden_fields", None)
+            if hidden_fields and substree and not include_hidden:
+                substree = {
+                    field: info
+                    for field, info in substree.items()
+                    if field not in hidden_fields
+                }
+
             has_subtree = substree is not None
             yield new_prefix, type_, has_subtree, depth
             if substree is not None:
-                yield from self._get_flat_tree(substree, new_prefix, depth + 1)
+                yield from self._get_flat_tree(
+                    substree, new_prefix, depth + 1, include_hidden
+                )

     def print_tree(self, indent: int = 4, start_at: int = 0):
         for path, type_, _, depth in self.get_flat_tree():
@@ -649,9 +707,13 @@ class SignalSchema:
         sub_schema = SignalSchema({"* list of": args[0]})
         sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

-    def get_headers_with_length(self):
+    def get_headers_with_length(self, include_hidden: bool = True):
         paths = [
-            path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
+            path
+            for path, _, has_subtree, _ in self.get_flat_tree(
+                include_hidden=include_hidden
+            )
+            if not has_subtree
         ]
         max_length = max([len(path) for path in paths], default=0)
         return [
@@ -749,3 +811,120 @@ class SignalSchema:
             res[name] = (anno, subtree)  # type: ignore[assignment]

         return res
+
+    def to_partial(self, *columns: str) -> "SignalSchema":
+        """
+        Convert the schema to a partial schema with only the specified columns.
+
+        E.g. if original schema is:
+
+        ```
+        signal: Foo@v1
+            name: str
+            value: float
+        count: int
+        ```
+
+        Then `to_partial("signal.name", "count")` will return a partial schema:
+
+        ```
+        signal: FooPartial@v1
+            name: str
+        count: int
+        ```
+
+        Note that partial schema will have a different name for the custom types
+        (e.g. `FooPartial@v1` instead of `Foo@v1`) to avoid conflicts
+        with the original schema.
+
+        Args:
+            *columns (str): The columns to include in the partial schema.
+
+        Returns:
+            SignalSchema: The new partial schema.
+        """
+        serialized = self.serialize()
+        custom_types = serialized.get("_custom_types", {})
+
+        schema: dict[str, Any] = {}
+        schema_custom_types: dict[str, CustomType] = {}
+
+        data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
+
+        signal_partials: dict[str, str] = {}
+        partial_versions: dict[str, int] = {}
+
+        def _type_name_to_partial(signal_name: str, type_name: str) -> str:
+            if "@" not in type_name:
+                return type_name
+            model_name, _ = ModelStore.parse_name_version(type_name)
+
+            if signal_name not in signal_partials:
+                partial_versions.setdefault(model_name, 0)
+                partial_versions[model_name] += 1
+                version = partial_versions[model_name]
+                signal_partials[signal_name] = f"{model_name}Partial{version}"
+
+            return signal_partials[signal_name]
+
+        for column in columns:
+            parent_type, parent_type_partial = "", ""
+            column_parts = column.split(".")
+            for i, signal in enumerate(column_parts):
+                if i == 0:
+                    if signal not in serialized:
+                        raise SignalSchemaError(
+                            f"Column {column} not found in the schema"
+                        )
+
+                    parent_type = serialized[signal]
+                    parent_type_partial = _type_name_to_partial(signal, parent_type)
+
+                    schema[signal] = parent_type_partial
+                    continue
+
+                if parent_type not in custom_types:
+                    raise SignalSchemaError(
+                        f"Custom type {parent_type} not found in the schema"
+                    )
+
+                custom_type = custom_types[parent_type]
+                signal_type = custom_type["fields"].get(signal)
+                if not signal_type:
+                    raise SignalSchemaError(
+                        f"Field {signal} not found in custom type {parent_type}"
+                    )
+
+                partial_type = _type_name_to_partial(
+                    ".".join(column_parts[: i + 1]),
+                    signal_type,
+                )
+
+                if parent_type_partial in schema_custom_types:
+                    schema_custom_types[parent_type_partial].fields[signal] = (
+                        partial_type
+                    )
+                else:
+                    if data_model_bases is None:
+                        data_model_bases = SignalSchema._get_bases(DataModel)
+
+                    partial_type_name, _ = ModelStore.parse_name_version(partial_type)
+                    schema_custom_types[parent_type_partial] = CustomType(
+                        schema_version=2,
+                        name=partial_type_name,
+                        fields={signal: partial_type},
+                        bases=[
+                            (partial_type_name, "__main__", partial_type),
+                            *data_model_bases,
+                        ],
+                    )
+
+                parent_type, parent_type_partial = signal_type, partial_type
+
+        if schema_custom_types:
+            schema["_custom_types"] = {
+                type_name: ct.model_dump()
+                for type_name, ct in schema_custom_types.items()
+            }
+
+        return SignalSchema.deserialize(schema)
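
A hedged usage sketch for `to_partial`, assuming a `DataModel` subclass registered the usual way (the class and field names here are illustrative, mirroring the docstring's `Foo` example):

```python
from datachain.lib.data_model import DataModel
from datachain.lib.signal_schema import SignalSchema


class Foo(DataModel):  # hypothetical model, serialized as e.g. "Foo@v1"
    name: str
    value: float


schema = SignalSchema({"signal": Foo, "count": int})
partial = schema.to_partial("signal.name", "count")

# Expected shape: "signal" now points at a generated partial type
# (e.g. "FooPartial1@v1") carrying only "name"; "count" is unchanged.
print(partial.serialize())
```
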
datachain/lib/video.py CHANGED
@@ -1,11 +1,11 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional
+from typing import Optional, Union

 from numpy import ndarray

-from datachain.lib.file import FileError, ImageFile, Video, VideoFile
+from datachain.lib.file import File, FileError, ImageFile, Video, VideoFile

 try:
     import ffmpeg
@@ -18,7 +18,7 @@ except ImportError as exc:
     ) from exc


-def video_info(file: VideoFile) -> Video:
+def video_info(file: Union[File, VideoFile]) -> Video:
     """
     Returns video file information.

@@ -28,6 +28,8 @@ def video_info(file: VideoFile) -> Video:
     Returns:
         Video: Video file information.
     """
+    file = file.as_video_file()
+
     if not (file_path := file.get_local_path()):
         file.ensure_cached()
         file_path = file.get_local_path()
@@ -170,7 +172,7 @@ def save_video_frame(
     output_file = posixpath.join(
         output, f"{video.get_file_stem()}_{frame:04d}.{format}"
     )
-    return ImageFile.upload(img, output_file)
+    return ImageFile.upload(img, output_file, catalog=video._catalog)


 def save_video_fragment(
@@ -218,6 +220,6 @@ def save_video_fragment(
         ).output(output_file_tmp).run(quiet=True)

         with open(output_file_tmp, "rb") as f:
-            return VideoFile.upload(f.read(), output_file)
+            return VideoFile.upload(f.read(), output_file, catalog=video._catalog)
     finally:
         shutil.rmtree(temp_dir)
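
The `video_info` change is the usual widen-and-coerce pattern: accept the base `File` and normalize with `as_video_file()` before any video-specific work, so rows coming straight from a listing don't need an explicit cast. A self-contained sketch of the pattern (toy classes, not DataChain's actual `File` model):

```python
from typing import Union


class File:  # toy stand-in for datachain.lib.file.File
    def __init__(self, path: str) -> None:
        self.path = path

    def as_video_file(self) -> "VideoFile":
        # Re-wrap the same underlying data as the specialized type.
        return VideoFile(self.path)


class VideoFile(File):  # toy stand-in for the specialized model
    def as_video_file(self) -> "VideoFile":
        return self  # already specialized; coercion is a no-op


def video_info(file: Union[File, VideoFile]) -> str:
    file = file.as_video_file()  # normalize up front, as in the diff
    return f"probing {file.path}"


print(video_info(File("clips/intro.mp4")))  # works without an explicit cast
```
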