datachain 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (38)
  1. datachain/__init__.py +2 -0
  2. datachain/catalog/catalog.py +62 -228
  3. datachain/cli.py +136 -22
  4. datachain/client/fsspec.py +9 -0
  5. datachain/client/local.py +11 -32
  6. datachain/config.py +126 -51
  7. datachain/data_storage/schema.py +66 -33
  8. datachain/data_storage/sqlite.py +12 -4
  9. datachain/data_storage/warehouse.py +101 -129
  10. datachain/lib/convert/sql_to_python.py +8 -12
  11. datachain/lib/dc.py +275 -80
  12. datachain/lib/func/__init__.py +32 -0
  13. datachain/lib/func/aggregate.py +353 -0
  14. datachain/lib/func/func.py +152 -0
  15. datachain/lib/listing.py +6 -21
  16. datachain/lib/listing_info.py +4 -0
  17. datachain/lib/signal_schema.py +17 -8
  18. datachain/lib/udf.py +3 -3
  19. datachain/lib/utils.py +5 -0
  20. datachain/listing.py +22 -48
  21. datachain/query/__init__.py +1 -2
  22. datachain/query/batch.py +0 -1
  23. datachain/query/dataset.py +33 -46
  24. datachain/query/schema.py +1 -61
  25. datachain/query/session.py +33 -25
  26. datachain/remote/studio.py +63 -14
  27. datachain/sql/functions/__init__.py +1 -1
  28. datachain/sql/functions/aggregate.py +47 -0
  29. datachain/sql/functions/array.py +0 -8
  30. datachain/sql/sqlite/base.py +20 -2
  31. datachain/studio.py +129 -0
  32. datachain/utils.py +58 -0
  33. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/METADATA +7 -6
  34. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/RECORD +38 -33
  35. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/WHEEL +1 -1
  36. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/LICENSE +0 -0
  37. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/top_level.txt +0 -0
datachain/data_storage/sqlite.py

@@ -643,7 +643,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c.file__source).distinct()
+        query = dr.select(dr.c("source", object_name="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]

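Across this release, attribute-style access to prefixed columns (dr.c.file__source) gives way to a callable accessor that takes the bare field name plus an object_name prefix. A minimal sketch of the naming convention, with a hypothetical col_name() helper standing in for DataTable's real logic:

    # Hypothetical helper for illustration; datachain's DataTable.c() is the
    # real accessor.
    def col_name(name: str, object_name: str = "file") -> str:
        # "source" under object "file" -> physical column "file__source"
        return f"{object_name}__{name}"

    assert col_name("source") == "file__source"
    assert col_name("rand", object_name="sys") == "sys__rand"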
@@ -671,13 +671,13 @@ class SQLiteWarehouse(AbstractWarehouse):
             # destination table doesn't exist, create it
             self.create_dataset_rows_table(
                 self.dataset_table_name(dst.name, dst_version),
-                columns=src_dr.c,
+                columns=src_dr.columns,
             )
             dst_empty = True
 
         dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
-        select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))
+        merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
+        select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))
 
         if dst_empty:
             # we don't need union, but just select from source to destination
@@ -763,6 +763,14 @@ class SQLiteWarehouse(AbstractWarehouse):
         query: Select,
         progress_cb: Optional[Callable[[int], None]] = None,
     ) -> None:
+        if len(query._group_by_clause) > 0:
+            select_q = query.with_only_columns(
+                *[c for c in query.selected_columns if c.name != "sys__id"]
+            )
+            q = table.insert().from_select(list(select_q.selected_columns), select_q)
+            self.db.execute(q)
+            return
+
         if "sys__id" in query.selected_columns:
             col_id = query.selected_columns.sys__id
         else:
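The new early-return handles grouped queries: a GROUP BY result carries no per-row sys__id, so rows are bulk-inserted straight from the SELECT instead of being paginated by id. A self-contained SQLAlchemy sketch of the same insert-from-select pattern (table and column names are illustrative):

    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    meta = sa.MetaData()
    src = sa.Table("src", meta, sa.Column("name", sa.Text), sa.Column("size", sa.Integer))
    dst = sa.Table("dst", meta, sa.Column("name", sa.Text), sa.Column("total", sa.Integer))
    meta.create_all(engine)

    # A grouped SELECT has no row-id column to paginate on.
    sel = sa.select(src.c.name, sa.func.sum(src.c.size).label("total")).group_by(src.c.name)
    ins = dst.insert().from_select(list(sel.selected_columns), sel)
    with engine.begin() as conn:
        conn.execute(ins)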
datachain/data_storage/warehouse.py

@@ -185,7 +185,12 @@ class AbstractWarehouse(ABC, Serializable):
     @abstractmethod
     def is_ready(self, timeout: Optional[int] = None) -> bool: ...
 
-    def dataset_rows(self, dataset: DatasetRecord, version: Optional[int] = None):
+    def dataset_rows(
+        self,
+        dataset: DatasetRecord,
+        version: Optional[int] = None,
+        object_name: str = "file",
+    ):
         version = version or dataset.latest_version
 
         table_name = self.dataset_table_name(dataset.name, version)
@@ -194,6 +199,7 @@ class AbstractWarehouse(ABC, Serializable):
             self.db.engine,
             self.db.metadata,
             dataset.get_schema(version),
+            object_name=object_name,
         )
 
     @property
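dataset_rows() now threads object_name through to the DataTable it builds, so the same physical table can be addressed under a different column prefix. A hedged usage sketch (warehouse, dataset, and the "meta" prefix are illustrative):

    dr = warehouse.dataset_rows(dataset)                           # columns under file__*
    dr_meta = warehouse.dataset_rows(dataset, object_name="meta")  # columns under meta__*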
@@ -215,10 +221,6 @@ class AbstractWarehouse(ABC, Serializable):
         limit = query._limit
         paginated_query = query.limit(page_size)
 
-        if not paginated_query._order_by_clauses:
-            # default order by is order by `sys__id`
-            paginated_query = paginated_query.order_by(query.selected_columns.sys__id)
-
         results = None
         offset = 0
         num_yielded = 0
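With the implicit sys__id ordering gone, LIMIT/OFFSET pagination is deterministic only if the caller orders the query itself. A one-line sketch of what a caller needing stable pages would now do (assuming the query still selects sys__id):

    paginated_query = query.order_by(query.selected_columns.sys__id).limit(page_size)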
@@ -323,55 +325,6 @@ class AbstractWarehouse(ABC, Serializable):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]: ...
 
-    def nodes_dataset_query(
-        self,
-        dataset_rows: "DataTable",
-        *,
-        column_names: Iterable[str],
-        path: Optional[str] = None,
-        recursive: Optional[bool] = False,
-    ) -> "sa.Select":
-        """
-        Creates query pointing to certain bucket listing represented by dataset_rows
-        The given `column_names`
-        will be selected in the order they're given. `path` is a glob which
-        will select files in matching directories, or if `recursive=True` is
-        set then the entire tree under matching directories will be selected.
-        """
-        dr = dataset_rows
-
-        def _is_glob(path: str) -> bool:
-            return any(c in path for c in ["*", "?", "[", "]"])
-
-        column_objects = [dr.c[c] for c in column_names]
-        # include all object types - file, tar archive, tar file (subobject)
-        select_query = dr.select(*column_objects).where(dr.c.is_latest == true())
-        if path is None:
-            return select_query
-        if recursive:
-            root = False
-            where = self.path_expr(dr).op("GLOB")(path)
-            if not path or path == "/":
-                # root of the bucket, e.g s3://bucket/ -> getting all the nodes
-                # in the bucket
-                root = True
-
-            if not root and not _is_glob(path):
-                # not a root and not a explicit glob, so it's pointing to some directory
-                # and we are adding a proper glob syntax for it
-                # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
-                dir_path = path.rstrip("/") + "/*"
-                where = where | self.path_expr(dr).op("GLOB")(dir_path)
-
-            if not root:
-                # not a root, so running glob query
-                select_query = select_query.where(where)
-
-        else:
-            parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
-            select_query = select_query.where(pathfunc.parent(dr.c.path) == parent.path)
-        return select_query
-
     def rename_dataset_table(
         self,
         old_name: str,
@@ -475,8 +428,14 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         query: sa.Select,
         type: str,
+        dataset_rows: "DataTable",
         include_subobjects: bool = True,
     ) -> sa.Select:
+        dr = dataset_rows
+
+        def col(name: str):
+            return getattr(query.selected_columns, dr.col_name(name))
+
         file_group: Sequence[int]
         if type in {"f", "file", "files"}:
             if include_subobjects:
@@ -491,21 +450,21 @@ class AbstractWarehouse(ABC, Serializable):
         else:
             raise ValueError(f"invalid file type: {type!r}")
 
-        c = query.selected_columns
-        q = query.where(c.dir_type.in_(file_group))
+        q = query.where(col("dir_type").in_(file_group))
         if not include_subobjects:
-            q = q.where((c.location == "") | (c.location.is_(None)))
+            q = q.where((col("location") == "") | (col("location").is_(None)))
         return q
 
-    def get_nodes(self, query) -> Iterator[Node]:
+    def get_nodes(self, query, dataset_rows: "DataTable") -> Iterator[Node]:
         """
         This gets nodes based on the provided query, and should be used sparingly,
         as it will be slow on any OLAP database systems.
         """
+        dr = dataset_rows
         columns = [c.name for c in query.selected_columns]
         for row in self.db.execute(query):
             d = dict(zip(columns, row))
-            yield Node(**d)
+            yield Node(**{dr.without_object(k): v for k, v in d.items()})
 
     def get_dirs_by_parent_path(
         self,
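get_nodes() now strips the object prefix before building Node kwargs, since Node's attrs fields are unprefixed. A hypothetical stand-in for DataTable.without_object(), shown only to illustrate the mapping:

    # Illustrative re-implementation; datachain's DataTable.without_object()
    # is the authoritative version.
    def without_object(col_name: str, object_name: str = "file") -> str:
        prefix = f"{object_name}__"
        return col_name[len(prefix):] if col_name.startswith(prefix) else col_name

    assert without_object("file__path") == "path"
    assert without_object("sys__id") == "sys__id"  # other prefixes pass through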
@@ -518,48 +477,56 @@ class AbstractWarehouse(ABC, Serializable):
             dr,
             parent_path,
             type="dir",
-            conds=[pathfunc.parent(sa.Column("path")) == parent_path],
-            order_by=["source", "path"],
+            conds=[pathfunc.parent(sa.Column(dr.col_name("path"))) == parent_path],
+            order_by=[dr.col_name("source"), dr.col_name("path")],
         )
-        return self.get_nodes(query)
+        return self.get_nodes(query, dr)
 
     def _get_nodes_by_glob_path_pattern(
-        self, dataset_rows: "DataTable", path_list: list[str], glob_name: str
+        self,
+        dataset_rows: "DataTable",
+        path_list: list[str],
+        glob_name: str,
+        object_name="file",
     ) -> Iterator[Node]:
         """Finds all Nodes that correspond to GLOB like path pattern."""
         dr = dataset_rows
-        de = dr.dataset_dir_expansion(
-            dr.select().where(dr.c.is_latest == true()).subquery()
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(self.path_expr(de), len(dirpath) + 1)
+        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
 
         return self.get_nodes(
-            self.expand_query(de, dr)
+            self.expand_query(de, q, dr)
             .where(
-                (self.path_expr(de).op("GLOB")(path_glob))
+                (de.c(q, "path").op("GLOB")(path_glob))
                 & ~self.instr(relpath, "/")
-                & (self.path_expr(de) != dirpath)
+                & (de.c(q, "path") != dirpath)
             )
-            .order_by(de.c.source, de.c.path, de.c.version)
+            .order_by(de.c(q, "source"), de.c(q, "path"), de.c(q, "version")),
+            dr,
         )
 
     def _get_node_by_path_list(
         self, dataset_rows: "DataTable", path_list: list[str], name: str
-    ) -> Node:
+    ) -> "Node":
         """
         Gets node that correspond some path list, e.g ["data-lakes", "dogs-and-cats"]
         """
         parent = "/".join(path_list)
         dr = dataset_rows
-        de = dr.dataset_dir_expansion(
-            dr.select().where(dr.c.is_latest == true()).subquery()
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery(),
+            object_name=dr.object_name,
         ).subquery()
-        query = self.expand_query(de, dr)
+        q = self.expand_query(de, q, dr)
 
-        q = query.where(de.c.path == get_path(parent, name)).order_by(
-            de.c.source, de.c.path, de.c.version
+        q = q.where(de.c(q, "path") == get_path(parent, name)).order_by(
+            de.c(q, "source"), de.c(q, "path"), de.c(q, "version")
        )
         row = next(self.dataset_rows_select(q), None)
         if not row:
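The old dataset_dir_expansion() call is split in two: dir_expansion() returns a helper whose query() builds the expansion SELECT and whose c(subquery, name) resolves prefixed columns against it. A sketch of the new flow, assuming warehouse and dataset are in scope (the calls mirror the hunks above):

    import sqlalchemy as sa
    from sqlalchemy import true

    dr = warehouse.dataset_rows(dataset)   # DataTable, default object_name="file"
    de = dr.dir_expansion()                # expansion helper introduced in this release
    q = de.query(
        dr.select().where(dr.c("is_latest") == true()).subquery()
    ).subquery()
    # Expansion columns are resolved through the helper, not attribute access:
    paths = sa.select(de.c(q, "path")).order_by(de.c(q, "source"), de.c(q, "path"))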
@@ -608,29 +575,34 @@ class AbstractWarehouse(ABC, Serializable):
         return result
 
     @staticmethod
-    def expand_query(dir_expanded_query, dataset_rows: "DataTable"):
+    def expand_query(dir_expansion, dir_expanded_query, dataset_rows: "DataTable"):
         dr = dataset_rows
-        de = dir_expanded_query
+        de = dir_expansion
+        q = dir_expanded_query
 
         def with_default(column):
-            default = getattr(attrs.fields(Node), column.name).default
+            default = getattr(
+                attrs.fields(Node), dr.without_object(column.name)
+            ).default
             return func.coalesce(column, default).label(column.name)
 
         return sa.select(
-            de.c.sys__id,
-            case((de.c.is_dir == true(), DirType.DIR), else_=DirType.FILE).label(
-                "dir_type"
+            q.c.sys__id,
+            case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
+                dr.col_name("dir_type")
             ),
-            de.c.path,
-            with_default(dr.c.etag),
-            de.c.version,
-            with_default(dr.c.is_latest),
-            dr.c.last_modified,
-            with_default(dr.c.size),
-            with_default(dr.c.sys__rand),
-            dr.c.location,
-            de.c.source,
-        ).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id))
+            de.c(q, "path"),
+            with_default(dr.c("etag")),
+            de.c(q, "version"),
+            with_default(dr.c("is_latest")),
+            dr.c("last_modified"),
+            with_default(dr.c("size")),
+            with_default(dr.c("rand", object_name="sys")),
+            dr.c("location"),
+            de.c(q, "source"),
+        ).select_from(
+            q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
+        )
 
     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
@@ -639,18 +611,18 @@ class AbstractWarehouse(ABC, Serializable):
         dr = dataset_rows
         if not path.endswith("/"):
             query = dr.select().where(
-                self.path_expr(dr) == path,
-                dr.c.is_latest == true(),
+                dr.c("path") == path,
+                dr.c("is_latest") == true(),
             )
-            row = next(self.db.execute(query), None)
-            if row is not None:
-                return Node(*row)
+            node = next(self.get_nodes(query, dr), None)
+            if node:
+                return node
         path += "/"
         query = sa.select(1).where(
             dr.select()
             .where(
-                dr.c.is_latest == true(),
-                dr.c.path.startswith(path),
+                dr.c("is_latest") == true(),
+                dr.c("path").startswith(path),
             )
             .exists()
         )
@@ -679,25 +651,26 @@ class AbstractWarehouse(ABC, Serializable):
         Gets latest-version file nodes from the provided parent path
         """
         dr = dataset_rows
-        de = dr.dataset_dir_expansion(
-            dr.select().where(dr.c.is_latest == true()).subquery()
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
-        where_cond = pathfunc.parent(de.c.path) == parent_path
+        where_cond = pathfunc.parent(de.c(q, "path")) == parent_path
         if parent_path == "":
             # Exclude the root dir
-            where_cond = where_cond & (de.c.path != "")
-        inner_query = self.expand_query(de, dr).where(where_cond).subquery()
+            where_cond = where_cond & (de.c(q, "path") != "")
+        inner_query = self.expand_query(de, q, dr).where(where_cond).subquery()
 
         def field_to_expr(f):
             if f == "name":
-                return pathfunc.name(inner_query.c.path)
-            return getattr(inner_query.c, f)
+                return pathfunc.name(de.c(inner_query, "path"))
+            return de.c(inner_query, f)
 
         return self.db.execute(
             select(*(field_to_expr(f) for f in fields)).order_by(
-                inner_query.c.source,
-                inner_query.c.path,
-                inner_query.c.version,
+                de.c(inner_query, "source"),
+                de.c(inner_query, "path"),
+                de.c(inner_query, "version"),
             )
         )
 
@@ -712,17 +685,17 @@ class AbstractWarehouse(ABC, Serializable):
 
         def field_to_expr(f):
             if f == "name":
-                return pathfunc.name(dr.c.path)
-            return getattr(dr.c, f)
+                return pathfunc.name(dr.c("path"))
+            return dr.c(f)
 
         q = (
             select(*(field_to_expr(f) for f in fields))
             .where(
-                self.path_expr(dr).like(f"{sql_escape_like(dirpath)}%"),
-                ~self.instr(pathfunc.name(dr.c.path), "/"),
-                dr.c.is_latest == true(),
+                dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
+                ~self.instr(pathfunc.name(dr.c("path")), "/"),
+                dr.c("is_latest") == true(),
             )
-            .order_by(dr.c.source, dr.c.path, dr.c.version, dr.c.etag)
+            .order_by(dr.c("source"), dr.c("path"), dr.c("version"), dr.c("etag"))
         )
         return self.db.execute(q)
 
@@ -751,15 +724,14 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c.size),
+            func.sum(dr.c("size")),
         ]
         if count_files:
             selections.append(func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
-                    (self.path_expr(dr).op("GLOB")(sub_glob))
-                    & (dr.c.is_latest == true())
+                    (dr.c("path").op("GLOB")(sub_glob)) & (dr.c("is_latest") == true())
                 )
             ),
             (0, 0),
@@ -768,9 +740,6 @@ class AbstractWarehouse(ABC, Serializable):
             return results[0] or 0, results[1] or 0
         return results[0] or 0, 0
 
-    def path_expr(self, t):
-        return t.c.path
-
     def _find_query(
         self,
         dataset_rows: "DataTable",
@@ -785,11 +754,12 @@ class AbstractWarehouse(ABC, Serializable):
         conds = []
 
         dr = dataset_rows
-        de = dr.dataset_dir_expansion(
-            dr.select().where(dr.c.is_latest == true()).subquery()
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
-        q = self.expand_query(de, dr).subquery()
-        path = self.path_expr(q)
+        q = self.expand_query(de, q, dr).subquery()
+        path = de.c(q, "path")
 
         if parent_path:
             sub_glob = posixpath.join(parent_path, "*")
@@ -804,7 +774,7 @@ class AbstractWarehouse(ABC, Serializable):
         query = sa.select(*columns)
         query = query.where(*conds)
         if type is not None:
-            query = self.add_node_type_where(query, type, include_subobjects)
+            query = self.add_node_type_where(query, type, dr, include_subobjects)
         if order_by is not None:
             if isinstance(order_by, str):
                 order_by = [order_by]
@@ -832,14 +802,14 @@ class AbstractWarehouse(ABC, Serializable):
         if sort is not None:
             if not isinstance(sort, list):
                 sort = [sort]
-            query = query.order_by(*(sa.text(s) for s in sort))  # type: ignore [attr-defined]
+            query = query.order_by(*(sa.text(dr.col_name(s)) for s in sort))  # type: ignore [attr-defined]
 
         prefix_len = len(node.path)
 
         def make_node_with_path(node: Node) -> NodeWithPath:
             return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))
 
-        return map(make_node_with_path, self.get_nodes(query))
+        return map(make_node_with_path, self.get_nodes(query, dr))
 
     def find(
         self,
@@ -854,8 +824,10 @@ class AbstractWarehouse(ABC, Serializable):
         Finds nodes that match certain criteria and only looks for latest nodes
         under the passed node.
         """
+        dr = dataset_rows
+        fields = [dr.col_name(f) for f in fields]
         query = self._find_query(
-            dataset_rows,
+            dr,
             node.path,
             fields=fields,
             type=type,
datachain/lib/convert/sql_to_python.py

@@ -4,15 +4,11 @@ from typing import Any
 from sqlalchemy import ColumnElement
 
 
-def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
-    res = {}
-    for name, sql_exp in args_map.items():
-        try:
-            type_ = sql_exp.type.python_type
-            if type_ == Decimal:
-                type_ = float
-        except NotImplementedError:
-            type_ = str
-        res[name] = type_
-
-    return res
+def sql_to_python(sql_exp: ColumnElement) -> Any:
+    try:
+        type_ = sql_exp.type.python_type
+        if type_ == Decimal:
+            type_ = float
+    except NotImplementedError:
+        type_ = str
+    return type_
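sql_to_python() now maps a single SQL expression to one Python type instead of converting a whole dict of them; callers loop themselves. A usage sketch with plain SQLAlchemy column elements:

    import sqlalchemy as sa
    from datachain.lib.convert.sql_to_python import sql_to_python

    assert sql_to_python(sa.column("size", sa.Integer())) is int
    # Numeric's python_type is Decimal, which this helper coerces to float.
    assert sql_to_python(sa.column("score", sa.Numeric())) is float
    # Types that don't implement python_type fall back to str.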