datachain 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

datachain/asyn.py CHANGED
@@ -224,3 +224,23 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
     async def _break_iteration(self) -> None:
         self.heap = []
         self._push_result(self._next_yield, None)
+
+
+def iter_over_async(ait, loop):
+    """Wrap an asynchronous iterator into a synchronous one"""
+    ait = ait.__aiter__()
+
+    # helper async fn that just gets the next element from the async iterator
+    async def get_next():
+        try:
+            obj = await ait.__anext__()
+            return False, obj
+        except StopAsyncIteration:
+            return True, None
+
+    # actual sync iterator
+    while True:
+        done, obj = asyncio.run_coroutine_threadsafe(get_next(), loop).result()
+        if done:
+            break
+        yield obj
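
The new iter_over_async helper turns an async iterator into a plain generator by scheduling __anext__ on an event loop that runs in another thread. A minimal usage sketch, assuming the helper sits at module level in datachain.asyn as added above, with a hypothetical toy generator:

import asyncio
import threading

from datachain.asyn import iter_over_async

async def countdown(n):
    # toy async generator, used only for illustration
    for i in range(n):
        await asyncio.sleep(0)
        yield i

# run an event loop in a background thread so run_coroutine_threadsafe has a target
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

print(list(iter_over_async(countdown(3), loop)))  # [0, 1, 2]
loop.call_soon_threadsafe(loop.stop)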
@@ -577,6 +577,7 @@ class Catalog:
         warehouse_ready_callback: Optional[
             Callable[["AbstractWarehouse"], None]
         ] = None,
+        in_memory: bool = False,
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
         datachain_dir.init()
@@ -590,6 +591,7 @@ class Catalog:
             "tmp_dir": tmp_dir,
         }
         self._warehouse_ready_callback = warehouse_ready_callback
+        self.in_memory = in_memory
 
     @cached_property
     def warehouse(self) -> "AbstractWarehouse":
@@ -1627,8 +1629,17 @@ class Catalog:
         version = self.get_dataset(dataset_name).get_version(dataset_version)
 
         file_signals_values = {}
+        file_schemas = {}
+        # TODO: To remove after we properly fix deserialization
+        for signal, type_name in version.feature_schema.items():
+            from datachain.lib.model_store import ModelStore
 
-        schema = SignalSchema.deserialize(version.feature_schema)
+            type_name_parsed, v = ModelStore.parse_name_version(type_name)
+            fr = ModelStore.get(type_name_parsed, v)
+            if fr and issubclass(fr, File):
+                file_schemas[signal] = type_name
+
+        schema = SignalSchema.deserialize(file_schemas)
         for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
@@ -28,8 +28,10 @@ WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
 DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
 DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"
 
+IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
 
-def get_id_generator() -> "AbstractIDGenerator":
+
+def get_id_generator(in_memory: bool = False) -> "AbstractIDGenerator":
     id_generator_serialized = os.environ.get(ID_GENERATOR_SERIALIZED)
     if id_generator_serialized:
         id_generator_obj = deserialize(id_generator_serialized)
@@ -43,25 +45,31 @@ def get_id_generator() -> "AbstractIDGenerator":
     id_generator_import_path = os.environ.get(ID_GENERATOR_IMPORT_PATH)
     id_generator_arg_envs = get_envs_by_prefix(ID_GENERATOR_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    id_generator_args = {k.lower(): v for k, v in id_generator_arg_envs.items()}
-
-    if id_generator_import_path:
-        # ID generator paths are specified as (for example):
-        # datachain.data_storage.SQLiteIDGenerator
-        if "." not in id_generator_import_path:
-            raise RuntimeError(
-                f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
-                f"{id_generator_import_path}"
-            )
-        module_name, _, class_name = id_generator_import_path.rpartition(".")
-        id_generator = import_module(module_name)
-        id_generator_class = getattr(id_generator, class_name)
-    else:
-        id_generator_class = SQLiteIDGenerator
+    id_generator_args: dict[str, Any] = {
+        k.lower(): v for k, v in id_generator_arg_envs.items()
+    }
+
+    if not id_generator_import_path:
+        id_generator_args["in_memory"] = in_memory
+        return SQLiteIDGenerator(**id_generator_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # ID generator paths are specified as (for example):
+    # datachain.data_storage.SQLiteIDGenerator
+    if "." not in id_generator_import_path:
+        raise RuntimeError(
+            f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
+            f"{id_generator_import_path}"
+        )
+    module_name, _, class_name = id_generator_import_path.rpartition(".")
+    id_generator = import_module(module_name)
+    id_generator_class = getattr(id_generator, class_name)
     return id_generator_class(**id_generator_args)
 
 
-def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
+def get_metastore(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractMetastore":
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -78,24 +86,32 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMet
     metastore_import_path = os.environ.get(METASTORE_IMPORT_PATH)
     metastore_arg_envs = get_envs_by_prefix(METASTORE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    metastore_args = {k.lower(): v for k, v in metastore_arg_envs.items()}
-
-    if metastore_import_path:
-        # Metastore paths are specified as (for example):
-        # datachain.data_storage.SQLiteMetastore
-        if "." not in metastore_import_path:
-            raise RuntimeError(
-                f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
-            )
-        module_name, _, class_name = metastore_import_path.rpartition(".")
-        metastore = import_module(module_name)
-        metastore_class = getattr(metastore, class_name)
-    else:
-        metastore_class = SQLiteMetastore
+    metastore_args: dict[str, Any] = {
+        k.lower(): v for k, v in metastore_arg_envs.items()
+    }
+
+    if not metastore_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteMetastore can only be used with SQLiteIDGenerator")
+        metastore_args["in_memory"] = in_memory
+        return SQLiteMetastore(id_generator, **metastore_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Metastore paths are specified as (for example):
+    # datachain.data_storage.SQLiteMetastore
+    if "." not in metastore_import_path:
+        raise RuntimeError(
+            f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
+        )
+    module_name, _, class_name = metastore_import_path.rpartition(".")
+    metastore = import_module(module_name)
+    metastore_class = getattr(metastore, class_name)
     return metastore_class(id_generator, **metastore_args)
 
 
-def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
+def get_warehouse(
+    id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
+) -> "AbstractWarehouse":
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -112,20 +128,26 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWar
     warehouse_import_path = os.environ.get(WAREHOUSE_IMPORT_PATH)
     warehouse_arg_envs = get_envs_by_prefix(WAREHOUSE_ARG_PREFIX)
     # Convert env variable names to keyword argument names by lowercasing them
-    warehouse_args = {k.lower(): v for k, v in warehouse_arg_envs.items()}
-
-    if warehouse_import_path:
-        # Warehouse paths are specified as (for example):
-        # datachain.data_storage.SQLiteWarehouse
-        if "." not in warehouse_import_path:
-            raise RuntimeError(
-                f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
-            )
-        module_name, _, class_name = warehouse_import_path.rpartition(".")
-        warehouse = import_module(module_name)
-        warehouse_class = getattr(warehouse, class_name)
-    else:
-        warehouse_class = SQLiteWarehouse
+    warehouse_args: dict[str, Any] = {
+        k.lower(): v for k, v in warehouse_arg_envs.items()
+    }
+
+    if not warehouse_import_path:
+        if not isinstance(id_generator, SQLiteIDGenerator):
+            raise ValueError("SQLiteWarehouse can only be used with SQLiteIDGenerator")
+        warehouse_args["in_memory"] = in_memory
+        return SQLiteWarehouse(id_generator, **warehouse_args)
+    if in_memory:
+        raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
+    # Warehouse paths are specified as (for example):
+    # datachain.data_storage.SQLiteWarehouse
+    if "." not in warehouse_import_path:
+        raise RuntimeError(
+            f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
+        )
+    module_name, _, class_name = warehouse_import_path.rpartition(".")
+    warehouse = import_module(module_name)
+    warehouse_class = getattr(warehouse, class_name)
     return warehouse_class(id_generator, **warehouse_args)
 
 
@@ -152,7 +174,9 @@ def get_distributed_class(**kwargs):
     return distributed_class(**distributed_args | kwargs)
 
 
-def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
+def get_catalog(
+    client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
+) -> Catalog:
     """
     Function that creates Catalog instance with appropriate metastore
    and warehouse classes. Metastore class can be provided with env variable
@@ -164,10 +188,11 @@ def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
    and name of variable after, e.g. if it accepts team_id as kwargs
    we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
    """
-    id_generator = get_id_generator()
+    id_generator = get_id_generator(in_memory=in_memory)
     return Catalog(
         id_generator=id_generator,
-        metastore=get_metastore(id_generator),
-        warehouse=get_warehouse(id_generator),
+        metastore=get_metastore(id_generator, in_memory=in_memory),
+        warehouse=get_warehouse(id_generator, in_memory=in_memory),
         client_config=client_config,
+        in_memory=in_memory,
     )
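
Taken together, these loader changes allow a throwaway, fully in-memory SQLite setup. A minimal sketch, assuming get_catalog is re-exported from the datachain.catalog package:

from datachain.catalog import get_catalog

# id generator, metastore and warehouse are all backed by SQLite ":memory:" databases;
# if DATACHAIN_METASTORE / DATACHAIN_WAREHOUSE select a non-SQLite backend, this raises
# RuntimeError("In-memory is only supported on SQLite").
catalog = get_catalog(in_memory=True)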
datachain/client/azure.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any
 from adlfs import AzureBlobFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -24,6 +25,18 @@ class AzureClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        version_id = v.get("version_id")
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", "").strip('"'),
+            version=version_id or "",
+            is_latest=version_id is None or bool(v.get("is_current_version")),
+            last_modified=v["last_modified"],
+            size=v.get("size", ""),
+        )
+
     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
         prefix = start_prefix
         if prefix:
datachain/client/gcs.py CHANGED
@@ -9,6 +9,7 @@ from dateutil.parser import isoparse
 from gcsfs import GCSFileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -120,3 +121,14 @@ class GCSClient(Client):
             last_modified=self.parse_timestamp(v["updated"]),
             size=v.get("size", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            etag=v.get("etag", ""),
+            version=v.get("generation", ""),
+            is_latest=not v.get("timeDeleted"),
+            last_modified=self.parse_timestamp(v["updated"]),
+            size=v.get("size", ""),
+        )
datachain/client/local.py CHANGED
@@ -7,6 +7,7 @@ from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
 
+from datachain.lib.file import File
 from datachain.node import Entry
 from datachain.storage import StorageURI
 
@@ -144,6 +145,16 @@ class FileClient(Client):
             size=v.get("size", ""),
         )
 
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v.get("size", ""),
+            etag=v["mtime"].hex(),
+            is_latest=True,
+            last_modified=datetime.fromtimestamp(v["mtime"], timezone.utc),
+        )
+
     def fetch_nodes(
         self,
         nodes,
datachain/client/s3.py CHANGED
@@ -5,6 +5,7 @@ from botocore.exceptions import NoCredentialsError
 from s3fs import S3FileSystem
 from tqdm import tqdm
 
+from datachain.lib.file import File
 from datachain.node import Entry
 
 from .fsspec import DELIMITER, Client, ResultQueue
@@ -167,3 +168,14 @@ class ClientS3(Client):
             owner_name=v.get("Owner", {}).get("DisplayName", ""),
             owner_id=v.get("Owner", {}).get("ID", ""),
         )
+
+    def info_to_file(self, v: dict[str, Any], path: str) -> File:
+        return File(
+            source=self.uri,
+            path=path,
+            size=v["size"],
+            version=ClientS3.clean_s3_version(v.get("VersionId", "")),
+            etag=v.get("ETag", "").strip('"'),
+            is_latest=v.get("IsLatest", True),
+            last_modified=v.get("LastModified", ""),
+        )
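
Each client now pairs info_to_entry with an info_to_file method that builds the new File model from an fsspec-style info dict. For illustration only, with a hypothetical S3 listing entry and an already-constructed client:

info = {
    "size": 1024,
    "ETag": '"abc123"',
    "IsLatest": True,
    "LastModified": last_modified,  # datetime taken from the listing
}
f = s3_client.info_to_file(info, "images/cat.jpg")
# -> File(source=s3_client.uri, path="images/cat.jpg", size=1024, etag="abc123", ...)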
@@ -67,7 +67,11 @@ def convert_rows_custom_column_types(
     for row in rows:
         row_list = list(row)
         for idx, t in custom_columns_types:
-            row_list[idx] = t.on_read_convert(row_list[idx], dialect)
+            row_list[idx] = (
+                t.default_value(dialect)
+                if row_list[idx] is None
+                else t.on_read_convert(row_list[idx], dialect)
+            )
 
         yield tuple(row_list)
 
@@ -136,7 +140,15 @@ class DataTable:
         self.column_types: dict[str, SQLType] = column_types or {}
 
     @staticmethod
-    def copy_column(column: sa.Column):
+    def copy_column(
+        column: sa.Column,
+        primary_key: Optional[bool] = None,
+        index: Optional[bool] = None,
+        nullable: Optional[bool] = None,
+        default: Optional[Any] = None,
+        server_default: Optional[Any] = None,
+        unique: Optional[bool] = None,
+    ) -> sa.Column:
         """
         Copy a sqlalchemy Column object intended for use as a signal column.
 
@@ -150,12 +162,14 @@
         return sa.Column(
             column.name,
             column.type,
-            primary_key=column.primary_key,
-            index=column.index,
-            nullable=column.nullable,
-            default=column.default,
-            server_default=column.server_default,
-            unique=column.unique,
+            primary_key=primary_key if primary_key is not None else column.primary_key,
+            index=index if index is not None else column.index,
+            nullable=nullable if nullable is not None else column.nullable,
+            default=default if default is not None else column.default,
+            server_default=(
+                server_default if server_default is not None else column.server_default
+            ),
+            unique=unique if unique is not None else column.unique,
         )
 
     @classmethod
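
The extra keyword arguments let copy_column override individual attributes while anything left as None still falls back to the source column. A hypothetical usage sketch:

import sqlalchemy as sa

col = sa.Column("score", sa.Float, nullable=False, index=True)
# same name and type, but nullable and without an index; the other attributes are inherited
copy = DataTable.copy_column(col, nullable=True, index=False)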
@@ -20,6 +20,8 @@ from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import bindparam, cast
+from sqlalchemy.sql.selectable import Select
+from tqdm import tqdm
 
 import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
@@ -35,14 +37,13 @@ from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_diale
 from datachain.sql.sqlite.base import load_usearch_extension
 from datachain.sql.types import SQLType
 from datachain.storage import StorageURI
-from datachain.utils import DataChainDir
+from datachain.utils import DataChainDir, batched_it
 
 if TYPE_CHECKING:
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.engine.base import Engine
     from sqlalchemy.schema import SchemaItem
-    from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
-    from sqlalchemy.sql.selectable import Select
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.types import TypeEngine
 
 
@@ -54,8 +55,6 @@ RETRY_FACTOR = 2
 
 DETECT_TYPES = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
 
-Column = Union[str, "ColumnClause[Any]", "TextClause"]
-
 datachain.sql.sqlite.setup()
 
 quote_schema = sqlite_dialect.identifier_preparer.quote_schema
@@ -82,6 +81,17 @@ def retry_sqlite_locks(func):
     return wrapper
 
 
+def get_db_file_in_memory(
+    db_file: Optional[str] = None, in_memory: bool = False
+) -> Optional[str]:
+    """Get in-memory db_file and check that conflicting arguments are not provided."""
+    if in_memory:
+        if db_file and db_file != ":memory:":
+            raise RuntimeError("A db_file cannot be specified if in_memory is True")
+        db_file = ":memory:"
+    return db_file
+
+
 class SQLiteDatabaseEngine(DatabaseEngine):
     dialect = sqlite_dialect
 
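
For reference, the new helper behaves as follows (hypothetical file names):

get_db_file_in_memory(None, in_memory=True)            # ":memory:"
get_db_file_in_memory("db.sqlite3", in_memory=False)   # "db.sqlite3"
get_db_file_in_memory(":memory:", in_memory=True)      # ":memory:"
get_db_file_in_memory("db.sqlite3", in_memory=True)    # raises RuntimeError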
@@ -122,6 +132,11 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         engine = sqlalchemy.create_engine(
             "sqlite+pysqlite:///", creator=lambda: db, future=True
         )
+        # ensure we run SA on_connect init (e.g it registers regexp function),
+        # also makes sure that it's consistent. Otherwise in some cases it
+        # seems we are getting different results if engine object is used in a
+        # different thread first and enine is not used in the Main thread.
+        engine.connect().close()
 
         db.isolation_level = None  # Use autocommit mode
         db.execute("PRAGMA foreign_keys = ON")
@@ -260,7 +275,10 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
         table_prefix: Optional[str] = None,
         skip_db_init: bool = False,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         super().__init__(db, table_prefix, skip_db_init)
@@ -378,6 +396,7 @@ class SQLiteMetastore(AbstractDBMetastore):
         partial_id: Optional[int] = None,
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator, uri, partial_id)
@@ -386,6 +405,8 @@
         # foreign keys
         self.default_table_names: list[str] = []
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
         self._init_tables()
@@ -550,10 +571,13 @@ class SQLiteWarehouse(AbstractWarehouse):
         id_generator: "SQLiteIDGenerator",
         db: Optional["SQLiteDatabaseEngine"] = None,
         db_file: Optional[str] = None,
+        in_memory: bool = False,
     ):
         self.schema: DefaultSchema = DefaultSchema()
         super().__init__(id_generator)
 
+        db_file = get_db_file_in_memory(db_file, in_memory)
+
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -626,9 +650,7 @@
         self.db.create_table(table, if_not_exists=if_not_exists)
         return table
 
-    def dataset_rows_select(
-        self, select_query: sqlalchemy.sql.selectable.Select, **kwargs
-    ):
+    def dataset_rows_select(self, select_query: Select, **kwargs):
         rows = self.db.execute(select_query, **kwargs)
         yield from convert_rows_custom_column_types(
             select_query.selected_columns, rows, sqlite_dialect
@@ -746,6 +768,34 @@
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
 
+    def copy_table(
+        self,
+        table: Table,
+        query: Select,
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        if "sys__id" in query.selected_columns:
+            col_id = query.selected_columns.sys__id
+        else:
+            col_id = sqlalchemy.column("sys__id")
+        select_ids = query.with_only_columns(col_id)
+
+        ids = self.db.execute(select_ids).fetchall()
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+
+        for batch in batched_it(ids, 10_000):
+            batch_ids = [row[0] for row in batch]
+            select_q._where_criteria = (col_id.in_(batch_ids),)
+            q = table.insert().from_select(list(select_q.selected_columns), select_q)
+
+            self.db.execute(q)
+
+            if progress_cb:
+                progress_cb(len(batch_ids))
+
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
         Create a temporary table from a query for use in a UDF.
@@ -757,11 +807,7 @@
         ]
         table = self.create_udf_table(columns)
 
-        select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "sys__id"]
-        )
-        self.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
+        with tqdm(desc="Preparing", unit=" rows") as pbar:
+            self.copy_table(table, query, progress_cb=pbar.update)
 
         return table
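
SQLiteWarehouse.copy_table performs the INSERT ... FROM SELECT in batches of 10,000 ids so the progress callback can fire per chunk. batched_it, now imported from datachain.utils, is assumed to chunk an iterable roughly like this sketch (not the actual implementation):

from itertools import islice

def batched(iterable, n):
    # yield successive lists of at most n items
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk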
@@ -6,7 +6,7 @@ import random
 import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 from urllib.parse import urlparse
 
 import attrs
@@ -14,6 +14,7 @@ import sqlalchemy as sa
 from sqlalchemy import Table, case, select
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
+from tqdm import tqdm
 
 from datachain.client import Client
 from datachain.data_storage.serializer import Serializable
@@ -901,6 +902,17 @@ class AbstractWarehouse(ABC, Serializable):
         self.db.create_table(tbl, if_not_exists=True)
         return tbl
 
+    @abstractmethod
+    def copy_table(
+        self,
+        table: Table,
+        query: "Select",
+        progress_cb: Optional[Callable[[int], None]] = None,
+    ) -> None:
+        """
+        Copy the results of a query into a table.
+        """
+
     @abstractmethod
     def create_pre_udf_table(self, query: "Select") -> "Table":
         """
@@ -928,8 +940,10 @@
         This should be implemented to ensure that the provided tables
         are cleaned up as soon as they are no longer needed.
         """
-        for name in names:
-            self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
+        with tqdm(desc="Cleanup", unit=" tables") as pbar:
+            for name in names:
+                self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
+                pbar.update(1)
 
     def changed_query(
         self,
datachain/lib/arrow.py CHANGED
@@ -122,7 +122,7 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if pa.types.is_string(col_type) or pa.types.is_large_string(col_type):
         return str
     if pa.types.is_list(col_type):
-        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[misc]
+        return list[_arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
     if pa.types.is_struct(col_type) or pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):
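
The list branch (now carrying a broader type: ignore) still maps an Arrow list type to a parameterized Python list; an illustrative call:

import pyarrow as pa

_arrow_type_mapper(pa.list_(pa.string()))  # -> list[str]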
@@ -1,7 +1,12 @@
 from collections.abc import Sequence
 from typing import Any, Union
 
-from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
+from datachain.lib.data_model import (
+    DataType,
+    DataTypeNames,
+    DataValuesType,
+    is_chain_type,
+)
 from datachain.lib.utils import DataChainParamsError
 
 
@@ -15,7 +20,7 @@ class ValuesToTupleError(DataChainParamsError):
 def values_to_tuples(  # noqa: C901, PLR0912
     ds_name: str = "",
     output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
-    **fr_map,
+    **fr_map: Sequence[DataValuesType],
 ) -> tuple[Any, Any, Any]:
     if output:
         if not isinstance(output, (Sequence, str, dict)):
@@ -47,10 +52,10 @@ def values_to_tuples(  # noqa: C901, PLR0912
             f" number of signals '{len(fr_map)}'",
         )
 
-    types_map = {}
+    types_map: dict[str, type] = {}
     length = -1
     for k, v in fr_map.items():
-        if not isinstance(v, Sequence) or isinstance(v, str):
+        if not isinstance(v, Sequence) or isinstance(v, str):  # type: ignore[unreachable]
            raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence")
         len_ = len(v)
 
@@ -64,15 +69,16 @@
         if len_ == 0:
            raise ValuesToTupleError(ds_name, f"signal '{k}' is empty list")
 
-        typ = type(v[0])
+        first_element = next(iter(v))
+        typ = type(first_element)
         if not is_chain_type(typ):
            raise ValuesToTupleError(
                ds_name,
                f"signal '{k}' has unsupported type '{typ.__name__}'."
                f" Please use DataModel types: {DataTypeNames}",
            )
-        if typ is list:
-            types_map[k] = list[type(v[0][0])]  # type: ignore[misc]
+        if isinstance(first_element, list):
+            types_map[k] = list[type(first_element[0])]  # type: ignore[assignment, misc]
         else:
            types_map[k] = typ
 
@@ -98,7 +104,7 @@ def values_to_tuples(  # noqa: C901, PLR0912
     if len(output) > 1:  # type: ignore[arg-type]
         tuple_type = tuple(output_types)
         res_type = tuple[tuple_type]  # type: ignore[valid-type]
-        res_values = list(zip(*fr_map.values()))
+        res_values: Sequence[Any] = list(zip(*fr_map.values()))
     else:
         res_type = output_types[0]  # type: ignore[misc]
         res_values = next(iter(fr_map.values()))
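
For illustration, the inference loop above now derives each signal's type from the first element of its sequence (hypothetical inputs, only the inferred types shown):

# ids  -> int        (type of the first element)
# tags -> list[str]  (first element is a list, so list of its first item's type)
values_to_tuples("example", ids=[1, 2, 3], tags=[["a"], ["b"], ["c"]])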