datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain has been flagged as possibly problematic by the registry.
- datachain/catalog/catalog.py +8 -0
- datachain/cli.py +3 -2
- datachain/data_storage/metastore.py +28 -9
- datachain/data_storage/sqlite.py +24 -32
- datachain/data_storage/warehouse.py +1 -3
- datachain/dataset.py +0 -3
- datachain/lib/arrow.py +64 -19
- datachain/lib/dc.py +310 -123
- datachain/lib/listing.py +5 -3
- datachain/lib/pytorch.py +5 -1
- datachain/lib/udf.py +100 -78
- datachain/lib/udf_signature.py +8 -6
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -2
- datachain/query/session.py +42 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/METADATA +1 -1
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/RECORD +21 -22
- datachain/query/udf.py +0 -126
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/LICENSE +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/WHEEL +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -988,6 +988,14 @@ class Catalog:
         schema = {
             c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
         }
+
+        job_id = job_id or os.getenv("DATACHAIN_JOB_ID")
+        if not job_id:
+            from datachain.query.session import Session
+
+            session = Session.get(catalog=self)
+            job_id = session.job_id
+
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
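The added fallback resolves the job id in three steps: an explicit argument wins, then the DATACHAIN_JOB_ID environment variable, then the job id owned by the current Session. A minimal standalone sketch of that precedence (resolve_job_id and its session_job_id argument are illustrative stand-ins, not datachain APIs):

import os
from typing import Optional

def resolve_job_id(job_id: Optional[str], session_job_id: str) -> str:
    # Same precedence as the catalog change: explicit value, then env var,
    # then the id owned by the current session.
    return job_id or os.getenv("DATACHAIN_JOB_ID") or session_job_id

assert resolve_job_id("explicit", "from-session") == "explicit"
os.environ["DATACHAIN_JOB_ID"] = "from-env"
assert resolve_job_id(None, "from-session") == "from-env"
del os.environ["DATACHAIN_JOB_ID"]
assert resolve_job_id(None, "from-session") == "from-session"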
datachain/cli.py
CHANGED
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import shtab
 
-from datachain import utils
+from datachain import Session, utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.lib.dc import DataChain
 from datachain.telemetry import telemetry
@@ -770,7 +770,8 @@ def show(
     show_records(records, collapse_columns=not no_collapse)
     if schema and dataset_version.feature_schema:
         print("\nSchema:")
-
+        session = Session.get(catalog=catalog)
+        dc = DataChain.from_dataset(name=name, version=version, session=session)
         dc.print_schema()
 
 
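With this change the show command's schema output builds the chain through an explicit Session bound to the CLI's catalog instead of an implicit one. A rough sketch of the same call sequence outside the CLI, assuming Session.get() falls back to the default catalog when none is passed; the dataset name and version below are placeholders:

from datachain import Session
from datachain.lib.dc import DataChain

session = Session.get()  # or Session.get(catalog=some_catalog), as in the diff
dc = DataChain.from_dataset(name="my-dataset", version=1, session=session)
dc.print_schema()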
datachain/data_storage/metastore.py
CHANGED
@@ -15,7 +15,6 @@ from uuid import uuid4
 from sqlalchemy import (
     JSON,
     BigInteger,
-    Boolean,
     Column,
     DateTime,
     ForeignKey,
@@ -51,7 +50,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractIDGenerator, schema
     from datachain.data_storage.db_engine import DatabaseEngine
 
-
 logger = logging.getLogger("datachain")
 
 
@@ -228,7 +226,7 @@ class AbstractMetastore(ABC, Serializable):
         self,
         dataset: DatasetRecord,
         version: int,
-        status: int
+        status: int,
         sources: str = "",
         feature_schema: Optional[dict] = None,
         query_script: str = "",
@@ -385,6 +383,11 @@ class AbstractMetastore(ABC, Serializable):
     ) -> None:
         """Set the status of the given job and dataset."""
 
+    @abstractmethod
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        raise NotImplementedError
+
 
 class AbstractDBMetastore(AbstractMetastore):
     """
@@ -448,7 +451,6 @@ class AbstractDBMetastore(AbstractMetastore):
             Column("name", Text, nullable=False),
             Column("description", Text),
             Column("labels", JSON, nullable=True),
-            Column("shadow", Boolean, nullable=False),
             Column("status", Integer, nullable=False),
             Column("feature_schema", JSON, nullable=True),
             Column("created_at", DateTime(timezone=True)),
@@ -481,8 +483,11 @@ class AbstractDBMetastore(AbstractMetastore):
                 nullable=False,
             ),
             Column("version", Integer, nullable=False),
-
-
+            Column(
+                "status",
+                Integer,
+                nullable=False,
+            ),
             Column("feature_schema", JSON, nullable=True),
             Column("created_at", DateTime(timezone=True)),
             Column("finished_at", DateTime(timezone=True)),
@@ -969,7 +974,6 @@ class AbstractDBMetastore(AbstractMetastore):
         # TODO abstract this method and add registered = True based on kwargs
         query = self._datasets_insert().values(
             name=name,
-            shadow=False,
             status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=datetime.now(timezone.utc),
@@ -992,7 +996,7 @@ class AbstractDBMetastore(AbstractMetastore):
         self,
         dataset: DatasetRecord,
         version: int,
-        status: int
+        status: int,
         sources: str = "",
         feature_schema: Optional[dict] = None,
         query_script: str = "",
@@ -1018,7 +1022,7 @@ class AbstractDBMetastore(AbstractMetastore):
         query = self._datasets_versions_insert().values(
             dataset_id=dataset.id,
             version=version,
-            status=status,
+            status=status,
             feature_schema=json.dumps(feature_schema or {}),
             created_at=created_at or datetime.now(timezone.utc),
             finished_at=finished_at,
@@ -1519,3 +1523,18 @@ class AbstractDBMetastore(AbstractMetastore):
             .values(status=dataset_status)
         )
         self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
+
+    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
+        """Returns dataset names and versions for the job."""
+        dv = self._datasets_versions
+        ds = self._datasets
+
+        join_condition = dv.c.dataset_id == ds.c.id
+
+        query = (
+            self._datasets_versions_select(ds.c.name, dv.c.version)
+            .select_from(dv.join(ds, join_condition))
+            .where(dv.c.job_id == job_id)
+        )
+
+        return list(self.db.execute(query))
datachain/data_storage/sqlite.py
CHANGED
@@ -15,6 +15,7 @@ from typing import (
 )
 
 import sqlalchemy
+from packaging import version
 from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
@@ -153,7 +154,7 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
             import sys
 
-            db.set_trace_callback(sys.stderr
+            db.set_trace_callback(lambda stmt: print(stmt, file=sys.stderr))
 
         load_usearch_extension(db)
 
@@ -345,45 +346,36 @@ class SQLiteIDGenerator(AbstractDBIDGenerator):
     def get_next_ids(self, uri: str, count: int) -> range:
         """Returns a range of IDs for the given URI."""
 
-
-
-
-
-        # leaving fallback to the current implementation for older versions of SQLite,
-        # which is still supported, for example, in Ubuntu 20.04 LTS (Focal Fossa),
-        # where SQLite version 3.31.1 is used.
-
-        # sqlite_version = version.parse(sqlite3.sqlite_version)
-        # if sqlite_version >= version.parse("3.35.0"):
-        #     # RETURNING is supported on SQLite 3.35.0 (2021-03-12) or newer
-        #     stmt = (
-        #         sqlite.insert(self._table)
-        #         .values(uri=uri, last_id=count)
-        #         .on_conflict_do_update(
-        #             index_elements=["uri"],
-        #             set_={"last_id": self._table.c.last_id + count},
-        #         )
-        #         .returning(self._table.c.last_id)
-        #     )
-        #     last_id = self._db.execute(stmt).fetchone()[0]
-        # else:
-        #     (fallback to the current implementation with a transaction)
-
-        # Transactions ensure no concurrency conflicts
-        with self._db.transaction() as conn:
-            # UPSERT syntax was added to SQLite with version 3.24.0 (2018-06-04).
-            stmt_ins = (
+        sqlite_version = version.parse(sqlite3.sqlite_version)
+        is_returning_supported = sqlite_version >= version.parse("3.35.0")
+        if is_returning_supported:
+            stmt = (
                 sqlite.insert(self._table)
                 .values(uri=uri, last_id=count)
                 .on_conflict_do_update(
                     index_elements=["uri"],
                     set_={"last_id": self._table.c.last_id + count},
                 )
+                .returning(self._table.c.last_id)
             )
-            self._db.execute(
+            last_id = self._db.execute(stmt).fetchone()[0]
+        else:
+            # Older versions of SQLite are still the default under Ubuntu LTS,
+            # e.g. Ubuntu 20.04 LTS (Focal Fossa) uses 3.31.1
+            # Transactions ensure no concurrency conflicts
+            with self._db.transaction() as conn:
+                stmt_ins = (
+                    sqlite.insert(self._table)
+                    .values(uri=uri, last_id=count)
+                    .on_conflict_do_update(
+                        index_elements=["uri"],
+                        set_={"last_id": self._table.c.last_id + count},
+                    )
+                )
+                self._db.execute(stmt_ins, conn=conn)
 
-
-
+                stmt_sel = select(self._table.c.last_id).where(self._table.c.uri == uri)
+                last_id = self._db.execute(stmt_sel, conn=conn).fetchone()[0]
 
         return range(last_id - count + 1, last_id + 1)
 
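The reworked get_next_ids detects RETURNING support (SQLite 3.35.0 and newer) at runtime and issues a single upsert that both bumps the counter and reads the new value back; the transaction-plus-SELECT path is kept only as a fallback for older SQLite builds. A small raw-sqlite3 sketch of the same two paths; the table name and URI are made up, and both branches assume at least SQLite 3.24 for ON CONFLICT support:

import sqlite3

from packaging import version

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE id_generator (uri TEXT PRIMARY KEY, last_id INTEGER NOT NULL)")

count = 10
uri = "s3://example-bucket"
if version.parse(sqlite3.sqlite_version) >= version.parse("3.35.0"):
    # One statement: bump the counter and read the new value back atomically.
    (last_id,) = db.execute(
        "INSERT INTO id_generator (uri, last_id) VALUES (?, ?) "
        "ON CONFLICT (uri) DO UPDATE SET last_id = last_id + ? "
        "RETURNING last_id",
        (uri, count, count),
    ).fetchone()
else:
    # Fallback: upsert first, then read the counter back with a second query.
    db.execute(
        "INSERT INTO id_generator (uri, last_id) VALUES (?, ?) "
        "ON CONFLICT (uri) DO UPDATE SET last_id = last_id + ?",
        (uri, count, count),
    )
    (last_id,) = db.execute(
        "SELECT last_id FROM id_generator WHERE uri = ?", (uri,)
    ).fetchone()

print(range(last_id - count + 1, last_id + 1))  # range(1, 11)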
datachain/data_storage/warehouse.py
CHANGED
@@ -919,9 +919,7 @@ class AbstractWarehouse(ABC, Serializable):
     def is_temp_table_name(self, name: str) -> bool:
         """Returns if the given table name refers to a temporary
         or no longer needed table."""
-        return name.startswith(
-            (self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX, "ds_shadow_")
-        ) or name.endswith("_shadow")
+        return name.startswith((self.TMP_TABLE_NAME_PREFIX, self.UDF_TABLE_NAME_PREFIX))
 
     def get_temp_table_names(self) -> list[str]:
         return [
datachain/dataset.py
CHANGED
@@ -267,7 +267,6 @@ class DatasetRecord:
     name: str
     description: Optional[str]
     labels: list[str]
-    shadow: bool
     schema: dict[str, Union[SQLType, type[SQLType]]]
     feature_schema: dict
     versions: list[DatasetVersion]
@@ -296,7 +295,6 @@ class DatasetRecord:
         name: str,
         description: Optional[str],
         labels: str,
-        shadow: int,
         status: int,
         feature_schema: Optional[str],
         created_at: datetime,
@@ -356,7 +354,6 @@ class DatasetRecord:
             name,
             description,
             labels_lst,
-            bool(shadow),
             cls.parse_schema(schema_dct),  # type: ignore[arg-type]
             json.loads(feature_schema) if feature_schema else {},
             [dataset_version],
datachain/lib/arrow.py
CHANGED
@@ -1,8 +1,9 @@
 import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
+import orjson
 import pyarrow as pa
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
@@ -10,6 +11,7 @@ from tqdm import tqdm
 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
+from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
@@ -20,6 +22,9 @@ if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
 
+DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -61,28 +66,35 @@ class ArrowGenerator(Generator):
             path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
         )
         hf_schema = _get_hf_schema(ds.schema)
+        use_datachain_schema = (
+            bool(ds.schema.metadata)
+            and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
+        )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
-                    vals = list(record.values())
-                    if self.output_schema:
-                        fields = self.output_schema.model_fields
-                        vals_dict = {}
-                        for i, ((field, field_info), val) in enumerate(
-                            zip(fields.items(), vals)
-                        ):
-                            anno = field_info.annotation
-                            if hf_schema:
-                                from datachain.lib.hf import convert_feature
-
-                                feat = list(hf_schema[0].values())[i]
-                                vals_dict[field] = convert_feature(val, feat, anno)
-                            elif ModelStore.is_pydantic(anno):
-                                vals_dict[field] = anno(**val)  # type: ignore[misc]
-                            else:
-                                vals_dict[field] = val
-                        vals = [self.output_schema(**vals_dict)]
+                    if use_datachain_schema and self.output_schema:
+                        vals = [_nested_model_instantiate(record, self.output_schema)]
+                    else:
+                        vals = list(record.values())
+                        if self.output_schema:
+                            fields = self.output_schema.model_fields
+                            vals_dict = {}
+                            for i, ((field, field_info), val) in enumerate(
+                                zip(fields.items(), vals)
+                            ):
+                                anno = field_info.annotation
+                                if hf_schema:
+                                    from datachain.lib.hf import convert_feature
+
+                                    feat = list(hf_schema[0].values())[i]
+                                    vals_dict[field] = convert_feature(val, feat, anno)
+                                elif ModelStore.is_pydantic(anno):
+                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
+                                else:
+                                    vals_dict[field] = val
+                            vals = [self.output_schema(**vals_dict)]
                     if self.source:
                         kwargs: dict = self.kwargs
                         # Can't serialize CsvFileFormat; may lose formatting options.
@@ -113,6 +125,9 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
     )
     if not col_names:
         col_names = schema.names
+    signal_schema = _get_datachain_schema(schema)
+    if signal_schema:
+        return signal_schema.values
     columns = _convert_col_names(col_names)  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
@@ -197,3 +212,33 @@ def _get_hf_schema(
         features = schema_from_arrow(schema)
         return features, get_output_schema(features)
     return None
+
+
+def _get_datachain_schema(schema: "pa.Schema") -> Optional[SignalSchema]:
+    """Return a restored SignalSchema from parquet metadata, if any is found."""
+    if schema.metadata and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in schema.metadata:
+        serialized_signal_schema = orjson.loads(
+            schema.metadata[DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY]
+        )
+        return SignalSchema.deserialize(serialized_signal_schema)
+    return None
+
+
+def _nested_model_instantiate(
+    column_values: dict[str, Any], model: type["BaseModel"], prefix: str = ""
+) -> "BaseModel":
+    """Instantiate the given model and all sub-models/fields based on the provided
+    column values."""
+    vals_dict = {}
+    for field, field_info in model.model_fields.items():
+        anno = field_info.annotation
+        cur_path = f"{prefix}.{field}" if prefix else field
+        if ModelStore.is_pydantic(anno):
+            vals_dict[field] = _nested_model_instantiate(
+                column_values,
+                anno,  # type: ignore[arg-type]
+                prefix=cur_path,
+            )
+        elif cur_path in column_values:
+            vals_dict[field] = column_values[cur_path]
+    return model(**vals_dict)