datachain 0.34.6__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/catalog.py +75 -83
- datachain/catalog/loader.py +3 -3
- datachain/checkpoint.py +1 -2
- datachain/cli/__init__.py +2 -4
- datachain/cli/commands/datasets.py +13 -13
- datachain/cli/commands/ls.py +4 -4
- datachain/cli/commands/query.py +3 -3
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +1 -2
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +11 -21
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +4 -4
- datachain/client/local.py +4 -4
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +5 -5
- datachain/data_storage/metastore.py +107 -107
- datachain/data_storage/schema.py +18 -24
- datachain/data_storage/sqlite.py +21 -28
- datachain/data_storage/warehouse.py +13 -13
- datachain/dataset.py +64 -70
- datachain/delta.py +21 -18
- datachain/diff/__init__.py +13 -13
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +45 -42
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +54 -81
- datachain/job.py +8 -8
- datachain/lib/arrow.py +17 -14
- datachain/lib/audio.py +6 -6
- datachain/lib/clip.py +5 -4
- datachain/lib/convert/python_to_sql.py +4 -22
- datachain/lib/convert/values_to_tuples.py +4 -9
- datachain/lib/data_model.py +20 -19
- datachain/lib/dataset_info.py +6 -6
- datachain/lib/dc/csv.py +10 -10
- datachain/lib/dc/database.py +28 -29
- datachain/lib/dc/datachain.py +98 -97
- datachain/lib/dc/datasets.py +22 -22
- datachain/lib/dc/hf.py +4 -4
- datachain/lib/dc/json.py +9 -10
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +5 -5
- datachain/lib/dc/storage.py +12 -12
- datachain/lib/dc/storage_pattern.py +2 -2
- datachain/lib/dc/utils.py +11 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +32 -28
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +1 -2
- datachain/lib/model_store.py +3 -3
- datachain/lib/namespaces.py +4 -6
- datachain/lib/projects.py +5 -9
- datachain/lib/pytorch.py +10 -10
- datachain/lib/settings.py +23 -23
- datachain/lib/signal_schema.py +52 -44
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +25 -17
- datachain/lib/udf_signature.py +11 -11
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +30 -35
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +4 -4
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +1 -7
- datachain/project.py +4 -4
- datachain/query/batch.py +7 -8
- datachain/query/dataset.py +80 -87
- datachain/query/dispatch.py +7 -7
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/schema.py +7 -6
- datachain/query/session.py +7 -7
- datachain/query/udf.py +8 -7
- datachain/query/utils.py +3 -5
- datachain/remote/studio.py +33 -39
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +6 -9
- datachain/studio.py +30 -30
- datachain/toolkit/split.py +1 -2
- datachain/utils.py +21 -21
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/METADATA +2 -3
- datachain-0.35.0.dist-info/RECORD +173 -0
- datachain-0.34.6.dist-info/RECORD +0 -173
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/WHEEL +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.34.6.dist-info → datachain-0.35.0.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -8,19 +8,11 @@ import string
 import subprocess
 import sys
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
 from types import GeneratorType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Optional,
-    Protocol,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 
 import attrs
 import sqlalchemy
@@ -67,11 +59,12 @@ from datachain.utils import (
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
+    from typing import Concatenate
 
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
-    from typing_extensions import Concatenate, ParamSpec, Self
+    from typing_extensions import ParamSpec, Self
 
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
@@ -83,13 +76,10 @@ if TYPE_CHECKING:
 
 INSERT_BATCH_SIZE = 10000
 
-PartitionByType = Union[
-    str,
-    Function,
-    ColumnElement,
-    Sequence[Union[str, Function, ColumnElement]],
-]
-JoinPredicateType = Union[str, ColumnClause, ColumnElement]
+PartitionByType = (
+    str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
+)
+JoinPredicateType = str | ColumnClause | ColumnElement
 DatasetDependencyType = tuple["DatasetRecord", str]
 
 logger = logging.getLogger("datachain")
@@ -411,14 +401,14 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 class UDFStep(Step, ABC):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def hash_inputs(self) -> str:
         partition_by = ensure_sequence(self.partition_by or [])
@@ -624,7 +614,7 @@ class UDFStep(Step, ABC):
 
         return tbl
 
-    def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
+    def clone(self, partition_by: PartitionByType | None = None) -> "Self":
         if partition_by is not None:
             return self.__class__(
                 self.udf,
@@ -681,14 +671,14 @@ class UDFStep(Step, ABC):
 class UDFSignal(UDFStep):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def create_udf_table(self, query: Select) -> "Table":
         udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -760,14 +750,14 @@ class RowGenerator(UDFStep):
 
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
+    partition_by: PartitionByType | None = None
     is_generator = True
     # Parameters from Settings
     cache: bool = False
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
-    batch_size: Optional[int] = None
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
 
     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse
@@ -814,7 +804,7 @@ class SQLClause(Step, ABC):
 
     def parse_cols(
         self,
-        cols: Sequence[Union[Function, ColumnElement]],
+        cols: Sequence[Function | ColumnElement],
     ) -> tuple[ColumnElement, ...]:
         return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)
 
@@ -825,7 +815,7 @@ class SQLClause(Step, ABC):
 
 @frozen
 class SQLSelect(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -844,7 +834,7 @@ class SQLSelect(SQLClause):
 
 @frozen
 class SQLSelectExcept(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -890,7 +880,7 @@ class SQLMutate(SQLClause):
 
 @frozen
 class SQLFilter(SQLClause):
-    expressions: tuple[Union[Function, ColumnElement], ...]
+    expressions: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.expressions)
@@ -906,7 +896,7 @@ class SQLFilter(SQLClause):
 
 @frozen
 class SQLOrderBy(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
 
     def hash_inputs(self) -> str:
         return hash_column_elements(self.args)
@@ -1011,7 +1001,7 @@ class SQLJoin(Step):
     catalog: "Catalog"
     query1: "DatasetQuery"
     query2: "DatasetQuery"
-    predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
+    predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
     inner: bool
     full: bool
     rname: str
@@ -1150,8 +1140,8 @@ class SQLJoin(Step):
 
 @frozen
 class SQLGroupBy(SQLClause):
-    cols: Sequence[Union[str, Function, ColumnElement]]
-    group_by: Sequence[Union[str, Function, ColumnElement]]
+    cols: Sequence[str | Function | ColumnElement]
+    group_by: Sequence[str | Function | ColumnElement]
 
     def hash_inputs(self) -> str:
         return hashlib.sha256(
@@ -1211,6 +1201,7 @@ def _validate_columns(
                 missing_left,
             ],
             ["left", "right"],
+            strict=False,
         )
         if missing_columns
     ]
@@ -1243,32 +1234,32 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version: Optional[str] = None,
-        project_name: Optional[str] = None,
-        namespace_name: Optional[str] = None,
-        catalog: Optional["Catalog"] = None,
-        session: Optional[Session] = None,
+        version: str | None = None,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
+        catalog: "Catalog | None" = None,
+        session: Session | None = None,
         in_memory: bool = False,
         update: bool = False,
     ) -> None:
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
-        self._chunk_index: Optional[int] = None
-        self._chunk_total: Optional[int] = None
+        self._chunk_index: int | None = None
+        self._chunk_total: int | None = None
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step: Optional[QueryStep] = None
-        self.name: Optional[str] = None
-        self.version: Optional[str] = None
-        self.feature_schema: Optional[dict] = None
-        self.column_types: Optional[dict[str, Any]] = None
+        self.starting_step: QueryStep | None = None
+        self.name: str | None = None
+        self.version: str | None = None
+        self.feature_schema: dict | None = None
+        self.column_types: dict[str, Any] | None = None
         self.before_steps: list[Callable] = []
-        self.listing_fn: Optional[Callable] = None
+        self.listing_fn: Callable | None = None
         self.update = update
 
-        self.list_ds_name: Optional[str] = None
+        self.list_ds_name: str | None = None
 
         self.name = name
         self.dialect = self.catalog.warehouse.db.dialect
@@ -1352,7 +1343,7 @@ class DatasetQuery:
         """
         return self.name is not None and self.version is not None
 
-    def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
+    def c(self, column: C | str) -> "ColumnClause[Any]":
         col: sqlalchemy.ColumnClause = (
             sqlalchemy.column(column)
             if isinstance(column, str)
@@ -1447,6 +1438,7 @@ class DatasetQuery:
         # This is needed to always use a new connection with all metastore and warehouse
         # implementations, as errors may close or render unusable the existing
        # connections.
+        assert len(self.temp_table_names) == len(set(self.temp_table_names))
        with self.catalog.metastore.clone(use_new_connection=True) as metastore:
            metastore.cleanup_tables(self.temp_table_names)
        with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
@@ -1461,7 +1453,7 @@ class DatasetQuery:
         return list(result)
 
     def to_db_records(self) -> list[dict[str, Any]]:
-        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+        return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))
 
     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
@@ -1500,7 +1492,7 @@ class DatasetQuery:
             yield from rows
 
         async def get_params(row: Sequence) -> tuple:
-            row_dict = RowDict(zip(query_fields, row))
+            row_dict = RowDict(zip(query_fields, row, strict=False))
             return tuple(  # noqa: C409
                 [
                     await p.get_value_async(
@@ -1540,6 +1532,7 @@ class DatasetQuery:
         obj.steps = obj.steps.copy()
         if new_table:
             obj.table = self.get_table()
+            obj.temp_table_names = []
         return obj
 
     @detach
@@ -1720,7 +1713,7 @@ class DatasetQuery:
     def join(
         self,
         dataset_query: "DatasetQuery",
-        predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
+        predicates: JoinPredicateType | Sequence[JoinPredicateType],
         inner=False,
         full=False,
         rname="{name}_right",
@@ -1762,17 +1755,17 @@ class DatasetQuery:
     def add_signals(
         self,
         udf: "UDFAdapter",
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
         # Parameters from Settings
         cache: bool = False,
-        parallel: Optional[int] = None,
-        workers: Union[bool, int] = False,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
         # Parameters are unused, kept only to match the signature of Settings.to_dict
-        prefetch: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         """
         Adds one or more signals based on the results from the provided UDF.
@@ -1813,17 +1806,17 @@ class DatasetQuery:
     def generate(
         self,
         udf: "UDFAdapter",
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
         # Parameters from Settings
         cache: bool = False,
-        parallel: Optional[int] = None,
-        workers: Union[bool, int] = False,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
         # Parameters are unused, kept only to match the signature of Settings.to_dict:
-        prefetch: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         query = self.clone()
         steps = query.steps
@@ -1879,23 +1872,23 @@ class DatasetQuery:
 
     def exec(self) -> "Self":
         """Execute the query."""
+        query = self.clone()
         try:
-            query = self.clone()
             query.apply_steps()
         finally:
-            self.cleanup()
+            query.cleanup()
         return query
 
     def save(
         self,
-        name: Optional[str] = None,
-        version: Optional[str] = None,
-        project: Optional[Project] = None,
-        feature_schema: Optional[dict] = None,
-        dependencies: Optional[list[DatasetDependency]] = None,
-        description: Optional[str] = None,
-        attrs: Optional[list[str]] = None,
-        update_version: Optional[str] = "patch",
+        name: str | None = None,
+        version: str | None = None,
+        project: Project | None = None,
+        feature_schema: dict | None = None,
+        dependencies: list[DatasetDependency] | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
     ) -> "Self":
         """Save the query as a dataset."""
@@ -1989,5 +1982,5 @@ class DatasetQuery:
         return isinstance(self.last_step, SQLOrderBy)
 
     @property
-    def last_step(self) -> Optional[Step]:
+    def last_step(self) -> Step | None:
         return self.steps[-1] if self.steps else None
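Most of this file's churn is a PEP 604 migration: typing.Optional[X] becomes X | None, typing.Union[A, B] becomes A | B, and zip() calls gain an explicit strict=False. A minimal sketch of both idioms follows; the function and variable names are illustrative, not from datachain:

from collections.abc import Sequence

# PEP 604 (Python 3.10+): "int | None" replaces typing.Optional[int], and
# "A | B" replaces typing.Union[A, B]; no typing import is needed.
def clamp(value: int | None = None, bounds: Sequence[int] | None = None) -> int | None:
    if value is None or not bounds:
        return value
    lo, hi = bounds
    return max(lo, min(hi, value))

# zip() accepts "strict" since Python 3.10. strict=False is the historical
# truncate-to-shortest behavior; spelling it out silences ruff's B905 check.
cols = ["id", "name"]
row = (1, "a", "extra")  # the extra element is silently dropped
record = dict(zip(cols, row, strict=False))
assert record == {"id": 1, "name": "a"}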
datachain/query/dispatch.py
CHANGED
@@ -3,9 +3,8 @@ from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Literal
 
-import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
@@ -27,6 +26,7 @@ from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
+    import multiprocess
     from sqlalchemy import Select, Table
 
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -41,7 +41,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"
 
 
-def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -86,7 +86,7 @@ def udf_entrypoint() -> int:
     return 0
 
 
-def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
+def udf_worker_entrypoint(fd: int | None = None) -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
@@ -97,9 +97,9 @@ def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
 
 
 class UDFDispatcher:
-    _catalog: Optional[Catalog] = None
-    task_queue: Optional[multiprocess.Queue] = None
-    done_queue: Optional[multiprocess.Queue] = None
+    _catalog: Catalog | None = None
+    task_queue: "multiprocess.Queue | None" = None
+    done_queue: "multiprocess.Queue | None" = None
 
     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
        self.udf_data = udf_info["udf_data"]
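The import of multiprocess moves under if TYPE_CHECKING:, so the dependency is only imported for static analysis, and annotations that mention it are quoted so they are never evaluated at runtime. A minimal sketch of the pattern (the Dispatcher class here is hypothetical):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; not imported when the module loads.
    import multiprocess


class Dispatcher:
    # The quoted annotation is never evaluated at runtime, so both the
    # TYPE_CHECKING-only name and the "| None" union are safe here.
    task_queue: "multiprocess.Queue | None" = None

    def start(self) -> None:
        import multiprocess  # real import deferred until first use

        self.task_queue = multiprocess.Queue()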
datachain/query/metrics.py
CHANGED
@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union
 
-metrics: dict[str, Union[str, int, float, bool, None]] = {}
+metrics: dict[str, str | int | float | bool | None] = {}
 
 
-def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -21,6 +20,6 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: P
         metastore.update_job(job_id, metrics=metrics)
 
 
-def get(key: str) -> Union[str, int, float, bool, None]:
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
datachain/query/params.py
CHANGED
@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional
 
-params_cache: Optional[dict[str, str]] = None
+params_cache: dict[str, str] | None = None
 
 
-def param(key: str, default: Optional[str] = None) -> Optional[str]:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")
datachain/query/schema.py
CHANGED
@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import sqlalchemy as sa
@@ -42,7 +43,7 @@ class ColumnMeta(type):
 
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache: Optional[bool] = True
+    inherit_cache: bool | None = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -177,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """
 
-    glob: Optional[str] = None
+    glob: str | None = None
 
     def get_value(
         self,
@@ -186,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -205,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -216,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)
 
 
-UDFParamSpec = Union[str, Column, UDFParameter]
+UDFParamSpec = str | Column | UDFParameter
 
 
 def normalize_param(param: UDFParamSpec) -> UDFParameter:
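Alongside the union rewrites, Callable moves from typing to collections.abc: since Python 3.9 the ABCs in collections.abc are subscriptable, and typing.Callable is deprecated in their favor. A small sketch:

from collections.abc import Callable

# collections.abc.Callable accepts the same subscript syntax as the
# deprecated typing.Callable.
def apply_twice(fn: Callable[[int], int], x: int) -> int:
    return fn(fn(x))

assert apply_twice(lambda n: n + 1, 0) == 2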
datachain/query/session.py
CHANGED
@@ -3,7 +3,7 @@ import gc
 import logging
 import re
 import sys
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
 
 from datachain.catalog import get_catalog
@@ -39,7 +39,7 @@ class Session:
         catalog (Catalog): Catalog object.
     """
 
-    GLOBAL_SESSION_CTX: Optional["Session"] = None
+    GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
     ORIGINAL_EXCEPT_HOOK = None
 
@@ -51,8 +51,8 @@ class Session:
     def __init__(
         self,
         name="",
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -126,9 +126,9 @@ class Session:
     @classmethod
     def get(
         cls,
-        session: Optional["Session"] = None,
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        session: "Session | None" = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
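GLOBAL_SESSION_CTX quotes its entire annotation because Session is not yet bound while its own class body executes; quoting also keeps the | from being evaluated at runtime. A minimal illustration with a hypothetical class:

class Node:
    # Forward reference: the class object is not bound until the class body
    # finishes, so the self-referential annotation must be a string; quoting
    # the whole union also avoids evaluating "Node | None" at import time.
    parent: "Node | None" = None

    def __init__(self, parent: "Node | None" = None) -> None:
        self.parent = parent


root = Node()
child = Node(parent=root)
assert child.parent is root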
datachain/query/udf.py
CHANGED
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypedDict
 
 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
@@ -17,7 +18,7 @@ class UdfInfo(TypedDict):
     query: "Select"
     udf_fields: list[str]
     batching: "BatchingStrategy"
-    processes: Optional[int]
+    processes: int | None
     is_generator: bool
     cache: bool
     rows_total: int
@@ -33,14 +34,14 @@ class AbstractUDFDistributor(ABC):
         query: "Select",
         udf_data: bytes,
         batching: "BatchingStrategy",
-        workers: Union[bool, int],
-        processes: Union[bool, int],
+        workers: bool | int,
+        processes: bool | int,
         udf_fields: list[str],
         rows_total: int,
         use_cache: bool,
         is_generator: bool = False,
-        min_task_size: Optional[Union[str, int]] = None,
-        batch_size: Optional[int] = None,
+        min_task_size: str | int | None = None,
+        batch_size: int | None = None,
     ) -> None: ...
 
     @abstractmethod
@@ -48,4 +49,4 @@ class AbstractUDFDistributor(ABC):
 
     @staticmethod
     @abstractmethod
-    def run_udf(fd: Optional[int] = None) -> int: ...
+    def run_udf(fd: int | None = None) -> int: ...
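Note that in a TypedDict such as UdfInfo, "processes: int | None" makes the value optional, not the key: callers must still supply processes, possibly as None (an omissible key would use NotRequired instead). A small illustration with hypothetical fields:

from typing import TypedDict


class JobInfo(TypedDict):
    name: str
    # Required key whose value may be None; "int | None" does not make the
    # key itself omissible (that is what NotRequired is for).
    processes: int | None


job: JobInfo = {"name": "ingest", "processes": None}  # OK
# {"name": "ingest"} alone would be flagged: the "processes" key is missing.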
datachain/query/utils.py
CHANGED
@@ -1,8 +1,6 @@
-from typing import Optional, Union
-
 import sqlalchemy as sa
 
-ColT = Union[sa.ColumnClause, sa.Column, sa.ColumnElement, sa.TextClause, sa.Label]
+ColT = sa.ColumnClause | sa.Column | sa.ColumnElement | sa.TextClause | sa.Label
 
 
 def column_name(col: ColT) -> str:
@@ -14,12 +12,12 @@ def column_name(col: ColT) -> str:
     )
 
 
-def get_query_column(query: sa.Select, name: str) -> Optional[ColT]:
+def get_query_column(query: sa.Select, name: str) -> ColT | None:
     """Returns column element from query by name or None if column not found."""
     return next((col for col in query.inner_columns if column_name(col) == name), None)
 
 
-def get_query_id_column(query: sa.Select) -> Optional[sa.ColumnElement]:
+def get_query_id_column(query: sa.Select) -> sa.ColumnElement | None:
     """Returns ID column element from query or None if column not found."""
     col = get_query_column(query, "sys__id")
     return col if col is not None and isinstance(col, sa.ColumnElement) else None