datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +35 -15
- datachain/cli.py +37 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +646 -152
- datachain/lib/file.py +117 -15
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +39 -14
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/lib/webdataset.py +4 -3
- datachain/node.py +11 -8
- datachain/query/dataset.py +66 -147
- datachain/query/dispatch.py +15 -13
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +45 -0
- datachain-0.2.12.dist-info/METADATA +412 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
- datachain/lib/feature_registry.py +0 -77
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.10.dist-info/METADATA +0 -430
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/lib/webdataset.py
CHANGED
@@ -13,8 +13,9 @@ from typing import (
     get_origin,
 )
 
-from pydantic import
+from pydantic import Field
 
+from datachain.lib.data_model import DataModel
 from datachain.lib.file import File, TarVFile
 from datachain.lib.utils import DataChainError
 
@@ -45,7 +46,7 @@ class UnknownFileExtensionError(WDSError):
         super().__init__(tar_stream, f"unknown extension '{ext}' for file '{name}'")
 
 
-class WDSBasic(
+class WDSBasic(DataModel):
     file: File
 
 
@@ -74,7 +75,7 @@ class WDSAllFile(WDSBasic):
     cbor: Optional[bytes] = Field(default=None)
 
 
-class WDSReadableSubclass(
+class WDSReadableSubclass(DataModel):
     @staticmethod
     def _reader(builder, item: tarfile.TarInfo) -> "WDSReadableSubclass":
         raise NotImplementedError
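The WebDataset record classes are now Pydantic-backed DataModel subclasses (the old feature-registry base module is removed in this release). A minimal sketch of extending the schema under that assumption; the label field and the `_reader` body are hypothetical, since the builder's parsing API is not part of this diff:

    import tarfile
    from typing import Optional

    from pydantic import Field

    from datachain.lib.webdataset import WDSReadableSubclass


    class LabelEntry(WDSReadableSubclass):
        # Hypothetical extra signal parsed out of a tar member.
        label: Optional[int] = Field(default=None)

        @staticmethod
        def _reader(builder, item: tarfile.TarInfo) -> "LabelEntry":
            # A real reader would decode the member's bytes via the builder;
            # that API is outside this diff, so return a placeholder.
            return LabelEntry(label=0)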
datachain/node.py
CHANGED
@@ -5,7 +5,7 @@ import attrs
 
 from datachain.cache import UniqueId
 from datachain.storage import StorageURI
-from datachain.utils import time_to_str
+from datachain.utils import TIME_ZERO, time_to_str
 
 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -111,13 +111,16 @@ class Node:
         if storage is None:
             storage = self.source
         return UniqueId(
-            storage,
-            self.parent,
-            self.name,
-            self.
-            self.
-            self.
-            self.
+            storage=storage,
+            parent=self.parent,
+            name=self.name,
+            size=self.size,
+            version=self.version or "",
+            etag=self.etag,
+            is_latest=self.is_latest,
+            vtype=self.vtype,
+            location=self.location,
+            last_modified=self.last_modified or TIME_ZERO,
         )
 
     @classmethod
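The switch to keyword arguments also normalizes optional fields, so a Node with no version or timestamp still produces a well-formed UniqueId. A rough illustration of the two fallbacks, assuming TIME_ZERO is the zero/epoch timestamp exported by datachain.utils:

    from datachain.utils import TIME_ZERO

    version = None
    last_modified = None

    # The same normalization the Node method above now applies.
    print(version or "")               # -> ""
    print(last_modified or TIME_ZERO)  # -> the TIME_ZERO sentinel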
datachain/query/dataset.py
CHANGED
@@ -1,4 +1,3 @@
-import ast
 import contextlib
 import datetime
 import inspect
@@ -10,7 +9,6 @@ import re
 import string
 import subprocess
 import sys
-import types
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from copy import copy
@@ -26,12 +24,9 @@ from typing import (
 )
 
 import attrs
-import pandas as pd
 import sqlalchemy
 from attrs import frozen
-from dill import dumps, source
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
-from pydantic import BaseModel
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
 from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -53,10 +48,14 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
-from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import
+from datachain.utils import (
+    batched,
+    determine_processes,
+    filtered_cloudpickle_dumps,
+    get_datachain_executable,
+)
 
 from .metrics import metrics
 from .schema import C, UDFParamSpec, normalize_param
@@ -428,7 +427,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 
 @frozen
-class UDF(Step, ABC):
+class UDFStep(Step, ABC):
     udf: UDFType
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
@@ -492,7 +491,7 @@ class UDF(Step, ABC):
         elif processes:
             # Parallel processing (faster for more CPU-heavy UDFs)
             udf_info = {
-                "
+                "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "id_generator_clone_params": (
                     self.catalog.id_generator.clone_params()
@@ -509,20 +508,18 @@ class UDF(Step, ABC):
 
             # Run the UDFDispatcher in another process to avoid needing
             # if __name__ == '__main__': in user scripts
-
-
+            exec_cmd = get_datachain_executable()
             envs = dict(os.environ)
             envs.update({"PYTHONPATH": os.getcwd()})
-
-
-
-
-
-
-
-
-            raise RuntimeError("UDF Execution Failed!")
+            process_data = filtered_cloudpickle_dumps(udf_info)
+            result = subprocess.run(  # noqa: S603
+                [*exec_cmd, "internal-run-udf"],
+                input=process_data,
+                check=False,
+                env=envs,
+            )
+            if result.returncode != 0:
+                raise RuntimeError("UDF Execution Failed!")
 
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
@@ -571,57 +568,6 @@ class UDF(Step, ABC):
             self.catalog.warehouse.close()
             raise
 
-    @contextlib.contextmanager
-    def process_feature_module(self):
-        # Generate a random name for the feature module
-        feature_module_name = "tmp" + _random_string(10)
-        # Create a dynamic module with the generated name
-        dynamic_module = types.ModuleType(feature_module_name)
-        # Get the import lines for the necessary objects from the main module
-        main_module = sys.modules["__main__"]
-        if getattr(main_module, "__file__", None):
-            import_lines = list(get_imports(main_module))
-        else:
-            import_lines = [
-                source.getimport(obj, alias=name)
-                for name, obj in main_module.__dict__.items()
-                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
-            ]
-
-        # Get the feature classes from the main module
-        feature_classes = {
-            name: obj
-            for name, obj in main_module.__dict__.items()
-            if _feature_predicate(obj)
-        }
-        if not feature_classes:
-            yield None
-            return
-
-        # Get the source code of the feature classes
-        feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
-        # Set the module name for the feature classes to the generated name
-        for name, cls in feature_classes.items():
-            cls.__module__ = feature_module_name
-            setattr(dynamic_module, name, cls)
-        # Add the dynamic module to the sys.modules dictionary
-        sys.modules[feature_module_name] = dynamic_module
-        # Combine the import lines and feature sources
-        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)
-
-        # Write the module content to a .py file
-        with open(f"{feature_module_name}.py", "w") as module_file:
-            module_file.write(feature_file)
-
-        try:
-            yield feature_module_name
-        finally:
-            for cls in feature_classes.values():
-                cls.__module__ = main_module.__name__
-            os.unlink(f"{feature_module_name}.py")
-            # Remove the dynamic module from sys.modules
-            del sys.modules[feature_module_name]
-
     def create_partitions_table(self, query: Select) -> "Table":
         """
         Create temporary table with group by partitions.
@@ -689,7 +635,7 @@ class UDF(Step, ABC):
 
 
 @frozen
-class UDFSignal(UDF):
+class UDFSignal(UDFStep):
     is_generator = False
 
     def create_udf_table(self, query: Select) -> "Table":
@@ -784,7 +730,7 @@ class UDFSignal(UDF):
 
 
 @frozen
-class RowGenerator(UDF):
+class RowGenerator(UDFStep):
     """Extend dataset with new rows."""
 
     is_generator = True
@@ -919,6 +865,18 @@ class SQLCount(SQLClause):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())
 
 
+@frozen
+class SQLDistinct(SQLClause):
+    args: tuple[ColumnElement, ...]
+    dialect: str
+
+    def apply_sql_clause(self, query):
+        if self.dialect == "sqlite":
+            return query.group_by(*self.args)
+
+        return query.distinct(*self.args)
+
+
 @frozen
 class SQLUnion(Step):
     query1: "DatasetQuery"
@@ -1000,12 +958,15 @@ class SQLJoin(Step):
 
         q1_columns = list(q1.c)
         q1_column_names = {c.name for c in q1_columns}
-
-
-
-
-
-
+
+        q2_columns = []
+        for c in q2.c:
+            if c.name.startswith("sys__"):
+                continue
+
+            if c.name in q1_column_names:
+                c = c.label(self.rname.format(name=c.name))
+            q2_columns.append(c)
 
         res_columns = q1_columns + q2_columns
         predicates = (
@@ -1112,6 +1073,7 @@ class DatasetQuery:
         anon: bool = False,
         indexing_feature_schema: Optional[dict] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
+        update: Optional[bool] = False,
     ):
         if client_config is None:
             client_config = {}
@@ -1134,10 +1096,12 @@ class DatasetQuery:
         self.session = Session.get(session, catalog=catalog)
 
         if path:
-
+            kwargs = {"update": True} if update else {}
+            self.starting_step = IndexingStep(path, self.catalog, kwargs, recursive)
             self.feature_schema = indexing_feature_schema
             self.column_types = indexing_column_types
         elif name:
+            self.name = name
             ds = self.catalog.get_dataset(name)
             self.version = version or ds.latest_version
             self.feature_schema = ds.get_version(self.version).feature_schema
@@ -1145,9 +1109,6 @@ class DatasetQuery:
             if "sys__id" in self.column_types:
                 self.column_types.pop("sys__id")
             self.starting_step = QueryStep(self.catalog, name, self.version)
-            # attaching to specific dataset
-            self.name = name
-            self.version = version
         else:
             raise ValueError("must provide path or name")
 
@@ -1156,7 +1117,7 @@ class DatasetQuery:
         return bool(re.compile(r"^[a-zA-Z0-9]+://").match(path))
 
     def __iter__(self):
-        return iter(self.results())
+        return iter(self.db_results())
 
     def __or__(self, other):
         return self.union(other)
@@ -1277,13 +1238,16 @@ class DatasetQuery:
             warehouse.close()
         self.temp_table_names = []
 
-    def results(self, row_factory=None, **kwargs):
+    def db_results(self, row_factory=None, **kwargs):
         with self.as_iterable(**kwargs) as result:
             if row_factory:
                 cols = result.columns
                 return [row_factory(cols, r) for r in result]
             return list(result)
 
+    def to_db_records(self) -> list[dict[str, Any]]:
+        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+
     @contextlib.contextmanager
    def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
        try:
@@ -1343,15 +1307,6 @@ class DatasetQuery:
         finally:
             self.cleanup()
 
-    def to_records(self) -> list[dict[str, Any]]:
-        return self.results(lambda cols, row: dict(zip(cols, row)))
-
-    def to_pandas(self) -> "pd.DataFrame":
-        records = self.to_records()
-        df = pd.DataFrame.from_records(records)
-        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
-        return df
-
     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
         return self.order_by(C.sys__rand)
@@ -1370,22 +1325,6 @@ class DatasetQuery:
 
         return sampled.limit(n)
 
-    def show(self, limit=20) -> None:
-        df = self.limit(limit).to_pandas()
-
-        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
-        with pd.option_context(*options):
-            if inside_notebook():
-                from IPython.display import display
-
-                display(df)
-
-            else:
-                print(df.to_string())
-
-        if len(df) == limit:
-            print(f"[limited by {limit} objects]")
-
     def clone(self, new_table=True) -> "Self":
         obj = copy(self)
         obj.steps = obj.steps.copy()
@@ -1483,6 +1422,14 @@ class DatasetQuery:
         query.steps.append(SQLOffset(offset))
         return query
 
+    @detach
+    def distinct(self, *args) -> "Self":
+        query = self.clone()
+        query.steps.append(
+            SQLDistinct(args, dialect=self.catalog.warehouse.db.dialect.name)
+        )
+        return query
+
     def as_scalar(self) -> Any:
         with self.as_iterable() as rows:
             row = next(iter(rows))
@@ -1781,10 +1728,13 @@ def _send_result(dataset_query: DatasetQuery) -> None:
 
     columns = preview_args.get("columns") or []
 
-
-        dataset_query.select(*columns)
-
-        .
+    if type(dataset_query) is DatasetQuery:
+        preview_query = dataset_query.select(*columns)
+    else:
+        preview_query = dataset_query.select(*columns, _sys=False)
+
+    preview_query = preview_query.limit(preview_args.get("limit", 10)).offset(
+        preview_args.get("offset", 0)
    )
 
    dataset: Optional[tuple[str, int]] = None
@@ -1793,7 +1743,7 @@ def _send_result(dataset_query: DatasetQuery) -> None:
     assert dataset_query.version, "Dataset version should be provided"
     dataset = dataset_query.name, dataset_query.version
 
-    preview = preview_query.
+    preview = preview_query.to_db_records()
     result = ExecutionResult(preview, dataset, metrics)
     data = attrs.asdict(result)
 
@@ -1853,34 +1803,3 @@ def _random_string(length: int) -> str:
         random.choice(string.ascii_letters + string.digits)  # noqa: S311
         for i in range(length)
     )
-
-
-def _feature_predicate(obj):
-    return (
-        inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
-    )
-
-
-def _imports(obj):
-    return not source.isfrommain(obj)
-
-
-def get_imports(m):
-    root = ast.parse(inspect.getsource(m))
-
-    for node in ast.iter_child_nodes(root):
-        if isinstance(node, ast.Import):
-            module = None
-        elif isinstance(node, ast.ImportFrom):
-            module = node.module
-        else:
-            continue
-
-        for n in node.names:
-            import_script = ""
-            if module:
-                import_script += f"from {module} "
-            import_script += f"import {n.name}"
-            if n.asname:
-                import_script += f" as {n.asname}"
-            yield import_script
datachain/query/dispatch.py
CHANGED
@@ -10,7 +10,7 @@ from typing import Any, Optional
 
 import attrs
 import multiprocess
-from
+from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
 
@@ -84,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
+    udf_info = load(stdin.buffer)
 
     (
         warehouse_class,
@@ -95,7 +95,7 @@ def udf_entrypoint() -> int:
 
     # Parallel processing (faster for more CPU-heavy UDFs)
     dispatch = UDFDispatcher(
-        udf_info["
+        udf_info["udf_data"],
         udf_info["catalog_init"],
         udf_info["id_generator_clone_params"],
         udf_info["metastore_clone_params"],
@@ -108,7 +108,7 @@ def udf_entrypoint() -> int:
     batching = udf_info["batching"]
     table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = udf_info["
+    udf = loads(udf_info["udf_data"])
     if n_workers is True:
         # Use default number of CPUs (cores)
         n_workers = None
@@ -146,7 +146,7 @@ class UDFDispatcher:
 
     def __init__(
         self,
-
+        udf_data,
         catalog_init_params,
         id_generator_clone_params,
         metastore_clone_params,
@@ -155,14 +155,7 @@ class UDFDispatcher:
         is_generator=False,
         buffer_size=DEFAULT_BATCH_SIZE,
     ):
-
-        # and so these two types are not considered exactly equal,
-        # even if they have the same import path.
-        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-            self.udf = udf
-        else:
-            self.udf = None
-            self.udf_factory = udf
+        self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
         (
             self.id_generator_class,
@@ -214,6 +207,15 @@ class UDFDispatcher:
         self.catalog = Catalog(
             id_generator, metastore, warehouse, **self.catalog_init_params
         )
+        udf = loads(self.udf_data)
+        # isinstance cannot be used here, as cloudpickle packages the entire class
+        # definition, and so these two types are not considered exactly equal,
+        # even if they have the same import path.
+        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
+            self.udf = udf
+        else:
+            self.udf = None
+            self.udf_factory = udf
         if not self.udf:
             self.udf = self.udf_factory()
 
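The dispatcher now receives the UDF as cloudpickle bytes and deserializes it only once a catalog exists. The comment about full_module_type_path reflects a general cloudpickle property: a class pickled by value round-trips to a distinct class object, so identity and isinstance checks can fail. A standalone sketch of that behavior; the helper below mimics, but is not, the one used in dispatch.py:

    import cloudpickle


    class Widget:
        """Stand-in for a class cloudpickle serializes by value."""


    def full_module_type_path(t: type) -> str:
        # Compare classes by dotted path instead of by identity.
        return f"{t.__module__}.{t.__qualname__}"


    obj = cloudpickle.loads(cloudpickle.dumps(Widget()))
    # When the class was pickled by value, type(obj) is a fresh class
    # object, so this can print False...
    print(type(obj) is Widget)
    # ...while the module-qualified path still matches.
    print(full_module_type_path(type(obj)) == full_module_type_path(Widget))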
datachain/query/schema.py
CHANGED
@@ -32,6 +32,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
     inherit_cache: Optional[bool] = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+        """Dataset column."""
         self.name = ColumnMeta.to_db_name(text)
         super().__init__(
             self.name, type_=type_, is_literal=is_literal, _selectable=_selectable
@@ -41,6 +42,7 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         return Column(self.name + DEFAULT_DELIMITER + name)
 
     def glob(self, glob_str):
+        """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)
 
 
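Both additions are docstrings only, but glob() deserves a concrete example. A sketch, assuming the SQLite-backed warehouse where GLOB is a native operator; the column name is illustrative:

    from datachain.query.schema import C

    # Builds a SQL expression equivalent to: file__path GLOB '*.jpg'
    expr = C("file.path").glob("*.jpg")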
datachain/query/session.py
CHANGED
@@ -28,9 +28,9 @@ class Session:
 
     Parameters:
 
-
+        name (str): The name of the session. Only latters and numbers are supported.
             It can be empty.
-
+        catalog (Catalog): Catalog object.
     """
 
     GLOBAL_SESSION_CTX: Optional["Session"] = None
@@ -80,9 +80,9 @@ class Session:
         """Creates a Session() object from a catalog.
 
         Parameters:
-
+            session (Session): Optional Session(). If not provided a new session will
                 be created. It's needed mostly for simplie API purposes.
-
+            catalog (Catalog): Optional catalog. By default a new catalog is created.
         """
         if session:
             return session
datachain/sql/functions/array.py
CHANGED
@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class cosine_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the cosine distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "cosine_distance"
@@ -12,6 +16,10 @@ class cosine_distance(GenericFunction):  # noqa: N801
 
 
 class euclidean_distance(GenericFunction):  # noqa: N801
+    """
+    Takes a column and array and returns the Euclidean distance between them.
+    """
+
     type = Float()
     package = "array"
     name = "euclidean_distance"
@@ -19,6 +27,10 @@ class euclidean_distance(GenericFunction):  # noqa: N801
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the array.
+    """
+
     type = Int64()
     package = "array"
     name = "length"
datachain/sql/functions/string.py
CHANGED

@@ -5,6 +5,10 @@ from datachain.sql.utils import compiler_not_implemented
 
 
 class length(GenericFunction):  # noqa: N801
+    """
+    Returns the length of the string.
+    """
+
     type = Int64()
     package = "string"
     name = "length"
@@ -12,6 +16,10 @@ class length(GenericFunction):  # noqa: N801
 
 
 class split(GenericFunction):  # noqa: N801
+    """
+    Takes a column and split character and returns an array of the parts.
+    """
+
     type = Array(String())
     package = "string"
     name = "split"
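These GenericFunction subclasses register array.length, string.length, string.split, and the distance functions with SQLAlchemy; per-dialect compilation lives elsewhere in datachain.sql, and unsupported dialects raise via compiler_not_implemented. A small sketch that only builds the expression tree, assuming the functions package exposes the array and string modules:

    import sqlalchemy as sa

    # Importing the modules registers the functions with SQLAlchemy.
    from datachain.sql.functions import array, string

    query = sa.select(
        array.length(sa.column("embedding")),              # array length
        string.split(sa.column("name"), sa.literal(",")),  # split into parts
    )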
datachain/torch/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 try:
-    from datachain.lib.clip import
+    from datachain.lib.clip import clip_similarity_scores
     from datachain.lib.image import convert_image, convert_images
     from datachain.lib.pytorch import PytorchDataset, label_to_int
     from datachain.lib.text import convert_text
datachain/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import glob
 import importlib.util
+import io
 import json
 import os
 import os.path as osp
@@ -13,8 +14,10 @@ from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from uuid import UUID
 
+import cloudpickle
 from dateutil import tz
 from dateutil.parser import isoparse
+from pydantic import BaseModel
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -388,3 +391,45 @@ def inside_notebook() -> bool:
         return False
 
     return False
+
+
+def get_all_subclasses(cls):
+    """Return all subclasses of a given class.
+    Can return duplicates due to multiple inheritance."""
+    for subclass in cls.__subclasses__():
+        yield from get_all_subclasses(subclass)
+        yield subclass
+
+
+def filtered_cloudpickle_dumps(obj: Any) -> bytes:
+    """Equivalent to cloudpickle.dumps, but this supports Pydantic models."""
+    model_namespaces = {}
+
+    with io.BytesIO() as f:
+        pickler = cloudpickle.CloudPickler(f)
+
+        for model_class in get_all_subclasses(BaseModel):
+            # This "is not None" check is needed, because due to multiple inheritance,
+            # it is theoretically possible to get the same class twice from
+            # get_all_subclasses.
+            if model_class.__pydantic_parent_namespace__ is not None:
+                # __pydantic_parent_namespace__ can contain many unnecessary and
+                # unpickleable entities, so should be removed for serialization.
+                model_namespaces[model_class] = (
+                    model_class.__pydantic_parent_namespace__
+                )
+                model_class.__pydantic_parent_namespace__ = None
+
+        try:
+            pickler.dump(obj)
+            return f.getvalue()
+        finally:
+            for model_class, namespace in model_namespaces.items():
+                # Restore original __pydantic_parent_namespace__ locally.
+                model_class.__pydantic_parent_namespace__ = namespace
+
+
+def get_datachain_executable() -> list[str]:
+    if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"):
+        return [datachain_exec_path]
+    return [sys.executable, "-m", "datachain"]