datachain 0.35.2__py3-none-any.whl → 0.36.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +45 -20
- datachain/catalog/dependency.py +164 -0
- datachain/data_storage/metastore.py +80 -0
- datachain/data_storage/schema.py +1 -2
- datachain/data_storage/sqlite.py +2 -9
- datachain/data_storage/warehouse.py +50 -33
- datachain/diff/__init__.py +2 -6
- datachain/lib/audio.py +54 -53
- datachain/lib/dc/datachain.py +13 -14
- datachain/query/dataset.py +21 -26
- datachain/query/dispatch.py +64 -42
- datachain/query/queue.py +2 -1
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/METADATA +3 -2
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/RECORD +18 -17
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/WHEEL +0 -0
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -54,6 +54,7 @@ from datachain.sql.types import DateTime, SQLType
 from datachain.utils import DataChainDir
 
 from .datasource import DataSource
+from .dependency import build_dependency_hierarchy, populate_nested_dependencies
 
 if TYPE_CHECKING:
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -1203,6 +1204,38 @@ class Catalog:
         assert isinstance(dataset_info, dict)
         return DatasetRecord.from_dict(dataset_info)
 
+    def get_dataset_dependencies_by_ids(
+        self,
+        dataset_id: int,
+        version_id: int,
+        indirect: bool = True,
+    ) -> list[DatasetDependency | None]:
+        dependency_nodes = self.metastore.get_dataset_dependency_nodes(
+            dataset_id=dataset_id,
+            version_id=version_id,
+        )
+
+        if not dependency_nodes:
+            return []
+
+        dependency_map, children_map = build_dependency_hierarchy(dependency_nodes)
+
+        root_key = (dataset_id, version_id)
+        if root_key not in children_map:
+            return []
+
+        root_dependency_ids = children_map[root_key]
+        root_dependencies = [dependency_map[dep_id] for dep_id in root_dependency_ids]
+
+        if indirect:
+            for dependency in root_dependencies:
+                if dependency is not None:
+                    populate_nested_dependencies(
+                        dependency, dependency_nodes, dependency_map, children_map
+                    )
+
+        return root_dependencies
+
     def get_dataset_dependencies(
         self,
         name: str,
@@ -1216,29 +1249,21 @@ class Catalog:
             namespace_name=namespace_name,
             project_name=project_name,
         )
-
-
-
-        )
+        dataset_version = dataset.get_version(version)
+        dataset_id = dataset.id
+        dataset_version_id = dataset_version.id
 
         if not indirect:
-            return
-
-
-
-            # dependency has been removed
-            continue
-            if d.is_dataset:
-                # only datasets can have dependencies
-                d.dependencies = self.get_dataset_dependencies(
-                    d.name,
-                    d.version,
-                    namespace_name=d.namespace,
-                    project_name=d.project,
-                    indirect=indirect,
-                )
+            return self.metastore.get_direct_dataset_dependencies(
+                dataset,
+                version,
+            )
 
-        return
+        return self.get_dataset_dependencies_by_ids(
+            dataset_id,
+            dataset_version_id,
+            indirect,
+        )
 
     def ls_datasets(
         self,
datachain/catalog/dependency.py
ADDED
@@ -0,0 +1,164 @@
+import builtins
+from dataclasses import dataclass
+from datetime import datetime
+from typing import TypeVar
+
+from datachain.dataset import DatasetDependency
+
+DDN = TypeVar("DDN", bound="DatasetDependencyNode")
+
+
+@dataclass
+class DatasetDependencyNode:
+    namespace: str
+    project: str
+    id: int
+    dataset_id: int | None
+    dataset_version_id: int | None
+    dataset_name: str | None
+    dataset_version: str | None
+    created_at: datetime
+    source_dataset_id: int
+    source_dataset_version_id: int | None
+    depth: int
+
+    @classmethod
+    def parse(
+        cls: builtins.type[DDN],
+        namespace: str,
+        project: str,
+        id: int,
+        dataset_id: int | None,
+        dataset_version_id: int | None,
+        dataset_name: str | None,
+        dataset_version: str | None,
+        created_at: datetime,
+        source_dataset_id: int,
+        source_dataset_version_id: int | None,
+        depth: int,
+    ) -> "DatasetDependencyNode | None":
+        return cls(
+            namespace,
+            project,
+            id,
+            dataset_id,
+            dataset_version_id,
+            dataset_name,
+            dataset_version,
+            created_at,
+            source_dataset_id,
+            source_dataset_version_id,
+            depth,
+        )
+
+    def to_dependency(self) -> "DatasetDependency | None":
+        return DatasetDependency.parse(
+            namespace_name=self.namespace,
+            project_name=self.project,
+            id=self.id,
+            dataset_id=self.dataset_id,
+            dataset_version_id=self.dataset_version_id,
+            dataset_name=self.dataset_name,
+            dataset_version=self.dataset_version,
+            dataset_version_created_at=self.created_at,
+        )
+
+
+def build_dependency_hierarchy(
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[
+    dict[int, DatasetDependency | None], dict[tuple[int, int | None], list[int]]
+]:
+    """
+    Build dependency hierarchy from dependency nodes.
+
+    Args:
+        dependency_nodes: List of DatasetDependencyNode objects from the database
+
+    Returns:
+        Tuple of (dependency_map, children_map) where:
+        - dependency_map: Maps dependency_id -> DatasetDependency
+        - children_map: Maps (source_dataset_id, source_version_id) ->
+          list of dependency_ids
+    """
+    dependency_map: dict[int, DatasetDependency | None] = {}
+    children_map: dict[tuple[int, int | None], list[int]] = {}
+
+    for node in dependency_nodes:
+        if node is None:
+            continue
+        dependency = node.to_dependency()
+        parent_key = (node.source_dataset_id, node.source_dataset_version_id)
+
+        if dependency is not None:
+            dependency_map[dependency.id] = dependency
+            children_map.setdefault(parent_key, []).append(dependency.id)
+        else:
+            # Handle case where dependency creation failed (e.g., deleted dependency)
+            dependency_map[node.id] = None
+            children_map.setdefault(parent_key, []).append(node.id)
+
+    return dependency_map, children_map
+
+
+def populate_nested_dependencies(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+    dependency_map: dict[int, DatasetDependency | None],
+    children_map: dict[tuple[int, int | None], list[int]],
+) -> None:
+    """
+    Recursively populate nested dependencies for a given dependency.
+
+    Args:
+        dependency: The dependency to populate nested dependencies for
+        dependency_nodes: All dependency nodes from the database
+        dependency_map: Maps dependency_id -> DatasetDependency
+        children_map: Maps (source_dataset_id, source_version_id) ->
+            list of dependency_ids
+    """
+    # Find the target dataset and version for this dependency
+    target_dataset_id, target_version_id = find_target_dataset_version(
+        dependency, dependency_nodes
+    )
+
+    if target_dataset_id is None or target_version_id is None:
+        return
+
+    # Get children for this target
+    target_key = (target_dataset_id, target_version_id)
+    if target_key not in children_map:
+        dependency.dependencies = []
+        return
+
+    child_dependency_ids = children_map[target_key]
+    child_dependencies = [dependency_map[child_id] for child_id in child_dependency_ids]
+
+    dependency.dependencies = child_dependencies
+
+    # Recursively populate children
+    for child_dependency in child_dependencies:
+        if child_dependency is not None:
+            populate_nested_dependencies(
+                child_dependency, dependency_nodes, dependency_map, children_map
+            )
+
+
+def find_target_dataset_version(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[int | None, int | None]:
+    """
+    Find the target dataset ID and version ID for a given dependency.
+
+    Args:
+        dependency: The dependency to find target for
+        dependency_nodes: All dependency nodes from the database
+
+    Returns:
+        Tuple of (target_dataset_id, target_version_id) or (None, None) if not found
+    """
+    for node in dependency_nodes:
+        if node is not None and node.id == dependency.id:
+            return node.dataset_id, node.dataset_version_id
+    return None, None
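The new helpers build the dependency tree from a flat list of rows: one pass fills a dependency map plus a children map keyed by (source_dataset_id, source_version_id), and a second recursive pass attaches children. The following is a minimal standalone sketch of that same pattern using plain tuples instead of DatasetDependencyNode; the names are illustrative, not datachain API:

    from collections import defaultdict

    # Each row: (dependency_id, source (dataset, version), target (dataset, version)).
    rows = [
        (1, ("a", 1), ("b", 1)),  # a@v1 depends on b@v1
        (2, ("b", 1), ("c", 1)),  # b@v1 depends on c@v1
        (3, ("a", 1), ("c", 2)),  # a@v1 also depends on c@v2
    ]

    target = {}                   # dependency_id -> target (dataset, version) key
    children = defaultdict(list)  # source key -> [dependency_id, ...]
    for dep_id, source_key, target_key in rows:
        target[dep_id] = target_key
        children[source_key].append(dep_id)

    def expand(source_key):
        # Recursively attach nested dependencies, mirroring populate_nested_dependencies.
        return [
            {"id": dep_id, "dependencies": expand(target[dep_id])}
            for dep_id in children[source_key]
        ]

    print(expand(("a", 1)))
    # [{'id': 1, 'dependencies': [{'id': 2, 'dependencies': []}]},
    #  {'id': 3, 'dependencies': []}]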
datachain/data_storage/metastore.py
CHANGED
@@ -22,10 +22,12 @@ from sqlalchemy import (
     Text,
     UniqueConstraint,
     desc,
+    literal,
     select,
 )
 from sqlalchemy.sql import func as f
 
+from datachain.catalog.dependency import DatasetDependencyNode
 from datachain.checkpoint import Checkpoint
 from datachain.data_storage import JobQueryType, JobStatus
 from datachain.data_storage.serializer import Serializable
@@ -78,6 +80,7 @@ class AbstractMetastore(ABC, Serializable):
     dataset_list_class: type[DatasetListRecord] = DatasetListRecord
     dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
     dependency_class: type[DatasetDependency] = DatasetDependency
+    dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
     job_class: type[Job] = Job
     checkpoint_class: type[Checkpoint] = Checkpoint
 
@@ -366,6 +369,12 @@ class AbstractMetastore(ABC, Serializable):
     ) -> list[DatasetDependency | None]:
         """Gets direct dataset dependencies."""
 
+    @abstractmethod
+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int
+    ) -> list[DatasetDependencyNode | None]:
+        """Gets dataset dependency node from database."""
+
     @abstractmethod
     def remove_dataset_dependencies(
         self, dataset: DatasetRecord, version: str | None = None
@@ -1483,6 +1492,77 @@ class AbstractDBMetastore(AbstractMetastore):
 
         return [self.dependency_class.parse(*r) for r in self.db.execute(query)]
 
+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int
+    ) -> list[DatasetDependencyNode | None]:
+        n = self._namespaces_select().subquery()
+        p = self._projects
+        d = self._datasets_select().subquery()
+        dd = self._datasets_dependencies
+        dv = self._datasets_versions
+
+        # Common dependency fields for CTE
+        dep_fields = [
+            dd.c.id,
+            dd.c.source_dataset_id,
+            dd.c.source_dataset_version_id,
+            dd.c.dataset_id,
+            dd.c.dataset_version_id,
+        ]
+
+        # Base case: direct dependencies
+        base_query = select(
+            *dep_fields,
+            literal(0).label("depth"),
+        ).where(
+            (dd.c.source_dataset_id == dataset_id)
+            & (dd.c.source_dataset_version_id == version_id)
+        )
+
+        cte = base_query.cte(name="dependency_tree", recursive=True)
+
+        # Recursive case: dependencies of dependencies
+        recursive_query = select(
+            *dep_fields,
+            (cte.c.depth + 1).label("depth"),
+        ).select_from(
+            cte.join(
+                dd,
+                (cte.c.dataset_id == dd.c.source_dataset_id)
+                & (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
+            )
+        )
+
+        cte = cte.union(recursive_query)
+
+        # Fetch all with full details
+        final_query = select(
+            n.c.name,
+            p.c.name,
+            cte.c.id,
+            cte.c.dataset_id,
+            cte.c.dataset_version_id,
+            d.c.name,
+            dv.c.version,
+            dv.c.created_at,
+            cte.c.source_dataset_id,
+            cte.c.source_dataset_version_id,
+            cte.c.depth,
+        ).select_from(
+            # Use outer joins to handle cases where dependent datasets have been
+            # physically deleted. This allows us to return dependency records with
+            # None values instead of silently omitting them, making broken
+            # dependencies visible to callers.
+            cte.join(d, cte.c.dataset_id == d.c.id, isouter=True)
+            .join(dv, cte.c.dataset_version_id == dv.c.id, isouter=True)
+            .join(p, d.c.project_id == p.c.id, isouter=True)
+            .join(n, p.c.namespace_id == n.c.id, isouter=True)
+        )
+
+        return [
+            self.dependency_node_class.parse(*r) for r in self.db.execute(final_query)
+        ]
+
     def remove_dataset_dependencies(
         self, dataset: DatasetRecord, version: str | None = None
     ) -> None:
datachain/data_storage/schema.py
CHANGED
@@ -11,7 +11,6 @@ from datachain.sql.types import (
     JSON,
     Boolean,
     DateTime,
-    Int,
     Int64,
     SQLType,
     String,
@@ -269,7 +268,7 @@ class DataTable:
     @classmethod
     def sys_columns(cls):
         return [
-            sa.Column("sys__id",
+            sa.Column("sys__id", UInt64, primary_key=True),
            sa.Column(
                "sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
            ),
datachain/data_storage/sqlite.py
CHANGED
@@ -868,11 +868,8 @@ class SQLiteWarehouse(AbstractWarehouse):
         if isinstance(c, BinaryExpression):
             right_left_join = add_left_rows_filter(c)
 
-        # Use CTE instead of subquery to force SQLite to materialize the result
-        # This breaks deep nesting and prevents parser stack overflow.
         union_cte = sqlalchemy.union(left_right_join, right_left_join).cte()
-
-        return self._regenerate_system_columns(union_cte)
+        return sqlalchemy.select(*union_cte.c).select_from(union_cte)
 
     def _system_row_number_expr(self):
         return func.row_number().over()
@@ -884,11 +881,7 @@ class SQLiteWarehouse(AbstractWarehouse):
         """
         Create a temporary table from a query for use in a UDF.
         """
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
+        columns = [sqlalchemy.Column(c.name, c.type) for c in query.selected_columns]
         table = self.create_udf_table(columns)
 
         with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:
datachain/data_storage/warehouse.py
CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Union, cast
 from urllib.parse import urlparse
 
 import attrs
@@ -23,7 +23,7 @@ from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
 from datachain.query.batch import RowsOutput
 from datachain.query.schema import ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.sql.types import
+from datachain.sql.types import SQLType
 from datachain.utils import sql_escape_like
 
 if TYPE_CHECKING:
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
         _FromClauseArgument,
         _OnClauseArgument,
     )
+    from sqlalchemy.sql.selectable import FromClause
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import schema
@@ -248,45 +249,56 @@ class AbstractWarehouse(ABC, Serializable):
 
     def _regenerate_system_columns(
         self,
-        selectable: sa.Select
+        selectable: sa.Select,
         keep_existing_columns: bool = False,
+        regenerate_columns: Iterable[str] | None = None,
     ) -> sa.Select:
         """
-        Return a SELECT that regenerates
+        Return a SELECT that regenerates system columns deterministically.
 
-        If keep_existing_columns is True, existing
-
-        """
-        base = selectable.subquery() if hasattr(selectable, "subquery") else selectable
-
-        result_columns: dict[str, sa.ColumnElement] = {}
-        for col in base.c:
-            if col.name in result_columns:
-                raise ValueError(f"Duplicate column name {col.name} in SELECT")
-            if col.name in ("sys__id", "sys__rand"):
-                if keep_existing_columns:
-                    result_columns[col.name] = col
-                else:
-                    result_columns[col.name] = col
+        If keep_existing_columns is True, existing system columns will be kept as-is
+        even when they are listed in ``regenerate_columns``.
 
-
+        Args:
+            selectable: Base SELECT
+            keep_existing_columns: When True, reuse existing system columns even if
+                they are part of the regeneration set.
+            regenerate_columns: Names of system columns to regenerate. Defaults to
+                {"sys__id", "sys__rand"}. Columns not listed are left untouched.
+        """
+        system_columns = {
            sys_col.name: sys_col.type
            for sys_col in self.schema.dataset_row_cls.sys_columns()
        }
+        regenerate = set(regenerate_columns or system_columns)
+        generators = {
+            "sys__id": self._system_row_number_expr,
+            "sys__rand": self._system_random_expr,
+        }
+
+        base = cast("FromClause", selectable.subquery())
+
+        def build(name: str) -> sa.ColumnElement:
+            expr = generators[name]()
+            return sa.cast(expr, system_columns[name]).label(name)
+
+        columns: list[sa.ColumnElement] = []
+        present: set[str] = set()
+        changed = False
+
+        for col in base.c:
+            present.add(col.name)
+            regen = col.name in regenerate and not keep_existing_columns
+            columns.append(build(col.name) if regen else col)
+            changed |= regen
+
+        for name in regenerate - present:
+            columns.append(build(name))
+            changed = True
+
+        if not changed:
+            return selectable
 
-        # Add missing system columns if needed
-        if "sys__id" not in result_columns:
-            expr = self._system_row_number_expr()
-            expr = sa.cast(expr, system_types["sys__id"])
-            result_columns["sys__id"] = expr.label("sys__id")
-        if "sys__rand" not in result_columns:
-            expr = self._system_random_expr()
-            expr = sa.cast(expr, system_types["sys__rand"])
-            result_columns["sys__rand"] = expr.label("sys__rand")
-
-        # Wrap in subquery to materialize window functions, then wrap again in SELECT
-        # This ensures window functions are computed before INSERT...FROM SELECT
-        columns = list(result_columns.values())
         inner = sa.select(*columns).select_from(base).subquery()
         return sa.select(*inner.c).select_from(inner)
 
@@ -950,10 +962,15 @@ class AbstractWarehouse(ABC, Serializable):
         SQLite TEMPORARY tables cannot be directly used as they are process-specific,
         and UDFs are run in other processes when run in parallel.
         """
+        columns = [
+            c
+            for c in columns
+            if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()]
+        ]
         tbl = sa.Table(
             name or self.udf_table_name(),
             sa.MetaData(),
-
+            *self.dataset_row_cls.sys_columns(),
             *columns,
         )
         self.db.create_table(tbl, if_not_exists=True)
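The rewritten _regenerate_system_columns builds replacement expressions (a row number for sys__id, a random value for sys__rand) only for the requested columns and returns the original SELECT untouched when nothing changes. A minimal sketch of the core idea, regenerating an id column over an arbitrary subquery with SQLAlchemy; the table and column names are illustrative, not the datachain warehouse API:

    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    meta = sa.MetaData()
    items = sa.Table(
        "items", meta,
        sa.Column("sys__id", sa.Integer),
        sa.Column("value", sa.String),
    )
    meta.create_all(engine)

    with engine.begin() as conn:
        conn.execute(items.insert(), [
            {"sys__id": 7, "value": "a"},
            {"sys__id": 9, "value": "b"},
        ])

        base = sa.select(items).subquery()
        # Rebuild sys__id as a fresh, contiguous row number; keep other columns as-is.
        regenerated = [
            sa.func.row_number().over().label(c.name) if c.name == "sys__id" else c
            for c in base.c
        ]
        # Wrap once more so the window function is materialized before further use.
        inner = sa.select(*regenerated).select_from(base).subquery()
        print(conn.execute(sa.select(*inner.c)).all())
        # e.g. [(1, 'a'), (2, 'b')]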
datachain/diff/__init__.py
CHANGED
@@ -24,7 +24,7 @@ class CompareStatus(str, Enum):
     SAME = "S"
 
 
-def _compare(  # noqa: C901
+def _compare(  # noqa: C901
     left: "DataChain",
     right: "DataChain",
     on: str | Sequence[str],
@@ -151,11 +151,7 @@ def _compare(  # noqa: C901, PLR0912
     if status_col:
         cols_select.append(diff_col)
 
-
-        # TODO workaround when sys signal is not available in diff
-        dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
-    else:
-        dc_diff = dc_diff.select(*cols_select)
+    dc_diff = dc_diff.select(*cols_select)
 
     # final schema is schema from the left chain with status column added if needed
     dc_diff.signals_schema = (
datachain/lib/audio.py
CHANGED
@@ -1,4 +1,5 @@
 import posixpath
+import re
 from typing import TYPE_CHECKING
 
 from datachain.lib.file import FileError
@@ -9,7 +10,7 @@ if TYPE_CHECKING:
     from datachain.lib.file import Audio, AudioFile, File
 
 try:
-    import
+    import soundfile as sf
 except ImportError as exc:
     raise ImportError(
         "Missing dependencies for processing audio.\n"
@@ -26,18 +27,25 @@ def audio_info(file: "File | AudioFile") -> "Audio":
 
     try:
         with file.open() as f:
-            info =
+            info = sf.info(f)
+
+        sample_rate = int(info.samplerate)
+        channels = int(info.channels)
+        frames = int(info.frames)
+        duration = float(info.duration)
 
-
-
-
-
+        # soundfile provides format and subtype
+        if info.format:
+            format_name = info.format.lower()
+        else:
+            format_name = file.get_file_ext().lower()
 
-
-
-
+        if not format_name:
+            format_name = "unknown"
+        codec_name = info.subtype if info.subtype else ""
 
-
+        # Calculate bit rate from subtype
+        bits_per_sample = _get_bits_per_sample(info.subtype)
         bit_rate = (
             bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
         )
@@ -58,44 +66,39 @@ def audio_info(file: "File | AudioFile") -> "Audio":
     )
 
 
-def
+def _get_bits_per_sample(subtype: str) -> int:
     """
-    Map
+    Map soundfile subtype to bits per sample.
 
     Args:
-
-        file_ext: The file extension as a fallback
+        subtype: The subtype string from soundfile
 
     Returns:
-
+        Bits per sample, or 0 if unknown
     """
-
-
-
-
-
-        "
-        "
-        "
-        "
+    if not subtype:
+        return 0
+
+    # Common PCM and floating-point subtypes
+    pcm_bits = {
+        "PCM_16": 16,
+        "PCM_24": 24,
+        "PCM_32": 32,
+        "PCM_S8": 8,
+        "PCM_U8": 8,
+        "FLOAT": 32,
+        "DOUBLE": 64,
    }
 
-    if
-        return
+    if subtype in pcm_bits:
+        return pcm_bits[subtype]
 
-    #
-
-
-
-        "wav": "wav",
-        "aiff": "aiff",
-        "au": "au",
-        "raw": "raw",
-    }
-    return pcm_formats.get(file_ext, "wav")  # Default to wav for PCM
+    # Handle variants such as PCM_S16LE, PCM_F32LE, etc.
+    match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
+    if match:
+        return int(match.group(1))
 
-
-    return file_ext if file_ext else "unknown"
+    return 0
 
 
 def audio_to_np(
@@ -114,27 +117,27 @@ def audio_to_np(
 
     try:
         with audio.open() as f:
-            info =
-            sample_rate = info.
+            info = sf.info(f)
+            sample_rate = info.samplerate
 
             frame_offset = int(start * sample_rate)
             num_frames = int(duration * sample_rate) if duration is not None else -1
 
             # Reset file pointer to the beginning
-            # This is important to ensure we read from the correct position later
             f.seek(0)
 
-
-
+            # Read audio data with offset and frame count
+            audio_np, sr = sf.read(
+                f,
+                start=frame_offset,
+                frames=num_frames,
+                always_2d=False,
+                dtype="float32",
             )
 
-
-
-
-                audio_np = audio_np.T
-            else:
-                audio_np = audio_np.squeeze()
-
+            # soundfile returns shape (frames,) for mono or
+            # (frames, channels) for multi-channel
+            # We keep this format as it matches expected output
         return audio_np, int(sr)
     except Exception as exc:
         raise FileError(
@@ -152,11 +155,9 @@ def audio_to_bytes(
 
     If duration is None, converts from start to end of file.
    If start is 0 and duration is None, converts entire file."""
-    y, sr = audio_to_np(audio, start, duration)
-
     import io
 
-
+    y, sr = audio_to_np(audio, start, duration)
 
     buffer = io.BytesIO()
    sf.write(buffer, y, sr, format=format)
datachain/lib/dc/datachain.py
CHANGED
@@ -856,7 +856,9 @@ class DataChain:
                 udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=
+            signal_schema=SignalSchema({"sys": Sys})
+            | self.signals_schema
+            | udf_obj.output,
         )
 
     def gen(
@@ -894,7 +896,7 @@ class DataChain:
                 udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
         )
 
     @delta_disabled
@@ -1031,7 +1033,7 @@ class DataChain:
                 partition_by=processed_partition_by,
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
         )
 
     def batch_map(
@@ -1097,11 +1099,7 @@ class DataChain:
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
         DataModel.register(list(sign.output_schema.values.values()))
 
-
-        if self._sys:
-            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
-
-        params_schema = signals_schema.slice(
+        params_schema = self.signals_schema.slice(
             sign.params, self._setup, is_batch=is_batch
         )
 
@@ -1156,11 +1154,9 @@ class DataChain:
             )
         )
 
-    def select(self, *args: str
+    def select(self, *args: str) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
-        if self._sys and _sys:
-            new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         return self._evolve(
             query=self._query.select(*columns), signal_schema=new_schema
@@ -1710,9 +1706,11 @@ class DataChain:
 
         signals_schema = self.signals_schema.clone_without_sys_signals()
         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
-
-
-
+
+        ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
+
+        if not full:
+            ds.signals_schema = SignalSchema({"sys": Sys}) | ds.signals_schema
 
         return ds
 
@@ -1723,6 +1721,7 @@ class DataChain:
         Parameters:
             other: chain whose rows will be added to `self`.
         """
+        self.signals_schema = self.signals_schema.clone_without_sys_signals()
         return self._evolve(query=self._query.union(other._query))
 
    def subtract(  # type: ignore[override]
datachain/query/dataset.py
CHANGED
@@ -438,9 +438,6 @@ class UDFStep(Step, ABC):
     """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        if "sys__id" not in query.selected_columns:
-            raise RuntimeError("Query must have sys__id column to run UDF")
-
         if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
             return
 
@@ -634,12 +631,11 @@ class UDFStep(Step, ABC):
 
         # Apply partitioning if needed.
         if self.partition_by is not None:
-
-
-
-
-
-
+            _query = query = self.catalog.warehouse._regenerate_system_columns(
+                query_generator.select(),
+                keep_existing_columns=True,
+                regenerate_columns=["sys__id"],
+            )
             partition_tbl = self.create_partitions_table(query)
             temp_tables.append(partition_tbl.name)
             query = query.outerjoin(
@@ -960,28 +956,23 @@ class SQLUnion(Step):
         q2 = self.query2.apply_steps().select().subquery()
         temp_tables.extend(self.query2.temp_table_names)
 
-        columns1
-
-
-            sqlalchemy.select(*columns2)
-        )
-        union_cte = union_select.cte()
-        regenerated = self.query1.catalog.warehouse._regenerate_system_columns(
-            union_cte
-        )
-        result_columns = tuple(regenerated.selected_columns)
+        columns1 = _drop_system_columns(q1.columns)
+        columns2 = _drop_system_columns(q2.columns)
+        columns1, columns2 = _order_columns(columns1, columns2)
 
         def q(*columns):
-
-
+            selected_names = [c.name for c in columns]
+            col1 = [c for c in columns1 if c.name in selected_names]
+            col2 = [c for c in columns2 if c.name in selected_names]
+            union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))
 
-
-
-            return
+            union_cte = union_query.cte()
+            select_cols = [union_cte.c[name] for name in selected_names]
+            return sqlalchemy.select(*select_cols)
 
         return step_result(
             q,
-
+            columns1,
             dependencies=self.query1.dependencies | self.query2.dependencies,
         )
 
@@ -1070,7 +1061,7 @@ class SQLJoin(Step):
         q1 = self.get_query(self.query1, temp_tables)
         q2 = self.get_query(self.query2, temp_tables)
 
-        q1_columns = list(q1.c)
+        q1_columns = _drop_system_columns(q1.c) if self.full else list(q1.c)
         q1_column_names = {c.name for c in q1_columns}
 
         q2_columns = []
@@ -1211,6 +1202,10 @@ def _order_columns(
     return [[d[n] for n in column_order] for d in column_dicts]
 
 
+def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
+    return [c for c in columns if not c.name.startswith("sys__")]
+
+
 @attrs.define
 class ResultIter:
    _row_iter: Iterable[Any]
datachain/query/dispatch.py
CHANGED
@@ -2,12 +2,16 @@ import contextlib
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
+from time import monotonic, sleep
 from typing import TYPE_CHECKING, Literal
 
+import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue
 
 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
@@ -25,7 +29,6 @@ from datachain.query.udf import UdfInfo
 from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
-    import multiprocess
     from sqlalchemy import Select, Table
 
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -101,8 +104,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:
 
 class UDFDispatcher:
     _catalog: Catalog | None = None
-    task_queue:
-    done_queue:
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None
 
     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -121,7 +124,7 @@ class UDFDispatcher:
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")
 
     @property
     def catalog(self) -> "Catalog":
@@ -259,8 +262,6 @@ class UDFDispatcher:
         for p in pool:
             p.start()
 
-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -283,10 +284,20 @@ class UDFDispatcher:
 
             # Process all tasks
             while n_workers > 0:
-
-
-
-
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)
 
                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
@@ -313,39 +324,50 @@ class UDFDispatcher:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
                     input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return
 
-
-
-
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()
 
 
 class DownloadCallback(Callback):
-    def __init__(self, queue:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()
 
@@ -360,7 +382,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue:
+        queue: MultiprocessQueue,
     ) -> None:
         self.name = name
         self.queue = queue
@@ -375,8 +397,8 @@ class UDFWorker:
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue:
-        done_queue:
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
        cache: bool,
datachain/query/queue.py
CHANGED
@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any
 
 import msgpack
+from multiprocess.queues import Queue
 
 from datachain.query.batch import RowsOutput
 
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: datachain
-Version: 0.35.2
+Version: 0.36.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License-Expression: Apache-2.0
@@ -64,7 +64,6 @@ Requires-Dist: torch>=2.1.0; extra == "torch"
 Requires-Dist: torchvision; extra == "torch"
 Requires-Dist: transformers>=4.36.0; extra == "torch"
 Provides-Extra: audio
-Requires-Dist: torchaudio; extra == "audio"
 Requires-Dist: soundfile; extra == "audio"
 Provides-Extra: remote
 Requires-Dist: lz4; extra == "remote"
@@ -76,6 +75,7 @@ Requires-Dist: numba>=0.60.0; extra == "hf"
 Requires-Dist: datasets[vision]>=4.0.0; extra == "hf"
 Requires-Dist: datasets[audio]>=4.0.0; (sys_platform == "linux" or sys_platform == "darwin") and extra == "hf"
 Requires-Dist: fsspec>=2024.12.0; extra == "hf"
+Requires-Dist: torch<2.9.0; extra == "hf"
 Provides-Extra: video
 Requires-Dist: ffmpeg-python; extra == "video"
 Requires-Dist: imageio[ffmpeg,pyav]>=2.37.0; extra == "video"
@@ -117,6 +117,7 @@ Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: ultralytics; extra == "examples"
 Requires-Dist: open_clip_torch; extra == "examples"
 Requires-Dist: openai; extra == "examples"
+Requires-Dist: torchaudio<2.9.0; extra == "examples"
 Dynamic: license-file
 
 ================
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/RECORD
CHANGED
@@ -24,8 +24,9 @@ datachain/studio.py,sha256=OHVAY8IcktgEHNSgYaJuBfAIln_nKBrF2j7BOM2Fxd0,15177
 datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
 datachain/utils.py,sha256=9KXA-fRH8lhK4E2JmdNOOH-74aUe-Sjb8wLiTiqXOh8,15710
 datachain/catalog/__init__.py,sha256=9NBaywvAOaXdkyqiHjbBEiXs7JImR1OJsY9r8D5Q16g,403
-datachain/catalog/catalog.py,sha256=
+datachain/catalog/catalog.py,sha256=Bb5xvC-qIGdUz_-epiFT9Eq6c3e00ZtNh_qFKyI_bp0,69862
 datachain/catalog/datasource.py,sha256=IkGMh0Ttg6Q-9DWfU_H05WUnZepbGa28HYleECi6K7I,1353
+datachain/catalog/dependency.py,sha256=EHuu_Ox76sEhy71NXjFJiHxQVTz19KecqBcrjwFCa7M,5280
 datachain/catalog/loader.py,sha256=VTaGPc4ASNdUdr7Elobp8qcXUOHwd0oqQcnk3LUwtF0,6244
 datachain/cli/__init__.py,sha256=y7wfBmKiBwPJiIOhoeIOXXBWankYbjknm6OnauEPQxM,8203
 datachain/cli/utils.py,sha256=WAeK_DSWGsYAYp58P4C9EYuAlfbUjW8PI0wh3TCfNUo,3005
@@ -53,12 +54,12 @@ datachain/client/s3.py,sha256=KS9o0jxXJRFp7Isdibz366VaWrULmpegzfYdurJpAl0,7499
 datachain/data_storage/__init__.py,sha256=9Wit-oe5P46V7CJQTD0BJ5MhOa2Y9h3ddJ4VWTe-Lec,273
 datachain/data_storage/db_engine.py,sha256=MGbrckXk5kHOfpjnhHhGpyJpAsgaBCxMmfd33hB2SWI,3756
 datachain/data_storage/job.py,sha256=NGFhXg0C0zRFTaF6ccjXZJT4xI4_gUr1WcxTLK6WYDE,448
-datachain/data_storage/metastore.py,sha256=
-datachain/data_storage/schema.py,sha256=
+datachain/data_storage/metastore.py,sha256=NLGYLErWFUNXjKbEoESFkKW222MQdMCBlpuqaYVugsE,63484
+datachain/data_storage/schema.py,sha256=3fAgiE11TIDYCW7EbTdiOm61SErRitvsLr7YPnUlVm0,9801
 datachain/data_storage/serializer.py,sha256=oL8i8smyAeVUyDepk8Xhf3lFOGOEHMoZjA5GdFzvfGI,3862
-datachain/data_storage/sqlite.py,sha256=
-datachain/data_storage/warehouse.py,sha256=
-datachain/diff/__init__.py,sha256=
+datachain/data_storage/sqlite.py,sha256=MgQ6bfJ7LGW91UiVHQtSkj_5HalRi1aeHCEW__5JEe8,30959
+datachain/data_storage/warehouse.py,sha256=nuGT27visvAi7jr7ZAZF-wmFe0ZEFD8qaTheINX_7RM,35269
+datachain/diff/__init__.py,sha256=Fo3xMnctKyA0YtvnsBXQ-P5gQeeEwed17Tn_i7vfLKs,9332
 datachain/fs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
 datachain/fs/utils.py,sha256=s-FkTOCGBk-b6TT3toQH51s9608pofoFjUSTc1yy7oE,825
@@ -75,7 +76,7 @@ datachain/func/string.py,sha256=kXkPHimtA__EVg_Th1yldGaLJpw4HYVhIeYtKy3DuyQ,7406
 datachain/func/window.py,sha256=ImyRpc1QI8QUSPO7KdD60e_DPVo7Ja0G5kcm6BlyMcw,1584
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/arrow.py,sha256=eCZtqbjAzkL4aemY74f_XkIJ_FWwXugJNjIFOwDa9w0,10815
-datachain/lib/audio.py,sha256=
+datachain/lib/audio.py,sha256=hHG29vqrV389im152wCjh80d0xqXGGvFnUpUwkzZejQ,7385
 datachain/lib/clip.py,sha256=nF8-N6Uz0MbAsPJBY2iXEYa3DPLo80OOer5SRNAtcGM,6149
 datachain/lib/data_model.py,sha256=H-bagx24-cLlC7ngSP6Dby4mB6kSxxV7KDiHxQjzwlg,3798
 datachain/lib/dataset_info.py,sha256=Ym7yYcGpfUmPLrfdxueijCVRP2Go6KbyuLk_fmzYgDU,3273
@@ -108,7 +109,7 @@ datachain/lib/convert/values_to_tuples.py,sha256=Sxj0ojeMSpAwM_NNoXa1dMR_2L_cQ6X
 datachain/lib/dc/__init__.py,sha256=UrUzmDH6YyVl8fxM5iXTSFtl5DZTUzEYm1MaazK4vdQ,900
 datachain/lib/dc/csv.py,sha256=fIfj5-2Ix4z5D5yZueagd5WUWw86pusJ9JJKD-U3KGg,4407
 datachain/lib/dc/database.py,sha256=Wqob3dQc9Mol_0vagzVEXzteCKS9M0E3U5130KVmQKg,14629
-datachain/lib/dc/datachain.py,sha256=
+datachain/lib/dc/datachain.py,sha256=cVqgemBiPVLSnfEVDLU1YH0dtowS-N-YFOAxV1k7i6U,104178
 datachain/lib/dc/datasets.py,sha256=A4SW-b3dkQnm9Wi7ciCdlXqtrsquIeRfBQN_bJ_ulqY,15237
 datachain/lib/dc/hf.py,sha256=FeruEO176L2qQ1Mnx0QmK4kV0GuQ4xtj717N8fGJrBI,2849
 datachain/lib/dc/json.py,sha256=iJ6G0jwTKz8xtfh1eICShnWk_bAMWjF5bFnOXLHaTlw,2683
@@ -131,11 +132,11 @@ datachain/model/ultralytics/pose.py,sha256=pvoXrWWUSWT_UBaMwUb5MBHAY57Co2HFDPigF
 datachain/model/ultralytics/segment.py,sha256=v9_xDxd5zw_I8rXsbl7yQXgEdTs2T38zyY_Y4XGN8ok,3194
 datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
 datachain/query/batch.py,sha256=ugTlSFqh_kxMcG6vJ5XrEzG9jBXRdb7KRAEEsFWiPew,4190
-datachain/query/dataset.py,sha256=
-datachain/query/dispatch.py,sha256=
+datachain/query/dataset.py,sha256=Pu8FC11VcIj8ewXJxe0mjJpr4HBr2-gvCtMk4GQCva0,67419
+datachain/query/dispatch.py,sha256=Tg73zB6vDnYYYAvtlS9l7BI3sI1EfRCbDjiasvNxz2s,16385
 datachain/query/metrics.py,sha256=qOMHiYPTMtVs2zI-mUSy8OPAVwrg4oJtVF85B9tdQyM,810
 datachain/query/params.py,sha256=JkVz6IKUIpF58JZRkUXFT8DAHX2yfaULbhVaGmHKFLc,826
-datachain/query/queue.py,sha256=
+datachain/query/queue.py,sha256=kCetMG6y7_ynV_jJDAXkLsf8WsVZCEk1fAuQGd7yTOo,3543
 datachain/query/schema.py,sha256=Cn1keXjktptAbEDbHlxSzdoCu5H6h_Vzp_DtNpMSr5w,6697
 datachain/query/session.py,sha256=lbwMDvxjZ2BS2rA9qk7MVBRzlsSrwH92yJ_waP3uvDc,6781
 datachain/query/udf.py,sha256=SLLLNLz3QmtaM04ZVTu7K6jo58I-1j5Jf7Lb4ORv4tQ,1385
@@ -164,9 +165,9 @@ datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR
 datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
 datachain/toolkit/split.py,sha256=xQzzmvQRKsPteDKbpgOxd4r971BnFaK33mcOl0FuGeI,2883
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
+datachain-0.36.1.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.36.1.dist-info/METADATA,sha256=BBaBx1Ail7RzpUlvEywlXKZtl_6Vn-KIEjm8OJdXrng,13657
+datachain-0.36.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+datachain-0.36.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.36.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.36.1.dist-info/RECORD,,
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/WHEEL
File without changes
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/entry_points.txt
File without changes
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/licenses/LICENSE
File without changes
{datachain-0.35.2.dist-info → datachain-0.36.1.dist-info}/top_level.txt
File without changes
|