datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

@@ -988,6 +988,14 @@ class Catalog:
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
+
+        job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
+        if not job_id:
+            from datachain.query.session import Session
+
+            session = Session.get(catalog=self)
+            job_id = session.job_id
+
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
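The added lines resolve a job id for the new dataset version from three sources, in order: the explicit job_id argument, the DATACHAIN_JOB_ID environment variable, and the job_id of the session returned by Session.get. A minimal sketch of that resolution order (the standalone resolve_job_id helper and its fallback value are illustrative, not part of datachain):

```python
import os

def resolve_job_id(job_id=None, session_job_id="session-generated-id"):
    # Same precedence as the diff: explicit argument, then env var, then session.
    return job_id or os.getenv("DATACHAIN_JOB_ID") or session_job_id

os.environ["DATACHAIN_JOB_ID"] = "job-from-env"
assert resolve_job_id("job-from-arg") == "job-from-arg"
assert resolve_job_id() == "job-from-env"
del os.environ["DATACHAIN_JOB_ID"]
assert resolve_job_id() == "session-generated-id"
```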
datachain/cli.py CHANGED
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import shtab
 
-from datachain import utils
+from datachain import Session, utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.lib.dc import DataChain
 from datachain.telemetry import telemetry
@@ -770,7 +770,8 @@ def show(
     show_records(records, collapse_columns=not no_collapse)
     if schema and dataset_version.feature_schema:
         print("\nSchema:")
-        dc = DataChain(name=name, version=version, catalog=catalog)
+        session = Session.get(catalog=catalog)
+        dc = DataChain.from_dataset(name=name, version=version, session=session)
         dc.print_schema()
 
 
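The show command now obtains a Session for the catalog and loads the dataset through DataChain.from_dataset instead of passing the catalog to the DataChain constructor. A hedged usage sketch of the same call path outside the CLI (the dataset name and version are placeholders):

```python
from datachain import Session
from datachain.lib.dc import DataChain

# The CLI passes an explicit catalog via Session.get(catalog=...); a bare
# Session.get() uses the default catalog. "my-dataset" / version 1 are placeholders.
session = Session.get()
dc = DataChain.from_dataset(name="my-dataset", version=1, session=session)
dc.print_schema()
```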
@@ -15,7 +15,6 @@ from uuid import uuid4
 from sqlalchemy import (
     JSON,
     BigInteger,
-    Boolean,
     Column,
     DateTime,
     ForeignKey,
@@ -51,7 +50,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractIDGenerator, schema
     from datachain.data_storage.db_engine import DatabaseEngine
 
-
 logger = logging.getLogger("datachain")
 
 
@@ -228,7 +226,7 @@ class AbstractMetastore(ABC, Serializable):
         self,
         dataset: DatasetRecord,
         version: int,
-        status: int = DatasetStatus.CREATED,
+        status: int,
         sources: str = "",
         feature_schema: Optional[dict] = None,
         query_script: str = "",
@@ -385,6 +383,11 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""
 
+    @abstractmethod
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        raise NotImplementedError
+
 
 class AbstractDBMetastore(AbstractMetastore):
     """
@@ -448,7 +451,6 @@ class AbstractDBMetastore(AbstractMetastore):
             Column("name", Text, nullable=False),
             Column("description", Text),
             Column("labels", JSON, nullable=True),
-            Column("shadow", Boolean, nullable=False),
             Column("status", Integer, nullable=False),
             Column("feature_schema", JSON, nullable=True),
             Column("created_at", DateTime(timezone=True)),
@@ -481,8 +483,11 @@ class AbstractDBMetastore(AbstractMetastore):
                 nullable=False,
             ),
             Column("version", Integer, nullable=False),
-            # adding default for now until we fully remove shadow datasets
-            Column("status", Integer, nullable=False, default=DatasetStatus.COMPLETE),
+            Column(
+                "status",
+                Integer,
+                nullable=False,
+            ),
             Column("feature_schema", JSON, nullable=True),
             Column("created_at", DateTime(timezone=True)),
             Column("finished_at", DateTime(timezone=True)),
@@ -969,7 +974,6 @@ class AbstractDBMetastore(AbstractMetastore):
         # TODO abstract this method and add registered = True based on kwargs
         query = self._datasets_insert().values(
             name=name,
-            shadow=False,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -992,7 +996,7 @@ class AbstractDBMetastore(AbstractMetastore):
         self,
         dataset: DatasetRecord,
         version: int,
-        status: int = DatasetStatus.CREATED,
+        status: int,
         sources: str = "",
         feature_schema: Optional[dict] = None,
         query_script: str = "",
@@ -1018,7 +1022,7 @@ class AbstractDBMetastore(AbstractMetastore):
         query = self._datasets_versions_insert().values(
             dataset_id=dataset.id,
             version=version,
-            status=status,  # for now until we remove shadow datasets
+            status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=created_at or datetime.now(timezone.utc),
             finished_at=finished_at,
@@ -1519,3 +1523,18 @@ class AbstractDBMetastore(AbstractMetastore):
             .values(status=dataset_status)
         )
         self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
+
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        dv = self._datasets_versions
+        ds = self._datasets
+
+        join_condition = dv.c.dataset_id == ds.c.id
+
+        query = (
+            self._datasets_versions_select(ds.c.name, dv.c.version)
+            .select_from(dv.join(ds, join_condition))
+            .where(dv.c.job_id == job_id)
+        )
+
+        return list(self.db.execute(query))
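The concrete get_job_dataset_versions implementation joins the datasets and dataset-versions tables on dataset_id and filters by job_id, yielding (name, version) pairs. A self-contained SQLAlchemy Core sketch of the same join and filter, using an in-memory SQLite database and assumed table layouts rather than datachain's real schema:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
ds = sa.Table(
    "datasets", meta,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
)
dv = sa.Table(
    "datasets_versions", meta,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("dataset_id", sa.Integer, sa.ForeignKey("datasets.id")),
    sa.Column("version", sa.Integer),
    sa.Column("job_id", sa.Text),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(ds.insert(), [{"id": 1, "name": "cats"}])
    conn.execute(dv.insert(), [{"dataset_id": 1, "version": 2, "job_id": "job-1"}])
    query = (
        sa.select(ds.c.name, dv.c.version)
        .select_from(dv.join(ds, dv.c.dataset_id == ds.c.id))
        .where(dv.c.job_id == "job-1")
    )
    print(list(conn.execute(query)))  # [('cats', 2)]
```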
@@ -15,6 +15,7 @@ from typing import (
 )
 
 import sqlalchemy
+from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -153,7 +154,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
             import sys
 
-            db.set_trace_callback(sys.stderr.write)
+            db.set_trace_callback(lambda stmt: print(stmt, file=sys.stderr))
 
         load_usearch_extension(db)
 
@@ -345,45 +346,36 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
     def get_next_ids(self, uri: str, count: int) -> range:
         """Returns a range of IDs for the given URI."""
 
-        # NOTE: we can't use RETURNING clause here because it is only available
-        # in sqlalchemy v2, see
-        # https://github.com/sqlalchemy/sqlalchemy/issues/6195#issuecomment-1248700677
-        # After we upgrade to sqlalchemy v2, we can use the following code,
-        # leaving fallback to the current implementation for older versions of SQLite,
-        # which is still supported, for example, in Ubuntu 20.04 LTS (Focal Fossa),
-        # where SQLite version 3.31.1 is used.
-
-        # sqlite_version = version.parse(sqlite3.sqlite_version)
-        # if sqlite_version >= version.parse("3.35.0"):
-        #     # RETURNING is supported on SQLite 3.35.0 (2021-03-12) or newer
-        #     stmt = (
-        #         sqlite.insert(self._table)
-        #         .values(uri=uri, last_id=count)
-        #         .on_conflict_do_update(
-        #             index_elements=["uri"],
-        #             set_={"last_id": self._table.c.last_id + count},
-        #         )
-        #         .returning(self._table.c.last_id)
-        #     )
-        #     last_id = self._db.execute(stmt).fetchone()[0]
-        # else:
-        #     (fallback to the current implementation with a transaction)
-
-        # Transactions ensure no concurrency conflicts
-        with self._db.transaction() as conn:
-            # UPSERT syntax was added to SQLite with version 3.24.0 (2018-06-04).
-            stmt_ins = (
+        sqlite_version = version.parse(sqlite3.sqlite_version)
+        is_returning_supported = sqlite_version >= version.parse("3.35.0")
+        if is_returning_supported:
+            stmt = (
                 sqlite.insert(self._table)
                 .values(uri=uri, last_id=count)
                 .on_conflict_do_update(
                     index_elements=["uri"],
                     set_={"last_id": self._table.c.last_id + count},
                 )
+                .returning(self._table.c.last_id)
             )
-            self._db.execute(stmt_ins, conn=conn)
+            last_id = self._db.execute(stmt).fetchone()[0]
+        else:
+            # Older versions of SQLite are still the default under Ubuntu LTS,
+            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
+            # Transactions ensure no concurrency conflicts
+            with self._db.transaction() as conn:
+                stmt_ins = (
+                    sqlite.insert(self._table)
+                    .values(uri=uri, last_id=count)
+                    .on_conflict_do_update(
+                        index_elements=["uri"],
+                        set_={"last_id": self._table.c.last_id + count},
+                    )
+                )
+                self._db.execute(stmt_ins, conn=conn)
 
-            stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
-            last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
+                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
+                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
 
         return range(last_id - count + 1, last_id + 1)
 
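get_next_ids now picks its strategy at runtime: on SQLite 3.35.0 or newer the upsert carries a RETURNING clause and reads the incremented counter back in a single statement, otherwise it falls back to the previous upsert-then-SELECT inside a transaction. A minimal sketch of the version gate itself, which prints the branch the current interpreter's bundled SQLite would take:

```python
import sqlite3

from packaging import version

sqlite_version = version.parse(sqlite3.sqlite_version)
if sqlite_version >= version.parse("3.35.0"):
    print(f"SQLite {sqlite3.sqlite_version}: INSERT ... RETURNING is available")
else:
    print(f"SQLite {sqlite3.sqlite_version}: falling back to upsert + SELECT")
```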
@@ -919,9 +919,7 @@ class AbstractWarehouse(ABC, Serializable):
     def is_temp_table_name(self, name: str) -> bool:
         """Returns if the given table name refers to a temporary
         or no longer needed table."""
-        return name.startswith(
-            (self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX, "ds_shadow_")
-        ) or name.endswith("_shadow")
+        return name.startswith((self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX))
 
     def get_temp_table_names(self) -> list[str]:
         return [
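With shadow datasets removed, temporary-table detection reduces to a plain prefix check; names such as ds_shadow_* or *_shadow are no longer treated as temporary. A tiny sketch with assumed prefix values (the real prefixes are class attributes on the warehouse):

```python
# Assumed values for illustration only.
TMP_TABLE_NAME_PREFIX = "tmp_"
UDF_TABLE_NAME_PREFIX = "udf_"

def is_temp_table_name(name: str) -> bool:
    return name.startswith((TMP_TABLE_NAME_PREFIX, UDF_TABLE_NAME_PREFIX))

assert is_temp_table_name("tmp_abc123")
assert not is_temp_table_name("ds_shadow_old")  # matched before this change
```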
datachain/dataset.py CHANGED
@@ -267,7 +267,6 @@ class DatasetRecord:
     name: str
     description: Optional[str]
     labels: list[str]
-    shadow: bool
     schema: dict[str, Union[SQLType, type[SQLType]]]
     feature_schema: dict
     versions: list[DatasetVersion]
@@ -296,7 +295,6 @@ class DatasetRecord:
         name: str,
         description: Optional[str],
         labels: str,
-        shadow: int,
         status: int,
         feature_schema: Optional[str],
         created_at: datetime,
@@ -356,7 +354,6 @@ class DatasetRecord:
             name,
             description,
             labels_lst,
-            bool(shadow),
             cls.parse_schema(schema_dct),  # type: ignore[arg-type]
             json.loads(feature_schema) if feature_schema else {},
             [dataset_version],
datachain/lib/arrow.py CHANGED
@@ -1,8 +1,9 @@
 import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
+import orjson
 import pyarrow as pa
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
@@ -10,6 +11,7 @@ from tqdm import tqdm
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
@@ -20,6 +22,9 @@ if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
 
+DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -61,28 +66,35 @@ class ArrowGenerator(Generator):
             path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
         )
         hf_schema = _get_hf_schema(ds.schema)
+        use_datachain_schema = (
+            bool(ds.schema.metadata)
+            and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
-                    vals = list(record.values())
-                    if self.output_schema:
-                        fields = self.output_schema.model_fields
-                        vals_dict = {}
-                        for i, ((field, field_info), val) in enumerate(
-                            zip(fields.items(), vals)
-                        ):
-                            anno = field_info.annotation
-                            if hf_schema:
-                                from datachain.lib.hf import convert_feature
-
-                                feat = list(hf_schema[0].values())[i]
-                                vals_dict[field] = convert_feature(val, feat, anno)
-                            elif ModelStore.is_pydantic(anno):
-                                vals_dict[field] = anno(**val)  # type: ignore[misc]
-                            else:
-                                vals_dict[field] = val
-                        vals = [self.output_schema(**vals_dict)]
+                    if use_datachain_schema and self.output_schema:
+                        vals = [_nested_model_instantiate(record, self.output_schema)]
+                    else:
+                        vals = list(record.values())
+                        if self.output_schema:
+                            fields = self.output_schema.model_fields
+                            vals_dict = {}
+                            for i, ((field, field_info), val) in enumerate(
+                                zip(fields.items(), vals)
+                            ):
+                                anno = field_info.annotation
+                                if hf_schema:
+                                    from datachain.lib.hf import convert_feature
+
+                                    feat = list(hf_schema[0].values())[i]
+                                    vals_dict[field] = convert_feature(val, feat, anno)
+                                elif ModelStore.is_pydantic(anno):
+                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
+                                else:
+                                    vals_dict[field] = val
+                            vals = [self.output_schema(**vals_dict)]
                     if self.source:
                         kwargs: dict = self.kwargs
                         # Can't serialize CsvFileFormat; may lose formatting options.
@@ -113,6 +125,9 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         )
     if not col_names:
         col_names = schema.names
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values
     columns = _convert_col_names(col_names)  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
@@ -197,3 +212,33 @@ def _get_hf_schema(
         features = schema_from_arrow(schema)
         return features, get_output_schema(features)
     return None
+
+
+def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+    """Return a restored SignalSchema from parquet metadata, if any is found."""
+    if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
+        serialized_signal_schema = orjson.loads(
+            schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
+        )
+        return SignalSchema.deserialize(serialized_signal_schema)
+    return None
+
+
+def _nested_model_instantiate(
+    column_values: dict[str, Any], model: type["BaseModel"], prefix: str = ""
+) -> "BaseModel":
+    """Instantiate the given model and all sub-models/fields based on the provided
+    column values."""
+    vals_dict = {}
+    for field, field_info in model.model_fields.items():
+        anno = field_info.annotation
+        cur_path = f"{prefix}.{field}" if prefix else field
+        if ModelStore.is_pydantic(anno):
+            vals_dict[field] = _nested_model_instantiate(
+                column_values,
+                anno,  # type: ignore[arg-type]
+                prefix=cur_path,
+            )
+        elif cur_path in column_values:
+            vals_dict[field] = column_values[cur_path]
+    return model(**vals_dict)
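Together these helpers let datachain round-trip its signal schema through parquet metadata: the serialized SignalSchema is stored under the DataChain SignalSchema metadata key, _get_datachain_schema restores it, and _nested_model_instantiate rebuilds nested pydantic models from flattened dotted column names. A self-contained sketch of that last step with toy models (a simplified re-implementation for illustration, not datachain's code):

```python
from pydantic import BaseModel

class Meta(BaseModel):
    width: int
    height: int

class Photo(BaseModel):
    path: str
    meta: Meta

def nested_instantiate(columns: dict, model: type[BaseModel], prefix: str = "") -> BaseModel:
    # Descend into pydantic-typed fields; otherwise look up the flattened
    # "outer.inner" path in the column dict, mirroring the helper above.
    vals = {}
    for field, info in model.model_fields.items():
        path = f"{prefix}.{field}" if prefix else field
        anno = info.annotation
        if isinstance(anno, type) and issubclass(anno, BaseModel):
            vals[field] = nested_instantiate(columns, anno, prefix=path)
        elif path in columns:
            vals[field] = columns[path]
    return model(**vals)

row = {"path": "cat.jpg", "meta.width": 640, "meta.height": 480}
print(nested_instantiate(row, Photo))
# path='cat.jpg' meta=Meta(width=640, height=480)
```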