datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +225 -168
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +139 -60
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -3
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +51 -178
- datachain/query/dispatch.py +43 -30
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- datachain/utils.py +39 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_registry.py +0 -51
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/lib/webdataset_laion.py
CHANGED

@@ -2,9 +2,8 @@ from collections.abc import Iterator
 from typing import Optional
 
 import numpy as np
-from pydantic import Field
+from pydantic import BaseModel, Field
 
-from datachain.lib.feature import Feature
 from datachain.lib.file import File
 from datachain.lib.webdataset import WDSBasic, WDSReadableSubclass
 
@@ -34,7 +33,7 @@ class WDSLaion(WDSBasic):
     json: Laion  # type: ignore[assignment]
 
 
-class LaionMeta(
+class LaionMeta(BaseModel):
     file: File
     index: Optional[int] = Field(default=None)
     b32_img: list[float] = Field(default=None)
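Note: LaionMeta now derives directly from pydantic's BaseModel instead of the removed datachain.lib.feature.Feature base. A minimal standalone sketch of what that implies for a schema class, assuming pydantic v2 and using hypothetical names (not the package's classes):

# sketch only: a plain pydantic model with optional/list fields, as LaionMeta now is
from typing import Optional

from pydantic import BaseModel, Field


class ExampleMeta(BaseModel):  # hypothetical stand-in for LaionMeta
    index: Optional[int] = Field(default=None)
    b32_img: list[float] = Field(default_factory=list)


meta = ExampleMeta(index=3, b32_img=[0.1, 0.2])
print(meta.model_dump())  # e.g. {'index': 3, 'b32_img': [0.1, 0.2]}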
datachain/node.py
CHANGED

@@ -46,8 +46,8 @@ class DirTypeGroup:
 
 @attrs.define
 class Node:
-
-
+    sys__id: int = 0
+    sys__rand: int = -1
     vtype: str = ""
     dir_type: Optional[int] = None
     parent: str = ""
@@ -127,11 +127,11 @@ class Node:
 
     @classmethod
    def from_dir(cls, parent, name, **kwargs) -> "Node":
-        return cls(
+        return cls(sys__id=-1, dir_type=DirType.DIR, parent=parent, name=name, **kwargs)
 
     @classmethod
     def root(cls) -> "Node":
-        return cls(
+        return cls(sys__id=-1, dir_type=DirType.DIR)
 
 
 @attrs.define
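Note: the id/random attributes on Node are renamed to the sys__-prefixed sys__id and sys__rand, and the constructors pass sys__id=-1 explicitly. A reduced sketch (not the package's Node) of the attrs pattern involved, with hypothetical field names trimmed down:

import attrs


@attrs.define
class MiniNode:  # hypothetical, reduced stand-in for datachain.node.Node
    sys__id: int = 0
    sys__rand: int = -1
    name: str = ""
    parent: str = ""

    @classmethod
    def from_dir(cls, parent: str, name: str) -> "MiniNode":
        # mirrors the new call shape: sys__id overridden, remaining fields defaulted
        return cls(sys__id=-1, parent=parent, name=name)


print(MiniNode.from_dir("root", "images"))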
datachain/query/batch.py
CHANGED

@@ -104,7 +104,7 @@ class Partition(BatchingStrategy):
         with contextlib.closing(
             execute(
                 query,
-                order_by=(PARTITION_COLUMN_ID, "
+                order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
                 limit=query._limit,
             )
         ) as rows:
datachain/query/dataset.py
CHANGED

@@ -1,4 +1,3 @@
-import ast
 import contextlib
 import datetime
 import inspect
@@ -10,7 +9,6 @@ import re
 import string
 import subprocess
 import sys
-import types
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from copy import copy
@@ -26,10 +24,8 @@ from typing import (
 )
 
 import attrs
-import pandas as pd
 import sqlalchemy
 from attrs import frozen
-from dill import dumps, source
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
@@ -52,12 +48,14 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
-from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import
+from datachain.utils import (
+    batched,
+    determine_processes,
+    filtered_cloudpickle_dumps,
+)
 
-from .batch import RowBatch
 from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
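Note: dill (dumps, source) is dropped in favor of the cloudpickle-based helper filtered_cloudpickle_dumps from datachain.utils. A sketch of the underlying primitive, using plain cloudpickle rather than the package's wrapper (which, per its name, presumably filters what gets dumped):

import cloudpickle


def make_adder(n):
    def add(x):
        return x + n  # closure over n; stdlib pickle cannot handle local functions
    return add


payload = cloudpickle.dumps(make_adder(5))  # bytes, safe to hand to another process
restored = cloudpickle.loads(payload)
assert restored(2) == 7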
@@ -257,7 +255,7 @@ class DatasetDiffOperation(Step):
     """
 
     def apply(self, query_generator, temp_tables: list[str]):
-        source_query = query_generator.exclude(("
+        source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
 
@@ -427,22 +425,6 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
     return DEFAULT_CALLBACK
 
 
-def run_udf(
-    udf,
-    udf_inputs,
-    catalog,
-    is_generator,
-    cache,
-    download_cb: Callback = DEFAULT_CALLBACK,
-    processed_cb: Callback = DEFAULT_CALLBACK,
-) -> Iterator[Iterable["UDFResult"]]:
-    for batch in udf_inputs:
-        n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-        output = udf(catalog, batch, is_generator, cache, cb=download_cb)
-        processed_cb.relative_update(n_rows)
-        yield output
-
-
 @frozen
 class UDF(Step, ABC):
     udf: UDFType
@@ -508,7 +490,7 @@ class UDF(Step, ABC):
         elif processes:
             # Parallel processing (faster for more CPU-heavy UDFs)
             udf_info = {
-                "
+                "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "id_generator_clone_params": (
                     self.catalog.id_generator.clone_params()
@@ -529,16 +511,15 @@ class UDF(Step, ABC):
 
             envs = dict(os.environ)
             envs.update({"PYTHONPATH": os.getcwd()})
-
-
-
-
-
-
-
-
-
-                raise RuntimeError("UDF Execution Failed!")
+            process_data = filtered_cloudpickle_dumps(udf_info)
+            result = subprocess.run(  # noqa: S603
+                [datachain_exec_path, "--internal-run-udf"],
+                input=process_data,
+                check=False,
+                env=envs,
+            )
+            if result.returncode != 0:
+                raise RuntimeError("UDF Execution Failed!")
 
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
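Note: the parallel path now cloudpickles udf_info and pipes the bytes into a child process over stdin, then checks the exit code. A generic sketch of that hand-off pattern; the child command here is a throwaway stand-in, not the real `datachain --internal-run-udf` entry point:

import subprocess
import sys

import cloudpickle

payload = cloudpickle.dumps({"task": "demo", "rows": [1, 2, 3]})

# child simply reads its stdin and reports how many bytes arrived
child_code = "import sys; data = sys.stdin.buffer.read(); print(len(data))"
result = subprocess.run(
    [sys.executable, "-c", child_code],
    input=payload,        # bytes are written to the child's stdin
    capture_output=True,
    check=False,
)
if result.returncode != 0:
    raise RuntimeError("child process failed")
print(result.stdout.decode().strip(), "bytes received by child")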
@@ -548,9 +529,6 @@ class UDF(Step, ABC):
             else:
                 udf = self.udf
 
-            if hasattr(udf.func, "setup") and callable(udf.func.setup):
-                udf.func.setup()
-
             warehouse = self.catalog.warehouse
 
             with contextlib.closing(
@@ -560,8 +538,7 @@ class UDF(Step, ABC):
                 processed_cb = get_processed_callback()
                 generated_cb = get_generated_callback(self.is_generator)
                 try:
-                    udf_results =
-                        udf,
+                    udf_results = udf.run(
                         udf_inputs,
                         self.catalog,
                         self.is_generator,
@@ -583,9 +560,6 @@ class UDF(Step, ABC):
 
                     warehouse.insert_rows_done(udf_table)
 
-                    if hasattr(udf.func, "teardown") and callable(udf.func.teardown):
-                        udf.func.teardown()
-
                 except QueryScriptCancelError:
                     self.catalog.warehouse.close()
                     sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
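Note: the explicit setup()/teardown() calls around the old run_udf() helper are gone from this call site, and the loop is now invoked as udf.run(...) (datachain/lib/udf.py, +114 -30, is where that method lives). A hedged, illustrative sketch of folding lifecycle hooks into a run() generator; this is not datachain's UDFBase API, just the general shape such a refactor suggests:

from collections.abc import Iterable, Iterator


class SketchUDF:
    def setup(self) -> None:        # optional lifecycle hook
        print("setup")

    def teardown(self) -> None:     # optional lifecycle hook
        print("teardown")

    def process(self, batch: Iterable[int]) -> list[int]:
        return [x * 2 for x in batch]

    def run(self, batches: Iterable[Iterable[int]]) -> Iterator[list[int]]:
        self.setup()
        try:
            for batch in batches:
                yield self.process(batch)
        finally:
            self.teardown()


for out in SketchUDF().run([[1, 2], [3]]):
    print(out)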
@@ -594,57 +568,6 @@ class UDF(Step, ABC):
                     self.catalog.warehouse.close()
                     raise
 
-    @contextlib.contextmanager
-    def process_feature_module(self):
-        # Generate a random name for the feature module
-        feature_module_name = "tmp" + _random_string(10)
-        # Create a dynamic module with the generated name
-        dynamic_module = types.ModuleType(feature_module_name)
-        # Get the import lines for the necessary objects from the main module
-        main_module = sys.modules["__main__"]
-        if getattr(main_module, "__file__", None):
-            import_lines = list(get_imports(main_module))
-        else:
-            import_lines = [
-                source.getimport(obj, alias=name)
-                for name, obj in main_module.__dict__.items()
-                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
-            ]
-
-        # Get the feature classes from the main module
-        feature_classes = {
-            name: obj
-            for name, obj in main_module.__dict__.items()
-            if _feature_predicate(obj)
-        }
-        if not feature_classes:
-            yield None
-            return
-
-        # Get the source code of the feature classes
-        feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
-        # Set the module name for the feature classes to the generated name
-        for name, cls in feature_classes.items():
-            cls.__module__ = feature_module_name
-            setattr(dynamic_module, name, cls)
-        # Add the dynamic module to the sys.modules dictionary
-        sys.modules[feature_module_name] = dynamic_module
-        # Combine the import lines and feature sources
-        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
-
-        # Write the module content to a .py file
-        with open(f"{feature_module_name}.py", "w") as module_file:
-            module_file.write(feature_file)
-
-        try:
-            yield feature_module_name
-        finally:
-            for cls in feature_classes.values():
-                cls.__module__ = main_module.__name__
-            os.unlink(f"{feature_module_name}.py")
-            # Remove the dynamic module from sys.modules
-            del sys.modules[feature_module_name]
-
     def create_partitions_table(self, query: Select) -> "Table":
         """
         Create temporary table with group by partitions.
@@ -663,7 +586,7 @@ class UDF(Step, ABC):
 
         # fill table with partitions
         cols = [
-            query.selected_columns.
+            query.selected_columns.sys__id,
            f.dense_rank().over(order_by=list_partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
@@ -697,7 +620,7 @@ class UDF(Step, ABC):
         subq = query.subquery()
         query = (
             sqlalchemy.select(*subq.c)
-            .outerjoin(partition_tbl, partition_tbl.c.
+            .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
             .add_columns(*partition_columns())
         )
 
@@ -729,18 +652,18 @@ class UDFSignal(UDF):
         columns = [
             sqlalchemy.Column(c.name, c.type)
             for c in query.selected_columns
-            if c.name != "
+            if c.name != "sys__id"
         ]
         table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
         select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "
+            *[c for c in query.selected_columns if c.name != "sys__id"]
         )
 
         # if there is order by clause we need row_number to preserve order
         # if there is no order by clause we still need row_number to generate
         # unique ids as uniqueness is important for this table
         select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("
+            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
         )
 
         self.catalog.warehouse.db.execute(
@@ -756,7 +679,7 @@ class UDFSignal(UDF):
         if query._order_by_clauses:
             # we are adding ordering only if it's explicitly added by user in
             # query part before adding signals
-            q = q.order_by(table.c.
+            q = q.order_by(table.c.sys__id)
         return q, [table]
 
     def create_result_query(
@@ -766,7 +689,7 @@ class UDFSignal(UDF):
         original_cols = [c for c in subq.c if c.name not in partition_col_names]
 
         # new signal columns that are added to udf_table
-        signal_cols = [c for c in udf_table.c if c.name != "
+        signal_cols = [c for c in udf_table.c if c.name != "sys__id"]
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols
 
@@ -786,7 +709,7 @@ class UDFSignal(UDF):
             res = (
                 sqlalchemy.select(*cols1)
                 .select_from(subq)
-                .outerjoin(udf_table, udf_table.c.
+                .outerjoin(udf_table, udf_table.c.sys__id == subq.c.sys__id)
                 .add_columns(*cols2)
             )
         else:
@@ -795,7 +718,7 @@ class UDFSignal(UDF):
         if query._order_by_clauses:
             # if ordering is used in query part before adding signals, we
             # will have it as order by id from select from pre-created udf table
-            res = res.order_by(subq.c.
+            res = res.order_by(subq.c.sys__id)
 
         if self.partition_by is not None:
             subquery = res.subquery()
@@ -833,7 +756,7 @@ class RowGenerator(UDF):
         # we get the same rows as we got as inputs of UDF since selecting
         # without ordering can be non deterministic in some databases
         c = query.selected_columns
-        query = query.order_by(c.
+        query = query.order_by(c.sys__id)
 
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
@@ -1025,7 +948,7 @@ class SQLJoin(Step):
         q1_column_names = {c.name for c in q1_columns}
         q2_columns = [
             c
-            if c.name not in q1_column_names and c.name != "
+            if c.name not in q1_column_names and c.name != "sys__id"
             else c.label(self.rname.format(name=c.name))
             for c in q2.c
         ]
@@ -1165,8 +1088,8 @@ class DatasetQuery:
             self.version = version or ds.latest_version
             self.feature_schema = ds.get_version(self.version).feature_schema
             self.column_types = copy(ds.schema)
-            if "
-                self.column_types.pop("
+            if "sys__id" in self.column_types:
+                self.column_types.pop("sys__id")
             self.starting_step = QueryStep(self.catalog, name, self.version)
             # attaching to specific dataset
             self.name = name
@@ -1239,7 +1162,7 @@ class DatasetQuery:
         query.steps = self._chunk_limit(query.steps, index, total)
 
         # Prepend the chunk filter to the step chain.
-        query = query.filter(C.
+        query = query.filter(C.sys__rand % total == index)
         query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
@@ -1366,20 +1289,12 @@ class DatasetQuery:
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict]:
-
-        cols = result.columns
-        return [dict(zip(cols, row)) for row in result]
-
-    def to_pandas(self) -> "pd.DataFrame":
-        records = self.to_records()
-        df = pd.DataFrame.from_records(records)
-        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
-        return df
+    def to_records(self) -> list[dict[str, Any]]:
+        return self.results(lambda cols, row: dict(zip(cols, row)))
 
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
-        return self.order_by(C.
+        return self.order_by(C.sys__rand)
 
     def sample(self, n) -> "Self":
         """
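Note: to_pandas() (and show(), below) no longer live on DatasetQuery, and the module-level pandas import is gone. A caller-side sketch of turning to_records() output into a DataFrame; the records list is fabricated for illustration, and the double-underscore delimiter is an assumption mirroring the removed DEFAULT_DELIMITER replacement:

import pandas as pd

records = [
    {"file__path": "a.jpg", "size": 100},
    {"file__path": "b.jpg", "size": 240},
]  # stand-in for DatasetQuery.to_records() output

df = pd.DataFrame.from_records(records)
# optionally mirror the old behaviour of showing nested fields with dots
df.columns = [c.replace("__", ".") for c in df.columns]
print(df)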
@@ -1395,22 +1310,6 @@ class DatasetQuery:
 
         return sampled.limit(n)
 
-    def show(self, limit=20) -> None:
-        df = self.limit(limit).to_pandas()
-
-        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
-        with pd.option_context(*options):
-            if inside_notebook():
-                from IPython.display import display
-
-                display(df)
-
-            else:
-                print(df.to_string())
-
-        if len(df) == limit:
-            print(f"[limited by {limit} objects]")
-
     def clone(self, new_table=True) -> "Self":
         obj = copy(self)
         obj.steps = obj.steps.copy()
@@ -1508,30 +1407,35 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query
 
+    def as_scalar(self) -> Any:
+        with self.as_iterable() as rows:
+            row = next(iter(rows))
+        return row[0]
+
     def count(self) -> int:
         query = self.clone()
         query.steps.append(SQLCount())
-        return query.
+        return query.as_scalar()
 
-    def sum(self, col: ColumnElement):
+    def sum(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.sum(col),)))
-        return query.
+        return query.as_scalar()
 
-    def avg(self, col: ColumnElement):
+    def avg(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.avg(col),)))
-        return query.
+        return query.as_scalar()
 
-    def min(self, col: ColumnElement):
+    def min(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.min(col),)))
-        return query.
+        return query.as_scalar()
 
-    def max(self, col: ColumnElement):
+    def max(self, col: ColumnElement) -> int:
         query = self.clone()
         query.steps.append(SQLSelect((f.max(col),)))
-        return query.
+        return query.as_scalar()
 
     @detach
     def group_by(self, *cols: ColumnElement) -> "Self":
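Note: the new as_scalar() consolidates the logic shared by count/sum/avg/min/max: the aggregate step yields a single one-column row, and as_scalar() pulls out that value. A toy stand-in (not datachain code) showing the same pattern with a context-managed row iterable:

import contextlib
from typing import Any


class ToyQuery:
    def __init__(self, rows):
        self._rows = rows

    @contextlib.contextmanager
    def as_iterable(self):
        # stands in for executing the query and yielding result rows
        yield iter(self._rows)

    def as_scalar(self) -> Any:
        with self.as_iterable() as rows:
            row = next(iter(rows))
        return row[0]


print(ToyQuery([(42,)]).as_scalar())  # 42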
@@ -1723,7 +1627,7 @@ class DatasetQuery:
             c if isinstance(c, Column) else Column(c.name, c.type)
             for c in query.columns
         ]
-        if not [c for c in columns if c.name != "
+        if not [c for c in columns if c.name != "sys__id"]:
             raise RuntimeError(
                 "No columns to save in the query. "
                 "Ensure at least one column (other than 'id') is selected."
@@ -1742,11 +1646,11 @@ class DatasetQuery:
 
         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        q = query.exclude(("
+        q = query.exclude(("sys__id",))
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
-                f.row_number().over(order_by=q._order_by_clauses).label("
+                f.row_number().over(order_by=q._order_by_clauses).label("sys__id")
             )
 
         cols = tuple(c.name for c in q.columns)
@@ -1873,34 +1777,3 @@ def _random_string(length: int) -> str:
         random.choice(string.ascii_letters + string.digits)  # noqa: S311
         for i in range(length)
     )
-
-
-def _feature_predicate(obj):
-    from datachain.lib.feature import Feature
-
-    return inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, Feature)
-
-
-def _imports(obj):
-    return not source.isfrommain(obj)
-
-
-def get_imports(m):
-    root = ast.parse(inspect.getsource(m))
-
-    for node in ast.iter_child_nodes(root):
-        if isinstance(node, ast.Import):
-            module = None
-        elif isinstance(node, ast.ImportFrom):
-            module = node.module
-        else:
-            continue
-
-        for n in node.names:
-            import_script = ""
-            if module:
-                import_script += f"from {module} "
-            import_script += f"import {n.name}"
-            if n.asname:
-                import_script += f" as {n.asname}"
-            yield import_script
datachain/query/dispatch.py
CHANGED

@@ -10,13 +10,12 @@ from typing import Any, Optional
 
 import attrs
 import multiprocess
-from
+from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.query.batch import RowBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -85,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
+    udf_info = load(stdin.buffer)
 
     (
         warehouse_class,
@@ -96,7 +95,7 @@ def udf_entrypoint() -> int:
 
     # Parallel processing (faster for more CPU-heavy UDFs)
     dispatch = UDFDispatcher(
-        udf_info["
+        udf_info["udf_data"],
        udf_info["catalog_init"],
        udf_info["id_generator_clone_params"],
        udf_info["metastore_clone_params"],
@@ -109,7 +108,7 @@ def udf_entrypoint() -> int:
     batching = udf_info["batching"]
     table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = udf_info["
+    udf = loads(udf_info["udf_data"])
     if n_workers is True:
         # Use default number of CPUs (cores)
         n_workers = None
@@ -147,7 +146,7 @@ class UDFDispatcher:
 
     def __init__(
         self,
-
+        udf_data,
         catalog_init_params,
         id_generator_clone_params,
        metastore_clone_params,
@@ -156,14 +155,7 @@ class UDFDispatcher:
         is_generator=False,
         buffer_size=DEFAULT_BATCH_SIZE,
     ):
-
-        # and so these two types are not considered exactly equal,
-        # even if they have the same import path.
-        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-            self.udf = udf
-        else:
-            self.udf = None
-            self.udf_factory = udf
+        self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
         (
             self.id_generator_class,
@@ -215,6 +207,15 @@ class UDFDispatcher:
         self.catalog = Catalog(
             id_generator, metastore, warehouse, **self.catalog_init_params
         )
+        udf = loads(self.udf_data)
+        # isinstance cannot be used here, as cloudpickle packages the entire class
+        # definition, and so these two types are not considered exactly equal,
+        # even if they have the same import path.
+        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
+            self.udf = udf
+        else:
+            self.udf = None
+            self.udf_factory = udf
         if not self.udf:
             self.udf = self.udf_factory()
 
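Note: the dispatcher now carries the pickled bytes (udf_data) and deserializes them lazily on the worker side, comparing dotted module paths rather than using isinstance, since a class rebuilt by cloudpickle is a distinct type object from the imported one. A sketch of that comparison idea; type_path() here is a hypothetical helper, not datachain's full_module_type_path:

def type_path(t: type) -> str:
    return f"{t.__module__}.{t.__qualname__}"


class UDFFactoryStub:  # stand-in for the real UDFFactory
    pass


def is_factory(obj) -> bool:
    # robust against "same class, different object" after unpickling in another process
    return type_path(type(obj)) == type_path(UDFFactoryStub)


print(is_factory(UDFFactoryStub()))  # True
print(is_factory(object()))          # False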
@@ -355,6 +356,15 @@ class WorkerCallback(Callback):
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
 
 
+class ProcessedCallback(Callback):
+    def __init__(self):
+        self.processed_rows: Optional[int] = None
+        super().__init__()
+
+    def relative_update(self, inc: int = 1) -> None:
+        self.processed_rows = inc
+
+
 @attrs.define
 class UDFWorker:
     catalog: Catalog
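Note: ProcessedCallback reuses fsspec's Callback interface purely as a counter hook; each relative_update(n) records how many rows the UDF just processed, so the worker can attach that number to its queue messages. A small usage sketch with a stand-in work loop (not datachain's worker):

from fsspec.callbacks import Callback


class LastBatchCallback(Callback):  # same idea as ProcessedCallback
    def __init__(self):
        self.processed_rows = None
        super().__init__()

    def relative_update(self, inc: int = 1) -> None:
        self.processed_rows = inc


cb = LastBatchCallback()
for batch in ([1, 2, 3], [4, 5]):
    # ... process batch ...
    cb.relative_update(len(batch))
    print("rows in last batch:", cb.processed_rows)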
@@ -370,25 +380,28 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)
 
     def run(self) -> None:
-
-
-
-
-
-
-
-
-
-
-        )
+        processed_cb = ProcessedCallback()
+        udf_results = self.udf.run(
+            self.get_inputs(),
+            self.catalog,
+            self.is_generator,
+            self.cache,
+            download_cb=self.cb,
+            processed_cb=processed_cb,
+        )
+        for udf_output in udf_results:
             if isinstance(udf_output, GeneratorType):
                 udf_output = list(udf_output)  # can not pickle generator
             put_into_queue(
                 self.done_queue,
-                {
+                {
+                    "status": OK_STATUS,
+                    "result": udf_output,
+                    "processed": processed_cb.processed_rows,
+                },
             )
-
-        if hasattr(self.udf.func, "teardown") and callable(self.udf.func.teardown):
-            self.udf.func.teardown()
-
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+
+    def get_inputs(self):
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            yield batch