kaparoo-python 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kaparoo/data/sequences/base.py +17 -3
- kaparoo/data/sequences/composers.py +153 -98
- kaparoo/data/sequences/templates.py +44 -106
- kaparoo/data/sequences/utils.py +18 -28
- kaparoo/filesystem/__init__.py +2 -15
- kaparoo/filesystem/directory.py +44 -21
- kaparoo/filesystem/exceptions.py +4 -3
- kaparoo/filesystem/exclude.py +109 -0
- kaparoo/filesystem/existence.py +29 -24
- kaparoo/filesystem/hierarchy/__init__.py +33 -0
- kaparoo/filesystem/hierarchy/base.py +100 -0
- kaparoo/filesystem/hierarchy/conditions.py +470 -0
- kaparoo/filesystem/hierarchy/entry.py +395 -0
- kaparoo/filesystem/hierarchy/group.py +302 -0
- kaparoo/filesystem/hierarchy/scaffold.py +227 -0
- kaparoo/filesystem/hierarchy/traverse/__init__.py +20 -0
- kaparoo/filesystem/hierarchy/traverse/_utils.py +89 -0
- kaparoo/filesystem/hierarchy/traverse/locate.py +181 -0
- kaparoo/filesystem/hierarchy/traverse/validate.py +609 -0
- kaparoo/filesystem/hierarchy/utils.py +45 -0
- kaparoo/filesystem/search/__init__.py +2 -82
- kaparoo/filesystem/search/classes.py +70 -67
- kaparoo/filesystem/search/wrappers.py +65 -29
- kaparoo/filesystem/staged.py +140 -114
- kaparoo/filesystem/types.py +2 -0
- kaparoo/filesystem/units.py +23 -0
- kaparoo/filesystem/utils.py +60 -41
- kaparoo/{filesystem/search/filters → filters}/__init__.py +25 -5
- kaparoo/{filesystem/search/filters → filters}/base.py +36 -16
- kaparoo/filters/enumerable.py +346 -0
- kaparoo/filters/logical.py +112 -0
- kaparoo/filters/multi_pattern.py +142 -0
- kaparoo/{filesystem/search/filters → filters}/pattern.py +58 -82
- kaparoo/filters/types.py +89 -0
- kaparoo/{filesystem/search/filters → filters}/utils.py +7 -2
- kaparoo/utils/__init__.py +11 -0
- kaparoo/utils/aggregate.py +388 -108
- kaparoo/utils/checks.py +99 -0
- kaparoo/utils/optional.py +36 -56
- kaparoo/utils/timer.py +95 -101
- {kaparoo_python-0.7.0.dist-info → kaparoo_python-0.8.0.dist-info}/METADATA +52 -10
- kaparoo_python-0.8.0.dist-info/RECORD +48 -0
- {kaparoo_python-0.7.0.dist-info → kaparoo_python-0.8.0.dist-info}/WHEEL +1 -1
- kaparoo/filesystem/search/deprecated.py +0 -289
- kaparoo/filesystem/search/filters/logical.py +0 -138
- kaparoo/filesystem/search/filters/multi_pattern.py +0 -160
- kaparoo/filesystem/search/filters/types.py +0 -47
- kaparoo_python-0.7.0.dist-info/RECORD +0 -34
- {kaparoo_python-0.7.0.dist-info → kaparoo_python-0.8.0.dist-info}/licenses/LICENSE +0 -0
kaparoo/data/sequences/base.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""The `DataSequence[T, M]` abstract base: indexable items with metadata."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
__all__ = ("DataSequence",)
|
|
@@ -32,7 +34,7 @@ class DataSequence[T, M = None](Sequence[T]):
|
|
|
32
34
|
|
|
33
35
|
@abstractmethod
|
|
34
36
|
def __len__(self) -> int:
|
|
35
|
-
|
|
37
|
+
"""Return the number of items in the sequence."""
|
|
36
38
|
|
|
37
39
|
# --- item access -------------------------------------------------------
|
|
38
40
|
|
|
@@ -50,24 +52,36 @@ class DataSequence[T, M = None](Sequence[T]):
|
|
|
50
52
|
|
|
51
53
|
@abstractmethod
|
|
52
54
|
def get_item(self, index: int) -> T:
|
|
53
|
-
|
|
55
|
+
"""Fetch and return the item at `index`."""
|
|
54
56
|
|
|
55
57
|
def get_items(self, indices: Sequence[int]) -> Sequence[T]:
|
|
58
|
+
"""Fetch many items at once, in `indices` order.
|
|
59
|
+
|
|
60
|
+
Defaults to one `get_item` per index; override to use a backing
|
|
61
|
+
store's native batch read.
|
|
62
|
+
"""
|
|
56
63
|
return [self.get_item(index) for index in indices]
|
|
57
64
|
|
|
58
65
|
# --- metadata access ---------------------------------------------------
|
|
59
66
|
|
|
60
67
|
@abstractmethod
|
|
61
68
|
def get_meta(self, index: int) -> M:
|
|
62
|
-
|
|
69
|
+
"""Return the metadata for the item at `index` (`None` when `M` is `None`)."""
|
|
63
70
|
|
|
64
71
|
def get_metas(self, indices: Sequence[int]) -> Sequence[M]:
|
|
72
|
+
"""Fetch many metadata values at once, in `indices` order.
|
|
73
|
+
|
|
74
|
+
Defaults to one `get_meta` per index; override alongside
|
|
75
|
+
`get_items` when a batch read is cheaper.
|
|
76
|
+
"""
|
|
65
77
|
return [self.get_meta(index) for index in indices]
|
|
66
78
|
|
|
67
79
|
# --- combined item + metadata ------------------------------------------
|
|
68
80
|
|
|
69
81
|
def get_pair(self, index: int) -> tuple[T, M]:
|
|
82
|
+
"""Return the `(item, metadata)` pair at `index`."""
|
|
70
83
|
return self.get_item(index), self.get_meta(index)
|
|
71
84
|
|
|
72
85
|
def get_pairs(self, indices: Sequence[int]) -> Sequence[tuple[T, M]]:
|
|
86
|
+
"""Fetch many `(item, metadata)` pairs at once, in `indices` order."""
|
|
73
87
|
return [self.get_pair(index) for index in indices]
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Lazy `DataSequence` composers: slice, concat, transform, window, zip."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
__all__ = (
|
|
@@ -10,7 +12,7 @@ __all__ = (
|
|
|
10
12
|
|
|
11
13
|
from abc import abstractmethod
|
|
12
14
|
from bisect import bisect_right
|
|
13
|
-
from typing import TYPE_CHECKING, cast
|
|
15
|
+
from typing import TYPE_CHECKING, cast, override
|
|
14
16
|
|
|
15
17
|
from kaparoo.data.sequences.base import DataSequence
|
|
16
18
|
|
|
@@ -18,23 +20,35 @@ if TYPE_CHECKING:
|
|
|
18
20
|
from collections.abc import Callable, Sequence
|
|
19
21
|
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
"""
|
|
23
|
+
def _resolve_index(index: int, length: int) -> int:
|
|
24
|
+
"""Normalize a possibly-negative index against `length`, validating range.
|
|
25
|
+
|
|
26
|
+
Used by `ConcatSequence`, `WindowedSequence`, and `ZippedSequence`.
|
|
27
|
+
`SlicedSequence` intentionally opts out -- it indexes its `indices` tuple
|
|
28
|
+
directly, which wraps and raises the same way but with the builtin message.
|
|
29
|
+
|
|
30
|
+
Raises:
|
|
31
|
+
IndexError: If `index` is outside `[-length, length)`.
|
|
32
|
+
"""
|
|
33
|
+
original = index
|
|
34
|
+
if index < 0:
|
|
35
|
+
index += length
|
|
36
|
+
if not 0 <= index < length:
|
|
37
|
+
msg = f"index {original} out of range for length {length}"
|
|
38
|
+
raise IndexError(msg)
|
|
39
|
+
return index
|
|
40
|
+
|
|
23
41
|
|
|
24
|
-
|
|
25
|
-
view
|
|
26
|
-
and out-of-range indices delegate to Python's tuple semantics
|
|
27
|
-
(negative wraps, out-of-range raises `IndexError`).
|
|
42
|
+
class SlicedSequence[T, M](DataSequence[T, M]):
|
|
43
|
+
"""A view of `source` exposing only the items at `indices`, in that order.
|
|
28
44
|
|
|
29
|
-
`indices` is taken as-is
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
out-of-range entry surfaces only when that position is accessed.
|
|
45
|
+
`indices` is taken as-is -- duplicates repeat the source item, order is
|
|
46
|
+
preserved -- and is not bounds-checked against `source` until a position
|
|
47
|
+
is accessed. A negative view index wraps; out of range raises `IndexError`.
|
|
33
48
|
|
|
34
49
|
Example:
|
|
35
50
|
>>> sliced = SlicedSequence(full_dataset, [3, 7, 11])
|
|
36
51
|
>>> sliced[0] # == full_dataset[3]
|
|
37
|
-
>>> sliced[1] # == full_dataset[7]
|
|
38
52
|
"""
|
|
39
53
|
|
|
40
54
|
def __init__(
|
|
@@ -58,42 +72,54 @@ class SlicedSequence[T, M](DataSequence[T, M]):
|
|
|
58
72
|
def __len__(self) -> int:
|
|
59
73
|
return len(self._indices)
|
|
60
74
|
|
|
75
|
+
@override
|
|
61
76
|
def get_item(self, index: int) -> T:
|
|
77
|
+
"""Fetch the source item at the mapped index `indices[index]`."""
|
|
62
78
|
return self._source.get_item(self._indices[index])
|
|
63
79
|
|
|
80
|
+
@override
|
|
64
81
|
def get_meta(self, index: int) -> M:
|
|
82
|
+
"""Fetch the source metadata at the mapped index `indices[index]`."""
|
|
65
83
|
return self._source.get_meta(self._indices[index])
|
|
66
84
|
|
|
85
|
+
@override
|
|
86
|
+
def get_items(self, indices: Sequence[int]) -> Sequence[T]:
|
|
87
|
+
"""Map each view index through `indices`, then batch-fetch from `source`."""
|
|
88
|
+
return self._source.get_items([self._indices[i] for i in indices])
|
|
89
|
+
|
|
90
|
+
@override
|
|
91
|
+
def get_metas(self, indices: Sequence[int]) -> Sequence[M]:
|
|
92
|
+
"""Map each view index through `indices`, then batch-fetch metadata."""
|
|
93
|
+
return self._source.get_metas([self._indices[i] for i in indices])
|
|
94
|
+
|
|
67
95
|
|
|
68
96
|
class TransformedSequence[T_in, M_in, T_out = T_in, M_out = M_in](
|
|
69
97
|
DataSequence[T_out, M_out]
|
|
70
98
|
):
|
|
71
99
|
"""A view of `source` with `transform` applied lazily to each item.
|
|
72
100
|
|
|
73
|
-
`transform`
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
101
|
+
`transform` runs on demand in `get_item`; nothing is converted at
|
|
102
|
+
construction. `get_meta` passes `source.get_meta` through unchanged, which
|
|
103
|
+
is correct only when the metadata type is unchanged (the default
|
|
104
|
+
`M_out == M_in`). **Override `get_meta` whenever `M_out != M_in`**: the
|
|
105
|
+
passthrough's `cast` cannot catch a missing override -- generics are erased
|
|
106
|
+
at runtime -- so a forgotten one silently yields an `M_in` value mistyped
|
|
107
|
+
as `M_out`.
|
|
77
108
|
|
|
78
109
|
Type Parameters:
|
|
79
|
-
T_in:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
When `M_out != M_in`, override `get_meta` in a subclass;
|
|
84
|
-
the default passthrough is only safe when `M_out == M_in`.
|
|
110
|
+
T_in, M_in: The source's element and metadata types.
|
|
111
|
+
T_out: The transformed element type. Defaults to `T_in`.
|
|
112
|
+
M_out: The transformed metadata type. Defaults to `M_in` (the
|
|
113
|
+
passthrough case); set it and override `get_meta` otherwise.
|
|
85
114
|
|
|
86
115
|
Example:
|
|
87
116
|
>>> # Item-only transform; metadata passes through unchanged.
|
|
88
117
|
>>> normalized = TransformedSequence(image_folder, normalize)
|
|
89
118
|
|
|
90
|
-
>>> #
|
|
119
|
+
>>> # Metadata transform via subclassing:
|
|
91
120
|
>>> class Augmented(TransformedSequence[ndarray, Path, ndarray, AugMeta]):
|
|
92
121
|
... def get_meta(self, index: int) -> AugMeta:
|
|
93
|
-
... return AugMeta(
|
|
94
|
-
... path=self.source.get_meta(index),
|
|
95
|
-
... applied="normalize",
|
|
96
|
-
... )
|
|
122
|
+
... return AugMeta(self.source.get_meta(index), applied="normalize")
|
|
97
123
|
"""
|
|
98
124
|
|
|
99
125
|
def __init__(
|
|
@@ -112,21 +138,33 @@ class TransformedSequence[T_in, M_in, T_out = T_in, M_out = M_in](
|
|
|
112
138
|
def __len__(self) -> int:
|
|
113
139
|
return len(self._source)
|
|
114
140
|
|
|
141
|
+
@override
|
|
115
142
|
def get_item(self, index: int) -> T_out:
|
|
143
|
+
"""Fetch the source item at `index` and apply `transform`."""
|
|
116
144
|
return self._transform(self._source.get_item(index))
|
|
117
145
|
|
|
146
|
+
@override
|
|
147
|
+
def get_items(self, indices: Sequence[int]) -> Sequence[T_out]:
|
|
148
|
+
"""Batch-fetch from `source` and apply `transform` to each item."""
|
|
149
|
+
return [self._transform(item) for item in self._source.get_items(indices)]
|
|
150
|
+
|
|
151
|
+
@override
|
|
118
152
|
def get_meta(self, index: int) -> M_out:
|
|
119
|
-
|
|
153
|
+
"""Pass `source`'s metadata through unchanged (valid only when `M_out == M_in`)."""
|
|
154
|
+
# Passthrough -- correct only when M_out == M_in. A subclass with a
|
|
155
|
+
# different M_out MUST override this; the cast cannot catch a missing
|
|
156
|
+
# override, since generics are erased at runtime.
|
|
120
157
|
return cast("M_out", self._source.get_meta(index))
|
|
121
158
|
|
|
122
159
|
|
|
123
160
|
class ConcatSequence[T, M](DataSequence[T, M]):
|
|
124
161
|
"""The end-to-end concatenation of zero or more `sources`.
|
|
125
162
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
163
|
+
A logical index maps to the `(source, local index)` it falls in -- an
|
|
164
|
+
O(log N) lookup in the number of sources. Negative indices wrap; out of
|
|
165
|
+
range raises `IndexError`. Batch access (`get_items` / `get_metas`)
|
|
166
|
+
delegates one grouped call per source, so a source's own batch
|
|
167
|
+
optimization is used, with results kept in request order.
|
|
130
168
|
|
|
131
169
|
Example:
|
|
132
170
|
>>> combined = ConcatSequence(train_a, train_b, train_c)
|
|
@@ -148,57 +186,89 @@ class ConcatSequence[T, M](DataSequence[T, M]):
|
|
|
148
186
|
def __len__(self) -> int:
|
|
149
187
|
return self._cumulative[-1]
|
|
150
188
|
|
|
151
|
-
def
|
|
152
|
-
"""Resolve a logical index to `(source,
|
|
189
|
+
def _locate_index(self, index: int) -> tuple[int, int]:
|
|
190
|
+
"""Resolve a logical index to `(source position, local index)`.
|
|
153
191
|
|
|
154
192
|
Raises:
|
|
155
193
|
IndexError: If `index` is outside `[-len(self), len(self))`.
|
|
156
194
|
"""
|
|
157
|
-
|
|
158
|
-
original = index
|
|
159
|
-
if index < 0:
|
|
160
|
-
index += n
|
|
161
|
-
if not 0 <= index < n:
|
|
162
|
-
msg = f"index {original} out of range for length {n}"
|
|
163
|
-
raise IndexError(msg)
|
|
195
|
+
index = _resolve_index(index, self._cumulative[-1])
|
|
164
196
|
i = bisect_right(self._cumulative, index) - 1
|
|
165
|
-
return
|
|
197
|
+
return i, index - self._cumulative[i]
|
|
198
|
+
|
|
199
|
+
def _locate(self, index: int) -> tuple[DataSequence[T, M], int]:
|
|
200
|
+
"""Resolve a logical index to `(source, local_index)`."""
|
|
201
|
+
i, local = self._locate_index(index)
|
|
202
|
+
return self._sources[i], local
|
|
166
203
|
|
|
204
|
+
def _gather[R](
|
|
205
|
+
self,
|
|
206
|
+
indices: Sequence[int],
|
|
207
|
+
fetch: Callable[[DataSequence[T, M], list[int]], Sequence[R]],
|
|
208
|
+
) -> list[R]:
|
|
209
|
+
"""Batch-fetch `indices` with one grouped `fetch` per source.
|
|
210
|
+
|
|
211
|
+
The shared core of `get_items` / `get_metas`, which differ only in the
|
|
212
|
+
per-source `fetch`; results are scattered back into request order.
|
|
213
|
+
"""
|
|
214
|
+
buckets: dict[int, list[tuple[int, int]]] = {}
|
|
215
|
+
for position, index in enumerate(indices):
|
|
216
|
+
source_index, local = self._locate_index(index)
|
|
217
|
+
buckets.setdefault(source_index, []).append((position, local))
|
|
218
|
+
|
|
219
|
+
gathered: dict[int, R] = {}
|
|
220
|
+
for source_index, entries in buckets.items():
|
|
221
|
+
fetched = fetch(
|
|
222
|
+
self._sources[source_index], [local for _, local in entries]
|
|
223
|
+
)
|
|
224
|
+
for (position, _), value in zip(entries, fetched, strict=True):
|
|
225
|
+
gathered[position] = value
|
|
226
|
+
return [gathered[position] for position in range(len(indices))]
|
|
227
|
+
|
|
228
|
+
@override
|
|
167
229
|
def get_item(self, index: int) -> T:
|
|
230
|
+
"""Locate the source for `index` and fetch its local item."""
|
|
168
231
|
source, local = self._locate(index)
|
|
169
232
|
return source.get_item(local)
|
|
170
233
|
|
|
234
|
+
@override
|
|
235
|
+
def get_items(self, indices: Sequence[int]) -> Sequence[T]:
|
|
236
|
+
"""Group `indices` by source and batch-fetch items, kept in request order."""
|
|
237
|
+
return self._gather(indices, lambda source, locals_: source.get_items(locals_))
|
|
238
|
+
|
|
239
|
+
@override
|
|
171
240
|
def get_meta(self, index: int) -> M:
|
|
241
|
+
"""Locate the source for `index` and fetch its local metadata."""
|
|
172
242
|
source, local = self._locate(index)
|
|
173
243
|
return source.get_meta(local)
|
|
174
244
|
|
|
245
|
+
@override
|
|
246
|
+
def get_metas(self, indices: Sequence[int]) -> Sequence[M]:
|
|
247
|
+
"""Group `indices` by source and batch-fetch metadata, kept in request order."""
|
|
248
|
+
return self._gather(indices, lambda source, locals_: source.get_metas(locals_))
|
|
249
|
+
|
|
175
250
|
|
|
176
251
|
class WindowedSequence[T, M_in, M_out = M_in](DataSequence[tuple[T, ...], M_out]):
|
|
177
252
|
"""An abstract sliding-window view over `source`.
|
|
178
253
|
|
|
179
|
-
Each item is a tuple of `size` items from `source`, starting at
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
Subclasses use the `source`, `size`, `step`, `skip` properties and
|
|
187
|
-
should call `_normalize_index` from `get_meta` so negative and
|
|
188
|
-
out-of-range window indices behave the same way as in `get_item`.
|
|
254
|
+
Each item is a tuple of `size` items from `source`, the window starting at
|
|
255
|
+
`i * step` with intra-window stride `skip`. `get_item` is implemented;
|
|
256
|
+
**the window's metadata is intentionally left abstract** so a subclass
|
|
257
|
+
decides how per-frame metadata becomes window metadata (`M_in` -> `M_out`).
|
|
258
|
+
Subclasses should call `_normalize_index` in their `get_meta` so window
|
|
259
|
+
indices behave as in `get_item`.
|
|
189
260
|
|
|
190
261
|
Type Parameters:
|
|
191
|
-
T:
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
Determined by the subclass's `get_meta` return.
|
|
262
|
+
T: The source's element type; each item is a `tuple[T, ...]`.
|
|
263
|
+
M_in: The source's per-frame metadata type.
|
|
264
|
+
M_out: The window's metadata type a subclass produces. Defaults to
|
|
265
|
+
`M_in`.
|
|
196
266
|
|
|
197
267
|
Args:
|
|
198
268
|
source: The sequence to window over.
|
|
199
|
-
size:
|
|
200
|
-
step:
|
|
201
|
-
|
|
269
|
+
size: Items per window. Must be positive.
|
|
270
|
+
step: Advance between consecutive windows. Defaults to 1 (windows
|
|
271
|
+
overlap by `size - 1`).
|
|
202
272
|
skip: Intra-window stride. Defaults to 1 (consecutive frames).
|
|
203
273
|
|
|
204
274
|
Raises:
|
|
@@ -261,25 +331,19 @@ class WindowedSequence[T, M_in, M_out = M_in](DataSequence[tuple[T, ...], M_out]
|
|
|
261
331
|
Raises:
|
|
262
332
|
IndexError: If `index` is outside `[-len(self), len(self))`.
|
|
263
333
|
"""
|
|
264
|
-
|
|
265
|
-
original = index
|
|
266
|
-
if index < 0:
|
|
267
|
-
index += n
|
|
268
|
-
if not 0 <= index < n:
|
|
269
|
-
msg = f"index {original} out of range for length {n}"
|
|
270
|
-
raise IndexError(msg)
|
|
271
|
-
return index
|
|
334
|
+
return _resolve_index(index, self._length)
|
|
272
335
|
|
|
336
|
+
@override
|
|
273
337
|
def get_item(self, index: int) -> tuple[T, ...]:
|
|
338
|
+
"""Build the window at `index` as a tuple of `size` strided source items."""
|
|
274
339
|
index = self._normalize_index(index)
|
|
275
340
|
start = index * self._step
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
)
|
|
341
|
+
stop = start + self._size * self._skip
|
|
342
|
+
return tuple(self._source.get_items(range(start, stop, self._skip)))
|
|
279
343
|
|
|
280
344
|
@abstractmethod
|
|
281
345
|
def get_meta(self, index: int) -> M_out:
|
|
282
|
-
|
|
346
|
+
"""Return the metadata for window `index` (the `M_in` -> `M_out` policy)."""
|
|
283
347
|
|
|
284
348
|
|
|
285
349
|
class ZippedSequence[T1, T2, M1 = None, M2 = None](
|
|
@@ -291,26 +355,16 @@ class ZippedSequence[T1, T2, M1 = None, M2 = None](
|
|
|
291
355
|
`(first.get_meta(i), second.get_meta(i))` -- the "paired image + label"
|
|
292
356
|
pattern that `ConcatSequence` (end-to-end) cannot express.
|
|
293
357
|
|
|
294
|
-
With `strict=True` (the default) the
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
override `get_meta`.
|
|
358
|
+
With `strict=True` (the default) the sequences must be the same length, or
|
|
359
|
+
construction raises `ValueError`; with `strict=False` the view truncates to
|
|
360
|
+
the shorter, like the builtin `zip`. For a different combined-metadata
|
|
361
|
+
shape, subclass and override `get_meta`.
|
|
299
362
|
|
|
300
363
|
Type Parameters:
|
|
301
|
-
T1:
|
|
302
|
-
|
|
303
|
-
M1:
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
Args:
|
|
307
|
-
first: The first sequence.
|
|
308
|
-
second: The second sequence.
|
|
309
|
-
strict: When True (default), require equal lengths and raise on a
|
|
310
|
-
mismatch. When False, truncate to the shorter length.
|
|
311
|
-
|
|
312
|
-
Raises:
|
|
313
|
-
ValueError: If `strict` is True and the sequences differ in length.
|
|
364
|
+
T1, T2: Element types of the first and second sequence; items are
|
|
365
|
+
`tuple[T1, T2]`.
|
|
366
|
+
M1, M2: Their metadata types; metadata is `tuple[M1, M2]`. Each
|
|
367
|
+
defaults to `None` (a sequence without metadata).
|
|
314
368
|
|
|
315
369
|
Example:
|
|
316
370
|
>>> pairs = ZippedSequence(images, labels)
|
|
@@ -354,20 +408,17 @@ class ZippedSequence[T1, T2, M1 = None, M2 = None](
|
|
|
354
408
|
Raises:
|
|
355
409
|
IndexError: If `index` is outside `[-len(self), len(self))`.
|
|
356
410
|
"""
|
|
357
|
-
|
|
358
|
-
original = index
|
|
359
|
-
if index < 0:
|
|
360
|
-
index += n
|
|
361
|
-
if not 0 <= index < n:
|
|
362
|
-
msg = f"index {original} out of range for length {n}"
|
|
363
|
-
raise IndexError(msg)
|
|
364
|
-
return index
|
|
411
|
+
return _resolve_index(index, self._length)
|
|
365
412
|
|
|
413
|
+
@override
|
|
366
414
|
def get_item(self, index: int) -> tuple[T1, T2]:
|
|
415
|
+
"""Fetch the paired `(first[index], second[index])` item."""
|
|
367
416
|
index = self._normalize_index(index)
|
|
368
417
|
return self._first.get_item(index), self._second.get_item(index)
|
|
369
418
|
|
|
419
|
+
@override
|
|
370
420
|
def get_items(self, indices: Sequence[int]) -> Sequence[tuple[T1, T2]]:
|
|
421
|
+
"""Normalize indices, then batch-fetch and pair items from both sources."""
|
|
371
422
|
# Normalize, then bulk-delegate so each source's `get_items`
|
|
372
423
|
# optimization is used.
|
|
373
424
|
normalized = [self._normalize_index(i) for i in indices]
|
|
@@ -379,11 +430,15 @@ class ZippedSequence[T1, T2, M1 = None, M2 = None](
|
|
|
379
430
|
)
|
|
380
431
|
)
|
|
381
432
|
|
|
433
|
+
@override
|
|
382
434
|
def get_meta(self, index: int) -> tuple[M1, M2]:
|
|
435
|
+
"""Fetch the paired `(first, second)` metadata at `index`."""
|
|
383
436
|
index = self._normalize_index(index)
|
|
384
437
|
return self._first.get_meta(index), self._second.get_meta(index)
|
|
385
438
|
|
|
439
|
+
@override
|
|
386
440
|
def get_metas(self, indices: Sequence[int]) -> Sequence[tuple[M1, M2]]:
|
|
441
|
+
"""Normalize indices, then batch-fetch and pair metadata from both sources."""
|
|
387
442
|
normalized = [self._normalize_index(i) for i in indices]
|
|
388
443
|
return list(
|
|
389
444
|
zip(
|