datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/asyn.py +16 -6
- datachain/cache.py +32 -10
- datachain/catalog/catalog.py +17 -1
- datachain/cli/__init__.py +311 -0
- datachain/cli/commands/__init__.py +29 -0
- datachain/cli/commands/datasets.py +129 -0
- datachain/cli/commands/du.py +14 -0
- datachain/cli/commands/index.py +12 -0
- datachain/cli/commands/ls.py +169 -0
- datachain/cli/commands/misc.py +28 -0
- datachain/cli/commands/query.py +53 -0
- datachain/cli/commands/show.py +38 -0
- datachain/cli/parser/__init__.py +547 -0
- datachain/cli/parser/job.py +120 -0
- datachain/cli/parser/studio.py +126 -0
- datachain/cli/parser/utils.py +63 -0
- datachain/{cli_utils.py → cli/utils.py} +27 -1
- datachain/client/azure.py +6 -2
- datachain/client/fsspec.py +9 -3
- datachain/client/gcs.py +6 -2
- datachain/client/s3.py +16 -1
- datachain/data_storage/db_engine.py +9 -0
- datachain/data_storage/schema.py +4 -10
- datachain/data_storage/sqlite.py +7 -1
- datachain/data_storage/warehouse.py +6 -4
- datachain/{lib/diff.py → diff/__init__.py} +116 -12
- datachain/func/__init__.py +3 -2
- datachain/func/conditional.py +74 -0
- datachain/func/func.py +5 -1
- datachain/lib/arrow.py +7 -1
- datachain/lib/dc.py +8 -3
- datachain/lib/file.py +16 -5
- datachain/lib/hf.py +1 -1
- datachain/lib/listing.py +19 -1
- datachain/lib/pytorch.py +57 -13
- datachain/lib/signal_schema.py +89 -27
- datachain/lib/udf.py +82 -40
- datachain/listing.py +1 -0
- datachain/progress.py +20 -3
- datachain/query/dataset.py +122 -93
- datachain/query/dispatch.py +22 -16
- datachain/studio.py +58 -38
- datachain/utils.py +14 -3
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
- datachain/cli.py +0 -1475
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/cli/parser/studio.py
ADDED

@@ -0,0 +1,126 @@
+def add_studio_parser(subparsers, parent_parser) -> None:
+    studio_help = "Commands to authenticate DataChain with Iterative Studio"
+    studio_description = (
+        "Authenticate DataChain with Studio and set the token. "
+        "Once this token has been properly configured,\n"
+        "DataChain will utilize it for seamlessly sharing datasets\n"
+        "and using Studio features from CLI"
+    )
+
+    studio_parser = subparsers.add_parser(
+        "studio",
+        parents=[parent_parser],
+        description=studio_description,
+        help=studio_help,
+    )
+    studio_subparser = studio_parser.add_subparsers(
+        dest="cmd",
+        help="Use `DataChain studio CMD --help` to display command-specific help.",
+        required=True,
+    )
+
+    studio_login_help = "Authenticate DataChain with Studio host"
+    studio_login_description = (
+        "By default, this command authenticates the DataChain with Studio\n"
+        "using default scopes and assigns a random name as the token name."
+    )
+    login_parser = studio_subparser.add_parser(
+        "login",
+        parents=[parent_parser],
+        description=studio_login_description,
+        help=studio_login_help,
+    )
+
+    login_parser.add_argument(
+        "-H",
+        "--hostname",
+        action="store",
+        default=None,
+        help="The hostname of the Studio instance to authenticate with.",
+    )
+    login_parser.add_argument(
+        "-s",
+        "--scopes",
+        action="store",
+        default=None,
+        help="The scopes for the authentication token. ",
+    )
+
+    login_parser.add_argument(
+        "-n",
+        "--name",
+        action="store",
+        default=None,
+        help="The name of the authentication token. It will be used to\n"
+        "identify token shown in Studio profile.",
+    )
+
+    login_parser.add_argument(
+        "--no-open",
+        action="store_true",
+        default=False,
+        help="Use authentication flow based on user code.\n"
+        "You will be presented with user code to enter in browser.\n"
+        "DataChain will also use this if it cannot launch browser on your behalf.",
+    )
+
+    studio_logout_help = "Logout user from Studio"
+    studio_logout_description = "This removes the studio token from your global config."
+
+    studio_subparser.add_parser(
+        "logout",
+        parents=[parent_parser],
+        description=studio_logout_description,
+        help=studio_logout_help,
+    )
+
+    studio_team_help = "Set the default team for DataChain"
+    studio_team_description = (
+        "Set the default team for DataChain to use when interacting with Studio."
+    )
+
+    team_parser = studio_subparser.add_parser(
+        "team",
+        parents=[parent_parser],
+        description=studio_team_description,
+        help=studio_team_help,
+    )
+    team_parser.add_argument(
+        "team_name",
+        action="store",
+        help="The name of the team to set as the default.",
+    )
+    team_parser.add_argument(
+        "--global",
+        action="store_true",
+        default=False,
+        help="Set the team globally for all DataChain projects.",
+    )
+
+    studio_token_help = "View the token datachain uses to contact Studio" # noqa: S105 # nosec B105
+
+    studio_subparser.add_parser(
+        "token",
+        parents=[parent_parser],
+        description=studio_token_help,
+        help=studio_token_help,
+    )
+
+    studio_ls_dataset_help = "List the available datasets from Studio"
+    studio_ls_dataset_description = (
+        "This command lists all the datasets available in Studio.\n"
+        "It will show the dataset name and the number of versions available."
+    )
+
+    ls_dataset_parser = studio_subparser.add_parser(
+        "dataset",
+        parents=[parent_parser],
+        description=studio_ls_dataset_description,
+        help=studio_ls_dataset_help,
+    )
+    ls_dataset_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to list datasets for. By default, it will use team from config.",
+    )
datachain/cli/parser/utils.py
ADDED

@@ -0,0 +1,63 @@
+from argparse import Action, ArgumentParser, ArgumentTypeError
+from typing import Union
+
+from datachain.cli.utils import CommaSeparatedArgs
+
+FIND_COLUMNS = ["du", "name", "path", "size", "type"]
+
+
+def find_columns_type(
+    columns_str: str,
+    default_colums_str: str = "path",
+) -> list[str]:
+    if not columns_str:
+        columns_str = default_colums_str
+
+    return [parse_find_column(c) for c in columns_str.split(",")]
+
+
+def parse_find_column(column: str) -> str:
+    column_lower = column.strip().lower()
+    if column_lower in FIND_COLUMNS:
+        return column_lower
+    raise ArgumentTypeError(
+        f"Invalid column for find: '{column}' Options are: {','.join(FIND_COLUMNS)}"
+    )
+
+
+def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Action:
+    return parser.add_argument(
+        "sources",
+        type=str,
+        nargs=nargs,
+        help="Data sources - paths to cloud storage dirs",
+    )
+
+
+def add_show_args(parser: ArgumentParser) -> None:
+    parser.add_argument(
+        "--limit",
+        action="store",
+        default=10,
+        type=int,
+        help="Number of rows to show",
+    )
+    parser.add_argument(
+        "--offset",
+        action="store",
+        default=0,
+        type=int,
+        help="Number of rows to offset",
+    )
+    parser.add_argument(
+        "--columns",
+        default=[],
+        action=CommaSeparatedArgs,
+        help="Columns to show",
+    )
+    parser.add_argument(
+        "--no-collapse",
+        action="store_true",
+        default=False,
+        help="Do not collapse the columns",
+    )
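Judging by FIND_COLUMNS and the error message, the two helpers above validate the comma-separated column list for the find command: values are split, lowercased, and checked against the allowed set. A minimal sketch of that behaviour; the import path follows the new file location from the file list, and the CLI wiring itself is not part of this hunk:

    from argparse import ArgumentTypeError

    from datachain.cli.parser.utils import find_columns_type

    # Comma-separated values are split and validated against FIND_COLUMNS.
    assert find_columns_type("name,size") == ["name", "size"]
    # An empty string falls back to the default column, "path".
    assert find_columns_type("") == ["path"]
    # Unknown columns raise ArgumentTypeError, which argparse reports as a CLI error.
    try:
        find_columns_type("owner")
    except ArgumentTypeError as exc:
        print(exc)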
datachain/{cli_utils.py → cli/utils.py}
CHANGED

@@ -1,4 +1,8 @@
-
+import logging
+from argparse import SUPPRESS, Action, ArgumentError, Namespace, _AppendAction
+from typing import Optional
+
+from datachain.error import DataChainError


 class BooleanOptionalAction(Action):

@@ -70,3 +74,25 @@ class KeyValueArgs(_AppendAction): # pylint: disable=protected-access
             items[key.strip()] = value

         setattr(namespace, self.dest, items)
+
+
+def get_logging_level(args: Namespace) -> int:
+    if args.quiet:
+        return logging.CRITICAL
+    if args.verbose:
+        return logging.DEBUG
+    return logging.INFO
+
+
+def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
+    if studio and not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    if local or studio:
+        all = False
+
+    all = all and not (local or studio)
+
+    return all, local, studio
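`determine_flavors` normalizes the flag combinations the new CLI commands use to decide whether to hit Studio, the local catalog, or both. A small sketch of the resulting tuples, based only on the logic shown above:

    from datachain.cli.utils import determine_flavors
    from datachain.error import DataChainError

    # An explicit --local (or --studio) switches the implicit "all" mode off.
    assert determine_flavors(studio=False, local=True, all=True, token=None) == (False, True, False)
    # With no explicit flavor, "all" stays on.
    assert determine_flavors(studio=False, local=False, all=True, token=None) == (True, False, False)
    # Asking for Studio without a saved token fails early.
    try:
        determine_flavors(studio=True, local=False, all=True, token=None)
    except DataChainError as exc:
        print(exc)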
datachain/client/azure.py
CHANGED

@@ -31,8 +31,12 @@ class AzureClient(Client):
         Generate a signed URL for the given path.
         """
         version_id = kwargs.pop("version_id", None)
+        content_disposition = kwargs.pop("content_disposition", None)
         result = self.fs.sign(
-            self.get_full_path(path, version_id),
+            self.get_full_path(path, version_id),
+            expiration=expires,
+            content_disposition=content_disposition,
+            **kwargs,
         )
         return result + (f"&versionid={version_id}" if version_id else "")

@@ -42,7 +46,7 @@ class AzureClient(Client):
         prefix = prefix.lstrip(DELIMITER) + DELIMITER
         found = False
         try:
-            with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+            with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
                 async with self.fs.service_client.get_container_client(
                     container=self.name
                 ) as container_client:
datachain/client/fsspec.py
CHANGED

@@ -215,7 +215,7 @@ class Client(ABC):
         info = await self.fs._info(
             self.get_full_path(file.path, file.version), **kwargs
         )
-        return self.info_to_file(info,
+        return self.info_to_file(info, file.path).etag

     def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
         info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)

@@ -249,7 +249,7 @@ class Client(ABC):
         await main_task

     async def _fetch_nested(self, start_prefix: str, result_queue: ResultQueue) -> None:
-        progress_bar = tqdm(desc=f"Listing {self.uri}", unit=" objects")
+        progress_bar = tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False)
         loop = get_loop()

         queue: asyncio.Queue[str] = asyncio.Queue()

@@ -343,7 +343,7 @@ class Client(ABC):
         return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)

     @abstractmethod
-    def info_to_file(self, v: dict[str, Any],
+    def info_to_file(self, v: dict[str, Any], path: str) -> "File": ...

     def fetch_nodes(
         self,

@@ -390,6 +390,12 @@
             self.fs.open(self.get_full_path(file.path, file.version)), cb
         ) # type: ignore[return-value]

+    def upload(self, path: str, data: bytes) -> "File":
+        full_path = self.get_full_path(path)
+        self.fs.pipe_file(full_path, data)
+        file_info = self.fs.info(full_path)
+        return self.info_to_file(file_info, path)
+
     def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
         sync(get_loop(), functools.partial(self._download, file, callback=callback))
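The new `Client.upload` method writes raw bytes through fsspec's `pipe_file` and then stats the object so it can return a populated `File`. A rough usage sketch; obtaining the client instance is outside this diff, so treat `client` as an assumed, already-configured Client subclass (for example, one pointed at an s3:// bucket):

    data = b'{"label": "cat"}\n'
    # pipe_file() writes the bytes, fs.info() re-reads the object metadata,
    # and info_to_file() converts that metadata into a datachain File object.
    file = client.upload("annotations/0001.json", data)
    print(file.path, file.size)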
datachain/client/gcs.py
CHANGED

@@ -39,11 +39,15 @@ class GCSClient(Client):
         (see https://cloud.google.com/storage/docs/access-public-data#api-link).
         """
         version_id = kwargs.pop("version_id", None)
+        content_disposition = kwargs.pop("content_disposition", None)
         if self.fs.storage_options.get("token") == "anon":
             query = f"?generation={version_id}" if version_id else ""
             return f"https://storage.googleapis.com/{self.name}/{path}{query}"
         return self.fs.sign(
-            self.get_full_path(path, version_id),
+            self.get_full_path(path, version_id),
+            expiration=expires,
+            response_disposition=content_disposition,
+            **kwargs,
         )

     @staticmethod

@@ -83,7 +87,7 @@ class GCSClient(Client):
         self, page_queue: PageQueue, result_queue: ResultQueue
     ) -> bool:
         found = False
-        with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+        with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
            while (page := await page_queue.get()) is not None:
                if page:
                    found = True
datachain/client/s3.py
CHANGED

@@ -51,6 +51,21 @@ class ClientS3(Client):

         return cast(S3FileSystem, super().create_fs(**kwargs))

+    def url(self, path: str, expires: int = 3600, **kwargs) -> str:
+        """
+        Generate a signed URL for the given path.
+        """
+        version_id = kwargs.pop("version_id", None)
+        content_disposition = kwargs.pop("content_disposition", None)
+        if content_disposition:
+            kwargs["ResponseContentDisposition"] = content_disposition
+
+        return self.fs.sign(
+            self.get_full_path(path, version_id),
+            expiration=expires,
+            **kwargs,
+        )
+
     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
         async def get_pages(it, page_queue):
             try:

@@ -61,7 +76,7 @@ class ClientS3(Client):

         async def process_pages(page_queue, result_queue):
             found = False
-            with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+            with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
                 while (res := await page_queue.get()) is not None:
                     if res:
                         found = True
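With this release all three cloud clients (S3, GCS, Azure) accept a `content_disposition` keyword when presigning, so a link can force a browser download with a chosen filename; on S3 it is forwarded as `ResponseContentDisposition`. A hedged sketch, again assuming an already-constructed ClientS3 instance named `client`:

    # Presign a 10-minute download link that suggests "report.csv" as the filename.
    url = client.url(
        "exports/report.csv",
        expires=600,
        content_disposition='attachment; filename="report.csv"',
    )
    print(url)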
datachain/data_storage/db_engine.py
CHANGED

@@ -79,6 +79,15 @@
         conn: Optional[Any] = None,
     ) -> Iterator[tuple[Any, ...]]: ...

+    def get_table(self, name: str) -> "Table":
+        table = self.metadata.tables.get(name)
+        if table is None:
+            sa.Table(name, self.metadata, autoload_with=self.engine)
+            # ^^^ This table may not be correctly initialised on some dialects
+            # Grab it from metadata instead.
+            table = self.metadata.tables[name]
+        return table
+
     @abstractmethod
     def executemany(
         self, query, params, cursor: Optional[Any] = None
datachain/data_storage/schema.py
CHANGED

@@ -16,7 +16,6 @@ from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType, UInt64

 if TYPE_CHECKING:
-    from sqlalchemy import Engine
     from sqlalchemy.engine.interfaces import Dialect
     from sqlalchemy.sql.base import (
         ColumnCollection,

@@ -25,6 +24,8 @@ if TYPE_CHECKING:
     )
     from sqlalchemy.sql.elements import ColumnElement

+    from datachain.data_storage.db_engine import DatabaseEngine
+

 DEFAULT_DELIMITER = "__"

@@ -150,14 +151,12 @@ class DataTable:
     def __init__(
         self,
         name: str,
-        engine: "
-        metadata: Optional["sa.MetaData"] = None,
+        engine: "DatabaseEngine",
         column_types: Optional[dict[str, SQLType]] = None,
         object_name: str = "file",
     ):
         self.name: str = name
         self.engine = engine
-        self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
         self.column_types: dict[str, SQLType] = column_types or {}
         self.object_name = object_name

@@ -211,12 +210,7 @@ class DataTable:
         return sa.Table(name, metadata, *columns)

     def get_table(self) -> "sa.Table":
-        table = self.
-        if table is None:
-            sa.Table(self.name, self.metadata, autoload_with=self.engine)
-            # ^^^ This table may not be correctly initialised on some dialects
-            # Grab it from metadata instead.
-            table = self.metadata.tables[self.name]
+        table = self.engine.get_table(self.name)

         column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
         # adjusting types for custom columns to be instances of SQLType if possible
datachain/data_storage/sqlite.py
CHANGED

@@ -186,6 +186,12 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         self.db_file = db_file
         self.is_closed = False

+    def get_table(self, name: str) -> Table:
+        if self.is_closed:
+            # Reconnect in case of being closed previously.
+            self._reconnect()
+        return super().get_table(name)
+
     @retry_sqlite_locks
     def execute(
         self,

@@ -670,7 +676,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         ]
         table = self.create_udf_table(columns)

-        with tqdm(desc="Preparing", unit=" rows") as pbar:
+        with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:
             self.copy_table(table, query, progress_cb=pbar.update)

         return table
datachain/data_storage/warehouse.py
CHANGED

@@ -191,8 +191,7 @@
         table_name = self.dataset_table_name(dataset.name, version)
         return self.schema.dataset_row_cls(
             table_name,
-            self.db
-            self.db.metadata,
+            self.db,
             dataset.get_schema(version),
             object_name=object_name,
         )

@@ -904,8 +903,11 @@
         This should be implemented to ensure that the provided tables
         are cleaned up as soon as they are no longer needed.
         """
-
-
+        to_drop = set(names)
+        with tqdm(
+            desc="Cleanup", unit=" tables", total=len(to_drop), leave=False
+        ) as pbar:
+            for name in to_drop:
                 self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
                 pbar.update(1)
datachain/{lib/diff.py → diff/__init__.py}
CHANGED

@@ -1,6 +1,7 @@
 import random
 import string
 from collections.abc import Sequence
+from enum import Enum
 from typing import TYPE_CHECKING, Optional, Union

 import sqlalchemy as sa

@@ -16,7 +17,22 @@ if TYPE_CHECKING:
 C = Column


-def
+def get_status_col_name() -> str:
+    """Returns new unique status col name"""
+    return "diff_" + "".join(
+        random.choice(string.ascii_letters) # noqa: S311
+        for _ in range(10)
+    )
+
+
+class CompareStatus(str, Enum):
+    ADDED = "A"
+    DELETED = "D"
+    MODIFIED = "M"
+    SAME = "S"
+
+
+def _compare( # noqa: PLR0912, PLR0915, C901
     left: "DataChain",
     right: "DataChain",
     on: Union[str, Sequence[str]],

@@ -72,13 +88,10 @@ def compare( # noqa: PLR0912, PLR0915, C901
             "At least one of added, deleted, modified, same flags must be set"
         )

-    # we still need status column for internal implementation even if not
-    # needed in output
     need_status_col = bool(status_col)
-
-
-
-    )
+    # we still need status column for internal implementation even if not
+    # needed in the output
+    status_col = status_col or get_status_col_name()

     # calculate on and compare column names
     right_on = right_on or on

@@ -112,7 +125,7 @@ def compare( # noqa: PLR0912, PLR0915, C901
                 for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
             ]
         )
-        diff_cond.append((added_cond,
+        diff_cond.append((added_cond, CompareStatus.ADDED))
     if modified and compare:
         modified_cond = sa.or_(
             *[

@@ -120,7 +133,7 @@ def compare( # noqa: PLR0912, PLR0915, C901
                 for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
             ]
         )
-        diff_cond.append((modified_cond,
+        diff_cond.append((modified_cond, CompareStatus.MODIFIED))
     if same and compare:
         same_cond = sa.and_(
             *[

@@ -128,9 +141,11 @@ def compare( # noqa: PLR0912, PLR0915, C901
                 for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
             ]
         )
-        diff_cond.append((same_cond,
+        diff_cond.append((same_cond, CompareStatus.SAME))

-    diff = sa.case(*diff_cond, else_=None if compare else
+    diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
+        status_col
+    )
     diff.type = String()

     left_right_merge = left.merge(

@@ -145,7 +160,7 @@ def compare( # noqa: PLR0912, PLR0915, C901
         )
     )

-    diff_col = sa.literal(
+    diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
     diff_col.type = String()

     right_left_merge = right.merge(

@@ -195,3 +210,92 @@ def compare( # noqa: PLR0912, PLR0915, C901
         res = res.select_except(C(status_col))

     return left._evolve(query=res, signal_schema=schema)
+
+
+def compare_and_split(
+    left: "DataChain",
+    right: "DataChain",
+    on: Union[str, Sequence[str]],
+    right_on: Optional[Union[str, Sequence[str]]] = None,
+    compare: Optional[Union[str, Sequence[str]]] = None,
+    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    added: bool = True,
+    deleted: bool = True,
+    modified: bool = True,
+    same: bool = False,
+) -> dict[str, "DataChain"]:
+    """Comparing two chains and returning multiple chains, one for each of `added`,
+    `deleted`, `modified` and `same` status. Result is returned in form of
+    dictionary where each item represents one of the statuses and key values
+    are `A`, `D`, `M`, `S` corresponding. Note that status column is not in the
+    resulting chains.
+
+    Parameters:
+        left: Chain to calculate diff on.
+        right: Chain to calculate diff from.
+        on: Column or list of columns to match on. If both chains have the
+            same columns then this column is enough for the match. Otherwise,
+            `right_on` parameter has to specify the columns for the other chain.
+            This value is used to find corresponding row in other dataset. If not
+            found there, row is considered as added (or removed if vice versa), and
+            if found then row can be either modified or same.
+        right_on: Optional column or list of columns
+            for the `other` to match.
+        compare: Column or list of columns to compare on. If both chains have
+            the same columns then this column is enough for the compare. Otherwise,
+            `right_compare` parameter has to specify the columns for the other
+            chain. This value is used to see if row is modified or same. If
+            not set, all columns will be used for comparison
+        right_compare: Optional column or list of columns
+            for the `other` to compare to.
+        added (bool): Whether to return chain containing only added rows.
+        deleted (bool): Whether to return chain containing only deleted rows.
+        modified (bool): Whether to return chain containing only modified rows.
+        same (bool): Whether to return chain containing only same rows.
+
+    Example:
+        ```py
+        chains = compare(
+            persons,
+            new_persons,
+            on=["id"],
+            right_on=["other_id"],
+            compare=["name"],
+            added=True,
+            deleted=True,
+            modified=True,
+            same=True,
+        )
+        ```
+    """
+    status_col = get_status_col_name()
+
+    res = _compare(
+        left,
+        right,
+        on,
+        right_on=right_on,
+        compare=compare,
+        right_compare=right_compare,
+        added=added,
+        deleted=deleted,
+        modified=modified,
+        same=same,
+        status_col=status_col,
+    )
+
+    chains = {}
+
+    def filter_by_status(compare_status) -> "DataChain":
+        return res.filter(C(status_col) == compare_status).select_except(status_col)
+
+    if added:
+        chains[CompareStatus.ADDED.value] = filter_by_status(CompareStatus.ADDED)
+    if deleted:
+        chains[CompareStatus.DELETED.value] = filter_by_status(CompareStatus.DELETED)
+    if modified:
+        chains[CompareStatus.MODIFIED.value] = filter_by_status(CompareStatus.MODIFIED)
+    if same:
+        chains[CompareStatus.SAME.value] = filter_by_status(CompareStatus.SAME)
+
+    return chains
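`compare_and_split` layers on top of the internal `_compare` helper above (formerly `compare`) and fans the result out into one chain per status, keyed by the one-letter codes from `CompareStatus`. A short usage sketch based on the docstring; `persons` and `new_persons` are assumed pre-existing DataChain objects, and the import path follows the rename to datachain/diff/__init__.py shown in the file list:

    from datachain.diff import CompareStatus, compare_and_split

    chains = compare_and_split(
        persons,
        new_persons,
        on=["id"],
        compare=["name"],
        added=True,
        deleted=True,
        modified=True,
        same=False,  # default: "same" rows are not returned
    )
    added_rows = chains[CompareStatus.ADDED.value]      # chains["A"]
    deleted_rows = chains[CompareStatus.DELETED.value]  # chains["D"]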
datachain/func/__init__.py
CHANGED

@@ -1,4 +1,4 @@
-from sqlalchemy import
+from sqlalchemy import literal

 from . import array, path, random, string
 from .aggregate import (

@@ -16,7 +16,7 @@ from .aggregate import (
     sum,
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
-from .conditional import greatest, least
+from .conditional import case, greatest, ifelse, least
 from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
 from .string import byte_hamming_distance

@@ -40,6 +40,7 @@ __all__ = [
     "euclidean_distance",
     "first",
     "greatest",
+    "ifelse",
     "int_hash_64",
     "least",
     "length",