datachain 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +61 -219
- datachain/cli.py +136 -22
- datachain/client/fsspec.py +9 -0
- datachain/client/local.py +11 -32
- datachain/config.py +126 -51
- datachain/data_storage/schema.py +66 -33
- datachain/data_storage/sqlite.py +4 -4
- datachain/data_storage/warehouse.py +101 -125
- datachain/lib/arrow.py +2 -15
- datachain/lib/data_model.py +10 -2
- datachain/lib/dc.py +211 -52
- datachain/lib/func/__init__.py +20 -2
- datachain/lib/func/aggregate.py +319 -8
- datachain/lib/func/func.py +97 -9
- datachain/lib/listing.py +6 -21
- datachain/lib/listing_info.py +4 -0
- datachain/lib/signal_schema.py +8 -5
- datachain/lib/udf.py +3 -3
- datachain/lib/utils.py +30 -0
- datachain/listing.py +22 -48
- datachain/query/dataset.py +11 -3
- datachain/remote/studio.py +63 -14
- datachain/studio.py +129 -0
- datachain/utils.py +58 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/METADATA +7 -6
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/RECORD +30 -29
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/WHEEL +1 -1
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/LICENSE +0 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/top_level.txt +0 -0
datachain/data_storage/sqlite.py
CHANGED
@@ -643,7 +643,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c
+        query = dr.select(dr.c("source", object_name="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]

@@ -671,13 +671,13 @@ class SQLiteWarehouse(AbstractWarehouse):
             # destination table doesn't exist, create it
             self.create_dataset_rows_table(
                 self.dataset_table_name(dst.name, dst_version),
-                columns=src_dr.
+                columns=src_dr.columns,
             )
             dst_empty = True

         dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.
-        select_src = select(*(getattr(src_dr.
+        merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
+        select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))

         if dst_empty:
             # we don't need union, but just select from source to destination
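Note: the hunks above replace attribute-style column access (dr.c.source) with DataTable's callable accessor, which resolves a plain signal name against an object prefix. The snippet below is only a rough, self-contained illustration of that idea in SQLAlchemy; the helper name, the table layout, and the "file__" prefix are assumptions, not datachain's actual implementation.

import sqlalchemy as sa

metadata = sa.MetaData()
rows = sa.Table(
    "dataset_rows",
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("file__source", sa.String),
    sa.Column("file__path", sa.String),
)

def col(table: sa.Table, name: str, object_name: str = "file") -> sa.Column:
    # "source" with object_name "file" resolves to the "file__source" column
    return table.c[f"{object_name}__{name}"]

# Roughly what the rewritten line builds: distinct storage sources of a dataset.
query = sa.select(col(rows, "source")).distinct()
print(query)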
datachain/data_storage/warehouse.py
CHANGED

@@ -185,7 +185,12 @@ class AbstractWarehouse(ABC, Serializable):
     @abstractmethod
     def is_ready(self, timeout: Optional[int] = None) -> bool: ...

-    def dataset_rows(
+    def dataset_rows(
+        self,
+        dataset: DatasetRecord,
+        version: Optional[int] = None,
+        object_name: str = "file",
+    ):
         version = version or dataset.latest_version

         table_name = self.dataset_table_name(dataset.name, version)

@@ -194,6 +199,7 @@ class AbstractWarehouse(ABC, Serializable):
             self.db.engine,
             self.db.metadata,
             dataset.get_schema(version),
+            object_name=object_name,
         )

     @property

@@ -319,55 +325,6 @@ class AbstractWarehouse(ABC, Serializable):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]: ...

-    def nodes_dataset_query(
-        self,
-        dataset_rows: "DataTable",
-        *,
-        column_names: Iterable[str],
-        path: Optional[str] = None,
-        recursive: Optional[bool] = False,
-    ) -> "sa.Select":
-        """
-        Creates query pointing to certain bucket listing represented by dataset_rows
-        The given `column_names`
-        will be selected in the order they're given. `path` is a glob which
-        will select files in matching directories, or if `recursive=True` is
-        set then the entire tree under matching directories will be selected.
-        """
-        dr = dataset_rows
-
-        def _is_glob(path: str) -> bool:
-            return any(c in path for c in ["*", "?", "[", "]"])
-
-        column_objects = [dr.c[c] for c in column_names]
-        # include all object types - file, tar archive, tar file (subobject)
-        select_query = dr.select(*column_objects).where(dr.c.is_latest == true())
-        if path is None:
-            return select_query
-        if recursive:
-            root = False
-            where = self.path_expr(dr).op("GLOB")(path)
-            if not path or path == "/":
-                # root of the bucket, e.g s3://bucket/ -> getting all the nodes
-                # in the bucket
-                root = True
-
-            if not root and not _is_glob(path):
-                # not a root and not a explicit glob, so it's pointing to some directory
-                # and we are adding a proper glob syntax for it
-                # e.g s3://bucket/dir1 -> s3://bucket/dir1/*
-                dir_path = path.rstrip("/") + "/*"
-                where = where | self.path_expr(dr).op("GLOB")(dir_path)
-
-            if not root:
-                # not a root, so running glob query
-                select_query = select_query.where(where)
-
-        else:
-            parent = self.get_node_by_path(dr, path.lstrip("/").rstrip("/*"))
-            select_query = select_query.where(pathfunc.parent(dr.c.path) == parent.path)
-        return select_query
-
     def rename_dataset_table(
         self,
         old_name: str,
@@ -471,8 +428,14 @@ class AbstractWarehouse(ABC, Serializable):
         self,
         query: sa.Select,
         type: str,
+        dataset_rows: "DataTable",
         include_subobjects: bool = True,
     ) -> sa.Select:
+        dr = dataset_rows
+
+        def col(name: str):
+            return getattr(query.selected_columns, dr.col_name(name))
+
         file_group: Sequence[int]
         if type in {"f", "file", "files"}:
             if include_subobjects:

@@ -487,21 +450,21 @@ class AbstractWarehouse(ABC, Serializable):
         else:
             raise ValueError(f"invalid file type: {type!r}")

-
-        q = query.where(c.dir_type.in_(file_group))
+        q = query.where(col("dir_type").in_(file_group))
         if not include_subobjects:
-            q = q.where((
+            q = q.where((col("location") == "") | (col("location").is_(None)))
         return q

-    def get_nodes(self, query) -> Iterator[Node]:
+    def get_nodes(self, query, dataset_rows: "DataTable") -> Iterator[Node]:
         """
         This gets nodes based on the provided query, and should be used sparingly,
         as it will be slow on any OLAP database systems.
         """
+        dr = dataset_rows
         columns = [c.name for c in query.selected_columns]
         for row in self.db.execute(query):
             d = dict(zip(columns, row))
-            yield Node(**d)
+            yield Node(**{dr.without_object(k): v for k, v in d.items()})

     def get_dirs_by_parent_path(
         self,
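Note: get_nodes now receives the DataTable so it can translate prefixed result-column names (for example file__path) back into Node field names before constructing each Node. The sketch below only mimics how the call site uses dr.without_object(); the double-underscore prefix convention is an assumption inferred from column names elsewhere in this diff, not datachain's actual helper.

def without_object(column_name: str, object_name: str = "file") -> str:
    # "file__path" -> "path"; names without the prefix pass through unchanged
    prefix = f"{object_name}__"
    if column_name.startswith(prefix):
        return column_name[len(prefix):]
    return column_name

row = {"file__path": "dogs/1.jpg", "file__size": 84, "sys__rand": 123}
print({without_object(k): v for k, v in row.items()})
# {'path': 'dogs/1.jpg', 'size': 84, 'sys__rand': 123}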
@@ -514,48 +477,56 @@ class AbstractWarehouse(ABC, Serializable):
             dr,
             parent_path,
             type="dir",
-            conds=[pathfunc.parent(sa.Column("path")) == parent_path],
-            order_by=["source", "path"],
+            conds=[pathfunc.parent(sa.Column(dr.col_name("path"))) == parent_path],
+            order_by=[dr.col_name("source"), dr.col_name("path")],
         )
-        return self.get_nodes(query)
+        return self.get_nodes(query, dr)

     def _get_nodes_by_glob_path_pattern(
-        self,
+        self,
+        dataset_rows: "DataTable",
+        path_list: list[str],
+        glob_name: str,
+        object_name="file",
     ) -> Iterator[Node]:
         """Finds all Nodes that correspond to GLOB like path pattern."""
         dr = dataset_rows
-        de = dr.
-
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
         path_glob = "/".join([*path_list, glob_name])
         dirpath = path_glob[: -len(glob_name)]
-        relpath = func.substr(
+        relpath = func.substr(de.c(q, "path"), len(dirpath) + 1)

         return self.get_nodes(
-            self.expand_query(de, dr)
+            self.expand_query(de, q, dr)
             .where(
-                (
+                (de.c(q, "path").op("GLOB")(path_glob))
                 & ~self.instr(relpath, "/")
-                & (
+                & (de.c(q, "path") != dirpath)
             )
-            .order_by(de.c
+            .order_by(de.c(q, "source"), de.c(q, "path"), de.c(q, "version")),
+            dr,
         )

     def _get_node_by_path_list(
         self, dataset_rows: "DataTable", path_list: list[str], name: str
-    ) -> Node:
+    ) -> "Node":
         """
         Gets node that correspond some path list, e.g ["data-lakes", "dogs-and-cats"]
         """
         parent = "/".join(path_list)
         dr = dataset_rows
-        de = dr.
-
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery(),
+            object_name=dr.object_name,
         ).subquery()
-
+        q = self.expand_query(de, q, dr)

-        q =
-        de.c
+        q = q.where(de.c(q, "path") == get_path(parent, name)).order_by(
+            de.c(q, "source"), de.c(q, "path"), de.c(q, "version")
         )
         row = next(self.dataset_rows_select(q), None)
         if not row:
@@ -604,29 +575,34 @@ class AbstractWarehouse(ABC, Serializable):
         return result

     @staticmethod
-    def expand_query(dir_expanded_query, dataset_rows: "DataTable"):
+    def expand_query(dir_expansion, dir_expanded_query, dataset_rows: "DataTable"):
         dr = dataset_rows
-        de =
+        de = dir_expansion
+        q = dir_expanded_query

         def with_default(column):
-            default = getattr(
+            default = getattr(
+                attrs.fields(Node), dr.without_object(column.name)
+            ).default
             return func.coalesce(column, default).label(column.name)

         return sa.select(
-
-            case((de.c
-            "dir_type"
+            q.c.sys__id,
+            case((de.c(q, "is_dir") == true(), DirType.DIR), else_=DirType.FILE).label(
+                dr.col_name("dir_type")
             ),
-            de.c
-            with_default(dr.c
-            de.c
-            with_default(dr.c
-            dr.c
-            with_default(dr.c
-            with_default(dr.c
-            dr.c
-            de.c
-        ).select_from(
+            de.c(q, "path"),
+            with_default(dr.c("etag")),
+            de.c(q, "version"),
+            with_default(dr.c("is_latest")),
+            dr.c("last_modified"),
+            with_default(dr.c("size")),
+            with_default(dr.c("rand", object_name="sys")),
+            dr.c("location"),
+            de.c(q, "source"),
+        ).select_from(
+            q.outerjoin(dr.table, q.c.sys__id == dr.c("id", object_name="sys"))
+        )

     def get_node_by_path(self, dataset_rows: "DataTable", path: str) -> Node:
         """Gets node that corresponds to some path"""
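Note: expand_query now outer-joins the expanded directory listing (q) back onto the dataset table, so dataset columns are NULL for synthesized directory rows; with_default fills those NULLs with the corresponding Node field default. Below is a standalone sketch of that coalesce-and-label pattern; the table and column names are illustrative only, not datachain's schema.

import sqlalchemy as sa

meta = sa.MetaData()
files = sa.Table(
    "files",
    meta,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("file__size", sa.Integer),
    sa.Column("file__is_latest", sa.Boolean),
)

def with_default(column, default):
    # NULLs produced by the outer join are replaced by the field default; the
    # label keeps the original (prefixed) column name in the result set.
    return sa.func.coalesce(column, default).label(column.name)

query = sa.select(
    files.c.sys__id,
    with_default(files.c.file__size, 0),
    with_default(files.c.file__is_latest, True),
)
print(query)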
@@ -635,18 +611,18 @@ class AbstractWarehouse(ABC, Serializable):
         dr = dataset_rows
         if not path.endswith("/"):
             query = dr.select().where(
-
-                dr.c
+                dr.c("path") == path,
+                dr.c("is_latest") == true(),
             )
-
-            if
-            return
+            node = next(self.get_nodes(query, dr), None)
+            if node:
+                return node
         path += "/"
         query = sa.select(1).where(
             dr.select()
             .where(
-                dr.c
-                dr.c
+                dr.c("is_latest") == true(),
+                dr.c("path").startswith(path),
             )
             .exists()
         )
@@ -675,25 +651,26 @@ class AbstractWarehouse(ABC, Serializable):
         Gets latest-version file nodes from the provided parent path
         """
         dr = dataset_rows
-        de = dr.
-
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
-        where_cond = pathfunc.parent(de.c
+        where_cond = pathfunc.parent(de.c(q, "path")) == parent_path
         if parent_path == "":
             # Exclude the root dir
-            where_cond = where_cond & (de.c
-        inner_query = self.expand_query(de, dr).where(where_cond).subquery()
+            where_cond = where_cond & (de.c(q, "path") != "")
+        inner_query = self.expand_query(de, q, dr).where(where_cond).subquery()

         def field_to_expr(f):
             if f == "name":
-                return pathfunc.name(
-            return
+                return pathfunc.name(de.c(inner_query, "path"))
+            return de.c(inner_query, f)

         return self.db.execute(
             select(*(field_to_expr(f) for f in fields)).order_by(
-
-
-
+                de.c(inner_query, "source"),
+                de.c(inner_query, "path"),
+                de.c(inner_query, "version"),
             )
         )

@@ -708,17 +685,17 @@ class AbstractWarehouse(ABC, Serializable):

         def field_to_expr(f):
             if f == "name":
-                return pathfunc.name(dr.c
-            return
+                return pathfunc.name(dr.c("path"))
+            return dr.c(f)

         q = (
             select(*(field_to_expr(f) for f in fields))
             .where(
-
-                ~self.instr(pathfunc.name(dr.c
-                dr.c
+                dr.c("path").like(f"{sql_escape_like(dirpath)}%"),
+                ~self.instr(pathfunc.name(dr.c("path")), "/"),
+                dr.c("is_latest") == true(),
             )
-            .order_by(dr.c
+            .order_by(dr.c("source"), dr.c("path"), dr.c("version"), dr.c("etag"))
         )
         return self.db.execute(q)
@@ -747,15 +724,14 @@ class AbstractWarehouse(ABC, Serializable):
         sub_glob = posixpath.join(path, "*")
         dr = dataset_rows
         selections: list[sa.ColumnElement] = [
-            func.sum(dr.c
+            func.sum(dr.c("size")),
         ]
         if count_files:
             selections.append(func.count())
         results = next(
             self.db.execute(
                 dr.select(*selections).where(
-                    (
-                    & (dr.c.is_latest == true())
+                    (dr.c("path").op("GLOB")(sub_glob)) & (dr.c("is_latest") == true())
                 )
             ),
             (0, 0),

@@ -764,9 +740,6 @@ class AbstractWarehouse(ABC, Serializable):
             return results[0] or 0, results[1] or 0
         return results[0] or 0, 0

-    def path_expr(self, t):
-        return t.c.path
-
     def _find_query(
         self,
         dataset_rows: "DataTable",

@@ -781,11 +754,12 @@ class AbstractWarehouse(ABC, Serializable):
         conds = []

         dr = dataset_rows
-        de = dr.
-
+        de = dr.dir_expansion()
+        q = de.query(
+            dr.select().where(dr.c("is_latest") == true()).subquery()
         ).subquery()
-        q = self.expand_query(de, dr).subquery()
-        path =
+        q = self.expand_query(de, q, dr).subquery()
+        path = de.c(q, "path")

         if parent_path:
             sub_glob = posixpath.join(parent_path, "*")

@@ -800,7 +774,7 @@ class AbstractWarehouse(ABC, Serializable):
         query = sa.select(*columns)
         query = query.where(*conds)
         if type is not None:
-            query = self.add_node_type_where(query, type, include_subobjects)
+            query = self.add_node_type_where(query, type, dr, include_subobjects)
         if order_by is not None:
             if isinstance(order_by, str):
                 order_by = [order_by]
@@ -828,14 +802,14 @@ class AbstractWarehouse(ABC, Serializable):
         if sort is not None:
             if not isinstance(sort, list):
                 sort = [sort]
-            query = query.order_by(*(sa.text(s) for s in sort))  # type: ignore [attr-defined]
+            query = query.order_by(*(sa.text(dr.col_name(s)) for s in sort))  # type: ignore [attr-defined]

         prefix_len = len(node.path)

         def make_node_with_path(node: Node) -> NodeWithPath:
             return NodeWithPath(node, node.path[prefix_len:].lstrip("/").split("/"))

-        return map(make_node_with_path, self.get_nodes(query))
+        return map(make_node_with_path, self.get_nodes(query, dr))

     def find(
         self,

@@ -850,8 +824,10 @@ class AbstractWarehouse(ABC, Serializable):
         Finds nodes that match certain criteria and only looks for latest nodes
         under the passed node.
         """
+        dr = dataset_rows
+        fields = [dr.col_name(f) for f in fields]
         query = self._find_query(
-
+            dr,
             node.path,
             fields=fields,
             type=type,
datachain/lib/arrow.py
CHANGED
@@ -1,4 +1,3 @@
-import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional

@@ -13,6 +12,7 @@ from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names

 if TYPE_CHECKING:
     from datasets.features.features import Features

@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
     signal_schema = _get_datachain_schema(schema)
     if signal_schema:
         return signal_schema.values
-    columns =
+    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {

@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
     return output


-def _convert_col_names(col_names: Sequence[str]) -> list[str]:
-    default_column = 0
-    converted_col_names = []
-    for column in col_names:
-        column = column.lower()
-        column = re.sub("[^0-9a-z_]+", "", column)
-        if not column:
-            column = f"c{default_column}"
-            default_column += 1
-        converted_col_names.append(column)
-    return converted_col_names
-
-
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime
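Note: the ad-hoc _convert_col_names helper removed above is superseded by normalize_col_names from datachain/lib/utils.py (added in this release, +30 lines, but not shown in this diff). Based on the removed code and on the "normalized_name -> original_name" comment in the data_model.py hunk below, a plausible minimal sketch follows; the real implementation may differ, for example in how duplicate names are handled.

import re
from collections.abc import Sequence

def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
    # Returns a mapping of normalized_name -> original_name.
    normalized: dict[str, str] = {}
    default_column = 0
    for column in col_names:
        name = re.sub("[^0-9a-z_]+", "", column.lower())
        if not name:
            name = f"c{default_column}"
            default_column += 1
        normalized[name] = column
    return normalized

print(normalize_col_names(["My Col 1", "price ($)"]))
# {'mycol1': 'My Col 1', 'price': 'price ($)'}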
datachain/lib/data_model.py
CHANGED
@@ -2,9 +2,10 @@ from collections.abc import Sequence
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin

-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model

 from datachain.lib.model_store import ModelStore
+from datachain.lib.utils import normalize_col_names

 StandardType = Union[
     type[int],

@@ -60,7 +61,14 @@ def is_chain_type(t: type) -> bool:


 def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-
+    # Gets a map of a normalized_name -> original_name
+    columns = normalize_col_names(list(data_dict.keys()))
+    # We reverse if for convenience to original_name -> normalized_name
+    columns = {v: k for k, v in columns.items()}
+
+    fields = {
+        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+    }
     return create_model(
         name,
         __base__=(DataModel,),  # type: ignore[call-overload]