rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +186 -49
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +27 -32
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/colors.py +20 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
rslearn/models/simple_time_series.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 """SimpleTimeSeries encoder."""
 
+import warnings
 from typing import Any
 
 import torch
@@ -25,13 +26,14 @@ class SimpleTimeSeries(FeatureExtractor):
     def __init__(
         self,
         encoder: FeatureExtractor,
-        image_channels: int | None = None,
+        num_timesteps_per_forward_pass: int = 1,
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
         image_key: str = "image",
         backbone_channels: list[tuple[int, int]] | None = None,
-        image_keys: dict[str, int] | None = None,
+        image_keys: list[str] | dict[str, int] | None = None,
+        image_channels: int | None = None,
     ) -> None:
         """Create a new SimpleTimeSeries.
 
@@ -39,9 +41,11 @@
             encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
                 function that returns the output channels, or backbone_channels must be set.
                 It must output a FeatureMaps.
-
-
-
+            num_timesteps_per_forward_pass: how many timesteps to pass to the encoder
+                in each forward pass. Defaults to 1 (one timestep per forward pass).
+                Set to a higher value to batch multiple timesteps together, e.g. for
+                pre/post change detection where you want 4 pre and 4 post images
+                processed together.
             op: one of max, mean, convrnn, conv3d, or conv1d
             groups: sets of images for which to combine features. Within each set,
                 features are combined using the specified operation; then, across sets,
@@ -51,28 +55,53 @@
                 combined before features and the combined after features. groups is a
                 list of sets, and each set is a list of image indices.
             num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
-            image_key: the key to access the images.
+            image_key: the key to access the images (used when image_keys is not set).
             backbone_channels: manually specify the backbone channels. Can be set if
                 the encoder does not provide get_backbone_channels function.
-            image_keys:
-
-
-
+            image_keys: list of keys in input dict to process as multimodal inputs.
+                All keys use the same num_timesteps_per_forward_pass. If not set,
+                only the single image_key is used. Passing a dict[str, int] is
+                deprecated and will be removed on 2026-04-01.
+            image_channels: Deprecated, use num_timesteps_per_forward_pass instead.
+                Will be removed on 2026-04-01.
         """
-
-
-
-
-                "
+        # Handle deprecated image_channels parameter
+        if image_channels is not None:
+            warnings.warn(
+                "image_channels is deprecated and will be removed on 2026-04-01. "
+                "Use num_timesteps_per_forward_pass instead. The new parameter directly "
+                "specifies the number of timesteps per forward pass rather than requiring "
+                "image_channels // actual_channels.",
+                FutureWarning,
+                stacklevel=2,
             )
 
+        # Handle deprecated dict form of image_keys
+        deprecated_image_keys_dict: dict[str, int] | None = None
+        if isinstance(image_keys, dict):
+            warnings.warn(
+                "Passing image_keys as a dict is deprecated and will be removed on "
+                "2026-04-01. Use image_keys as a list[str] and set "
+                "num_timesteps_per_forward_pass instead.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            deprecated_image_keys_dict = image_keys
+            image_keys = None  # Will use deprecated path in forward
+
         super().__init__()
         self.encoder = encoder
-        self.image_channels = image_channels
+        self.num_timesteps_per_forward_pass = num_timesteps_per_forward_pass
+        # Store deprecated parameters for runtime conversion
+        self._deprecated_image_channels = image_channels
+        self._deprecated_image_keys_dict = deprecated_image_keys_dict
         self.op = op
         self.groups = groups
-        self.image_key = image_key
-        self.image_keys = image_keys
+        # Normalize image_key to image_keys list form
+        if image_keys is not None:
+            self.image_keys = image_keys
+        else:
+            self.image_keys = [image_key]
 
         if backbone_channels is not None:
             out_channels = backbone_channels
```
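The deprecation shims above translate old configurations at runtime: an image_channels value (or a dict-valued image_keys entry) is converted to a timestep count by dividing by the actual per-timestep channel count. A minimal before/after sketch of the migration; `encoder` stands in for any FeatureExtractor instance and is a placeholder here:

```python
from rslearn.models.simple_time_series import SimpleTimeSeries

# Before (deprecated): with 3-band imagery, image_channels=12 implied
# 12 // 3 = 4 timesteps per encoder forward pass.
old_style = SimpleTimeSeries(encoder=encoder, image_channels=12)  # FutureWarning

# After: state the timestep count directly.
new_style = SimpleTimeSeries(encoder=encoder, num_timesteps_per_forward_pass=4)
```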
```diff
@@ -163,24 +192,25 @@
         return out_channels
 
     def _get_batched_images(
-        self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+        self, input_dicts: list[dict[str, Any]], image_key: str, num_timesteps: int
     ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.
 
         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
+
+        Args:
+            input_dicts: list of input dictionaries containing RasterImage objects.
+            image_key: the key to access the RasterImage in each input dict.
+            num_timesteps: how many timesteps to batch together per forward pass.
         """
         images = torch.stack(
             [input_dict[image_key].image for input_dict in input_dicts], dim=0
         )  # B, C, T, H, W
         timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
-        #
-        #
-        #
-        # want to pass 2 timesteps to the model.
-        # TODO is probably to make this behaviour clearer but lets leave it like
-        # this for now to not break things.
-        num_timesteps = image_channels // images.shape[1]
+        # num_timesteps specifies how many timesteps to batch together per forward pass.
+        # For example, if the input has 8 timesteps and num_timesteps=4, we do 2
+        # forward passes, each with 4 timesteps batched together.
         batched_timesteps = images.shape[2] // num_timesteps
         images = rearrange(
             images,
```
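The exact rearrange pattern is truncated in this rendering, but the reshaping it describes can be illustrated standalone with einops (shapes are hypothetical):

```python
import torch
from einops import rearrange

x = torch.randn(2, 3, 8, 32, 32)  # B, C, T, H, W: 2 windows, 3 bands, 8 timesteps

# With num_timesteps=4 there are 8 // 4 = 2 batched forward passes; each chunk
# of 4 timesteps is folded into the batch dimension so the underlying
# per-image encoder can process all chunks in one call.
chunks = rearrange(x, "b c (n t) h w -> (b n) c t h w", t=4)
assert chunks.shape == (4, 3, 4, 32, 32)
```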
```diff
@@ -222,10 +252,22 @@
         n_batch = len(context.inputs)
         n_images: int | None = None
 
-        if self.image_keys:
-            for image_key, image_channels in self.image_keys.items():
+        if self._deprecated_image_keys_dict is not None:
+            # Deprecated dict form: each key has its own channels_per_timestep.
+            # The channels_per_timestep could be used to group multiple timesteps
+            # together, so we need to divide by the actual image channel count to get
+            # the number of timesteps to be grouped.
+            for (
+                image_key,
+                channels_per_timestep,
+            ) in self._deprecated_image_keys_dict.items():
+                # For deprecated image_keys dict, the value is channels per timestep,
+                # so we need to compute num_timesteps from the actual image channels
+                sample_image = context.inputs[0][image_key].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = channels_per_timestep // actual_channels
                 batched_images = self._get_batched_images(
-                    context.inputs, image_key, image_channels
+                    context.inputs, image_key, num_timesteps
                 )
 
                 if batched_inputs is None:
@@ -240,12 +282,32 @@
                     batched_inputs[i][image_key] = image
 
         else:
-
-
-
-
-
-
+            # Determine num_timesteps - either from deprecated image_channels or
+            # directly from num_timesteps_per_forward_pass
+            if self._deprecated_image_channels is not None:
+                # Backwards compatibility: compute num_timesteps from image_channels
+                # (which should be a multiple of the actual per-timestep channels).
+                sample_image = context.inputs[0][self.image_keys[0]].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = self._deprecated_image_channels // actual_channels
+            else:
+                num_timesteps = self.num_timesteps_per_forward_pass
+
+            for image_key in self.image_keys:
+                batched_images = self._get_batched_images(
+                    context.inputs, image_key, num_timesteps
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = len(batched_images) // n_batch
+                elif n_images != len(batched_images) // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image
 
         assert n_images is not None
         # Now we can apply the underlying FeatureExtractor.
```
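In the non-deprecated path, every key in image_keys is chunked with the same num_timesteps_per_forward_pass, and forward raises ValueError if the modalities end up with different chunk counts. A hypothetical two-modality configuration (key names and `encoder` are illustrative):

```python
model = SimpleTimeSeries(
    encoder=encoder,
    image_keys=["sentinel2", "sentinel1"],  # both must yield the same chunk count
    num_timesteps_per_forward_pass=2,
)
```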
rslearn/train/data_module.py
CHANGED

```diff
@@ -21,6 +21,7 @@ from .all_patches_dataset import (
 )
 from .dataset import (
     DataInput,
+    IndexMode,
     ModelDataset,
     MultiDataset,
     RetryDataset,
@@ -69,6 +70,7 @@ class RslearnDataModule(L.LightningDataModule):
         name: str | None = None,
         retries: int = 0,
         use_in_memory_all_patches_dataset: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Initialize a new RslearnDataModule.
 
@@ -92,6 +94,7 @@
             retries: number of retries to attempt for getitem calls
             use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
                 instead of IterableAllPatchesDataset if load_all_patches is set to true.
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
         super().__init__()
         self.inputs = inputs
@@ -103,6 +106,7 @@
         self.name = name
         self.retries = retries
         self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
+        self.index_mode = index_mode
         self.split_configs = {
             "train": default_config.update(train_config),
             "val": default_config.update(val_config),
@@ -138,6 +142,7 @@
             workers=self.init_workers,
             name=self.name,
             fix_patch_pick=(split != "train"),
+            index_mode=self.index_mode,
         )
         logger.info(f"got {len(dataset)} examples in split {split}")
         if split_config.get_load_all_patches():
```
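A sketch of enabling the new window-index cache from the data module; `existing_kwargs` is a placeholder for the constructor arguments (inputs, task configuration, and so on) that an existing setup already provides:

```python
from rslearn.train.data_module import RslearnDataModule
from rslearn.train.dataset import IndexMode

data_module = RslearnDataModule(
    **existing_kwargs,  # inputs, task, path, etc. from your current config
    index_mode=IndexMode.USE,  # reuse the cached window index when present
)
```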
rslearn/train/dataset.py
CHANGED

```diff
@@ -9,6 +9,7 @@ import tempfile
 import time
 import uuid
 from datetime import datetime
+from enum import StrEnum
 from typing import Any
 
 import torch
@@ -29,6 +30,7 @@ from rslearn.dataset.window import (
     get_layer_and_group_from_dir_name,
 )
 from rslearn.log_utils import get_logger
+from rslearn.train.dataset_index import DatasetIndex
 from rslearn.train.model_context import RasterImage
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds, ResolutionFactor
@@ -41,6 +43,19 @@ from .transforms import Sequential
 logger = get_logger(__name__)
 
 
+class IndexMode(StrEnum):
+    """Controls dataset index caching behavior."""
+
+    OFF = "off"
+    """No caching - always load windows from dataset."""
+
+    USE = "use"
+    """Use cached index if available, create if not."""
+
+    REFRESH = "refresh"
+    """Ignore existing cache and rebuild."""
+
+
 def get_torch_dtype(dtype: DType) -> torch.dtype:
     """Convert rslearn DType to torch dtype."""
     if dtype == DType.INT32:
```
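Since IndexMode is a StrEnum (Python 3.11+), members compare equal to their string values, so plain strings from YAML or CLI configs coerce cleanly:

```python
from rslearn.train.dataset import IndexMode

assert IndexMode("refresh") is IndexMode.REFRESH  # parse from a config string
assert IndexMode.USE == "use"  # members compare equal to plain strings
```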
```diff
@@ -445,6 +460,7 @@ class SplitConfig:
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
         skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.
 
@@ -467,6 +483,10 @@
                 for each window, read all patches as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
+            output_layer_name_skip_inference_if_exists: optional name of the output
+                layer used during prediction. If set, windows that already have this
+                layer completed will be skipped (useful for resuming partial
+                inference runs).
         """
         self.groups = groups
         self.names = names
@@ -477,6 +497,9 @@
         self.sampler = sampler
         self.patch_size = patch_size
         self.skip_targets = skip_targets
+        self.output_layer_name_skip_inference_if_exists = (
+            output_layer_name_skip_inference_if_exists
+        )
 
         # Note that load_all_patches are handled by the RslearnDataModule rather than
         # the ModelDataset.
@@ -504,6 +527,7 @@
             overlap_ratio=self.overlap_ratio,
             load_all_patches=self.load_all_patches,
             skip_targets=self.skip_targets,
+            output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
         if other.groups:
             result.groups = other.groups
@@ -527,6 +551,10 @@
             result.load_all_patches = other.load_all_patches
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
+        if other.output_layer_name_skip_inference_if_exists is not None:
+            result.output_layer_name_skip_inference_if_exists = (
+                other.output_layer_name_skip_inference_if_exists
+            )
         return result
 
     def get_patch_size(self) -> tuple[int, int] | None:
@@ -549,16 +577,26 @@
         """Returns whether skip_targets is enabled (default False)."""
         return True if self.skip_targets is True else False
 
+    def get_output_layer_name_skip_inference_if_exists(self) -> str | None:
+        """Returns output layer to use for resume checks (default None)."""
+        return self.output_layer_name_skip_inference_if_exists
+
 
-def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
+def check_window(
+    inputs: dict[str, DataInput],
+    window: Window,
+    output_layer_name_skip_inference_if_exists: str | None = None,
+) -> Window | None:
     """Verify that the window has the required layers based on the specified inputs.
 
     Args:
         inputs: the inputs to the dataset.
         window: the window to check.
+        output_layer_name_skip_inference_if_exists: optional name of the output layer to check for existence.
 
     Returns:
-        the window if it has all the required inputs
+        the window if it has all the required inputs and does not need to be skipped
+        due to an existing output layer; or None otherwise
     """
 
     # Make sure window has all the needed layers.
@@ -588,6 +626,16 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
         )
         return None
 
+    # Optionally skip windows that already have the specified output layer completed.
+    if output_layer_name_skip_inference_if_exists is not None:
+        if window.is_layer_completed(output_layer_name_skip_inference_if_exists):
+            logger.debug(
+                "Skipping window %s since output layer '%s' already exists",
+                window.name,
+                output_layer_name_skip_inference_if_exists,
+            )
+            return None
+
     return window
```
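Combined with the check_window change above, this lets an interrupted predict run resume cheaply: windows whose output layer is already written are filtered out up front. A sketch, with "output" as an illustrative layer name:

```python
from rslearn.train.dataset import SplitConfig

predict_config = SplitConfig(
    skip_targets=True,
    # Drop windows where the "output" layer is already completed, so a
    # restarted prediction only processes the remaining windows.
    output_layer_name_skip_inference_if_exists="output",
)
```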
```diff
@@ -603,6 +651,7 @@ class ModelDataset(torch.utils.data.Dataset):
         workers: int,
         name: str | None = None,
         fix_patch_pick: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.
 
@@ -612,9 +661,10 @@
             inputs: data to read from the dataset for training
             task: the task to train on
             workers: number of workers to use for initializing the dataset
-            name: name of the dataset
+            name: name of the dataset
             fix_patch_pick: if True, fix the patch pick to be the same every time
                 for a given window. Useful for testing (default: False)
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
         self.dataset = dataset
         self.split_config = split_config
@@ -635,58 +685,14 @@
         else:
             self.patch_size = split_config.get_patch_size()
 
-        windows = self._get_initial_windows(split_config, workers)
-
         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
            for k in list(self.inputs.keys()):
                 if self.inputs[k].is_target:
                     del self.inputs[k]
 
-        #
-
-        new_windows = []
-        if workers == 0:
-            for window in windows:
-                if check_window(self.inputs, window) is None:
-                    continue
-                new_windows.append(window)
-        else:
-            p = multiprocessing.Pool(workers)
-            outputs = star_imap_unordered(
-                p,
-                check_window,
-                [
-                    dict(
-                        inputs=self.inputs,
-                        window=window,
-                    )
-                    for window in windows
-                ],
-            )
-            for window in tqdm.tqdm(
-                outputs, total=len(windows), desc="Checking available layers in windows"
-            ):
-                if window is None:
-                    continue
-                new_windows.append(window)
-            p.close()
-        windows = new_windows
-
-        # Sort the windows to ensure that the dataset is consistent across GPUs.
-        # Inconsistent ordering can lead to a subset of windows being processed during
-        # "model test" / "model predict" when using multiple GPUs.
-        # We use a hash so that functionality like num_samples limit gets a random
-        # subset of windows (with respect to the hash function choice).
-        windows.sort(
-            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
-        )
-
-        # Limit windows to num_samples if requested.
-        if split_config.num_samples:
-            # The windows are sorted by hash of window name so this distribution should
-            # be representative of the population.
-            windows = windows[0 : split_config.num_samples]
+        # Load windows (from index if available, otherwise from dataset)
+        windows = self._load_windows(split_config, workers, index_mode)
 
         # Write dataset_examples to a file so that we can load it lazily in the worker
         # processes. Otherwise it takes a long time to transmit it when spawning each
@@ -755,6 +761,137 @@
 
         return windows
 
+    def _load_windows(
+        self,
+        split_config: SplitConfig,
+        workers: int,
+        index_mode: IndexMode,
+    ) -> list[Window]:
+        """Load windows, using index if available.
+
+        This method handles:
+        1. Loading from index if index_mode is USE and index exists
+        2. Otherwise, loading from dataset, filtering, sorting, limiting
+        3. Saving to index if index_mode is USE or REFRESH
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+            index_mode: controls caching behavior.
+
+        Returns:
+            list of processed windows ready for training.
+        """
+        # Try to load from index
+        index: DatasetIndex | None = None
+
+        if index_mode != IndexMode.OFF:
+            logger.info(f"Checking index for dataset {self.dataset.path}")
+            index = DatasetIndex(
+                storage=self.dataset.storage,
+                dataset_path=self.dataset.path,
+                groups=split_config.groups,
+                names=split_config.names,
+                tags=split_config.tags,
+                num_samples=split_config.num_samples,
+                skip_targets=split_config.get_skip_targets(),
+                inputs=self.inputs,
+            )
+            refresh = index_mode == IndexMode.REFRESH
+            indexed_windows = index.load_windows(refresh)
+
+            if indexed_windows is not None:
+                logger.info(f"Loaded {len(indexed_windows)} windows from index")
+                return indexed_windows
+
+        # No index available, load and process windows from dataset
+        logger.debug("Loading windows from dataset...")
+        windows = self._get_initial_windows(split_config, workers)
+        windows = self._filter_windows_by_layers(windows, workers)
+        windows = self._sort_and_limit_windows(windows, split_config)
+
+        # Save to index if enabled
+        if index is not None:
+            index.save_windows(windows)
+
+        return windows
+
+    def _filter_windows_by_layers(
+        self, windows: list[Window], workers: int
+    ) -> list[Window]:
+        """Filter windows to only include those with required layers.
+
+        Args:
+            windows: list of windows to filter.
+            workers: number of worker processes for parallel filtering.
+
+        Returns:
+            list of windows that have all required input layers.
+        """
+        output_layer_skip = (
+            self.split_config.get_output_layer_name_skip_inference_if_exists()
+        )
+
+        if workers == 0:
+            return [
+                w
+                for w in windows
+                if check_window(
+                    self.inputs,
+                    w,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                is not None
+            ]
+
+        p = multiprocessing.Pool(workers)
+        outputs = star_imap_unordered(
+            p,
+            check_window,
+            [
+                dict(
+                    inputs=self.inputs,
+                    window=window,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                for window in windows
+            ],
+        )
+        filtered = []
+        for window in tqdm.tqdm(
+            outputs,
+            total=len(windows),
+            desc="Checking available layers in windows",
+        ):
+            if window is not None:
+                filtered.append(window)
+        p.close()
+        return filtered
+
+    def _sort_and_limit_windows(
+        self, windows: list[Window], split_config: SplitConfig
+    ) -> list[Window]:
+        """Sort windows by hash and apply num_samples limit.
+
+        Sorting ensures consistent ordering across GPUs. Using hash gives a
+        pseudo-random but deterministic order for sampling.
+
+        Args:
+            windows: list of windows to sort and limit.
+            split_config: the split configuration with num_samples.
+
+        Returns:
+            sorted and optionally limited list of windows.
+        """
+        windows.sort(
+            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
+        )
+
+        if split_config.num_samples:
+            windows = windows[: split_config.num_samples]
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()
```
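_sort_and_limit_windows keeps the pre-existing behavior: ordering by SHA-256 of the window name is deterministic across processes (so every GPU sees the same window order) while being decorrelated from the original naming, so a num_samples prefix behaves like a random but reproducible sample. The idea in isolation:

```python
import hashlib

names = ["window_c", "window_a", "window_b"]
names.sort(key=lambda n: hashlib.sha256(n.encode()).hexdigest())
# Same order on every run and every process; take a stable subset:
subset = names[:2]
```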