datachain 0.6.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

@@ -643,7 +643,7 @@ class SQLiteWarehouse(AbstractWarehouse):
643
643
  self, dataset: DatasetRecord, version: int
644
644
  ) -> list[StorageURI]:
645
645
  dr = self.dataset_rows(dataset, version)
646
- query = dr.select(dr.c.file__source).distinct()
646
+ query = dr.select(dr.c("source", object_name="file")).distinct()
647
647
  cur = self.db.cursor()
648
648
  cur.row_factory = sqlite3.Row # type: ignore[assignment]
649
649
 
@@ -671,13 +671,13 @@ class SQLiteWarehouse(AbstractWarehouse):
671
671
  # destination table doesn't exist, create it
672
672
  self.create_dataset_rows_table(
673
673
  self.dataset_table_name(dst.name, dst_version),
674
- columns=src_dr.c,
674
+ columns=src_dr.columns,
675
675
  )
676
676
  dst_empty = True
677
677
 
678
678
  dst_dr = self.dataset_rows(dst, dst_version).table
679
- merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
680
- select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))
679
+ merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
680
+ select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))
681
681
 
682
682
  if dst_empty:
683
683
  # we don't need union, but just select from source to destination
@@ -185,7 +185,12 @@ class AbstractWarehouse(ABC, Serializable):
185
185
  @abstractmethod
186
186
  def is_ready(self, timeout: Optional[int] = None) -> bool: ...
187
187
 
188
- def dataset_rows(self, dataset: DatasetRecord, version: Optional[int] = None):
188
+ def dataset_rows(
189
+ self,
190
+ dataset: DatasetRecord,
191
+ version: Optional[int] = None,
192
+ object_name: str = "file",
193
+ ):
189
194
  version = version or dataset.latest_version
190
195
 
191
196
  table_name = self.dataset_table_name(dataset.name, version)
@@ -194,6 +199,7 @@ class AbstractWarehouse(ABC, Serializable):
194
199
  self.db.engine,
195
200
  self.db.metadata,
196
201
  dataset.get_schema(version),
202
+ object_name=object_name,
197
203
  )
198
204
 
199
205
  @property
@@ -319,55 +325,6 @@ class AbstractWarehouse(ABC, Serializable):
319
325
  self, dataset: DatasetRecord, version: int
320
326
  ) -> list[StorageURI]: ...
321
327
 
322
- def nodes_dataset_query(
323
- self,
324
- dataset_rows: "DataTable",
325
- *,
326
- column_names: Iterable[str],
327
- path: Optional[str] = None,
328
- recursive: Optional[bool] = False,
329
- ) -> "sa.Select":
330
- """
331
- Creates query pointing to certain bucket listing represented by dataset_rows
332
- The given `column_names`
333
- will be selected in the order they're given. `path` is a glob which
334
- will select files in matching directories, or if `recursive=True` is
335
- set then the entire tree under matching directories will be selected.
336
- """
337
- dr = dataset_rows
338
-
339
- def _is_glob(path: str) -> bool:
340
- return any(c in path for c in ["*", "?", "[", "]"])
341
-
342
- column_objects = [dr.c[c] for c in column_names]
343
- # include all object types - file, tar archive, tar file (subobject)
344
- select_query = dr.select(*column_objects).where(dr.c.is_latest == true())
345
- if path is None:
346
- return select_query
347
- if recursive:
348
- root = False
349
- where = self.path_expr(dr).op("GLOB")(path)
350
- if not path or path == "/":
351
- # root of the bucket, e.g s3://bucket/ -> getting all the nodes
352
- # in the bucket
353
- root = True
354
-
355
- if not root and not _is_glob(path):
356
- # not a root and not a explicit glob, so it's pointing to some directory
357
- # and we are adding a proper glob syntax for it
358
- # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
359
- dir_path = path.rstrip("/") + "/*"
360
- where = where | self.path_expr(dr).op("GLOB")(dir_path)
361
-
362
- if not root:
363
- # not a root, so running glob query
364
- select_query = select_query.where(where)
365
-
366
- else:
367
- parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
368
- select_query = select_query.where(pathfunc.parent(dr.c.path) == parent.path)
369
- return select_query
370
-
371
328
  def rename_dataset_table(
372
329
  self,
373
330
  old_name: str,
@@ -471,8 +428,14 @@ class AbstractWarehouse(ABC, Serializable):
471
428
  self,
472
429
  query: sa.Select,
473
430
  type: str,
431
+ dataset_rows: "DataTable",
474
432
  include_subobjects: bool = True,
475
433
  ) -> sa.Select:
434
+ dr = dataset_rows
435
+
436
+ def col(name: str):
437
+ return getattr(query.selected_columns, dr.col_name(name))
438
+
476
439
  file_group: Sequence[int]
477
440
  if type in {"f", "file", "files"}:
478
441
  if include_subobjects:
@@ -487,21 +450,21 @@ class AbstractWarehouse(ABC, Serializable):
487
450
  else:
488
451
  raise ValueError(f"invalid file type: {type!r}")
489
452
 
490
- c = query.selected_columns
491
- q = query.where(c.dir_type.in_(file_group))
453
+ q = query.where(col("dir_type").in_(file_group))
492
454
  if not include_subobjects:
493
- q = q.where((c.location == "") | (c.location.is_(None)))
455
+ q = q.where((col("location") == "") | (col("location").is_(None)))
494
456
  return q
495
457
 
496
- def get_nodes(self, query) -> Iterator[Node]:
458
+ def get_nodes(self, query, dataset_rows: "DataTable") -> Iterator[Node]:
497
459
  """
498
460
  This gets nodes based on the provided query, and should be used sparingly,
499
461
  as it will be slow on any OLAP database systems.
500
462
  """
463
+ dr = dataset_rows
501
464
  columns = [c.name for c in query.selected_columns]
502
465
  for row in self.db.execute(query):
503
466
  d = dict(zip(columns, row))
504
- yield Node(**d)
467
+ yield Node(**{dr.without_object(k): v for k, v in d.items()})
505
468
 
506
469
  def get_dirs_by_parent_path(
507
470
  self,
@@ -514,48 +477,56 @@ class AbstractWarehouse(ABC, Serializable):
514
477
  dr,
515
478
  parent_path,
516
479
  type="dir",
517
- conds=[pathfunc.parent(sa.Column("path")) == parent_path],
518
- order_by=["source", "path"],
480
+ conds=[pathfunc.parent(sa.Column(dr.col_name("path"))) == parent_path],
481
+ order_by=[dr.col_name("source"), dr.col_name("path")],
519
482
  )
520
- return self.get_nodes(query)
483
+ return self.get_nodes(query, dr)
521
484
 
522
485
  def _get_nodes_by_glob_path_pattern(
523
- self, dataset_rows: "DataTable", path_list: list[str], glob_name: str
486
+ self,
487
+ dataset_rows: "DataTable",
488
+ path_list: list[str],
489
+ glob_name: str,
490
+ object_name="file",
524
491
  ) -> Iterator[Node]:
525
492
  """Finds all Nodes that correspond to GLOB like path pattern."""
526
493
  dr = dataset_rows
527
- de = dr.dataset_dir_expansion(
528
- dr.select().where(dr.c.is_latest == true()).subquery()
494
+ de = dr.dir_expansion()
495
+ q = de.query(
496
+ dr.select().where(dr.c("is_latest") == true()).subquery()
529
497
  ).subquery()
530
498
  path_glob = "/".join([*path_list, glob_name])
531
499
  dirpath = path_glob[: -len(glob_name)]
532
- relpath = func.substr(self.path_expr(de), len(dirpath) + 1)
500
+ relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)
533
501
 
534
502
  return self.get_nodes(
535
- self.expand_query(de, dr)
503
+ self.expand_query(de, q, dr)
536
504
  .where(
537
- (self.path_expr(de).op("GLOB")(path_glob))
505
+ (de.c(q, "path").op("GLOB")(path_glob))
538
506
  & ~self.instr(relpath, "/")
539
- & (self.path_expr(de) != dirpath)
507
+ & (de.c(q, "path") != dirpath)
540
508
  )
541
- .order_by(de.c.source, de.c.path, de.c.version)
509
+ .order_by(de.c(q, "source"), de.c(q, "path"), de.c(q, "version")),
510
+ dr,
542
511
  )
543
512
 
544
513
  def _get_node_by_path_list(
545
514
  self, dataset_rows: "DataTable", path_list: list[str], name: str
546
- ) -> Node:
515
+ ) -> "Node":
547
516
  """
548
517
  Gets node that correspond some path list, e.g ["data-lakes", "dogs-and-cats"]
549
518
  """
550
519
  parent = "/".join(path_list)
551
520
  dr = dataset_rows
552
- de = dr.dataset_dir_expansion(
553
- dr.select().where(dr.c.is_latest == true()).subquery()
521
+ de = dr.dir_expansion()
522
+ q = de.query(
523
+ dr.select().where(dr.c("is_latest") == true()).subquery(),
524
+ object_name=dr.object_name,
554
525
  ).subquery()
555
- query = self.expand_query(de, dr)
526
+ q = self.expand_query(de, q, dr)
556
527
 
557
- q = query.where(de.c.path == get_path(parent, name)).order_by(
558
- de.c.source, de.c.path, de.c.version
528
+ q = q.where(de.c(q, "path") == get_path(parent, name)).order_by(
529
+ de.c(q, "source"), de.c(q, "path"), de.c(q, "version")
559
530
  )
560
531
  row = next(self.dataset_rows_select(q), None)
561
532
  if not row:
@@ -604,29 +575,34 @@ class AbstractWarehouse(ABC, Serializable):
604
575
  return result
605
576
 
606
577
  @staticmethod
607
- def expand_query(dir_expanded_query, dataset_rows: "DataTable"):
578
+ def expand_query(dir_expansion, dir_expanded_query, dataset_rows: "DataTable"):
608
579
  dr = dataset_rows
609
- de = dir_expanded_query
580
+ de = dir_expansion
581
+ q = dir_expanded_query
610
582
 
611
583
  def with_default(column):
612
- default = getattr(attrs.fields(Node), column.name).default
584
+ default = getattr(
585
+ attrs.fields(Node), dr.without_object(column.name)
586
+ ).default
613
587
  return func.coalesce(column, default).label(column.name)
614
588
 
615
589
  return sa.select(
616
- de.c.sys__id,
617
- case((de.c.is_dir == true(), DirType.DIR), else_=DirType.FILE).label(
618
- "dir_type"
590
+ q.c.sys__id,
591
+ case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
592
+ dr.col_name("dir_type")
619
593
  ),
620
- de.c.path,
621
- with_default(dr.c.etag),
622
- de.c.version,
623
- with_default(dr.c.is_latest),
624
- dr.c.last_modified,
625
- with_default(dr.c.size),
626
- with_default(dr.c.sys__rand),
627
- dr.c.location,
628
- de.c.source,
629
- ).select_from(de.outerjoin(dr.table, de.c.sys__id == dr.c.sys__id))
594
+ de.c(q, "path"),
595
+ with_default(dr.c("etag")),
596
+ de.c(q, "version"),
597
+ with_default(dr.c("is_latest")),
598
+ dr.c("last_modified"),
599
+ with_default(dr.c("size")),
600
+ with_default(dr.c("rand", object_name="sys")),
601
+ dr.c("location"),
602
+ de.c(q, "source"),
603
+ ).select_from(
604
+ q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
605
+ )
630
606
 
631
607
  def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
632
608
  """Gets node that corresponds to some path"""
@@ -635,18 +611,18 @@ class AbstractWarehouse(ABC, Serializable):
635
611
  dr = dataset_rows
636
612
  if not path.endswith("/"):
637
613
  query = dr.select().where(
638
- self.path_expr(dr) == path,
639
- dr.c.is_latest == true(),
614
+ dr.c("path") == path,
615
+ dr.c("is_latest") == true(),
640
616
  )
641
- row = next(self.db.execute(query), None)
642
- if row is not None:
643
- return Node(*row)
617
+ node = next(self.get_nodes(query, dr), None)
618
+ if node:
619
+ return node
644
620
  path += "/"
645
621
  query = sa.select(1).where(
646
622
  dr.select()
647
623
  .where(
648
- dr.c.is_latest == true(),
649
- dr.c.path.startswith(path),
624
+ dr.c("is_latest") == true(),
625
+ dr.c("path").startswith(path),
650
626
  )
651
627
  .exists()
652
628
  )
@@ -675,25 +651,26 @@ class AbstractWarehouse(ABC, Serializable):
675
651
  Gets latest-version file nodes from the provided parent path
676
652
  """
677
653
  dr = dataset_rows
678
- de = dr.dataset_dir_expansion(
679
- dr.select().where(dr.c.is_latest == true()).subquery()
654
+ de = dr.dir_expansion()
655
+ q = de.query(
656
+ dr.select().where(dr.c("is_latest") == true()).subquery()
680
657
  ).subquery()
681
- where_cond = pathfunc.parent(de.c.path) == parent_path
658
+ where_cond = pathfunc.parent(de.c(q, "path")) == parent_path
682
659
  if parent_path == "":
683
660
  # Exclude the root dir
684
- where_cond = where_cond & (de.c.path != "")
685
- inner_query = self.expand_query(de, dr).where(where_cond).subquery()
661
+ where_cond = where_cond & (de.c(q, "path") != "")
662
+ inner_query = self.expand_query(de, q, dr).where(where_cond).subquery()
686
663
 
687
664
  def field_to_expr(f):
688
665
  if f == "name":
689
- return pathfunc.name(inner_query.c.path)
690
- return getattr(inner_query.c, f)
666
+ return pathfunc.name(de.c(inner_query, "path"))
667
+ return de.c(inner_query, f)
691
668
 
692
669
  return self.db.execute(
693
670
  select(*(field_to_expr(f) for f in fields)).order_by(
694
- inner_query.c.source,
695
- inner_query.c.path,
696
- inner_query.c.version,
671
+ de.c(inner_query, "source"),
672
+ de.c(inner_query, "path"),
673
+ de.c(inner_query, "version"),
697
674
  )
698
675
  )
699
676
 
@@ -708,17 +685,17 @@ class AbstractWarehouse(ABC, Serializable):
708
685
 
709
686
  def field_to_expr(f):
710
687
  if f == "name":
711
- return pathfunc.name(dr.c.path)
712
- return getattr(dr.c, f)
688
+ return pathfunc.name(dr.c("path"))
689
+ return dr.c(f)
713
690
 
714
691
  q = (
715
692
  select(*(field_to_expr(f) for f in fields))
716
693
  .where(
717
- self.path_expr(dr).like(f"{sql_escape_like(dirpath)}%"),
718
- ~self.instr(pathfunc.name(dr.c.path), "/"),
719
- dr.c.is_latest == true(),
694
+ dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
695
+ ~self.instr(pathfunc.name(dr.c("path")), "/"),
696
+ dr.c("is_latest") == true(),
720
697
  )
721
- .order_by(dr.c.source, dr.c.path, dr.c.version, dr.c.etag)
698
+ .order_by(dr.c("source"), dr.c("path"), dr.c("version"), dr.c("etag"))
722
699
  )
723
700
  return self.db.execute(q)
724
701
 
@@ -747,15 +724,14 @@ class AbstractWarehouse(ABC, Serializable):
747
724
  sub_glob = posixpath.join(path, "*")
748
725
  dr = dataset_rows
749
726
  selections: list[sa.ColumnElement] = [
750
- func.sum(dr.c.size),
727
+ func.sum(dr.c("size")),
751
728
  ]
752
729
  if count_files:
753
730
  selections.append(func.count())
754
731
  results = next(
755
732
  self.db.execute(
756
733
  dr.select(*selections).where(
757
- (self.path_expr(dr).op("GLOB")(sub_glob))
758
- & (dr.c.is_latest == true())
734
+ (dr.c("path").op("GLOB")(sub_glob)) & (dr.c("is_latest") == true())
759
735
  )
760
736
  ),
761
737
  (0, 0),
@@ -764,9 +740,6 @@ class AbstractWarehouse(ABC, Serializable):
764
740
  return results[0] or 0, results[1] or 0
765
741
  return results[0] or 0, 0
766
742
 
767
- def path_expr(self, t):
768
- return t.c.path
769
-
770
743
  def _find_query(
771
744
  self,
772
745
  dataset_rows: "DataTable",
@@ -781,11 +754,12 @@ class AbstractWarehouse(ABC, Serializable):
781
754
  conds = []
782
755
 
783
756
  dr = dataset_rows
784
- de = dr.dataset_dir_expansion(
785
- dr.select().where(dr.c.is_latest == true()).subquery()
757
+ de = dr.dir_expansion()
758
+ q = de.query(
759
+ dr.select().where(dr.c("is_latest") == true()).subquery()
786
760
  ).subquery()
787
- q = self.expand_query(de, dr).subquery()
788
- path = self.path_expr(q)
761
+ q = self.expand_query(de, q, dr).subquery()
762
+ path = de.c(q, "path")
789
763
 
790
764
  if parent_path:
791
765
  sub_glob = posixpath.join(parent_path, "*")
@@ -800,7 +774,7 @@ class AbstractWarehouse(ABC, Serializable):
800
774
  query = sa.select(*columns)
801
775
  query = query.where(*conds)
802
776
  if type is not None:
803
- query = self.add_node_type_where(query, type, include_subobjects)
777
+ query = self.add_node_type_where(query, type, dr, include_subobjects)
804
778
  if order_by is not None:
805
779
  if isinstance(order_by, str):
806
780
  order_by = [order_by]
@@ -828,14 +802,14 @@ class AbstractWarehouse(ABC, Serializable):
828
802
  if sort is not None:
829
803
  if not isinstance(sort, list):
830
804
  sort = [sort]
831
- query = query.order_by(*(sa.text(s) for s in sort)) # type: ignore [attr-defined]
805
+ query = query.order_by(*(sa.text(dr.col_name(s)) for s in sort)) # type: ignore [attr-defined]
832
806
 
833
807
  prefix_len = len(node.path)
834
808
 
835
809
  def make_node_with_path(node: Node) -> NodeWithPath:
836
810
  return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))
837
811
 
838
- return map(make_node_with_path, self.get_nodes(query))
812
+ return map(make_node_with_path, self.get_nodes(query, dr))
839
813
 
840
814
  def find(
841
815
  self,
@@ -850,8 +824,10 @@ class AbstractWarehouse(ABC, Serializable):
850
824
  Finds nodes that match certain criteria and only looks for latest nodes
851
825
  under the passed node.
852
826
  """
827
+ dr = dataset_rows
828
+ fields = [dr.col_name(f) for f in fields]
853
829
  query = self._find_query(
854
- dataset_rows,
830
+ dr,
855
831
  node.path,
856
832
  fields=fields,
857
833
  type=type,