rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +151 -55
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""SimpleTimeSeries encoder."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
6
|
import torch
|
|
@@ -25,13 +26,14 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
25
26
|
def __init__(
|
|
26
27
|
self,
|
|
27
28
|
encoder: FeatureExtractor,
|
|
28
|
-
|
|
29
|
+
num_timesteps_per_forward_pass: int = 1,
|
|
29
30
|
op: str = "max",
|
|
30
31
|
groups: list[list[int]] | None = None,
|
|
31
32
|
num_layers: int | None = None,
|
|
32
33
|
image_key: str = "image",
|
|
33
34
|
backbone_channels: list[tuple[int, int]] | None = None,
|
|
34
|
-
image_keys: dict[str, int] | None = None,
|
|
35
|
+
image_keys: list[str] | dict[str, int] | None = None,
|
|
36
|
+
image_channels: int | None = None,
|
|
35
37
|
) -> None:
|
|
36
38
|
"""Create a new SimpleTimeSeries.
|
|
37
39
|
|
|
@@ -39,9 +41,11 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
39
41
|
encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
|
|
40
42
|
function that returns the output channels, or backbone_channels must be set.
|
|
41
43
|
It must output a FeatureMaps.
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
44
|
+
num_timesteps_per_forward_pass: how many timesteps to pass to the encoder
|
|
45
|
+
in each forward pass. Defaults to 1 (one timestep per forward pass).
|
|
46
|
+
Set to a higher value to batch multiple timesteps together, e.g. for
|
|
47
|
+
pre/post change detection where you want 4 pre and 4 post images
|
|
48
|
+
processed together.
|
|
45
49
|
op: one of max, mean, convrnn, conv3d, or conv1d
|
|
46
50
|
groups: sets of images for which to combine features. Within each set,
|
|
47
51
|
features are combined using the specified operation; then, across sets,
|
|
@@ -51,28 +55,53 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
51
55
|
combined before features and the combined after features. groups is a
|
|
52
56
|
list of sets, and each set is a list of image indices.
|
|
53
57
|
num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
|
|
54
|
-
image_key: the key to access the images.
|
|
58
|
+
image_key: the key to access the images (used when image_keys is not set).
|
|
55
59
|
backbone_channels: manually specify the backbone channels. Can be set if
|
|
56
60
|
the encoder does not provide get_backbone_channels function.
|
|
57
|
-
image_keys:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
+
image_keys: list of keys in input dict to process as multimodal inputs.
|
|
62
|
+
All keys use the same num_timesteps_per_forward_pass. If not set,
|
|
63
|
+
only the single image_key is used. Passing a dict[str, int] is
|
|
64
|
+
deprecated and will be removed on 2026-04-01.
|
|
65
|
+
image_channels: Deprecated, use num_timesteps_per_forward_pass instead.
|
|
66
|
+
Will be removed on 2026-04-01.
|
|
61
67
|
"""
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
"
|
|
68
|
+
# Handle deprecated image_channels parameter
|
|
69
|
+
if image_channels is not None:
|
|
70
|
+
warnings.warn(
|
|
71
|
+
"image_channels is deprecated and will be removed on 2026-04-01. "
|
|
72
|
+
"Use num_timesteps_per_forward_pass instead. The new parameter directly "
|
|
73
|
+
"specifies the number of timesteps per forward pass rather than requiring "
|
|
74
|
+
"image_channels // actual_channels.",
|
|
75
|
+
FutureWarning,
|
|
76
|
+
stacklevel=2,
|
|
67
77
|
)
|
|
68
78
|
|
|
79
|
+
# Handle deprecated dict form of image_keys
|
|
80
|
+
deprecated_image_keys_dict: dict[str, int] | None = None
|
|
81
|
+
if isinstance(image_keys, dict):
|
|
82
|
+
warnings.warn(
|
|
83
|
+
"Passing image_keys as a dict is deprecated and will be removed on "
|
|
84
|
+
"2026-04-01. Use image_keys as a list[str] and set "
|
|
85
|
+
"num_timesteps_per_forward_pass instead.",
|
|
86
|
+
FutureWarning,
|
|
87
|
+
stacklevel=2,
|
|
88
|
+
)
|
|
89
|
+
deprecated_image_keys_dict = image_keys
|
|
90
|
+
image_keys = None # Will use deprecated path in forward
|
|
91
|
+
|
|
69
92
|
super().__init__()
|
|
70
93
|
self.encoder = encoder
|
|
71
|
-
self.
|
|
94
|
+
self.num_timesteps_per_forward_pass = num_timesteps_per_forward_pass
|
|
95
|
+
# Store deprecated parameters for runtime conversion
|
|
96
|
+
self._deprecated_image_channels = image_channels
|
|
97
|
+
self._deprecated_image_keys_dict = deprecated_image_keys_dict
|
|
72
98
|
self.op = op
|
|
73
99
|
self.groups = groups
|
|
74
|
-
|
|
75
|
-
|
|
100
|
+
# Normalize image_key to image_keys list form
|
|
101
|
+
if image_keys is not None:
|
|
102
|
+
self.image_keys = image_keys
|
|
103
|
+
else:
|
|
104
|
+
self.image_keys = [image_key]
|
|
76
105
|
|
|
77
106
|
if backbone_channels is not None:
|
|
78
107
|
out_channels = backbone_channels
|
|
@@ -163,24 +192,25 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
163
192
|
return out_channels
|
|
164
193
|
|
|
165
194
|
def _get_batched_images(
|
|
166
|
-
self, input_dicts: list[dict[str, Any]], image_key: str,
|
|
195
|
+
self, input_dicts: list[dict[str, Any]], image_key: str, num_timesteps: int
|
|
167
196
|
) -> list[RasterImage]:
|
|
168
197
|
"""Collect and reshape images across input dicts.
|
|
169
198
|
|
|
170
199
|
The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
|
|
171
200
|
the forward pass of a per-image (unitemporal) model.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
input_dicts: list of input dictionaries containing RasterImage objects.
|
|
204
|
+
image_key: the key to access the RasterImage in each input dict.
|
|
205
|
+
num_timesteps: how many timesteps to batch together per forward pass.
|
|
172
206
|
"""
|
|
173
207
|
images = torch.stack(
|
|
174
208
|
[input_dict[image_key].image for input_dict in input_dicts], dim=0
|
|
175
209
|
) # B, C, T, H, W
|
|
176
210
|
timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
|
|
177
|
-
#
|
|
178
|
-
#
|
|
179
|
-
#
|
|
180
|
-
# want to pass 2 timesteps to the model.
|
|
181
|
-
# TODO is probably to make this behaviour clearer but lets leave it like
|
|
182
|
-
# this for now to not break things.
|
|
183
|
-
num_timesteps = image_channels // images.shape[1]
|
|
211
|
+
# num_timesteps specifies how many timesteps to batch together per forward pass.
|
|
212
|
+
# For example, if the input has 8 timesteps and num_timesteps=4, we do 2
|
|
213
|
+
# forward passes, each with 4 timesteps batched together.
|
|
184
214
|
batched_timesteps = images.shape[2] // num_timesteps
|
|
185
215
|
images = rearrange(
|
|
186
216
|
images,
|
|
@@ -222,10 +252,22 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
222
252
|
n_batch = len(context.inputs)
|
|
223
253
|
n_images: int | None = None
|
|
224
254
|
|
|
225
|
-
if self.
|
|
226
|
-
|
|
255
|
+
if self._deprecated_image_keys_dict is not None:
|
|
256
|
+
# Deprecated dict form: each key has its own channels_per_timestep.
|
|
257
|
+
# The channels_per_timestep could be used to group multiple timesteps,
|
|
258
|
+
# together, so we need to divide by the actual image channel count to get
|
|
259
|
+
# the number of timesteps to be grouped.
|
|
260
|
+
for (
|
|
261
|
+
image_key,
|
|
262
|
+
channels_per_timestep,
|
|
263
|
+
) in self._deprecated_image_keys_dict.items():
|
|
264
|
+
# For deprecated image_keys dict, the value is channels per timestep,
|
|
265
|
+
# so we need to compute num_timesteps from the actual image channels
|
|
266
|
+
sample_image = context.inputs[0][image_key].image
|
|
267
|
+
actual_channels = sample_image.shape[0] # C in CTHW
|
|
268
|
+
num_timesteps = channels_per_timestep // actual_channels
|
|
227
269
|
batched_images = self._get_batched_images(
|
|
228
|
-
context.inputs, image_key,
|
|
270
|
+
context.inputs, image_key, num_timesteps
|
|
229
271
|
)
|
|
230
272
|
|
|
231
273
|
if batched_inputs is None:
|
|
@@ -240,12 +282,32 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
240
282
|
batched_inputs[i][image_key] = image
|
|
241
283
|
|
|
242
284
|
else:
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
285
|
+
# Determine num_timesteps - either from deprecated image_channels or
|
|
286
|
+
# directly from num_timesteps_per_forward_pass
|
|
287
|
+
if self._deprecated_image_channels is not None:
|
|
288
|
+
# Backwards compatibility: compute num_timesteps from image_channels
|
|
289
|
+
# (which should be a multiple of the actual per-timestep channels).
|
|
290
|
+
sample_image = context.inputs[0][self.image_keys[0]].image
|
|
291
|
+
actual_channels = sample_image.shape[0] # C in CTHW
|
|
292
|
+
num_timesteps = self._deprecated_image_channels // actual_channels
|
|
293
|
+
else:
|
|
294
|
+
num_timesteps = self.num_timesteps_per_forward_pass
|
|
295
|
+
|
|
296
|
+
for image_key in self.image_keys:
|
|
297
|
+
batched_images = self._get_batched_images(
|
|
298
|
+
context.inputs, image_key, num_timesteps
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
if batched_inputs is None:
|
|
302
|
+
batched_inputs = [{} for _ in batched_images]
|
|
303
|
+
n_images = len(batched_images) // n_batch
|
|
304
|
+
elif n_images != len(batched_images) // n_batch:
|
|
305
|
+
raise ValueError(
|
|
306
|
+
"expected all modalities to have the same number of timesteps"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
for i, image in enumerate(batched_images):
|
|
310
|
+
batched_inputs[i][image_key] = image
|
|
249
311
|
|
|
250
312
|
assert n_images is not None
|
|
251
313
|
# Now we can apply the underlying FeatureExtractor.
|
rslearn/train/data_module.py
CHANGED
|
@@ -21,6 +21,7 @@ from .all_patches_dataset import (
|
|
|
21
21
|
)
|
|
22
22
|
from .dataset import (
|
|
23
23
|
DataInput,
|
|
24
|
+
IndexMode,
|
|
24
25
|
ModelDataset,
|
|
25
26
|
MultiDataset,
|
|
26
27
|
RetryDataset,
|
|
@@ -69,6 +70,7 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
69
70
|
name: str | None = None,
|
|
70
71
|
retries: int = 0,
|
|
71
72
|
use_in_memory_all_patches_dataset: bool = False,
|
|
73
|
+
index_mode: IndexMode = IndexMode.OFF,
|
|
72
74
|
) -> None:
|
|
73
75
|
"""Initialize a new RslearnDataModule.
|
|
74
76
|
|
|
@@ -92,6 +94,7 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
92
94
|
retries: number of retries to attempt for getitem calls
|
|
93
95
|
use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
|
|
94
96
|
instead of IterableAllPatchesDataset if load_all_patches is set to true.
|
|
97
|
+
index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
|
|
95
98
|
"""
|
|
96
99
|
super().__init__()
|
|
97
100
|
self.inputs = inputs
|
|
@@ -103,6 +106,7 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
103
106
|
self.name = name
|
|
104
107
|
self.retries = retries
|
|
105
108
|
self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
|
|
109
|
+
self.index_mode = index_mode
|
|
106
110
|
self.split_configs = {
|
|
107
111
|
"train": default_config.update(train_config),
|
|
108
112
|
"val": default_config.update(val_config),
|
|
@@ -138,6 +142,7 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
138
142
|
workers=self.init_workers,
|
|
139
143
|
name=self.name,
|
|
140
144
|
fix_patch_pick=(split != "train"),
|
|
145
|
+
index_mode=self.index_mode,
|
|
141
146
|
)
|
|
142
147
|
logger.info(f"got {len(dataset)} examples in split {split}")
|
|
143
148
|
if split_config.get_load_all_patches():
|
rslearn/train/dataset.py
CHANGED
|
@@ -9,6 +9,7 @@ import tempfile
|
|
|
9
9
|
import time
|
|
10
10
|
import uuid
|
|
11
11
|
from datetime import datetime
|
|
12
|
+
from enum import StrEnum
|
|
12
13
|
from typing import Any
|
|
13
14
|
|
|
14
15
|
import torch
|
|
@@ -29,6 +30,7 @@ from rslearn.dataset.window import (
|
|
|
29
30
|
get_layer_and_group_from_dir_name,
|
|
30
31
|
)
|
|
31
32
|
from rslearn.log_utils import get_logger
|
|
33
|
+
from rslearn.train.dataset_index import DatasetIndex
|
|
32
34
|
from rslearn.train.model_context import RasterImage
|
|
33
35
|
from rslearn.utils.feature import Feature
|
|
34
36
|
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
@@ -41,6 +43,19 @@ from .transforms import Sequential
|
|
|
41
43
|
logger = get_logger(__name__)
|
|
42
44
|
|
|
43
45
|
|
|
46
|
+
class IndexMode(StrEnum):
|
|
47
|
+
"""Controls dataset index caching behavior."""
|
|
48
|
+
|
|
49
|
+
OFF = "off"
|
|
50
|
+
"""No caching - always load windows from dataset."""
|
|
51
|
+
|
|
52
|
+
USE = "use"
|
|
53
|
+
"""Use cached index if available, create if not."""
|
|
54
|
+
|
|
55
|
+
REFRESH = "refresh"
|
|
56
|
+
"""Ignore existing cache and rebuild."""
|
|
57
|
+
|
|
58
|
+
|
|
44
59
|
def get_torch_dtype(dtype: DType) -> torch.dtype:
|
|
45
60
|
"""Convert rslearn DType to torch dtype."""
|
|
46
61
|
if dtype == DType.INT32:
|
|
@@ -636,6 +651,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
636
651
|
workers: int,
|
|
637
652
|
name: str | None = None,
|
|
638
653
|
fix_patch_pick: bool = False,
|
|
654
|
+
index_mode: IndexMode = IndexMode.OFF,
|
|
639
655
|
) -> None:
|
|
640
656
|
"""Instantiate a new ModelDataset.
|
|
641
657
|
|
|
@@ -645,9 +661,10 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
645
661
|
inputs: data to read from the dataset for training
|
|
646
662
|
task: the task to train on
|
|
647
663
|
workers: number of workers to use for initializing the dataset
|
|
648
|
-
name: name of the dataset
|
|
664
|
+
name: name of the dataset
|
|
649
665
|
fix_patch_pick: if True, fix the patch pick to be the same every time
|
|
650
666
|
for a given window. Useful for testing (default: False)
|
|
667
|
+
index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
|
|
651
668
|
"""
|
|
652
669
|
self.dataset = dataset
|
|
653
670
|
self.split_config = split_config
|
|
@@ -668,66 +685,14 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
668
685
|
else:
|
|
669
686
|
self.patch_size = split_config.get_patch_size()
|
|
670
687
|
|
|
671
|
-
windows = self._get_initial_windows(split_config, workers)
|
|
672
|
-
|
|
673
688
|
# If targets are not needed, remove them from the inputs.
|
|
674
689
|
if split_config.get_skip_targets():
|
|
675
690
|
for k in list(self.inputs.keys()):
|
|
676
691
|
if self.inputs[k].is_target:
|
|
677
692
|
del self.inputs[k]
|
|
678
693
|
|
|
679
|
-
#
|
|
680
|
-
|
|
681
|
-
new_windows = []
|
|
682
|
-
if workers == 0:
|
|
683
|
-
for window in windows:
|
|
684
|
-
if (
|
|
685
|
-
check_window(
|
|
686
|
-
self.inputs,
|
|
687
|
-
window,
|
|
688
|
-
output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
|
|
689
|
-
)
|
|
690
|
-
is None
|
|
691
|
-
):
|
|
692
|
-
continue
|
|
693
|
-
new_windows.append(window)
|
|
694
|
-
else:
|
|
695
|
-
p = multiprocessing.Pool(workers)
|
|
696
|
-
outputs = star_imap_unordered(
|
|
697
|
-
p,
|
|
698
|
-
check_window,
|
|
699
|
-
[
|
|
700
|
-
dict(
|
|
701
|
-
inputs=self.inputs,
|
|
702
|
-
window=window,
|
|
703
|
-
output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
|
|
704
|
-
)
|
|
705
|
-
for window in windows
|
|
706
|
-
],
|
|
707
|
-
)
|
|
708
|
-
for window in tqdm.tqdm(
|
|
709
|
-
outputs, total=len(windows), desc="Checking available layers in windows"
|
|
710
|
-
):
|
|
711
|
-
if window is None:
|
|
712
|
-
continue
|
|
713
|
-
new_windows.append(window)
|
|
714
|
-
p.close()
|
|
715
|
-
windows = new_windows
|
|
716
|
-
|
|
717
|
-
# Sort the windows to ensure that the dataset is consistent across GPUs.
|
|
718
|
-
# Inconsistent ordering can lead to a subset of windows being processed during
|
|
719
|
-
# "model test" / "model predict" when using multiple GPUs.
|
|
720
|
-
# We use a hash so that functionality like num_samples limit gets a random
|
|
721
|
-
# subset of windows (with respect to the hash function choice).
|
|
722
|
-
windows.sort(
|
|
723
|
-
key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
|
|
724
|
-
)
|
|
725
|
-
|
|
726
|
-
# Limit windows to num_samples if requested.
|
|
727
|
-
if split_config.num_samples:
|
|
728
|
-
# The windows are sorted by hash of window name so this distribution should
|
|
729
|
-
# be representative of the population.
|
|
730
|
-
windows = windows[0 : split_config.num_samples]
|
|
694
|
+
# Load windows (from index if available, otherwise from dataset)
|
|
695
|
+
windows = self._load_windows(split_config, workers, index_mode)
|
|
731
696
|
|
|
732
697
|
# Write dataset_examples to a file so that we can load it lazily in the worker
|
|
733
698
|
# processes. Otherwise it takes a long time to transmit it when spawning each
|
|
@@ -796,6 +761,137 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
796
761
|
|
|
797
762
|
return windows
|
|
798
763
|
|
|
764
|
+
def _load_windows(
    self,
    split_config: SplitConfig,
    workers: int,
    index_mode: IndexMode,
) -> list[Window]:
    """Produce the list of windows for this dataset, consulting the index cache.

    When indexing is enabled (any mode other than IndexMode.OFF) a
    DatasetIndex is created for this split's configuration and asked for a
    cached window list; in REFRESH mode the index is told to ignore whatever
    it has. If the index yields a list, it is returned as-is. Otherwise the
    windows are loaded from the dataset, filtered by available layers,
    sorted and limited, and - when indexing is enabled - written back to the
    index before being returned.

    Args:
        split_config: the split configuration.
        workers: number of worker processes.
        index_mode: controls caching behavior.

    Returns:
        list of processed windows ready for training.
    """
    cache: DatasetIndex | None = None

    if index_mode != IndexMode.OFF:
        logger.info(f"Checking index for dataset {self.dataset.path}")
        cache = DatasetIndex(
            storage=self.dataset.storage,
            dataset_path=self.dataset.path,
            groups=split_config.groups,
            names=split_config.names,
            tags=split_config.tags,
            num_samples=split_config.num_samples,
            skip_targets=split_config.get_skip_targets(),
            inputs=self.inputs,
        )
        force_rebuild = index_mode == IndexMode.REFRESH
        cached_windows = cache.load_windows(force_rebuild)

        if cached_windows is not None:
            logger.info(f"Loaded {len(cached_windows)} windows from index")
            return cached_windows

    # Cache miss (or caching disabled): build the window list from scratch.
    logger.debug("Loading windows from dataset...")
    candidates = self._get_initial_windows(split_config, workers)
    usable = self._filter_windows_by_layers(candidates, workers)
    result = self._sort_and_limit_windows(usable, split_config)

    # Persist the freshly built list so later runs can skip the work above.
    if cache is not None:
        cache.save_windows(result)

    return result
|
|
818
|
+
|
|
819
|
+
def _filter_windows_by_layers(
    self, windows: list[Window], workers: int
) -> list[Window]:
    """Filter windows to only include those accepted by check_window.

    Each window is passed to check_window together with this dataset's
    inputs and the split's output_layer_name_skip_inference_if_exists
    setting; windows for which check_window returns None are dropped.

    With workers == 0 the check runs in-process and the original window
    objects are returned in their input order. Otherwise the checks are
    dispatched to a multiprocessing pool and the objects returned by
    check_window are collected (presumably in completion order, given the
    star_imap_unordered helper - callers should not rely on ordering).

    Args:
        windows: list of windows to filter.
        workers: number of worker processes for parallel filtering; 0 means
            run serially in the current process.

    Returns:
        list of windows that passed the check.
    """
    output_layer_skip = (
        self.split_config.get_output_layer_name_skip_inference_if_exists()
    )

    if workers == 0:
        return [
            w
            for w in windows
            if check_window(
                self.inputs,
                w,
                output_layer_name_skip_inference_if_exists=output_layer_skip,
            )
            is not None
        ]

    jobs = [
        dict(
            inputs=self.inputs,
            window=window,
            output_layer_name_skip_inference_if_exists=output_layer_skip,
        )
        for window in windows
    ]

    # Use the pool as a context manager so the worker processes are always
    # released, including when check_window or the progress loop raises.
    # All results are consumed inside the block, so no work is pending when
    # the pool is torn down on exit.
    with multiprocessing.Pool(workers) as p:
        outputs = star_imap_unordered(p, check_window, jobs)
        return [
            window
            for window in tqdm.tqdm(
                outputs,
                total=len(windows),
                desc="Checking available layers in windows",
            )
            if window is not None
        ]
|
|
870
|
+
|
|
871
|
+
def _sort_and_limit_windows(
|
|
872
|
+
self, windows: list[Window], split_config: SplitConfig
|
|
873
|
+
) -> list[Window]:
|
|
874
|
+
"""Sort windows by hash and apply num_samples limit.
|
|
875
|
+
|
|
876
|
+
Sorting ensures consistent ordering across GPUs. Using hash gives a
|
|
877
|
+
pseudo-random but deterministic order for sampling.
|
|
878
|
+
|
|
879
|
+
Args:
|
|
880
|
+
windows: list of windows to sort and limit.
|
|
881
|
+
split_config: the split configuration with num_samples.
|
|
882
|
+
|
|
883
|
+
Returns:
|
|
884
|
+
sorted and optionally limited list of windows.
|
|
885
|
+
"""
|
|
886
|
+
windows.sort(
|
|
887
|
+
key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
|
|
888
|
+
)
|
|
889
|
+
|
|
890
|
+
if split_config.num_samples:
|
|
891
|
+
windows = windows[: split_config.num_samples]
|
|
892
|
+
|
|
893
|
+
return windows
|
|
894
|
+
|
|
799
895
|
def _serialize_item(self, example: Window) -> dict[str, Any]:
|
|
800
896
|
return example.get_metadata()
|
|
801
897
|
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Dataset index for caching window lists to speed up ModelDataset initialization."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.dataset.window import Window
|
|
11
|
+
from rslearn.log_utils import get_logger
|
|
12
|
+
from rslearn.utils.fsspec import open_atomic
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
# Increment this when the index format changes to force rebuild
|
|
20
|
+
INDEX_VERSION = 1
|
|
21
|
+
|
|
22
|
+
# Directory name for storing index files
|
|
23
|
+
INDEX_DIR_NAME = ".rslearn_dataset_index"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DatasetIndex:
    """Manages indexed window lists for faster ModelDataset initialization.

    Each index is a JSON file named ``<index_key>.json`` inside the
    ``INDEX_DIR_NAME`` directory under the dataset path. The index key is a
    SHA-256 digest of the configuration passed to the constructor (groups,
    names, tags, num_samples, skip_targets and a summary of the inputs), so
    different configurations map to different index files.

    Note: The index does NOT automatically detect when windows are added or removed
    from the dataset. Use refresh=True after modifying dataset windows.

    Note: only the constructor arguments contribute to the index key; any other
    setting that affects which windows are selected is not tracked, and also
    requires refresh=True when it changes.
    """

    def __init__(
        self,
        storage: "WindowStorage",
        dataset_path: UPath,
        groups: list[str] | None,
        names: list[str] | None,
        tags: dict[str, Any] | None,
        num_samples: int | None,
        skip_targets: bool,
        inputs: dict[str, Any],
    ) -> None:
        """Initialize DatasetIndex with specific configuration.

        Args:
            storage: WindowStorage for deserializing windows.
            dataset_path: Path to the dataset directory.
            groups: list of window groups to include.
            names: list of window names to include.
            tags: tags to filter windows by.
            num_samples: limit on number of samples.
            skip_targets: whether targets are skipped.
            inputs: dict mapping input names to DataInput objects. Only the
                layers, required, load_all_layers and is_target attributes
                of each input are folded into the index key.
        """
        self.storage = storage
        self.dataset_path = dataset_path
        self.index_dir = dataset_path / INDEX_DIR_NAME

        # Summarize each input by the attributes that feed the index key.
        inputs_data = {
            name: {
                "layers": inp.layers,
                "required": inp.required,
                "load_all_layers": inp.load_all_layers,
                "is_target": inp.is_target,
            }
            for name, inp in inputs.items()
        }

        key_data = {
            "groups": groups,
            "names": names,
            "tags": tags,
            "num_samples": num_samples,
            "skip_targets": skip_targets,
            "inputs": inputs_data,
        }
        # sort_keys makes the digest independent of dict insertion order.
        self.index_key = hashlib.sha256(
            json.dumps(key_data, sort_keys=True).encode()
        ).hexdigest()

    def _index_file(self) -> UPath:
        """Return the path of the index file for this configuration."""
        return self.index_dir / f"{self.index_key}.json"

    def _get_config_hash(self) -> str:
        """Get hash of config.json for quick validation.

        Returns:
            A 16-character hex string hash of the config, or empty string if no config.
        """
        config_path = self.dataset_path / "config.json"
        if config_path.exists():
            with config_path.open() as f:
                return hashlib.sha256(f.read().encode()).hexdigest()[:16]
        return ""

    def load_windows(self, refresh: bool = False) -> list[Window] | None:
        """Load indexed window list if valid, else return None.

        The index is considered invalid (and None is returned so the caller
        rebuilds it) when the file is missing, unreadable, not valid JSON,
        not a JSON object, lacks a "windows" list, was written with a
        different INDEX_VERSION, or was written against a different
        config.json.

        Args:
            refresh: If True, ignore existing index and return None.

        Returns:
            List of Window objects if index is valid, None otherwise.
        """
        if refresh:
            logger.info("refresh=True, rebuilding index")
            return None

        index_file = self._index_file()
        if not index_file.exists():
            logger.info(f"No index found at {index_file}, will build")
            return None

        try:
            with index_file.open() as f:
                index_data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError as well as
            # UnicodeDecodeError raised while reading a damaged file.
            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
            return None

        # A file can be valid JSON without being a valid index (e.g. a list
        # or a bare string); treat that as corruption rather than crashing.
        if not isinstance(index_data, dict):
            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
            return None

        # Check index version
        if index_data.get("version") != INDEX_VERSION:
            logger.info(
                f"Index version mismatch (got {index_data.get('version')}, "
                f"expected {INDEX_VERSION}), will rebuild"
            )
            return None

        # Quick validation: check config hash
        if index_data.get("config_hash") != self._get_config_hash():
            logger.info("Config hash mismatch, index invalidated")
            return None

        serialized_windows = index_data.get("windows")
        if not isinstance(serialized_windows, list):
            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
            return None

        # Deserialize windows
        return [Window.from_metadata(self.storage, w) for w in serialized_windows]

    def save_windows(self, windows: list[Window]) -> None:
        """Save processed windows to index with atomic write.

        The index directory is created if needed. Alongside the serialized
        windows, the file records the index version, the current config
        hash, a creation timestamp and the number of windows.

        Args:
            windows: List of Window objects to index.
        """
        self.index_dir.mkdir(parents=True, exist_ok=True)
        index_file = self._index_file()

        # Serialize windows
        serialized_windows = [w.get_metadata() for w in windows]

        index_data = {
            "version": INDEX_VERSION,
            "config_hash": self._get_config_hash(),
            "created_at": datetime.now().isoformat(),
            "num_windows": len(windows),
            "windows": serialized_windows,
        }
        with open_atomic(index_file, "w") as f:
            json.dump(index_data, f)
        logger.info(f"Saved {len(windows)} windows to index at {index_file}")
|
rslearn/train/model_context.py
CHANGED
|
@@ -43,6 +43,22 @@ class RasterImage:
|
|
|
43
43
|
raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
|
|
44
44
|
return self.image[:, 0]
|
|
45
45
|
|
|
46
|
+
def get_hw_tensor(self) -> torch.Tensor:
    """Return the image as a 2D (H, W) tensor.

    Only valid when the image has exactly one channel and one timestep
    (first two dimensions both of size 1); the result is the slice at index
    0 of both. Handy for per-pixel labels such as segmentation masks.

    Raises:
        ValueError: if the first dimension (channels) or the second
            dimension (timesteps) is not of size 1.
    """
    shape = self.image.shape

    num_channels = shape[0]
    if num_channels != 1:
        raise ValueError(f"Expected single channel (C=1), got {num_channels}")

    num_timesteps = shape[1]
    if num_timesteps != 1:
        raise ValueError(f"Expected single timestep (T=1), got {num_timesteps}")

    return self.image[0, 0]
|
|
61
|
+
|
|
46
62
|
|
|
47
63
|
@dataclass
|
|
48
64
|
class SampleMetadata:
|