datachain 0.34.5__py3-none-any.whl → 0.34.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/catalog.py +75 -83
- datachain/catalog/loader.py +3 -3
- datachain/checkpoint.py +1 -2
- datachain/cli/__init__.py +2 -4
- datachain/cli/commands/datasets.py +13 -13
- datachain/cli/commands/ls.py +4 -4
- datachain/cli/commands/query.py +3 -3
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +1 -2
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +11 -21
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +4 -4
- datachain/client/local.py +4 -4
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +5 -5
- datachain/data_storage/metastore.py +107 -107
- datachain/data_storage/schema.py +18 -24
- datachain/data_storage/sqlite.py +21 -28
- datachain/data_storage/warehouse.py +13 -13
- datachain/dataset.py +64 -70
- datachain/delta.py +21 -18
- datachain/diff/__init__.py +13 -13
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +45 -42
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +54 -81
- datachain/job.py +8 -8
- datachain/lib/arrow.py +17 -14
- datachain/lib/audio.py +6 -6
- datachain/lib/clip.py +5 -4
- datachain/lib/convert/python_to_sql.py +4 -22
- datachain/lib/convert/values_to_tuples.py +4 -9
- datachain/lib/data_model.py +20 -19
- datachain/lib/dataset_info.py +6 -6
- datachain/lib/dc/csv.py +10 -10
- datachain/lib/dc/database.py +28 -29
- datachain/lib/dc/datachain.py +98 -97
- datachain/lib/dc/datasets.py +22 -22
- datachain/lib/dc/hf.py +4 -4
- datachain/lib/dc/json.py +9 -10
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +5 -5
- datachain/lib/dc/storage.py +12 -12
- datachain/lib/dc/storage_pattern.py +2 -2
- datachain/lib/dc/utils.py +11 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +26 -26
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +1 -2
- datachain/lib/model_store.py +3 -3
- datachain/lib/namespaces.py +4 -6
- datachain/lib/projects.py +5 -9
- datachain/lib/pytorch.py +10 -10
- datachain/lib/settings.py +23 -23
- datachain/lib/signal_schema.py +52 -44
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +25 -17
- datachain/lib/udf_signature.py +11 -11
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +30 -35
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +4 -4
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +1 -7
- datachain/project.py +4 -4
- datachain/query/batch.py +7 -8
- datachain/query/dataset.py +80 -87
- datachain/query/dispatch.py +7 -7
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/schema.py +7 -6
- datachain/query/session.py +7 -7
- datachain/query/udf.py +8 -7
- datachain/query/utils.py +8 -6
- datachain/remote/studio.py +33 -39
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +6 -9
- datachain/studio.py +30 -30
- datachain/toolkit/split.py +1 -2
- datachain/utils.py +21 -21
- {datachain-0.34.5.dist-info → datachain-0.34.7.dist-info}/METADATA +2 -3
- datachain-0.34.7.dist-info/RECORD +173 -0
- datachain-0.34.5.dist-info/RECORD +0 -173
- {datachain-0.34.5.dist-info → datachain-0.34.7.dist-info}/WHEEL +0 -0
- {datachain-0.34.5.dist-info → datachain-0.34.7.dist-info}/entry_points.txt +0 -0
- {datachain-0.34.5.dist-info → datachain-0.34.7.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.34.5.dist-info → datachain-0.34.7.dist-info}/top_level.txt +0 -0
datachain/lib/video.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import posixpath
|
|
2
2
|
import shutil
|
|
3
3
|
import tempfile
|
|
4
|
-
from typing import Optional, Union
|
|
5
4
|
|
|
6
5
|
from numpy import ndarray
|
|
7
6
|
|
|
@@ -18,7 +17,7 @@ except ImportError as exc:
|
|
|
18
17
|
) from exc
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
def video_info(file:
|
|
20
|
+
def video_info(file: File | VideoFile) -> Video:
|
|
22
21
|
"""
|
|
23
22
|
Returns video file information.
|
|
24
23
|
|
|
@@ -108,7 +107,7 @@ def video_frame_np(video: VideoFile, frame: int) -> ndarray:
|
|
|
108
107
|
def validate_frame_range(
|
|
109
108
|
video: VideoFile,
|
|
110
109
|
start: int = 0,
|
|
111
|
-
end:
|
|
110
|
+
end: int | None = None,
|
|
112
111
|
step: int = 1,
|
|
113
112
|
) -> tuple[int, int, int]:
|
|
114
113
|
"""
|
|
@@ -186,7 +185,7 @@ def save_video_fragment(
|
|
|
186
185
|
start: float,
|
|
187
186
|
end: float,
|
|
188
187
|
output: str,
|
|
189
|
-
format:
|
|
188
|
+
format: str | None = None,
|
|
190
189
|
) -> VideoFile:
|
|
191
190
|
"""
|
|
192
191
|
Saves video interval as a new video file. If output is a remote path,
|
datachain/lib/webdataset.py
CHANGED
|
@@ -1,17 +1,10 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import tarfile
|
|
3
|
+
import types
|
|
3
4
|
import warnings
|
|
4
|
-
from collections.abc import Iterator, Sequence
|
|
5
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
7
|
-
Any,
|
|
8
|
-
Callable,
|
|
9
|
-
ClassVar,
|
|
10
|
-
Optional,
|
|
11
|
-
Union,
|
|
12
|
-
get_args,
|
|
13
|
-
get_origin,
|
|
14
|
-
)
|
|
7
|
+
from typing import Any, ClassVar, Union, get_args, get_origin
|
|
15
8
|
|
|
16
9
|
from pydantic import Field
|
|
17
10
|
|
|
@@ -64,28 +57,28 @@ class WDSBasic(DataModel):
|
|
|
64
57
|
|
|
65
58
|
|
|
66
59
|
class WDSAllFile(WDSBasic):
|
|
67
|
-
txt:
|
|
68
|
-
text:
|
|
69
|
-
cap:
|
|
70
|
-
transcript:
|
|
71
|
-
cls:
|
|
72
|
-
cls2:
|
|
73
|
-
index:
|
|
74
|
-
inx:
|
|
75
|
-
id:
|
|
76
|
-
json:
|
|
77
|
-
jsn:
|
|
78
|
-
|
|
79
|
-
pyd:
|
|
80
|
-
pickle:
|
|
81
|
-
pth:
|
|
82
|
-
ten:
|
|
83
|
-
tb:
|
|
84
|
-
mp:
|
|
85
|
-
msg:
|
|
86
|
-
npy:
|
|
87
|
-
npz:
|
|
88
|
-
cbor:
|
|
60
|
+
txt: str | None = Field(default=None)
|
|
61
|
+
text: str | None = Field(default=None)
|
|
62
|
+
cap: str | None = Field(default=None)
|
|
63
|
+
transcript: str | None = Field(default=None)
|
|
64
|
+
cls: int | None = Field(default=None)
|
|
65
|
+
cls2: int | None = Field(default=None)
|
|
66
|
+
index: int | None = Field(default=None)
|
|
67
|
+
inx: int | None = Field(default=None)
|
|
68
|
+
id: int | None = Field(default=None)
|
|
69
|
+
json: dict | None = Field(default=None) # type: ignore[assignment]
|
|
70
|
+
jsn: dict | None = Field(default=None)
|
|
71
|
+
|
|
72
|
+
pyd: bytes | None = Field(default=None)
|
|
73
|
+
pickle: bytes | None = Field(default=None)
|
|
74
|
+
pth: bytes | None = Field(default=None)
|
|
75
|
+
ten: bytes | None = Field(default=None)
|
|
76
|
+
tb: bytes | None = Field(default=None)
|
|
77
|
+
mp: bytes | None = Field(default=None)
|
|
78
|
+
msg: bytes | None = Field(default=None)
|
|
79
|
+
npy: bytes | None = Field(default=None)
|
|
80
|
+
npz: bytes | None = Field(default=None)
|
|
81
|
+
cbor: bytes | None = Field(default=None)
|
|
89
82
|
|
|
90
83
|
|
|
91
84
|
class WDSReadableSubclass(DataModel):
|
|
@@ -189,9 +182,11 @@ class Builder:
|
|
|
189
182
|
return
|
|
190
183
|
|
|
191
184
|
anno = field.annotation
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
185
|
+
anno_origin = get_origin(anno)
|
|
186
|
+
if anno_origin in (Union, types.UnionType):
|
|
187
|
+
anno_args = get_args(anno)
|
|
188
|
+
if len(anno_args) == 2 and type(None) in anno_args:
|
|
189
|
+
return anno_args[0] if anno_args[1] is type(None) else anno_args[1]
|
|
195
190
|
|
|
196
191
|
return anno
|
|
197
192
|
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import warnings
|
|
2
2
|
from collections.abc import Iterator
|
|
3
|
-
from typing import Optional
|
|
4
3
|
|
|
5
4
|
import numpy as np
|
|
6
5
|
from pydantic import BaseModel, Field
|
|
@@ -23,18 +22,18 @@ warnings.filterwarnings(
|
|
|
23
22
|
|
|
24
23
|
class Laion(WDSReadableSubclass):
|
|
25
24
|
uid: str = Field(default="")
|
|
26
|
-
face_bboxes:
|
|
27
|
-
caption:
|
|
28
|
-
url:
|
|
29
|
-
key:
|
|
30
|
-
status:
|
|
31
|
-
error_message:
|
|
32
|
-
width:
|
|
33
|
-
height:
|
|
34
|
-
original_width:
|
|
35
|
-
original_height:
|
|
36
|
-
exif:
|
|
37
|
-
sha256:
|
|
25
|
+
face_bboxes: list[list[float]] | None = Field(default=None)
|
|
26
|
+
caption: str | None = Field(default=None)
|
|
27
|
+
url: str | None = Field(default=None)
|
|
28
|
+
key: str | None = Field(default=None)
|
|
29
|
+
status: str | None = Field(default=None)
|
|
30
|
+
error_message: str | None = Field(default=None)
|
|
31
|
+
width: int | None = Field(default=None)
|
|
32
|
+
height: int | None = Field(default=None)
|
|
33
|
+
original_width: int | None = Field(default=None)
|
|
34
|
+
original_height: int | None = Field(default=None)
|
|
35
|
+
exif: str | None = Field(default=None)
|
|
36
|
+
sha256: str | None = Field(default=None)
|
|
38
37
|
|
|
39
38
|
@staticmethod
|
|
40
39
|
def _reader(builder, item):
|
|
@@ -42,13 +41,13 @@ class Laion(WDSReadableSubclass):
|
|
|
42
41
|
|
|
43
42
|
|
|
44
43
|
class WDSLaion(WDSBasic):
|
|
45
|
-
txt:
|
|
46
|
-
json: Laion # type: ignore[assignment]
|
|
44
|
+
txt: str | None = Field(default=None)
|
|
45
|
+
json: Laion = Field(default_factory=Laion) # type: ignore[assignment]
|
|
47
46
|
|
|
48
47
|
|
|
49
48
|
class LaionMeta(BaseModel):
|
|
50
49
|
file: File
|
|
51
|
-
index:
|
|
50
|
+
index: int | None = Field(default=None)
|
|
52
51
|
b32_img: list[float] = Field(default=[])
|
|
53
52
|
b32_txt: list[float] = Field(default=[])
|
|
54
53
|
l14_img: list[float] = Field(default=[])
|
datachain/listing.py
CHANGED
|
@@ -2,7 +2,7 @@ import glob
|
|
|
2
2
|
import os
|
|
3
3
|
from collections.abc import Iterable, Iterator
|
|
4
4
|
from functools import cached_property
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
6
|
|
|
7
7
|
from sqlalchemy import Column
|
|
8
8
|
from sqlalchemy.sql import func
|
|
@@ -25,8 +25,8 @@ class Listing:
|
|
|
25
25
|
metastore: "AbstractMetastore",
|
|
26
26
|
warehouse: "AbstractWarehouse",
|
|
27
27
|
client: "Client",
|
|
28
|
-
dataset_name:
|
|
29
|
-
dataset_version:
|
|
28
|
+
dataset_name: str | None = None,
|
|
29
|
+
dataset_version: str | None = None,
|
|
30
30
|
column: str = "file",
|
|
31
31
|
):
|
|
32
32
|
self.metastore = metastore
|
|
@@ -102,7 +102,7 @@ class Listing:
|
|
|
102
102
|
def collect_nodes_to_instantiate(
|
|
103
103
|
self,
|
|
104
104
|
sources: Iterable["DataSource"],
|
|
105
|
-
copy_to_filename:
|
|
105
|
+
copy_to_filename: str | None,
|
|
106
106
|
recursive=False,
|
|
107
107
|
copy_dir_contents=False,
|
|
108
108
|
from_dataset=False,
|
datachain/model/bbox.py
CHANGED
|
@@ -198,7 +198,9 @@ class BBox(DataModel):
|
|
|
198
198
|
def pose_inside(self, pose: Union["Pose", "Pose3D"]) -> bool:
|
|
199
199
|
"""Return True if the pose is inside the bounding box."""
|
|
200
200
|
return all(
|
|
201
|
-
self.point_inside(x, y)
|
|
201
|
+
self.point_inside(x, y)
|
|
202
|
+
for x, y in zip(pose.x, pose.y, strict=False)
|
|
203
|
+
if x > 0 or y > 0
|
|
202
204
|
)
|
|
203
205
|
|
|
204
206
|
@staticmethod
|
datachain/namespace.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import builtins
|
|
2
2
|
from dataclasses import dataclass, fields
|
|
3
3
|
from datetime import datetime
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, TypeVar
|
|
5
5
|
|
|
6
6
|
from datachain.error import InvalidNamespaceNameError
|
|
7
7
|
|
|
@@ -9,7 +9,7 @@ N = TypeVar("N", bound="Namespace")
|
|
|
9
9
|
NAMESPACE_NAME_RESERVED_CHARS = [".", "@"]
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def parse_name(name: str) -> tuple[str,
|
|
12
|
+
def parse_name(name: str) -> tuple[str, str | None]:
|
|
13
13
|
"""
|
|
14
14
|
Parses namespace name into namespace and optional project name.
|
|
15
15
|
If both namespace and project are defined in name, they need to be split by dot
|
|
@@ -33,7 +33,7 @@ class Namespace:
|
|
|
33
33
|
id: int
|
|
34
34
|
uuid: str
|
|
35
35
|
name: str
|
|
36
|
-
descr:
|
|
36
|
+
descr: str | None
|
|
37
37
|
created_at: datetime
|
|
38
38
|
|
|
39
39
|
@staticmethod
|
|
@@ -73,7 +73,7 @@ class Namespace:
|
|
|
73
73
|
id: int,
|
|
74
74
|
uuid: str,
|
|
75
75
|
name: str,
|
|
76
|
-
descr:
|
|
76
|
+
descr: str | None,
|
|
77
77
|
created_at: datetime,
|
|
78
78
|
) -> "Namespace":
|
|
79
79
|
return cls(id, uuid, name, descr, created_at)
|
datachain/node.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from datetime import datetime
|
|
3
|
-
from typing import TYPE_CHECKING, Any
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
4
|
|
|
5
5
|
import attrs
|
|
6
6
|
|
|
@@ -53,11 +53,11 @@ class Node:
|
|
|
53
53
|
sys__rand: int = 0
|
|
54
54
|
path: str = ""
|
|
55
55
|
etag: str = ""
|
|
56
|
-
version:
|
|
56
|
+
version: str | None = None
|
|
57
57
|
is_latest: bool = True
|
|
58
|
-
last_modified:
|
|
58
|
+
last_modified: datetime | None = None
|
|
59
59
|
size: int = 0
|
|
60
|
-
location:
|
|
60
|
+
location: str | None = None
|
|
61
61
|
source: StorageURI = StorageURI("") # noqa: RUF009
|
|
62
62
|
dir_type: int = DirType.FILE
|
|
63
63
|
|
|
@@ -90,7 +90,7 @@ class Node:
|
|
|
90
90
|
return self.path + "/"
|
|
91
91
|
return self.path
|
|
92
92
|
|
|
93
|
-
def to_file(self, source:
|
|
93
|
+
def to_file(self, source: StorageURI | None = None) -> File:
|
|
94
94
|
if source is None:
|
|
95
95
|
source = self.source
|
|
96
96
|
return File(
|
|
@@ -189,7 +189,7 @@ class NodeWithPath:
|
|
|
189
189
|
TIME_FMT = "%Y-%m-%d %H:%M"
|
|
190
190
|
|
|
191
191
|
|
|
192
|
-
def long_line_str(name: str, timestamp:
|
|
192
|
+
def long_line_str(name: str, timestamp: datetime | None) -> str:
|
|
193
193
|
if timestamp is None:
|
|
194
194
|
time = "-"
|
|
195
195
|
else:
|
datachain/nodes_thread_pool.py
CHANGED
datachain/plugins.py
CHANGED
|
@@ -17,13 +17,7 @@ def ensure_plugins_loaded() -> None:
|
|
|
17
17
|
|
|
18
18
|
# Compatible across importlib.metadata versions
|
|
19
19
|
eps_obj = importlib_metadata.entry_points()
|
|
20
|
-
|
|
21
|
-
eps_list = eps_obj.select(group="datachain.callables")
|
|
22
|
-
else:
|
|
23
|
-
# Compatibility for older versions of importlib_metadata, Python 3.9
|
|
24
|
-
eps_list = eps_obj.get("datachain.callables", []) # type: ignore[attr-defined]
|
|
25
|
-
|
|
26
|
-
for ep in eps_list:
|
|
20
|
+
for ep in eps_obj.select(group="datachain.callables"):
|
|
27
21
|
func = ep.load()
|
|
28
22
|
func()
|
|
29
23
|
|
datachain/project.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import builtins
|
|
2
2
|
from dataclasses import dataclass, fields
|
|
3
3
|
from datetime import datetime
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, TypeVar
|
|
5
5
|
|
|
6
6
|
from datachain.error import InvalidProjectNameError
|
|
7
7
|
from datachain.namespace import Namespace
|
|
@@ -15,7 +15,7 @@ class Project:
|
|
|
15
15
|
id: int
|
|
16
16
|
uuid: str
|
|
17
17
|
name: str
|
|
18
|
-
descr:
|
|
18
|
+
descr: str | None
|
|
19
19
|
created_at: datetime
|
|
20
20
|
namespace: Namespace
|
|
21
21
|
|
|
@@ -52,12 +52,12 @@ class Project:
|
|
|
52
52
|
namespace_id: int,
|
|
53
53
|
namespace_uuid: str,
|
|
54
54
|
namespace_name: str,
|
|
55
|
-
namespace_descr:
|
|
55
|
+
namespace_descr: str | None,
|
|
56
56
|
namespace_created_at: datetime,
|
|
57
57
|
project_id: int,
|
|
58
58
|
uuid: str,
|
|
59
59
|
name: str,
|
|
60
|
-
descr:
|
|
60
|
+
descr: str | None,
|
|
61
61
|
created_at: datetime,
|
|
62
62
|
project_namespace_id: int,
|
|
63
63
|
) -> "Project":
|
datachain/query/batch.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
import math
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from collections.abc import Generator, Sequence
|
|
5
|
-
from typing import Callable, Optional, Union
|
|
4
|
+
from collections.abc import Callable, Generator, Sequence
|
|
6
5
|
|
|
7
6
|
import sqlalchemy as sa
|
|
8
7
|
|
|
@@ -10,7 +9,7 @@ from datachain.data_storage.schema import PARTITION_COLUMN_ID
|
|
|
10
9
|
from datachain.query.utils import get_query_column
|
|
11
10
|
|
|
12
11
|
RowsOutputBatch = Sequence[Sequence]
|
|
13
|
-
RowsOutput =
|
|
12
|
+
RowsOutput = Sequence | RowsOutputBatch
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class BatchingStrategy(ABC):
|
|
@@ -23,7 +22,7 @@ class BatchingStrategy(ABC):
|
|
|
23
22
|
self,
|
|
24
23
|
execute: Callable,
|
|
25
24
|
query: sa.Select,
|
|
26
|
-
id_col:
|
|
25
|
+
id_col: sa.ColumnElement | None = None,
|
|
27
26
|
) -> Generator[RowsOutput, None, None]:
|
|
28
27
|
"""Apply the provided parameters to the UDF."""
|
|
29
28
|
|
|
@@ -40,7 +39,7 @@ class NoBatching(BatchingStrategy):
|
|
|
40
39
|
self,
|
|
41
40
|
execute: Callable,
|
|
42
41
|
query: sa.Select,
|
|
43
|
-
id_col:
|
|
42
|
+
id_col: sa.ColumnElement | None = None,
|
|
44
43
|
) -> Generator[Sequence, None, None]:
|
|
45
44
|
ids_only = False
|
|
46
45
|
if id_col is not None:
|
|
@@ -66,7 +65,7 @@ class Batch(BatchingStrategy):
|
|
|
66
65
|
self,
|
|
67
66
|
execute: Callable,
|
|
68
67
|
query: sa.Select,
|
|
69
|
-
id_col:
|
|
68
|
+
id_col: sa.ColumnElement | None = None,
|
|
70
69
|
) -> Generator[RowsOutput, None, None]:
|
|
71
70
|
from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
|
|
72
71
|
|
|
@@ -105,7 +104,7 @@ class Partition(BatchingStrategy):
|
|
|
105
104
|
self,
|
|
106
105
|
execute: Callable,
|
|
107
106
|
query: sa.Select,
|
|
108
|
-
id_col:
|
|
107
|
+
id_col: sa.ColumnElement | None = None,
|
|
109
108
|
) -> Generator[RowsOutput, None, None]:
|
|
110
109
|
if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
|
|
111
110
|
raise RuntimeError("partition column not found in query")
|
|
@@ -115,7 +114,7 @@ class Partition(BatchingStrategy):
|
|
|
115
114
|
query = query.with_only_columns(id_col, partition_col)
|
|
116
115
|
ids_only = True
|
|
117
116
|
|
|
118
|
-
current_partition:
|
|
117
|
+
current_partition: int | None = None
|
|
119
118
|
batch: list = []
|
|
120
119
|
|
|
121
120
|
query_fields = [str(c.name) for c in query.selected_columns]
|