datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (50)
  1. datachain/asyn.py +16 -6
  2. datachain/cache.py +32 -10
  3. datachain/catalog/catalog.py +17 -1
  4. datachain/cli/__init__.py +311 -0
  5. datachain/cli/commands/__init__.py +29 -0
  6. datachain/cli/commands/datasets.py +129 -0
  7. datachain/cli/commands/du.py +14 -0
  8. datachain/cli/commands/index.py +12 -0
  9. datachain/cli/commands/ls.py +169 -0
  10. datachain/cli/commands/misc.py +28 -0
  11. datachain/cli/commands/query.py +53 -0
  12. datachain/cli/commands/show.py +38 -0
  13. datachain/cli/parser/__init__.py +547 -0
  14. datachain/cli/parser/job.py +120 -0
  15. datachain/cli/parser/studio.py +126 -0
  16. datachain/cli/parser/utils.py +63 -0
  17. datachain/{cli_utils.py → cli/utils.py} +27 -1
  18. datachain/client/azure.py +6 -2
  19. datachain/client/fsspec.py +9 -3
  20. datachain/client/gcs.py +6 -2
  21. datachain/client/s3.py +16 -1
  22. datachain/data_storage/db_engine.py +9 -0
  23. datachain/data_storage/schema.py +4 -10
  24. datachain/data_storage/sqlite.py +7 -1
  25. datachain/data_storage/warehouse.py +6 -4
  26. datachain/{lib/diff.py → diff/__init__.py} +116 -12
  27. datachain/func/__init__.py +3 -2
  28. datachain/func/conditional.py +74 -0
  29. datachain/func/func.py +5 -1
  30. datachain/lib/arrow.py +7 -1
  31. datachain/lib/dc.py +8 -3
  32. datachain/lib/file.py +16 -5
  33. datachain/lib/hf.py +1 -1
  34. datachain/lib/listing.py +19 -1
  35. datachain/lib/pytorch.py +57 -13
  36. datachain/lib/signal_schema.py +89 -27
  37. datachain/lib/udf.py +82 -40
  38. datachain/listing.py +1 -0
  39. datachain/progress.py +20 -3
  40. datachain/query/dataset.py +122 -93
  41. datachain/query/dispatch.py +22 -16
  42. datachain/studio.py +58 -38
  43. datachain/utils.py +14 -3
  44. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
  45. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
  46. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
  47. datachain/cli.py +0 -1475
  48. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
  49. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/cli/parser/studio.py ADDED
@@ -0,0 +1,126 @@
+ def add_studio_parser(subparsers, parent_parser) -> None:
+     studio_help = "Commands to authenticate DataChain with Iterative Studio"
+     studio_description = (
+         "Authenticate DataChain with Studio and set the token. "
+         "Once this token has been properly configured,\n"
+         "DataChain will utilize it for seamlessly sharing datasets\n"
+         "and using Studio features from CLI"
+     )
+
+     studio_parser = subparsers.add_parser(
+         "studio",
+         parents=[parent_parser],
+         description=studio_description,
+         help=studio_help,
+     )
+     studio_subparser = studio_parser.add_subparsers(
+         dest="cmd",
+         help="Use `DataChain studio CMD --help` to display command-specific help.",
+         required=True,
+     )
+
+     studio_login_help = "Authenticate DataChain with Studio host"
+     studio_login_description = (
+         "By default, this command authenticates the DataChain with Studio\n"
+         "using default scopes and assigns a random name as the token name."
+     )
+     login_parser = studio_subparser.add_parser(
+         "login",
+         parents=[parent_parser],
+         description=studio_login_description,
+         help=studio_login_help,
+     )
+
+     login_parser.add_argument(
+         "-H",
+         "--hostname",
+         action="store",
+         default=None,
+         help="The hostname of the Studio instance to authenticate with.",
+     )
+     login_parser.add_argument(
+         "-s",
+         "--scopes",
+         action="store",
+         default=None,
+         help="The scopes for the authentication token. ",
+     )
+
+     login_parser.add_argument(
+         "-n",
+         "--name",
+         action="store",
+         default=None,
+         help="The name of the authentication token. It will be used to\n"
+         "identify token shown in Studio profile.",
+     )
+
+     login_parser.add_argument(
+         "--no-open",
+         action="store_true",
+         default=False,
+         help="Use authentication flow based on user code.\n"
+         "You will be presented with user code to enter in browser.\n"
+         "DataChain will also use this if it cannot launch browser on your behalf.",
+     )
+
+     studio_logout_help = "Logout user from Studio"
+     studio_logout_description = "This removes the studio token from your global config."
+
+     studio_subparser.add_parser(
+         "logout",
+         parents=[parent_parser],
+         description=studio_logout_description,
+         help=studio_logout_help,
+     )
+
+     studio_team_help = "Set the default team for DataChain"
+     studio_team_description = (
+         "Set the default team for DataChain to use when interacting with Studio."
+     )
+
+     team_parser = studio_subparser.add_parser(
+         "team",
+         parents=[parent_parser],
+         description=studio_team_description,
+         help=studio_team_help,
+     )
+     team_parser.add_argument(
+         "team_name",
+         action="store",
+         help="The name of the team to set as the default.",
+     )
+     team_parser.add_argument(
+         "--global",
+         action="store_true",
+         default=False,
+         help="Set the team globally for all DataChain projects.",
+     )
+
+     studio_token_help = "View the token datachain uses to contact Studio" # noqa: S105 # nosec B105
+
+     studio_subparser.add_parser(
+         "token",
+         parents=[parent_parser],
+         description=studio_token_help,
+         help=studio_token_help,
+     )
+
+     studio_ls_dataset_help = "List the available datasets from Studio"
+     studio_ls_dataset_description = (
+         "This command lists all the datasets available in Studio.\n"
+         "It will show the dataset name and the number of versions available."
+     )
+
+     ls_dataset_parser = studio_subparser.add_parser(
+         "dataset",
+         parents=[parent_parser],
+         description=studio_ls_dataset_description,
+         help=studio_ls_dataset_help,
+     )
+     ls_dataset_parser.add_argument(
+         "--team",
+         action="store",
+         default=None,
+         help="The team to list datasets for. By default, it will use team from config.",
+     )
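
To see what this new sub-parser wires up, here is a minimal sketch of driving it directly with argparse. The empty `parent_parser` and the hostname value are assumptions for illustration; the real CLI presumably builds its parent parser in `datachain/cli/parser/__init__.py`, also added in this release.

```py
# Sketch only: exercising add_studio_parser in isolation.
import argparse

from datachain.cli.parser.studio import add_studio_parser

parent = argparse.ArgumentParser(add_help=False)  # assumed bare parent parser
root = argparse.ArgumentParser(prog="datachain")
subparsers = root.add_subparsers(dest="command")
add_studio_parser(subparsers, parent)

args = root.parse_args(
    ["studio", "login", "--hostname", "https://studio.example.com", "--no-open"]
)
print(args.command, args.cmd)       # studio login
print(args.hostname, args.no_open)  # https://studio.example.com True
```
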
datachain/cli/parser/utils.py ADDED
@@ -0,0 +1,63 @@
+ from argparse import Action, ArgumentParser, ArgumentTypeError
+ from typing import Union
+
+ from datachain.cli.utils import CommaSeparatedArgs
+
+ FIND_COLUMNS = ["du", "name", "path", "size", "type"]
+
+
+ def find_columns_type(
+     columns_str: str,
+     default_colums_str: str = "path",
+ ) -> list[str]:
+     if not columns_str:
+         columns_str = default_colums_str
+
+     return [parse_find_column(c) for c in columns_str.split(",")]
+
+
+ def parse_find_column(column: str) -> str:
+     column_lower = column.strip().lower()
+     if column_lower in FIND_COLUMNS:
+         return column_lower
+     raise ArgumentTypeError(
+         f"Invalid column for find: '{column}' Options are: {','.join(FIND_COLUMNS)}"
+     )
+
+
+ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Action:
+     return parser.add_argument(
+         "sources",
+         type=str,
+         nargs=nargs,
+         help="Data sources - paths to cloud storage dirs",
+     )
+
+
+ def add_show_args(parser: ArgumentParser) -> None:
+     parser.add_argument(
+         "--limit",
+         action="store",
+         default=10,
+         type=int,
+         help="Number of rows to show",
+     )
+     parser.add_argument(
+         "--offset",
+         action="store",
+         default=0,
+         type=int,
+         help="Number of rows to offset",
+     )
+     parser.add_argument(
+         "--columns",
+         default=[],
+         action=CommaSeparatedArgs,
+         help="Columns to show",
+     )
+     parser.add_argument(
+         "--no-collapse",
+         action="store_true",
+         default=False,
+         help="Do not collapse the columns",
+     )
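
A quick sketch of how the column helpers above behave; the expected outputs in the comments are derived directly from the code shown in this hunk.

```py
# Sketch: behavior of find_columns_type / parse_find_column as defined above.
from argparse import ArgumentTypeError

from datachain.cli.parser.utils import find_columns_type

print(find_columns_type(""))            # ['path']  (falls back to default_colums_str)
print(find_columns_type("Name, SIZE"))  # ['name', 'size']  (trimmed and lower-cased)

try:
    find_columns_type("owner")
except ArgumentTypeError as exc:
    # Invalid column for find: 'owner' Options are: du,name,path,size,type
    print(exc)
```
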
datachain/{cli_utils.py → cli/utils.py} CHANGED
@@ -1,4 +1,8 @@
- from argparse import SUPPRESS, Action, ArgumentError, _AppendAction
+ import logging
+ from argparse import SUPPRESS, Action, ArgumentError, Namespace, _AppendAction
+ from typing import Optional
+
+ from datachain.error import DataChainError


  class BooleanOptionalAction(Action):
@@ -70,3 +74,25 @@ class KeyValueArgs(_AppendAction): # pylint: disable=protected-access
              items[key.strip()] = value

          setattr(namespace, self.dest, items)
+
+
+ def get_logging_level(args: Namespace) -> int:
+     if args.quiet:
+         return logging.CRITICAL
+     if args.verbose:
+         return logging.DEBUG
+     return logging.INFO
+
+
+ def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
+     if studio and not token:
+         raise DataChainError(
+             "Not logged in to Studio. Log in with 'datachain studio login'."
+         )
+
+     if local or studio:
+         all = False
+
+     all = all and not (local or studio)
+
+     return all, local, studio
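
The new `determine_flavors` helper decides whether a command should target Studio, the local catalog, or both. A short sketch of its observable behavior; the expected tuples in the comments follow directly from the code above.

```py
# Sketch: determine_flavors as defined above. local/studio win over "all",
# and asking for studio without a token raises.
from datachain.cli.utils import determine_flavors
from datachain.error import DataChainError

print(determine_flavors(studio=False, local=True, all=True, token=None))
# (False, True, False)

print(determine_flavors(studio=False, local=False, all=True, token=None))
# (True, False, False)

try:
    determine_flavors(studio=True, local=False, all=False, token=None)
except DataChainError as exc:
    print(exc)  # Not logged in to Studio. Log in with 'datachain studio login'.
```
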
datachain/client/azure.py CHANGED
@@ -31,8 +31,12 @@ class AzureClient(Client):
          Generate a signed URL for the given path.
          """
          version_id = kwargs.pop("version_id", None)
+         content_disposition = kwargs.pop("content_disposition", None)
          result = self.fs.sign(
-             self.get_full_path(path, version_id), expiration=expires, **kwargs
+             self.get_full_path(path, version_id),
+             expiration=expires,
+             content_disposition=content_disposition,
+             **kwargs,
          )
          return result + (f"&versionid={version_id}" if version_id else "")

@@ -42,7 +46,7 @@ class AzureClient(Client):
          prefix = prefix.lstrip(DELIMITER) + DELIMITER
          found = False
          try:
-             with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+             with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
                  async with self.fs.service_client.get_container_client(
                      container=self.name
                  ) as container_client:
datachain/client/fsspec.py CHANGED
@@ -215,7 +215,7 @@ class Client(ABC):
          info = await self.fs._info(
              self.get_full_path(file.path, file.version), **kwargs
          )
-         return self.info_to_file(info, "").etag
+         return self.info_to_file(info, file.path).etag

      def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
          info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
@@ -249,7 +249,7 @@ class Client(ABC):
          await main_task

      async def _fetch_nested(self, start_prefix: str, result_queue: ResultQueue) -> None:
-         progress_bar = tqdm(desc=f"Listing {self.uri}", unit=" objects")
+         progress_bar = tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False)
          loop = get_loop()

          queue: asyncio.Queue[str] = asyncio.Queue()
@@ -343,7 +343,7 @@ class Client(ABC):
          return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)

      @abstractmethod
-     def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...
+     def info_to_file(self, v: dict[str, Any], path: str) -> "File": ...

      def fetch_nodes(
          self,
@@ -390,6 +390,12 @@ class Client(ABC):
              self.fs.open(self.get_full_path(file.path, file.version)), cb
          ) # type: ignore[return-value]

+     def upload(self, path: str, data: bytes) -> "File":
+         full_path = self.get_full_path(path)
+         self.fs.pipe_file(full_path, data)
+         file_info = self.fs.info(full_path)
+         return self.info_to_file(file_info, path)
+
      def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
          sync(get_loop(), functools.partial(self._download, file, callback=callback))

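The new `Client.upload` simply pipes raw bytes to the resolved path and then stats the result to build a `File` entry. A rough equivalent on a throwaway fsspec memory filesystem (a sketch; the path and content are made up):

```py
# Sketch of what Client.upload amounts to, on an in-memory fsspec filesystem.
import fsspec

fs = fsspec.filesystem("memory")
fs.makedirs("/bucket", exist_ok=True)
fs.pipe_file("/bucket/hello.txt", b"hello world")  # write the raw bytes
info = fs.info("/bucket/hello.txt")                # stat afterwards for metadata
print(info["size"])  # 11
```
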
datachain/client/gcs.py CHANGED
@@ -39,11 +39,15 @@ class GCSClient(Client):
          (see https://cloud.google.com/storage/docs/access-public-data#api-link).
          """
          version_id = kwargs.pop("version_id", None)
+         content_disposition = kwargs.pop("content_disposition", None)
          if self.fs.storage_options.get("token") == "anon":
              query = f"?generation={version_id}" if version_id else ""
              return f"https://storage.googleapis.com/{self.name}/{path}{query}"
          return self.fs.sign(
-             self.get_full_path(path, version_id), expiration=expires, **kwargs
+             self.get_full_path(path, version_id),
+             expiration=expires,
+             response_disposition=content_disposition,
+             **kwargs,
          )

      @staticmethod
@@ -83,7 +87,7 @@ class GCSClient(Client):
          self, page_queue: PageQueue, result_queue: ResultQueue
      ) -> bool:
          found = False
-         with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+         with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
              while (page := await page_queue.get()) is not None:
                  if page:
                      found = True
datachain/client/s3.py CHANGED
@@ -51,6 +51,21 @@ class ClientS3(Client):

          return cast(S3FileSystem, super().create_fs(**kwargs))

+     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
+         """
+         Generate a signed URL for the given path.
+         """
+         version_id = kwargs.pop("version_id", None)
+         content_disposition = kwargs.pop("content_disposition", None)
+         if content_disposition:
+             kwargs["ResponseContentDisposition"] = content_disposition
+
+         return self.fs.sign(
+             self.get_full_path(path, version_id),
+             expiration=expires,
+             **kwargs,
+         )
+
      async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
          async def get_pages(it, page_queue):
              try:
@@ -61,7 +76,7 @@ class ClientS3(Client):

          async def process_pages(page_queue, result_queue):
              found = False
-             with tqdm(desc=f"Listing {self.uri}", unit=" objects") as pbar:
+             with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
                  while (res := await page_queue.get()) is not None:
                      if res:
                          found = True
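
Here, `content_disposition` is mapped to S3's `ResponseContentDisposition` query parameter before signing. At the boto3 level the resulting pre-signed URL looks roughly like this (a sketch with placeholder bucket, key, and filename; `ClientS3.url` itself goes through `s3fs`'s `sign` rather than calling boto3 directly):

```py
# Sketch: what ResponseContentDisposition does in a pre-signed GET URL.
import boto3

s3 = boto3.client("s3")
url = s3.generate_presigned_url(
    "get_object",
    Params={
        "Bucket": "my-bucket",        # placeholder
        "Key": "reports/2024.csv",    # placeholder
        "ResponseContentDisposition": 'attachment; filename="2024.csv"',
    },
    ExpiresIn=3600,
)
print(url)  # fetching this URL forces a download with the given filename
```
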
datachain/data_storage/db_engine.py CHANGED
@@ -79,6 +79,15 @@ class DatabaseEngine(ABC, Serializable):
          conn: Optional[Any] = None,
      ) -> Iterator[tuple[Any, ...]]: ...

+     def get_table(self, name: str) -> "Table":
+         table = self.metadata.tables.get(name)
+         if table is None:
+             sa.Table(name, self.metadata, autoload_with=self.engine)
+             # ^^^ This table may not be correctly initialised on some dialects
+             # Grab it from metadata instead.
+             table = self.metadata.tables[name]
+         return table
+
      @abstractmethod
      def executemany(
          self, query, params, cursor: Optional[Any] = None
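
`DatabaseEngine.get_table` caches reflected tables in `self.metadata`. The reflection pattern it relies on, shown standalone with plain SQLAlchemy (a sketch using a throwaway in-memory SQLite database and a made-up table name):

```py
# Sketch: SQLAlchemy reflection as used by DatabaseEngine.get_table.
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")  # in-memory database for the sketch
metadata = sa.MetaData()
with engine.begin() as conn:
    conn.execute(sa.text("CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)"))

sa.Table("items", metadata, autoload_with=engine)  # reflect into metadata
table = metadata.tables["items"]                   # then read it back, as the diff does
print([c.name for c in table.columns])  # ['id', 'name']
```
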
datachain/data_storage/schema.py CHANGED
@@ -16,7 +16,6 @@ from datachain.sql.functions import path as pathfunc
  from datachain.sql.types import Int, SQLType, UInt64

  if TYPE_CHECKING:
-     from sqlalchemy import Engine
      from sqlalchemy.engine.interfaces import Dialect
      from sqlalchemy.sql.base import (
          ColumnCollection,
@@ -25,6 +24,8 @@ if TYPE_CHECKING:
      )
      from sqlalchemy.sql.elements import ColumnElement

+     from datachain.data_storage.db_engine import DatabaseEngine
+

  DEFAULT_DELIMITER = "__"

@@ -150,14 +151,12 @@ class DataTable:
      def __init__(
          self,
          name: str,
-         engine: "Engine",
-         metadata: Optional["sa.MetaData"] = None,
+         engine: "DatabaseEngine",
          column_types: Optional[dict[str, SQLType]] = None,
          object_name: str = "file",
      ):
          self.name: str = name
          self.engine = engine
-         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
          self.column_types: dict[str, SQLType] = column_types or {}
          self.object_name = object_name

@@ -211,12 +210,7 @@
          return sa.Table(name, metadata, *columns)

      def get_table(self) -> "sa.Table":
-         table = self.metadata.tables.get(self.name)
-         if table is None:
-             sa.Table(self.name, self.metadata, autoload_with=self.engine)
-             # ^^^ This table may not be correctly initialised on some dialects
-             # Grab it from metadata instead.
-             table = self.metadata.tables[self.name]
+         table = self.engine.get_table(self.name)

          column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
          # adjusting types for custom columns to be instances of SQLType if possible
datachain/data_storage/sqlite.py CHANGED
@@ -186,6 +186,12 @@ class SQLiteDatabaseEngine(DatabaseEngine):
          self.db_file = db_file
          self.is_closed = False

+     def get_table(self, name: str) -> Table:
+         if self.is_closed:
+             # Reconnect in case of being closed previously.
+             self._reconnect()
+         return super().get_table(name)
+
      @retry_sqlite_locks
      def execute(
          self,
@@ -670,7 +676,7 @@ class SQLiteWarehouse(AbstractWarehouse):
          ]
          table = self.create_udf_table(columns)

-         with tqdm(desc="Preparing", unit=" rows") as pbar:
+         with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:
              self.copy_table(table, query, progress_cb=pbar.update)

          return table
datachain/data_storage/warehouse.py CHANGED
@@ -191,8 +191,7 @@ class AbstractWarehouse(ABC, Serializable):
          table_name = self.dataset_table_name(dataset.name, version)
          return self.schema.dataset_row_cls(
              table_name,
-             self.db.engine,
-             self.db.metadata,
+             self.db,
              dataset.get_schema(version),
              object_name=object_name,
          )
@@ -904,8 +903,11 @@ class AbstractWarehouse(ABC, Serializable):
          This should be implemented to ensure that the provided tables
          are cleaned up as soon as they are no longer needed.
          """
-         with tqdm(desc="Cleanup", unit=" tables") as pbar:
-             for name in set(names):
+         to_drop = set(names)
+         with tqdm(
+             desc="Cleanup", unit=" tables", total=len(to_drop), leave=False
+         ) as pbar:
+             for name in to_drop:
                  self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
                  pbar.update(1)

datachain/{lib/diff.py → diff/__init__.py} CHANGED
@@ -1,6 +1,7 @@
  import random
  import string
  from collections.abc import Sequence
+ from enum import Enum
  from typing import TYPE_CHECKING, Optional, Union

  import sqlalchemy as sa
@@ -16,7 +17,22 @@ if TYPE_CHECKING:
  C = Column


- def compare( # noqa: PLR0912, PLR0915, C901
+ def get_status_col_name() -> str:
+     """Returns new unique status col name"""
+     return "diff_" + "".join(
+         random.choice(string.ascii_letters) # noqa: S311
+         for _ in range(10)
+     )
+
+
+ class CompareStatus(str, Enum):
+     ADDED = "A"
+     DELETED = "D"
+     MODIFIED = "M"
+     SAME = "S"
+
+
+ def _compare( # noqa: PLR0912, PLR0915, C901
      left: "DataChain",
      right: "DataChain",
      on: Union[str, Sequence[str]],
@@ -72,13 +88,10 @@ def compare( # noqa: PLR0912, PLR0915, C901
              "At least one of added, deleted, modified, same flags must be set"
          )

-     # we still need status column for internal implementation even if not
-     # needed in output
      need_status_col = bool(status_col)
-     status_col = status_col or "diff_" + "".join(
-         random.choice(string.ascii_letters) # noqa: S311
-         for _ in range(10)
-     )
+     # we still need status column for internal implementation even if not
+     # needed in the output
+     status_col = status_col or get_status_col_name()

      # calculate on and compare column names
      right_on = right_on or on
@@ -112,7 +125,7 @@
                  for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
              ]
          )
-         diff_cond.append((added_cond, "A"))
+         diff_cond.append((added_cond, CompareStatus.ADDED))
      if modified and compare:
          modified_cond = sa.or_(
              *[
@@ -120,7 +133,7 @@ def compare( # noqa: PLR0912, PLR0915, C901
                  for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
              ]
          )
-         diff_cond.append((modified_cond, "M"))
+         diff_cond.append((modified_cond, CompareStatus.MODIFIED))
      if same and compare:
          same_cond = sa.and_(
              *[
@@ -128,9 +141,11 @@
                  for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
              ]
          )
-         diff_cond.append((same_cond, "S"))
+         diff_cond.append((same_cond, CompareStatus.SAME))

-     diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
+     diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
+         status_col
+     )
      diff.type = String()

      left_right_merge = left.merge(
@@ -145,7 +160,7 @@
          )
      )

-     diff_col = sa.literal("D").label(status_col)
+     diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
      diff_col.type = String()

      right_left_merge = right.merge(
@@ -195,3 +210,92 @@
          res = res.select_except(C(status_col))

      return left._evolve(query=res, signal_schema=schema)
+
+
+ def compare_and_split(
+     left: "DataChain",
+     right: "DataChain",
+     on: Union[str, Sequence[str]],
+     right_on: Optional[Union[str, Sequence[str]]] = None,
+     compare: Optional[Union[str, Sequence[str]]] = None,
+     right_compare: Optional[Union[str, Sequence[str]]] = None,
+     added: bool = True,
+     deleted: bool = True,
+     modified: bool = True,
+     same: bool = False,
+ ) -> dict[str, "DataChain"]:
+     """Comparing two chains and returning multiple chains, one for each of `added`,
+     `deleted`, `modified` and `same` status. Result is returned in form of
+     dictionary where each item represents one of the statuses and key values
+     are `A`, `D`, `M`, `S` corresponding. Note that status column is not in the
+     resulting chains.
+
+     Parameters:
+         left: Chain to calculate diff on.
+         right: Chain to calculate diff from.
+         on: Column or list of columns to match on. If both chains have the
+             same columns then this column is enough for the match. Otherwise,
+             `right_on` parameter has to specify the columns for the other chain.
+             This value is used to find corresponding row in other dataset. If not
+             found there, row is considered as added (or removed if vice versa), and
+             if found then row can be either modified or same.
+         right_on: Optional column or list of columns
+             for the `other` to match.
+         compare: Column or list of columns to compare on. If both chains have
+             the same columns then this column is enough for the compare. Otherwise,
+             `right_compare` parameter has to specify the columns for the other
+             chain. This value is used to see if row is modified or same. If
+             not set, all columns will be used for comparison
+         right_compare: Optional column or list of columns
+             for the `other` to compare to.
+         added (bool): Whether to return chain containing only added rows.
+         deleted (bool): Whether to return chain containing only deleted rows.
+         modified (bool): Whether to return chain containing only modified rows.
+         same (bool): Whether to return chain containing only same rows.
+
+     Example:
+         ```py
+         chains = compare(
+             persons,
+             new_persons,
+             on=["id"],
+             right_on=["other_id"],
+             compare=["name"],
+             added=True,
+             deleted=True,
+             modified=True,
+             same=True,
+         )
+         ```
+     """
+     status_col = get_status_col_name()
+
+     res = _compare(
+         left,
+         right,
+         on,
+         right_on=right_on,
+         compare=compare,
+         right_compare=right_compare,
+         added=added,
+         deleted=deleted,
+         modified=modified,
+         same=same,
+         status_col=status_col,
+     )
+
+     chains = {}
+
+     def filter_by_status(compare_status) -> "DataChain":
+         return res.filter(C(status_col) == compare_status).select_except(status_col)
+
+     if added:
+         chains[CompareStatus.ADDED.value] = filter_by_status(CompareStatus.ADDED)
+     if deleted:
+         chains[CompareStatus.DELETED.value] = filter_by_status(CompareStatus.DELETED)
+     if modified:
+         chains[CompareStatus.MODIFIED.value] = filter_by_status(CompareStatus.MODIFIED)
+     if same:
+         chains[CompareStatus.SAME.value] = filter_by_status(CompareStatus.SAME)
+
+     return chains
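
A sketch of consuming the new `compare_and_split` from the `datachain.diff` package; the `persons` and `new_persons` chains are assumed to already exist, and the dictionary keys follow the `CompareStatus` values defined above.

```py
# Sketch: splitting a diff into per-status chains. `persons` and `new_persons`
# are assumed to be existing DataChain instances sharing `id` and `name` columns.
from datachain.diff import CompareStatus, compare_and_split

chains = compare_and_split(
    persons,
    new_persons,
    on=["id"],
    compare=["name"],
    added=True,
    deleted=True,
    modified=True,
    same=False,
)

added = chains[CompareStatus.ADDED.value]        # keyed by "A"
deleted = chains[CompareStatus.DELETED.value]    # keyed by "D"
modified = chains[CompareStatus.MODIFIED.value]  # keyed by "M"
# Each returned chain comes back without the internal status column.
```
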
datachain/func/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from sqlalchemy import case, literal
+ from sqlalchemy import literal

  from . import array, path, random, string
  from .aggregate import (
@@ -16,7 +16,7 @@ from .aggregate import (
      sum,
  )
  from .array import cosine_distance, euclidean_distance, length, sip_hash_64
- from .conditional import greatest, least
+ from .conditional import case, greatest, ifelse, least
  from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
  from .random import rand
  from .string import byte_hamming_distance
@@ -40,6 +40,7 @@ __all__ = [
      "euclidean_distance",
      "first",
      "greatest",
+     "ifelse",
      "int_hash_64",
      "least",
      "length",