datachain 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (46)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +42 -16
  4. datachain/cli.py +48 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +618 -156
  16. datachain/lib/file.py +130 -22
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +19 -11
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/node.py +11 -8
  26. datachain/query/dataset.py +62 -28
  27. datachain/query/schema.py +2 -0
  28. datachain/query/session.py +4 -4
  29. datachain/sql/functions/array.py +12 -0
  30. datachain/sql/functions/string.py +8 -0
  31. datachain/torch/__init__.py +1 -1
  32. datachain/utils.py +6 -0
  33. datachain-0.2.13.dist-info/METADATA +411 -0
  34. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/RECORD +38 -42
  35. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/WHEEL +1 -1
  36. datachain/lib/gpt4_vision.py +0 -97
  37. datachain/lib/hf_image_to_text.py +0 -97
  38. datachain/lib/hf_pipeline.py +0 -90
  39. datachain/lib/image_transform.py +0 -103
  40. datachain/lib/iptc_exif_xmp.py +0 -76
  41. datachain/lib/unstructured.py +0 -41
  42. datachain/text/__init__.py +0 -3
  43. datachain-0.2.11.dist-info/METADATA +0 -431
  44. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/LICENSE +0 -0
  45. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/entry_points.txt +0 -0
  46. {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
+ from datachain.lib.data_model import DataModel, DataType, is_chain_type
  from datachain.lib.dc import C, Column, DataChain, Sys
  from datachain.lib.file import (
  File,
@@ -8,15 +8,14 @@ from datachain.lib.file import (
  TarVFile,
  TextFile,
  )
+ from datachain.lib.model_store import ModelStore
  from datachain.lib.udf import Aggregator, Generator, Mapper
  from datachain.lib.utils import AbstractUDF, DataChainError
- from datachain.query.dataset import UDF as BaseUDF # noqa: N811
  from datachain.query.session import Session

  __all__ = [
  "AbstractUDF",
  "Aggregator",
- "BaseUDF",
  "C",
  "Column",
  "DataChain",
@@ -24,12 +23,12 @@ __all__ = [
  "DataModel",
  "DataType",
  "File",
- "FileBasic",
  "FileError",
  "Generator",
  "ImageFile",
  "IndexedFile",
  "Mapper",
+ "ModelStore",
  "Session",
  "Sys",
  "TarVFile",
datachain/cache.py CHANGED
@@ -1,6 +1,7 @@
  import hashlib
  import json
  import os
+ from datetime import datetime
  from functools import partial
  from typing import TYPE_CHECKING, Optional

@@ -9,6 +10,8 @@ from dvc_data.hashfile.db.local import LocalHashFileDB
  from dvc_objects.fs.local import LocalFileSystem
  from fsspec.callbacks import Callback, TqdmCallback

+ from datachain.utils import TIME_ZERO
+
  from .progress import Tqdm

  if TYPE_CHECKING:
@@ -23,10 +26,13 @@ class UniqueId:
  storage: "StorageURI"
  parent: str
  name: str
- etag: str
  size: int
- vtype: str
- location: Optional[str]
+ etag: str
+ version: str = ""
+ is_latest: bool = True
+ vtype: str = ""
+ location: Optional[str] = None
+ last_modified: datetime = TIME_ZERO

  @property
  def path(self) -> str:
@@ -49,7 +55,7 @@ class UniqueId:
  def get_hash(self) -> str:
  etag = f"{self.vtype}{self.location}" if self.vtype else self.etag
  return sha256(
- f"{self.storage}/{self.parent}/{self.name}/{etag}".encode()
+ f"{self.storage}/{self.parent}/{self.name}/{self.version}/{etag}".encode()
  ).hexdigest()

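UniqueId now carries version, is_latest, and last_modified (all defaulted), and get_hash folds the object version into the cache key. A rough sketch of the effect, assuming UniqueId stays a plain dataclass and using made-up values:

from datachain.cache import UniqueId

# Same object, two storage versions: the cache hashes now differ because
# get_hash() includes self.version in the hashed string.
uid_v1 = UniqueId("s3://bucket", "images", "cat.jpg", size=4096, etag="abc", version="v1")
uid_v2 = UniqueId("s3://bucket", "images", "cat.jpg", size=4096, etag="abc", version="v2")
assert uid_v1.get_hash() != uid_v2.get_hash()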
datachain/catalog/catalog.py CHANGED
@@ -1,4 +1,5 @@
  import ast
+ import glob
  import io
  import json
  import logging
@@ -84,12 +85,14 @@ if TYPE_CHECKING:
  AbstractMetastore,
  AbstractWarehouse,
  )
+ from datachain.dataset import DatasetVersion
+ from datachain.job import Job

  logger = logging.getLogger("datachain")

  DEFAULT_DATASET_DIR = "dataset"
  DATASET_FILE_SUFFIX = ".edatachain"
- FEATURE_CLASSES = ["Feature"]
+ FEATURE_CLASSES = ["DataModel"]

  TTL_INT = 4 * 60 * 60

@@ -707,7 +710,12 @@ class Catalog:

  client_config = client_config or self.client_config
  client, path = self.parse_url(source, **client_config)
- prefix = posixpath.dirname(path)
+ stem = os.path.basename(os.path.normpath(path))
+ prefix = (
+ posixpath.dirname(path)
+ if glob.has_magic(stem) or client.fs.isfile(source)
+ else path
+ )
  storage_dataset_name = Storage.dataset_name(
  client.uri, posixpath.join(prefix, "")
  )
@@ -948,13 +956,9 @@ class Catalog:
  ms = self.metastore.clone(uri, None)
  st = self.warehouse.clone()
  listing = Listing(None, ms, st, client, None)
- rows = (
- DatasetQuery(
- name=dataset.name, version=ds_version, catalog=self
- )
- .select()
- .to_records()
- )
+ rows = DatasetQuery(
+ name=dataset.name, version=ds_version, catalog=self
+ ).to_db_records()
  indexed_sources.append(
  (
  listing,
@@ -1160,9 +1164,8 @@ class Catalog:
  if not dataset_version.preview:
  values["preview"] = (
  DatasetQuery(name=dataset.name, version=version, catalog=self)
- .select()
  .limit(20)
- .to_records()
+ .to_db_records()
  )

  if not values:
@@ -1420,6 +1423,25 @@ class Catalog:
  if not d.is_bucket_listing:
  yield d

+ def list_datasets_versions(
+ self,
+ ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+ """Iterate over all dataset versions with related jobs."""
+ datasets = list(self.ls_datasets())
+
+ # preselect dataset versions jobs from db to avoid multiple queries
+ jobs_ids: set[str] = {
+ v.job_id for ds in datasets for v in ds.versions if v.job_id
+ }
+ jobs: dict[str, Job] = {}
+ if jobs_ids:
+ jobs = {j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))}
+
+ for d in datasets:
+ yield from (
+ (d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
+ )
+
  def ls_dataset_rows(
  self, name: str, version: int, offset=None, limit=None
  ) -> list[dict]:
@@ -1427,7 +1449,7 @@

  dataset = self.get_dataset(name)

- q = DatasetQuery(name=dataset.name, version=version, catalog=self).select()
+ q = DatasetQuery(name=dataset.name, version=version, catalog=self)
  if limit:
  q = q.limit(limit)
  if offset:
@@ -1435,7 +1457,7 @@

  q = q.order_by("sys__id")

- return q.to_records()
+ return q.to_db_records()

  def signed_url(self, source: str, path: str, client_config=None) -> str:
  client_config = client_config or self.client_config
@@ -1609,6 +1631,7 @@ class Catalog:
  ...
  }
  """
+ from datachain.lib.file import File
  from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema

  version = self.get_dataset(dataset_name).get_version(dataset_version)
@@ -1616,7 +1639,7 @@ class Catalog:
  file_signals_values = {}

  schema = SignalSchema.deserialize(version.feature_schema)
- for file_signals in schema.get_file_signals():
+ for file_signals in schema.get_signals(File):
  prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
  file_signals_values[file_signals] = {
  c_name.removeprefix(prefix): c_value
@@ -1657,10 +1680,13 @@ class Catalog:
  row["source"],
  row["parent"],
  row["name"],
- row["etag"],
  row["size"],
+ row["etag"],
+ row["version"],
+ row["is_latest"],
  row["vtype"],
  row["location"],
+ row["last_modified"],
  )

  def ls(
@@ -1992,7 +2018,7 @@ class Catalog:
  )
  if proc.returncode == QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE:
  raise QueryScriptRunError(
- "Last line in a script was not an instance of DatasetQuery",
+ "Last line in a script was not an instance of DataChain",
  return_code=proc.returncode,
  output=output,
  )
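The new Catalog.list_datasets_versions() pairs every dataset version with the Job that produced it, prefetching jobs in a single metastore call instead of one query per version. A sketch of how it might be consumed, assuming get_catalog() (the helper the CLI itself uses) and that DatasetVersion exposes its version number:

from datachain.catalog import get_catalog

catalog = get_catalog()
for dataset, version, job in catalog.list_datasets_versions():
    # job is None for versions that were not created by a tracked job
    print(dataset.name, version.version, job.name if job else "-")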
datachain/cli.py CHANGED
@@ -3,7 +3,7 @@ import os
  import shlex
  import sys
  import traceback
- from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
+ from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
  from collections.abc import Iterable, Iterator, Mapping, Sequence
  from importlib.metadata import PackageNotFoundError, version
  from itertools import chain
@@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  parser = ArgumentParser(
  description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
  )
-
  parser.add_argument("-V", "--version", action="version", version=__version__)
- parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS)
- parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS)

  parent_parser = ArgumentParser(add_help=False)
  parent_parser.add_argument(
@@ -150,9 +147,15 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  help="Drop into the pdb debugger on fatal exception",
  )

- subp = parser.add_subparsers(help="Sub-command help", dest="command")
+ subp = parser.add_subparsers(
+ title="Available Commands",
+ metavar="command",
+ dest="command",
+ help=f"Use `{parser.prog} command --help` for command-specific help.",
+ required=True,
+ )
  parse_cp = subp.add_parser(
- "cp", parents=[parent_parser], help="Copy data files from the cloud"
+ "cp", parents=[parent_parser], description="Copy data files from the cloud"
  )
  add_sources_arg(parse_cp).complete = shtab.DIR # type: ignore[attr-defined]
  parse_cp.add_argument("output", type=str, help="Output")
@@ -179,7 +182,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_clone = subp.add_parser(
- "clone", parents=[parent_parser], help="Copy data files from the cloud"
+ "clone", parents=[parent_parser], description="Copy data files from the cloud"
  )
  add_sources_arg(parse_clone).complete = shtab.DIR # type: ignore[attr-defined]
  parse_clone.add_argument("output", type=str, help="Output")
@@ -222,7 +225,9 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_pull = subp.add_parser(
- "pull", parents=[parent_parser], help="Pull specific dataset version from SaaS"
+ "pull",
+ parents=[parent_parser],
+ description="Pull specific dataset version from SaaS",
  )
  parse_pull.add_argument(
  "dataset",
@@ -263,7 +268,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_edit_dataset = subp.add_parser(
- "edit-dataset", parents=[parent_parser], help="Edit dataset metadata"
+ "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"
  )
  parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
  parse_edit_dataset.add_argument(
@@ -285,9 +290,9 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  help="Dataset labels",
  )

- subp.add_parser("ls-datasets", parents=[parent_parser], help="List datasets")
+ subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
  rm_dataset_parser = subp.add_parser(
- "rm-dataset", parents=[parent_parser], help="Removes dataset"
+ "rm-dataset", parents=[parent_parser], description="Removes dataset"
  )
  rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
  rm_dataset_parser.add_argument(
@@ -305,7 +310,9 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  dataset_stats_parser = subp.add_parser(
- "dataset-stats", parents=[parent_parser], help="Shows basic dataset stats"
+ "dataset-stats",
+ parents=[parent_parser],
+ description="Shows basic dataset stats",
  )
  dataset_stats_parser.add_argument("name", type=str, help="Dataset name")
  dataset_stats_parser.add_argument(
@@ -330,7 +337,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_merge_datasets = subp.add_parser(
- "merge-datasets", parents=[parent_parser], help="Merges datasets"
+ "merge-datasets", parents=[parent_parser], description="Merges datasets"
  )
  parse_merge_datasets.add_argument(
  "--src",
@@ -360,7 +367,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_ls = subp.add_parser(
- "ls", parents=[parent_parser], help="List storage contents"
+ "ls", parents=[parent_parser], description="List storage contents"
  )
  add_sources_arg(parse_ls, nargs="*")
  parse_ls.add_argument(
@@ -378,7 +385,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_du = subp.add_parser(
- "du", parents=[parent_parser], help="Display space usage"
+ "du", parents=[parent_parser], description="Display space usage"
  )
  add_sources_arg(parse_du)
  parse_du.add_argument(
@@ -408,7 +415,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_find = subp.add_parser(
- "find", parents=[parent_parser], help="Search in a directory hierarchy"
+ "find", parents=[parent_parser], description="Search in a directory hierarchy"
  )
  add_sources_arg(parse_find)
  parse_find.add_argument(
@@ -461,20 +468,20 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  parse_index = subp.add_parser(
- "index", parents=[parent_parser], help="Index storage location"
+ "index", parents=[parent_parser], description="Index storage location"
  )
  add_sources_arg(parse_index)

  subp.add_parser(
  "find-stale-storages",
  parents=[parent_parser],
- help="Finds and marks stale storages",
+ description="Finds and marks stale storages",
  )

  show_parser = subp.add_parser(
  "show",
  parents=[parent_parser],
- help="Create a new dataset with a query script",
+ description="Create a new dataset with a query script",
  )
  show_parser.add_argument("name", type=str, help="Dataset name")
  show_parser.add_argument(
@@ -484,12 +491,13 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  type=int,
  help="Dataset version",
  )
+ show_parser.add_argument("--schema", action="store_true", help="Show schema")
  add_show_args(show_parser)

  query_parser = subp.add_parser(
  "query",
  parents=[parent_parser],
- help="Create a new dataset with a query script",
+ description="Create a new dataset with a query script",
  )
  query_parser.add_argument(
  "script", metavar="<script.py>", type=str, help="Filepath for script"
@@ -520,7 +528,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  )

  apply_udf_parser = subp.add_parser(
- "apply-udf", parents=[parent_parser], help="Apply UDF"
+ "apply-udf", parents=[parent_parser], description="Apply UDF"
  )
  apply_udf_parser.add_argument("udf", type=str, help="UDF location")
  apply_udf_parser.add_argument("source", type=str, help="Source storage or dataset")
@@ -541,12 +549,14 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
  "--udf-params", type=str, default=None, help="UDF class parameters"
  )
  subp.add_parser(
- "clear-cache", parents=[parent_parser], help="Clear the local file cache"
+ "clear-cache", parents=[parent_parser], description="Clear the local file cache"
  )
  subp.add_parser(
- "gc", parents=[parent_parser], help="Garbage collect temporary tables"
+ "gc", parents=[parent_parser], description="Garbage collect temporary tables"
  )

+ subp.add_parser("internal-run-udf", parents=[parent_parser])
+ subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
  add_completion_parser(subp, [parent_parser])
  return parser

@@ -555,7 +565,7 @@ def add_completion_parser(subparsers, parents):
  parser = subparsers.add_parser(
  "completion",
  parents=parents,
- help="Output shell completion script",
+ description="Output shell completion script",
  )
  parser.add_argument(
  "-s",
@@ -807,18 +817,27 @@ def show(
  offset: int = 0,
  columns: Sequence[str] = (),
  no_collapse: bool = False,
+ schema: bool = False,
  ) -> None:
+ from datachain.lib.dc import DataChain
  from datachain.query import DatasetQuery
  from datachain.utils import show_records

+ dataset = catalog.get_dataset(name)
+ dataset_version = dataset.get_version(version or dataset.latest_version)
+
  query = (
  DatasetQuery(name=name, version=version, catalog=catalog)
  .select(*columns)
  .limit(limit)
  .offset(offset)
  )
- records = query.to_records()
+ records = query.to_db_records()
  show_records(records, collapse_columns=not no_collapse)
+ if schema and dataset_version.feature_schema:
+ print("\nSchema:")
+ dc = DataChain(name=name, version=version, catalog=catalog)
+ dc.print_schema()


  def query(
@@ -901,27 +920,23 @@ def completion(shell: str) -> str:
  )


- def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0911, PLR0912, PLR0915
+ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR0915
  # Required for Windows multiprocessing support
  freeze_support()

  parser = get_parser()
  args = parser.parse_args(argv)

- if args.internal_run_udf:
+ if args.command == "internal-run-udf":
  from datachain.query.dispatch import udf_entrypoint

  return udf_entrypoint()

- if args.internal_run_udf_worker:
+ if args.command == "internal-run-udf-worker":
  from datachain.query.dispatch import udf_worker_entrypoint

  return udf_worker_entrypoint()

- if args.command is None:
- parser.print_help()
- return 1
-
  from .catalog import get_catalog

  logger.addHandler(logging.StreamHandler())
@@ -1008,6 +1023,7 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0911, PLR09
  offset=args.offset,
  columns=args.columns,
  no_collapse=args.no_collapse,
+ schema=args.schema,
  )
  elif args.command == "rm-dataset":
  rm_dataset(catalog, args.name, version=args.version, force=args.force)
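Two CLI behaviors change here: the hidden --internal-run-udf / --internal-run-udf-worker flags become ordinary subcommands dispatched via args.command, and a subcommand is now required, so running datachain with no command produces an argparse usage error rather than the old print-help-and-return-1 path. The show command also gains a --schema flag. A sketch of driving the entry point programmatically, with a placeholder dataset name:

from datachain.cli import main

# Equivalent to the shell invocation: datachain show my-dataset --schema
exit_code = main(["show", "my-dataset", "--schema"])

# With required=True on the subparsers, main([]) now raises SystemExit(2)
# from argparse instead of printing help and returning 1.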
datachain/data_storage/metastore.py CHANGED
@@ -40,6 +40,7 @@ from datachain.error import (
  StorageNotFoundError,
  TableMissingError,
  )
+ from datachain.job import Job
  from datachain.storage import Storage, StorageStatus, StorageURI
  from datachain.utils import JSONSerialize, is_expired

@@ -67,6 +68,7 @@ class AbstractMetastore(ABC, Serializable):
  storage_class: type[Storage] = Storage
  dataset_class: type[DatasetRecord] = DatasetRecord
  dependency_class: type[DatasetDependency] = DatasetDependency
+ job_class: type[Job] = Job

  def __init__(
  self,
@@ -377,6 +379,9 @@ class AbstractMetastore(ABC, Serializable):
  # Jobs
  #

+ def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+ raise NotImplementedError
+
  @abstractmethod
  def create_job(
  self,
@@ -1467,6 +1472,10 @@ class AbstractDBMetastore(AbstractMetastore):
  Column("metrics", JSON, nullable=False),
  ]

+ @cached_property
+ def _job_fields(self) -> list[str]:
+ return [c.name for c in self._jobs_columns() if c.name] # type: ignore[attr-defined]
+
  @cached_property
  def _jobs(self) -> "Table":
  return Table(self.JOBS_TABLE, self.db.metadata, *self._jobs_columns())
@@ -1484,6 +1493,21 @@ class AbstractDBMetastore(AbstractMetastore):
  return self._jobs.update()
  return self._jobs.update().where(*where)

+ def _parse_job(self, rows) -> Job:
+ return Job.parse(*rows)
+
+ def _parse_jobs(self, rows) -> Iterator["Job"]:
+ for _, g in groupby(rows, lambda r: r[0]):
+ yield self._parse_job(*list(g))
+
+ def _jobs_query(self):
+ return self._jobs_select(*[getattr(self._jobs.c, f) for f in self._job_fields])
+
+ def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+ """List jobs by ids."""
+ query = self._jobs_query().where(self._jobs.c.id.in_(ids))
+ yield from self._parse_jobs(self.db.execute(query, conn=conn))
+
  def create_job(
  self,
  name: str,
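list_jobs_by_ids() is the metastore hook that Catalog.list_datasets_versions() relies on: the base class raises NotImplementedError, while AbstractDBMetastore selects the jobs table columns and parses each row through Job.parse. A sketch of calling it directly (internal API, placeholder ids):

from datachain.catalog import get_catalog

metastore = get_catalog().metastore
for job in metastore.list_jobs_by_ids(["job-id-1", "job-id-2"]):
    print(job.id, job.status, job.metrics)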
datachain/data_storage/warehouse.py CHANGED
@@ -390,7 +390,9 @@ class AbstractWarehouse(ABC, Serializable):
  expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
  sa.func.count(table.c.sys__id),
  )
- if "size" in table.columns:
+ if "file__size" in table.columns:
+ expressions = (*expressions, sa.func.sum(table.c.file__size))
+ elif "size" in table.columns:
  expressions = (*expressions, sa.func.sum(table.c.size))
  query = select(*expressions)
  ((nrows, *rest),) = self.db.execute(query)
datachain/job.py ADDED
@@ -0,0 +1,56 @@
+ import json
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any, Optional, TypeVar
+
+ J = TypeVar("J", bound="Job")
+
+
+ @dataclass
+ class Job:
+ id: str
+ name: str
+ status: int
+ created_at: datetime
+ query: str
+ query_type: int
+ workers: int
+ params: dict[str, str]
+ metrics: dict[str, Any]
+ finished_at: Optional[datetime] = None
+ python_version: Optional[str] = None
+ error_message: str = ""
+ error_stack: str = ""
+
+ @classmethod
+ def parse(
+ cls: type[J],
+ id: str,
+ name: str,
+ status: int,
+ created_at: datetime,
+ finished_at: Optional[datetime],
+ query: str,
+ query_type: int,
+ workers: int,
+ python_version: Optional[str],
+ error_message: str,
+ error_stack: str,
+ params: str,
+ metrics: str,
+ ) -> "Job":
+ return cls(
+ id,
+ name,
+ status,
+ created_at,
+ query,
+ query_type,
+ workers,
+ json.loads(params),
+ json.loads(metrics),
+ finished_at,
+ python_version,
+ error_message,
+ error_stack,
+ )
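Job.parse() maps a raw jobs-table row onto the dataclass, decoding the params and metrics JSON columns and moving finished_at and python_version behind the required fields. A sketch with invented values to show the column order it expects:

from datetime import datetime, timezone

from datachain.job import Job

job = Job.parse(
    "job-0001",                  # id (placeholder)
    "index-bucket",              # name
    4,                           # status (placeholder status code)
    datetime.now(timezone.utc),  # created_at
    None,                        # finished_at
    "print('hello')",            # query text
    1,                           # query_type (placeholder)
    0,                           # workers
    "3.11",                      # python_version
    "",                          # error_message
    "",                          # error_stack
    '{"param": "value"}',        # params JSON string -> dict
    '{"rows": 100}',             # metrics JSON string -> dict
)
assert job.params == {"param": "value"}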
datachain/lib/arrow.py CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional

  import pyarrow as pa
  from pyarrow.dataset import dataset
+ from tqdm import tqdm

  from datachain.lib.file import File, IndexedFile
  from datachain.lib.udf import Generator
@@ -13,33 +14,44 @@ if TYPE_CHECKING:


  class ArrowGenerator(Generator):
- def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
+ def __init__(
+ self,
+ schema: Optional["pa.Schema"] = None,
+ nrows: Optional[int] = None,
+ **kwargs,
+ ):
  """
  Generator for getting rows from tabular files.

  Parameters:

  schema : Optional pyarrow schema for validation.
+ nrows : Optional row limit.
  kwargs: Parameters to pass to pyarrow.dataset.dataset.
  """
  super().__init__()
  self.schema = schema
+ self.nrows = nrows
  self.kwargs = kwargs

  def process(self, file: File):
  path = file.get_path()
  ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
  index = 0
- for record_batch in ds.to_batches():
- for record in record_batch.to_pylist():
- source = IndexedFile(file=file, index=index)
- yield [source, *record.values()]
- index += 1
+ with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
+ for record_batch in ds.to_batches():
+ for record in record_batch.to_pylist():
+ source = IndexedFile(file=file, index=index)
+ yield [source, *record.values()]
+ index += 1
+ if self.nrows and index >= self.nrows:
+ return
+ pbar.update(len(record_batch))


  def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
  schemas = []
- for file in chain.iterate_one("file"):
+ for file in chain.collect("file"):
  ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
  schemas.append(ds.schema)
  return pa.unify_schemas(schemas)
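ArrowGenerator now takes an nrows cap and reports progress through tqdm while it walks record batches; infer_schema switches from iterate_one() to collect(). A sketch of the new constructor, where format="parquet" is only an illustrative kwarg forwarded to pyarrow.dataset.dataset:

from datachain.lib.arrow import ArrowGenerator

# Yields at most 1,000 [IndexedFile, *values] rows per processed file,
# updating a "Parsed by pyarrow" progress bar per record batch.
gen = ArrowGenerator(schema=None, nrows=1000, format="parquet")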