datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (49)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +35 -15
  4. datachain/cli.py +37 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +646 -152
  16. datachain/lib/file.py +117 -15
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +39 -14
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/lib/webdataset.py +4 -3
  26. datachain/node.py +11 -8
  27. datachain/query/dataset.py +66 -147
  28. datachain/query/dispatch.py +15 -13
  29. datachain/query/schema.py +2 -0
  30. datachain/query/session.py +4 -4
  31. datachain/sql/functions/array.py +12 -0
  32. datachain/sql/functions/string.py +8 -0
  33. datachain/torch/__init__.py +1 -1
  34. datachain/utils.py +45 -0
  35. datachain-0.2.12.dist-info/METADATA +412 -0
  36. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
  37. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
  38. datachain/lib/feature_registry.py +0 -77
  39. datachain/lib/gpt4_vision.py +0 -97
  40. datachain/lib/hf_image_to_text.py +0 -97
  41. datachain/lib/hf_pipeline.py +0 -90
  42. datachain/lib/image_transform.py +0 -103
  43. datachain/lib/iptc_exif_xmp.py +0 -76
  44. datachain/lib/unstructured.py +0 -41
  45. datachain/text/__init__.py +0 -3
  46. datachain-0.2.10.dist-info/METADATA +0 -430
  47. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
+ from datachain.lib.data_model import DataModel, DataType, is_chain_type
  from datachain.lib.dc import C, Column, DataChain, Sys
  from datachain.lib.file import (
      File,
@@ -8,15 +8,14 @@ from datachain.lib.file import (
      TarVFile,
      TextFile,
  )
+ from datachain.lib.model_store import ModelStore
  from datachain.lib.udf import Aggregator, Generator, Mapper
  from datachain.lib.utils import AbstractUDF, DataChainError
- from datachain.query.dataset import UDF as BaseUDF  # noqa: N811
  from datachain.query.session import Session

  __all__ = [
      "AbstractUDF",
      "Aggregator",
-     "BaseUDF",
      "C",
      "Column",
      "DataChain",
@@ -24,12 +23,12 @@ __all__ = [
      "DataModel",
      "DataType",
      "File",
-     "FileBasic",
      "FileError",
      "Generator",
      "ImageFile",
      "IndexedFile",
      "Mapper",
+     "ModelStore",
      "Session",
      "Sys",
      "TarVFile",
datachain/cache.py CHANGED
@@ -1,6 +1,7 @@
  import hashlib
  import json
  import os
+ from datetime import datetime
  from functools import partial
  from typing import TYPE_CHECKING, Optional

@@ -9,6 +10,8 @@ from dvc_data.hashfile.db.local import LocalHashFileDB
  from dvc_objects.fs.local import LocalFileSystem
  from fsspec.callbacks import Callback, TqdmCallback

+ from datachain.utils import TIME_ZERO
+
  from .progress import Tqdm

  if TYPE_CHECKING:
@@ -23,10 +26,13 @@ class UniqueId:
      storage: "StorageURI"
      parent: str
      name: str
-     etag: str
      size: int
-     vtype: str
-     location: Optional[str]
+     etag: str
+     version: str = ""
+     is_latest: bool = True
+     vtype: str = ""
+     location: Optional[str] = None
+     last_modified: datetime = TIME_ZERO

      @property
      def path(self) -> str:
@@ -49,7 +55,7 @@ class UniqueId:
      def get_hash(self) -> str:
          etag = f"{self.vtype}{self.location}" if self.vtype else self.etag
          return sha256(
-             f"{self.storage}/{self.parent}/{self.name}/{etag}".encode()
+             f"{self.storage}/{self.parent}/{self.name}/{self.version}/{etag}".encode()
          ).hexdigest()

datachain/catalog/catalog.py CHANGED
@@ -84,12 +84,14 @@ if TYPE_CHECKING:
          AbstractMetastore,
          AbstractWarehouse,
      )
+     from datachain.dataset import DatasetVersion
+     from datachain.job import Job

  logger = logging.getLogger("datachain")

  DEFAULT_DATASET_DIR = "dataset"
  DATASET_FILE_SUFFIX = ".edatachain"
- FEATURE_CLASSES = ["Feature"]
+ FEATURE_CLASSES = ["DataModel"]

  TTL_INT = 4 * 60 * 60

@@ -948,13 +950,9 @@ class Catalog:
              ms = self.metastore.clone(uri, None)
              st = self.warehouse.clone()
              listing = Listing(None, ms, st, client, None)
-             rows = (
-                 DatasetQuery(
-                     name=dataset.name, version=ds_version, catalog=self
-                 )
-                 .select()
-                 .to_records()
-             )
+             rows = DatasetQuery(
+                 name=dataset.name, version=ds_version, catalog=self
+             ).to_db_records()
              indexed_sources.append(
                  (
                      listing,
@@ -1160,9 +1158,8 @@ class Catalog:
          if not dataset_version.preview:
              values["preview"] = (
                  DatasetQuery(name=dataset.name, version=version, catalog=self)
-                 .select()
                  .limit(20)
-                 .to_records()
+                 .to_db_records()
              )

          if not values:
@@ -1420,6 +1417,25 @@ class Catalog:
              if not d.is_bucket_listing:
                  yield d

+     def list_datasets_versions(
+         self,
+     ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+         """Iterate over all dataset versions with related jobs."""
+         datasets = list(self.ls_datasets())
+
+         # preselect dataset versions jobs from db to avoid multiple queries
+         jobs_ids: set[str] = {
+             v.job_id for ds in datasets for v in ds.versions if v.job_id
+         }
+         jobs: dict[str, Job] = {}
+         if jobs_ids:
+             jobs = {j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))}
+
+         for d in datasets:
+             yield from (
+                 (d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
+             )
+
      def ls_dataset_rows(
          self, name: str, version: int, offset=None, limit=None
      ) -> list[dict]:
@@ -1427,7 +1443,7 @@

          dataset = self.get_dataset(name)

-         q = DatasetQuery(name=dataset.name, version=version, catalog=self).select()
+         q = DatasetQuery(name=dataset.name, version=version, catalog=self)
          if limit:
              q = q.limit(limit)
          if offset:
@@ -1435,7 +1451,7 @@

          q = q.order_by("sys__id")

-         return q.to_records()
+         return q.to_db_records()

      def signed_url(self, source: str, path: str, client_config=None) -> str:
          client_config = client_config or self.client_config
@@ -1609,6 +1625,7 @@ class Catalog:
              ...
          }
          """
+         from datachain.lib.file import File
          from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema

          version = self.get_dataset(dataset_name).get_version(dataset_version)
@@ -1616,7 +1633,7 @@
          file_signals_values = {}

          schema = SignalSchema.deserialize(version.feature_schema)
-         for file_signals in schema.get_file_signals():
+         for file_signals in schema.get_signals(File):
              prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
              file_signals_values[file_signals] = {
                  c_name.removeprefix(prefix): c_value
@@ -1657,10 +1674,13 @@ class Catalog:
                  row["source"],
                  row["parent"],
                  row["name"],
-                 row["etag"],
                  row["size"],
+                 row["etag"],
+                 row["version"],
+                 row["is_latest"],
                  row["vtype"],
                  row["location"],
+                 row["last_modified"],
              )

      def ls(
@@ -1992,7 +2012,7 @@ class Catalog:
          )
          if proc.returncode == QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE:
              raise QueryScriptRunError(
-                 "Last line in a script was not an instance of DatasetQuery",
+                 "Last line in a script was not an instance of DataChain",
                  return_code=proc.returncode,
                  output=output,
              )
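
The new Catalog.list_datasets_versions() pairs each dataset version with the job that produced it, prefetching all jobs in one metastore query. A hedged usage sketch, assuming the default local catalog obtained through get_catalog():

from datachain.catalog import get_catalog

catalog = get_catalog()  # assumes the default local setup

# Each item is (DatasetRecord, DatasetVersion, Optional[Job]); the job is None
# when the version was not produced by a tracked job. Assumes DatasetVersion
# exposes a numeric .version attribute.
for dataset, version, job in catalog.list_datasets_versions():
    job_name = job.name if job else "-"
    print(dataset.name, version.version, job_name)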
datachain/cli.py CHANGED
@@ -3,7 +3,7 @@ import os
  import shlex
  import sys
  import traceback
- from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
+ from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
  from collections.abc import Iterable, Iterator, Mapping, Sequence
  from importlib.metadata import PackageNotFoundError, version
  from itertools import chain
@@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      parser = ArgumentParser(
          description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
      )
-
      parser.add_argument("-V", "--version", action="version", version=__version__)
-     parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS)
-     parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS)

      parent_parser = ArgumentParser(add_help=False)
      parent_parser.add_argument(
@@ -150,9 +147,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
          help="Drop into the pdb debugger on fatal exception",
      )

-     subp = parser.add_subparsers(help="Sub-command help", dest="command")
+     subp = parser.add_subparsers(
+         title="Available Commands",
+         metavar="command",
+         dest="command",
+         help=f"Use `{parser.prog} command --help` for command-specific help.",
+         required=True,
+     )
      parse_cp = subp.add_parser(
-         "cp", parents=[parent_parser], help="Copy data files from the cloud"
+         "cp", parents=[parent_parser], description="Copy data files from the cloud"
      )
      add_sources_arg(parse_cp).complete = shtab.DIR  # type: ignore[attr-defined]
      parse_cp.add_argument("output", type=str, help="Output")
@@ -179,7 +182,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_clone = subp.add_parser(
-         "clone", parents=[parent_parser], help="Copy data files from the cloud"
+         "clone", parents=[parent_parser], description="Copy data files from the cloud"
      )
      add_sources_arg(parse_clone).complete = shtab.DIR  # type: ignore[attr-defined]
      parse_clone.add_argument("output", type=str, help="Output")
@@ -222,7 +225,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_pull = subp.add_parser(
-         "pull", parents=[parent_parser], help="Pull specific dataset version from SaaS"
+         "pull",
+         parents=[parent_parser],
+         description="Pull specific dataset version from SaaS",
      )
      parse_pull.add_argument(
          "dataset",
@@ -263,7 +268,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_edit_dataset = subp.add_parser(
-         "edit-dataset", parents=[parent_parser], help="Edit dataset metadata"
+         "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"
      )
      parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
      parse_edit_dataset.add_argument(
@@ -285,9 +290,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
          help="Dataset labels",
      )

-     subp.add_parser("ls-datasets", parents=[parent_parser], help="List datasets")
+     subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
      rm_dataset_parser = subp.add_parser(
-         "rm-dataset", parents=[parent_parser], help="Removes dataset"
+         "rm-dataset", parents=[parent_parser], description="Removes dataset"
      )
      rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
      rm_dataset_parser.add_argument(
@@ -305,7 +310,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      dataset_stats_parser = subp.add_parser(
-         "dataset-stats", parents=[parent_parser], help="Shows basic dataset stats"
+         "dataset-stats",
+         parents=[parent_parser],
+         description="Shows basic dataset stats",
      )
      dataset_stats_parser.add_argument("name", type=str, help="Dataset name")
      dataset_stats_parser.add_argument(
@@ -330,7 +337,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_merge_datasets = subp.add_parser(
-         "merge-datasets", parents=[parent_parser], help="Merges datasets"
+         "merge-datasets", parents=[parent_parser], description="Merges datasets"
      )
      parse_merge_datasets.add_argument(
          "--src",
@@ -360,7 +367,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_ls = subp.add_parser(
-         "ls", parents=[parent_parser], help="List storage contents"
+         "ls", parents=[parent_parser], description="List storage contents"
      )
      add_sources_arg(parse_ls, nargs="*")
      parse_ls.add_argument(
@@ -378,7 +385,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_du = subp.add_parser(
-         "du", parents=[parent_parser], help="Display space usage"
+         "du", parents=[parent_parser], description="Display space usage"
      )
      add_sources_arg(parse_du)
      parse_du.add_argument(
@@ -408,7 +415,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_find = subp.add_parser(
-         "find", parents=[parent_parser], help="Search in a directory hierarchy"
+         "find", parents=[parent_parser], description="Search in a directory hierarchy"
      )
      add_sources_arg(parse_find)
      parse_find.add_argument(
@@ -461,20 +468,20 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      parse_index = subp.add_parser(
-         "index", parents=[parent_parser], help="Index storage location"
+         "index", parents=[parent_parser], description="Index storage location"
      )
      add_sources_arg(parse_index)

      subp.add_parser(
          "find-stale-storages",
          parents=[parent_parser],
-         help="Finds and marks stale storages",
+         description="Finds and marks stale storages",
      )

      show_parser = subp.add_parser(
          "show",
          parents=[parent_parser],
-         help="Create a new dataset with a query script",
+         description="Create a new dataset with a query script",
      )
      show_parser.add_argument("name", type=str, help="Dataset name")
      show_parser.add_argument(
@@ -489,7 +496,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      query_parser = subp.add_parser(
          "query",
          parents=[parent_parser],
-         help="Create a new dataset with a query script",
+         description="Create a new dataset with a query script",
      )
      query_parser.add_argument(
          "script", metavar="<script.py>", type=str, help="Filepath for script"
@@ -520,7 +527,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
      )

      apply_udf_parser = subp.add_parser(
-         "apply-udf", parents=[parent_parser], help="Apply UDF"
+         "apply-udf", parents=[parent_parser], description="Apply UDF"
      )
      apply_udf_parser.add_argument("udf", type=str, help="UDF location")
      apply_udf_parser.add_argument("source", type=str, help="Source storage or dataset")
@@ -541,12 +548,14 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
          "--udf-params", type=str, default=None, help="UDF class parameters"
      )
      subp.add_parser(
-         "clear-cache", parents=[parent_parser], help="Clear the local file cache"
+         "clear-cache", parents=[parent_parser], description="Clear the local file cache"
      )
      subp.add_parser(
-         "gc", parents=[parent_parser], help="Garbage collect temporary tables"
+         "gc", parents=[parent_parser], description="Garbage collect temporary tables"
      )

+     subp.add_parser("internal-run-udf", parents=[parent_parser])
+     subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
      add_completion_parser(subp, [parent_parser])
      return parser

@@ -555,7 +564,7 @@ def add_completion_parser(subparsers, parents):
      parser = subparsers.add_parser(
          "completion",
          parents=parents,
-         help="Output shell completion script",
+         description="Output shell completion script",
      )
      parser.add_argument(
          "-s",
@@ -817,7 +826,7 @@ def show(
          .limit(limit)
          .offset(offset)
      )
-     records = query.to_records()
+     records = query.to_db_records()
      show_records(records, collapse_columns=not no_collapse)


@@ -901,27 +910,23 @@ def completion(shell: str) -> str:
      )


- def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0911, PLR0912, PLR0915
+ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
      # Required for Windows multiprocessing support
      freeze_support()

      parser = get_parser()
      args = parser.parse_args(argv)

-     if args.internal_run_udf:
+     if args.command == "internal-run-udf":
          from datachain.query.dispatch import udf_entrypoint

          return udf_entrypoint()

-     if args.internal_run_udf_worker:
+     if args.command == "internal-run-udf-worker":
          from datachain.query.dispatch import udf_worker_entrypoint

          return udf_worker_entrypoint()

-     if args.command is None:
-         parser.print_help()
-         return 1
-
      from .catalog import get_catalog

      logger.addHandler(logging.StreamHandler())
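
With the subparser now marked required=True, the internal UDF entry points become plain subcommands instead of hidden global flags, and running the CLI without a command hits argparse's usage error rather than the removed print-help-and-return-1 branch. A small sketch of how this should behave when calling main() programmatically, under those assumptions:

from datachain.cli import main

# "ls-datasets" is a regular subcommand; "internal-run-udf" and
# "internal-run-udf-worker" are now dispatched via args.command as well.
exit_code = main(["ls-datasets"])

# Calling main([]) should now raise SystemExit(2) from argparse because a
# subcommand is required, instead of printing help and returning 1.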
datachain/data_storage/metastore.py CHANGED
@@ -40,6 +40,7 @@ from datachain.error import (
      StorageNotFoundError,
      TableMissingError,
  )
+ from datachain.job import Job
  from datachain.storage import Storage, StorageStatus, StorageURI
  from datachain.utils import JSONSerialize, is_expired

@@ -67,6 +68,7 @@ class AbstractMetastore(ABC, Serializable):
      storage_class: type[Storage] = Storage
      dataset_class: type[DatasetRecord] = DatasetRecord
      dependency_class: type[DatasetDependency] = DatasetDependency
+     job_class: type[Job] = Job

      def __init__(
          self,
@@ -377,6 +379,9 @@ class AbstractMetastore(ABC, Serializable):
      # Jobs
      #

+     def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+         raise NotImplementedError
+
      @abstractmethod
      def create_job(
          self,
@@ -1467,6 +1472,10 @@ class AbstractDBMetastore(AbstractMetastore):
              Column("metrics", JSON, nullable=False),
          ]

+     @cached_property
+     def _job_fields(self) -> list[str]:
+         return [c.name for c in self._jobs_columns() if c.name]  # type: ignore[attr-defined]
+
      @cached_property
      def _jobs(self) -> "Table":
          return Table(self.JOBS_TABLE, self.db.metadata, *self._jobs_columns())
@@ -1484,6 +1493,21 @@ class AbstractDBMetastore(AbstractMetastore):
              return self._jobs.update()
          return self._jobs.update().where(*where)

+     def _parse_job(self, rows) -> Job:
+         return Job.parse(*rows)
+
+     def _parse_jobs(self, rows) -> Iterator["Job"]:
+         for _, g in groupby(rows, lambda r: r[0]):
+             yield self._parse_job(*list(g))
+
+     def _jobs_query(self):
+         return self._jobs_select(*[getattr(self._jobs.c, f) for f in self._job_fields])
+
+     def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+         """List jobs by ids."""
+         query = self._jobs_query().where(self._jobs.c.id.in_(ids))
+         yield from self._parse_jobs(self.db.execute(query, conn=conn))
+
      def create_job(
          self,
          name: str,
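
list_jobs_by_ids() is what Catalog.list_datasets_versions() uses to fetch all referenced jobs in a single query: it selects the jobs columns for the given ids and parses each row through Job.parse(). A rough sketch of calling it directly, assuming a DB-backed metastore reachable through the default catalog:

from datachain.catalog import get_catalog

catalog = get_catalog()  # assumes the default DB-backed metastore

# Hypothetical job ids; in practice they come from DatasetVersion.job_id.
job_ids = ["0b6a-example", "4f2c-example"]

for job in catalog.metastore.list_jobs_by_ids(job_ids):
    # Job fields come from datachain/job.py, added in this release.
    print(job.id, job.name, job.status, job.metrics)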
datachain/data_storage/warehouse.py CHANGED
@@ -390,7 +390,9 @@ class AbstractWarehouse(ABC, Serializable):
          expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
              sa.func.count(table.c.sys__id),
          )
-         if "size" in table.columns:
+         if "file__size" in table.columns:
+             expressions = (*expressions, sa.func.sum(table.c.file__size))
+         elif "size" in table.columns:
              expressions = (*expressions, sa.func.sum(table.c.size))
          query = select(*expressions)
          ((nrows, *rest),) = self.db.execute(query)
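
The stats query now prefers the namespaced file__size column (produced by File-based signal schemas) and only falls back to a bare size column. A minimal SQLAlchemy sketch of the same pattern, with a made-up table standing in for a dataset rows table:

import sqlalchemy as sa

metadata = sa.MetaData()
table = sa.Table(
    "rows",
    metadata,
    sa.Column("sys__id", sa.Integer),
    sa.Column("file__size", sa.Integer),
)

expressions = (sa.func.count(table.c.sys__id),)
if "file__size" in table.columns:      # new-style File signal column
    expressions = (*expressions, sa.func.sum(table.c.file__size))
elif "size" in table.columns:          # legacy flat schema
    expressions = (*expressions, sa.func.sum(table.c.size))
query = sa.select(*expressions)
print(query)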
datachain/job.py ADDED
@@ -0,0 +1,56 @@
+ import json
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any, Optional, TypeVar
+
+ J = TypeVar("J", bound="Job")
+
+
+ @dataclass
+ class Job:
+     id: str
+     name: str
+     status: int
+     created_at: datetime
+     query: str
+     query_type: int
+     workers: int
+     params: dict[str, str]
+     metrics: dict[str, Any]
+     finished_at: Optional[datetime] = None
+     python_version: Optional[str] = None
+     error_message: str = ""
+     error_stack: str = ""
+
+     @classmethod
+     def parse(
+         cls: type[J],
+         id: str,
+         name: str,
+         status: int,
+         created_at: datetime,
+         finished_at: Optional[datetime],
+         query: str,
+         query_type: int,
+         workers: int,
+         python_version: Optional[str],
+         error_message: str,
+         error_stack: str,
+         params: str,
+         metrics: str,
+     ) -> "Job":
+         return cls(
+             id,
+             name,
+             status,
+             created_at,
+             query,
+             query_type,
+             workers,
+             json.loads(params),
+             json.loads(metrics),
+             finished_at,
+             python_version,
+             error_message,
+             error_stack,
+         )
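
Job.parse() exists because job rows come back from the database with params and metrics as JSON strings and in a column order that differs from the dataclass field order. A sketch with made-up values:

from datetime import datetime, timezone

from datachain.job import Job

# Positional arguments follow the jobs table column order used by
# list_jobs_by_ids(); params and metrics arrive as JSON strings.
job = Job.parse(
    "job-1",                        # id
    "index bucket",                 # name
    2,                              # status
    datetime.now(timezone.utc),     # created_at
    None,                           # finished_at
    "DataChain.from_storage(...)",  # query
    1,                              # query_type
    4,                              # workers
    "3.11",                         # python_version
    "",                             # error_message
    "",                             # error_stack
    '{"param": "value"}',           # params (JSON string)
    '{"rows": 100}',                # metrics (JSON string)
)
print(job.params, job.metrics)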
datachain/lib/arrow.py CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional

  import pyarrow as pa
  from pyarrow.dataset import dataset
+ from tqdm import tqdm

  from datachain.lib.file import File, IndexedFile
  from datachain.lib.udf import Generator
@@ -13,33 +14,44 @@ if TYPE_CHECKING:


  class ArrowGenerator(Generator):
-     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
+     def __init__(
+         self,
+         schema: Optional["pa.Schema"] = None,
+         nrows: Optional[int] = None,
+         **kwargs,
+     ):
          """
          Generator for getting rows from tabular files.

          Parameters:

          schema : Optional pyarrow schema for validation.
+         nrows : Optional row limit.
          kwargs: Parameters to pass to pyarrow.dataset.dataset.
          """
          super().__init__()
          self.schema = schema
+         self.nrows = nrows
          self.kwargs = kwargs

      def process(self, file: File):
          path = file.get_path()
          ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
          index = 0
-         for record_batch in ds.to_batches():
-             for record in record_batch.to_pylist():
-                 source = IndexedFile(file=file, index=index)
-                 yield [source, *record.values()]
-                 index += 1
+         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
+             for record_batch in ds.to_batches():
+                 for record in record_batch.to_pylist():
+                     source = IndexedFile(file=file, index=index)
+                     yield [source, *record.values()]
+                     index += 1
+                     if self.nrows and index >= self.nrows:
+                         return
+                 pbar.update(len(record_batch))


  def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
      schemas = []
-     for file in chain.iterate_one("file"):
+     for file in chain.collect("file"):
          ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
          schemas.append(ds.schema)
      return pa.unify_schemas(schemas)
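
Besides the tqdm progress bar, ArrowGenerator gains an nrows cap that stops iteration once the limit is reached, even mid-batch. A hedged sketch, assuming `file` is a datachain File pointing at a parquet object (e.g. one yielded by a DataChain over storage):

from datachain.lib.arrow import ArrowGenerator

# format is forwarded to pyarrow.dataset.dataset via **kwargs.
gen = ArrowGenerator(schema=None, nrows=100, format="parquet")

def preview(file):
    # Each yielded row is [IndexedFile(file=..., index=i), *column_values];
    # iteration ends after at most 100 rows because of the nrows cap.
    for row in gen.process(file):
        source, *values = row
        print(source.index, values)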