datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -56,16 +56,16 @@ class YoloPose(DataModel):
         if not summary:
             return YoloPose(box=BBox(), pose=Pose3D())
         name = summary[0].get("name", "")
-        [10 lines removed; original content not captured in the source view]
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("keypoints"):
+            assert isinstance(summary[0]["keypoints"], dict)
+            pose = Pose3D.from_dict(summary[0]["keypoints"])
+        else:
+            pose = Pose3D()
         return YoloPose(
             cls=summary[0]["class"],
             name=name,
@@ -102,8 +102,12 @@ class YoloPoses(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-            [2 lines removed; original content not captured in the source view]
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
+            if s.get("keypoints"):
+                assert isinstance(s["keypoints"], dict)
+                pose.append(Pose3D.from_dict(s["keypoints"]))
         return YoloPoses(
             cls=cls,
             name=names,

datachain/model/ultralytics/segment.py
CHANGED
@@ -34,16 +34,16 @@ class YoloSegment(DataModel):
         if not summary:
             return YoloSegment(box=BBox(), segment=Segment())
         name = summary[0].get("name", "")
-        [10 lines removed; original content not captured in the source view]
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("segments"):
+            assert isinstance(summary[0]["segments"], dict)
+            segment = Segment.from_dict(summary[0]["segments"], title=name)
+        else:
+            segment = Segment()
         return YoloSegment(
             cls=summary[0]["class"],
             name=summary[0]["name"],
@@ -80,8 +80,12 @@ class YoloSegments(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-            [2 lines removed; original content not captured in the source view]
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
+            if s.get("segments"):
+                assert isinstance(s["segments"], dict)
+                segment.append(Segment.from_dict(s["segments"], title=name))
         return YoloSegments(
             cls=cls,
             name=names,
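
All four hunks apply the same fix: an Ultralytics result.summary() entry only carries a "box", "keypoints", or "segments" key when the model actually produced that output, so the new code guards each from_dict call and falls back to an empty model instead of assuming the key is present. A standalone sketch of the guard, using plain dicts in place of the DataChain BBox model (the helper name and sample data are illustrative, not package API):

# Mirrors the guard added above; `entry` mimics one item of result.summary().
def parse_box(entry: dict) -> dict:
    if entry.get("box"):  # key may be missing or empty
        assert isinstance(entry["box"], dict)
        return dict(entry["box"])  # stands in for BBox.from_dict(entry["box"], title=...)
    return {}  # stands in for the empty BBox() fallback

print(parse_box({"class": 0, "name": "person", "box": {"x1": 0.1, "y1": 0.2}}))
print(parse_box({"class": 0, "name": "person"}))  # no detection -> empty fallback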
datachain/namespace.py
ADDED
@@ -0,0 +1,84 @@
+import builtins
+from dataclasses import dataclass, fields
+from datetime import datetime
+from typing import Any, TypeVar
+
+from datachain.error import InvalidNamespaceNameError
+
+N = TypeVar("N", bound="Namespace")
+NAMESPACE_NAME_RESERVED_CHARS = [".", "@"]
+
+
+def parse_name(name: str) -> tuple[str, str | None]:
+    """
+    Parses namespace name into namespace and optional project name.
+    If both namespace and project are defined in name, they need to be split by dot
+    e.g dev.my-project
+    Valid inputs:
+    - dev.my-project
+    - dev
+    """
+    parts = name.split(".")
+    if len(parts) == 1:
+        return name, None
+    if len(parts) == 2:
+        return parts[0], parts[1]
+    raise InvalidNamespaceNameError(
+        f"Invalid namespace format: {name}. Expected 'namespace' or 'ns1.ns2'."
+    )
+
+
+@dataclass(frozen=True)
+class Namespace:
+    id: int
+    uuid: str
+    name: str
+    descr: str | None
+    created_at: datetime
+
+    @staticmethod
+    def validate_name(name: str) -> None:
+        """Throws exception if name is invalid, otherwise returns None"""
+        if not name:
+            raise InvalidNamespaceNameError("Namespace name cannot be empty")
+
+        for c in NAMESPACE_NAME_RESERVED_CHARS:
+            if c in name:
+                raise InvalidNamespaceNameError(
+                    f"Character {c} is reserved and not allowed in namespace name"
+                )
+
+        if name in [Namespace.default(), Namespace.system()]:
+            raise InvalidNamespaceNameError(
+                f"Namespace name {name} is reserved and cannot be used."
+            )
+
+    @staticmethod
+    def default() -> str:
+        """Name of default namespace"""
+        return "local"
+
+    @staticmethod
+    def system() -> str:
+        """Name of the system namespace"""
+        return "system"
+
+    @property
+    def is_system(self):
+        return self.name == Namespace.system()
+
+    @classmethod
+    def parse(
+        cls: builtins.type[N],
+        id: int,
+        uuid: str,
+        name: str,
+        descr: str | None,
+        created_at: datetime,
+    ) -> "Namespace":
+        return cls(id, uuid, name, descr, created_at)
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Namespace":
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        return cls(**kwargs)
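
Taken together, parse_name and validate_name define the naming rules for the new namespace feature: at most one dot (separating namespace from project name), no reserved characters, and no reserved names. A quick usage sketch, assuming datachain 0.39.0 with the module exactly as shown above:

from datachain.error import InvalidNamespaceNameError
from datachain.namespace import Namespace, parse_name

assert parse_name("dev") == ("dev", None)  # namespace only
assert parse_name("dev.my-project") == ("dev", "my-project")  # namespace + project

# Each of these violates a rule: empty, reserved char, reserved name.
for bad in ("", "with.dot", "has@char", "local", "system"):
    try:
        Namespace.validate_name(bad)
    except InvalidNamespaceNameError as exc:
        print(f"{bad!r}: {exc}")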
datachain/node.py
CHANGED
@@ -1,6 +1,6 @@
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 import attrs

@@ -53,11 +53,11 @@ class Node:
     sys__rand: int = 0
     path: str = ""
     etag: str = ""
-    version: Optional[str] = None
+    version: str | None = None
     is_latest: bool = True
-    last_modified: Optional[datetime] = None
+    last_modified: datetime | None = None
     size: int = 0
-    location: Optional[str] = None
+    location: str | None = None
     source: StorageURI = StorageURI("")  # noqa: RUF009
     dir_type: int = DirType.FILE

@@ -90,7 +90,7 @@ class Node:
             return self.path + "/"
         return self.path

-    def to_file(self, source: Optional[StorageURI] = None) -> File:
+    def to_file(self, source: StorageURI | None = None) -> File:
         if source is None:
             source = self.source
         return File(
@@ -189,7 +189,7 @@ class NodeWithPath:
 TIME_FMT = "%Y-%m-%d %H:%M"


-def long_line_str(name: str, timestamp: Optional[datetime]) -> str:
+def long_line_str(name: str, timestamp: datetime | None) -> str:
     if timestamp is None:
         time = "-"
     else:
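
The node.py changes are a mechanical migration from typing.Optional to PEP 604 union syntax; nothing behavioral changes, since the two spellings compare equal at runtime on Python 3.10+:

from typing import Optional

# PEP 604 unions and typing.Optional are interchangeable on Python 3.10+.
assert (str | None) == Optional[str]
assert (datetime | None if (datetime := type("datetime", (), {})) else None)  # placeholder; any type works
assert (int | None) == Optional[int]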
datachain/nodes_thread_pool.py
CHANGED
datachain/plugins.py
ADDED
@@ -0,0 +1,24 @@
+"""Plugin loader for DataChain callables.
+
+Discovers and invokes entry points in the group "datachain.callables" once
+per process. This enables external packages (e.g., Studio) to register
+their callables with the serializer registry without explicit imports.
+"""
+
+from importlib import metadata as importlib_metadata
+
+_plugins_loaded = False
+
+
+def ensure_plugins_loaded() -> None:
+    global _plugins_loaded  # noqa: PLW0603
+    if _plugins_loaded:
+        return
+
+    # Compatible across importlib.metadata versions
+    eps_obj = importlib_metadata.entry_points()
+    for ep in eps_obj.select(group="datachain.callables"):
+        func = ep.load()
+        func()
+
+    _plugins_loaded = True
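
plugins.py is new infrastructure: any installed distribution that declares an entry point in the "datachain.callables" group gets its callable invoked once per process. A hypothetical plugin would hook in like this (the package name, module path, and function below are made up for illustration; only the group name comes from the code above):

# In the plugin's pyproject.toml:
#
#   [project.entry-points."datachain.callables"]
#   register = "my_plugin.hooks:register"

# my_plugin/hooks.py
def register() -> None:
    # Invoked by datachain.plugins.ensure_plugins_loaded() during discovery;
    # a real plugin would register its serializer callables here.
    print("my_plugin callables registered")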
datachain/project.py
ADDED
@@ -0,0 +1,78 @@
+import builtins
+from dataclasses import dataclass, fields
+from datetime import datetime
+from typing import Any, TypeVar
+
+from datachain.error import InvalidProjectNameError
+from datachain.namespace import Namespace
+
+P = TypeVar("P", bound="Project")
+PROJECT_NAME_RESERVED_CHARS = [".", "@"]
+
+
+@dataclass(frozen=True)
+class Project:
+    id: int
+    uuid: str
+    name: str
+    descr: str | None
+    created_at: datetime
+    namespace: Namespace
+
+    @staticmethod
+    def validate_name(name: str) -> None:
+        """Throws exception if name is invalid, otherwise returns None"""
+        if not name:
+            raise InvalidProjectNameError("Project name cannot be empty")
+
+        for c in PROJECT_NAME_RESERVED_CHARS:
+            if c in name:
+                raise InvalidProjectNameError(
+                    f"Character {c} is reserved and not allowed in project name."
+                )
+
+        if name in [Project.default(), Project.listing()]:
+            raise InvalidProjectNameError(
+                f"Project name {name} is reserved and cannot be used."
+            )
+
+    @staticmethod
+    def default() -> str:
+        """Name of default project"""
+        return "local"
+
+    @staticmethod
+    def listing() -> str:
+        """Name of listing project where all listing datasets will be saved"""
+        return "listing"
+
+    @classmethod
+    def parse(
+        cls: builtins.type[P],
+        namespace_id: int,
+        namespace_uuid: str,
+        namespace_name: str,
+        namespace_descr: str | None,
+        namespace_created_at: datetime,
+        project_id: int,
+        uuid: str,
+        name: str,
+        descr: str | None,
+        created_at: datetime,
+        project_namespace_id: int,
+    ) -> "Project":
+        namespace = Namespace.parse(
+            namespace_id,
+            namespace_uuid,
+            namespace_name,
+            namespace_descr,
+            namespace_created_at,
+        )
+
+        return cls(project_id, uuid, name, descr, created_at, namespace)
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Project":
+        namespace = Namespace.from_dict(d.pop("namespace"))
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        return cls(**kwargs, namespace=namespace)
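
Project.from_dict expects the nested namespace as its own dict: it pops "namespace" out, builds the Namespace first, then matches the remaining keys against the dataclass fields. A round-trip sketch with made-up values (uuids and ids are placeholders):

from datetime import datetime, timezone

from datachain.project import Project

now = datetime.now(timezone.utc)
project = Project.from_dict(
    {
        "id": 1,
        "uuid": "aaaa-1111",  # placeholder
        "name": "my-project",
        "descr": None,
        "created_at": now,
        "namespace": {
            "id": 2,
            "uuid": "bbbb-2222",  # placeholder
            "name": "dev",
            "descr": None,
            "created_at": now,
        },
    }
)
assert project.name == "my-project"
assert project.namespace.name == "dev"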
datachain/query/batch.py
CHANGED
@@ -1,24 +1,14 @@
 import contextlib
 import math
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from collections.abc import Callable, Generator, Sequence

+import sqlalchemy as sa

-@dataclass
-class RowsOutputBatch:
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID

-RowsOutput = Union[Sequence, RowsOutputBatch]
+RowsOutputBatch = Sequence[Sequence]
+RowsOutput = Sequence | RowsOutputBatch


 class BatchingStrategy(ABC):
@@ -30,8 +20,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""

@@ -47,12 +37,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
     ) -> Generator[Sequence, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
-        yield from execute(query)
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows


 class Batch(BatchingStrategy):
@@ -69,27 +63,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True

         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

         # select rows in batches
-        results: list[Sequence] = []
+        results = []

         with contextlib.closing(execute(query, page_size=page_size)) as rows:
             for row in rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowsOutputBatch(batch)
+                    yield [r[0] for r in batch] if ids_only else batch

         if len(results) > 0:
-            yield RowsOutputBatch(results)
+            yield [r[0] for r in results] if ids_only else results


 class Partition(BatchingStrategy):
@@ -104,18 +102,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        id_col = get_query_id_column(query)
-        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+        query: sa.Select,
+        id_col: sa.ColumnElement | None = None,
+    ) -> Generator[RowsOutput, None, None]:
+        if (partition_col := query.selected_columns.get(PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")

-        if ids_only:
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True

-        current_partition: Optional[int] = None
-        batch: list[Sequence] = []
+        current_partition: int | None = None
+        batch: list = []

         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
@@ -132,9 +131,9 @@ class Partition(BatchingStrategy):
             if current_partition != partition:
                 current_partition = partition
                 if len(batch) > 0:
-                    yield RowsOutputBatch(batch)
+                    yield batch
                 batch = []
-            batch.append(row)
+            batch.append(row[id_column_idx] if ids_only else row)

         if len(batch) > 0:
-            yield RowsOutputBatch(batch)
+            yield batch
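
The reworked strategies also fold id extraction into batching itself: when id_col is passed, each strategy narrows the query to that column and yields bare ids instead of full rows, which is what replaces the old RowsOutputBatch wrapper and ids_only flag. The core regrouping loop of Batch reduces to the following standalone sketch (not the DataChain class itself; SELECT_BATCH_SIZE stands in for the warehouse constant):

import contextlib
import math
from collections.abc import Iterator, Sequence

SELECT_BATCH_SIZE = 10_000  # stand-in for datachain.data_storage.warehouse

def batched(execute, query, count: int) -> Iterator[Sequence]:
    # Page size is rounded up to a multiple of the batch size so a fetch
    # boundary never splits a batch.
    page_size = math.ceil(SELECT_BATCH_SIZE / count) * count
    results: list = []
    with contextlib.closing(execute(query, page_size=page_size)) as rows:
        for row in rows:
            results.append(row)
            if len(results) >= count:
                batch, results = results[:count], results[count:]
                yield batch
    if results:
        yield results

def fake_rows(query, page_size):
    # Stands in for a warehouse cursor; a generator, so it has .close().
    yield from ((i,) for i in range(7))

print(list(batched(fake_rows, None, 3)))
# [[(0,), (1,), (2,)], [(3,), (4,), (5,)], [(6,)]]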