rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/data_module.py
CHANGED
@@ -1,30 +1,47 @@
 """Default LightningDataModule for rslearn."""
 
+import math
+import random
+from collections import defaultdict
+from collections.abc import Iterator
 from typing import Any
 
 import lightning as L
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
 from upath import UPath
 
 from rslearn.dataset import Dataset
+from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 
-from .
+from .all_patches_dataset import (
+    InMemoryAllPatchesDataset,
+    IterableAllPatchesDataset,
+)
+from .dataset import (
+    DataInput,
+    ModelDataset,
+    MultiDataset,
+    RetryDataset,
+    SplitConfig,
+)
+
+logger = get_logger(__name__)
 
 
 def collate_fn(
-    batch: list[tuple[dict[str, Any], dict[str, Any]]],
-) -> tuple
+    batch: list[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]],
+) -> tuple:
     """Collate batch of training examples.
 
     We just make list of the inputs and another of the targets.
 
     Args:
-        batch: list of input/target for each example
+        batch: list of input/target/metadata for each example
 
     Returns:
-        a tuple (inputs, targets)
+        a tuple (inputs, targets, metadatas)
     """
     return tuple(zip(*batch))
 
@@ -43,27 +60,38 @@ class RslearnDataModule(L.LightningDataModule):
         path_options: dict[str, Any] = {},
         batch_size: int = 1,
         num_workers: int = 0,
+        init_workers: int = 0,
         default_config: SplitConfig = SplitConfig(),
         train_config: SplitConfig = SplitConfig(),
         val_config: SplitConfig = SplitConfig(),
         test_config: SplitConfig = SplitConfig(),
         predict_config: SplitConfig = SplitConfig(),
-
+        name: str | None = None,
+        retries: int = 0,
+        use_in_memory_all_patches_dataset: bool = False,
+    ) -> None:
         """Initialize a new RslearnDataModule.
 
         Args:
             inputs: what to read from the underlying dataset
            task: the task to train on
-            path: the dataset path
+            path: the dataset path
            path_options: additional options for path to pass to fsspec.
            batch_size: the batch size
            num_workers: number of data loader worker processes, or 0 to use main
                process only
+            init_workers: number of workers used to initialize the dataset, e.g. for
+                loading the list of windows. Defaults to 0 which uses num_workers for
+                this setting
            default_config: default split configuration
            train_config: split config for train split
            val_config: split config for val split
            test_config: split config for test split
            predict_config: split config for predict split
+            name: name of the dataset
+            retries: number of retries to attempt for getitem calls
+            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
+                instead of IterableAllPatchesDataset if load_all_patches is set to true.
         """
         super().__init__()
         self.inputs = inputs
@@ -71,7 +99,10 @@ class RslearnDataModule(L.LightningDataModule):
         self.path = UPath(path, **path_options)
         self.batch_size = batch_size
         self.num_workers = num_workers
-
+        self.init_workers = init_workers if init_workers > 0 else self.num_workers
+        self.name = name
+        self.retries = retries
+        self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
         self.split_configs = {
             "train": default_config.update(train_config),
             "val": default_config.update(val_config),
@@ -79,11 +110,16 @@ class RslearnDataModule(L.LightningDataModule):
             "predict": default_config.update(predict_config),
         }
 
-    def setup(
+    def setup(
+        self, stage: str, use_in_memory_all_patches_dataset: bool | None = None
+    ) -> None:
         """Set up datasets and samplers.
 
         Args:
             stage: Either 'fit', 'validate', 'test', or 'predict'.
+            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
+                instead of IterableAllPatchesDataset if load_all_patches is set to true.
+                If None, uses the value of self.use_in_memory_all_patches_dataset.
         """
         stage_to_splits = {
             "fit": ["train", "val"],
@@ -93,31 +129,112 @@ class RslearnDataModule(L.LightningDataModule):
         }
         self.datasets = {}
         for split in stage_to_splits[stage]:
+            split_config = self.split_configs[split]
             dataset = ModelDataset(
                 dataset=Dataset(path=self.path),
                 split_config=self.split_configs[split],
                 inputs=self.inputs,
                 task=self.task,
-                workers=self.
+                workers=self.init_workers,
+                name=self.name,
+                fix_patch_pick=(split != "train"),
             )
-
+            logger.info(f"got {len(dataset)} examples in split {split}")
+            if split_config.get_load_all_patches():
+                if use_in_memory_all_patches_dataset is None:
+                    use_in_memory_all_patches_dataset = (
+                        self.use_in_memory_all_patches_dataset
+                    )
+                logger.info(
+                    f"using AllPatchesDataset (in_memory={use_in_memory_all_patches_dataset})"
+                )
+                patch_size = split_config.get_patch_size()
+                if patch_size is None:
+                    raise ValueError(
+                        "patch_size is not set but must be set if load_all_patches is set"
+                    )
+
+                all_patches_cls = IterableAllPatchesDataset
+                kwargs = dict(
+                    dataset=dataset,
+                    patch_size=patch_size,
+                    overlap_ratio=split_config.get_overlap_ratio(),
+                    rank=self.trainer.global_rank if self.trainer else 0,
+                    world_size=self.trainer.world_size if self.trainer else 1,
+                )
+                if use_in_memory_all_patches_dataset:
+                    kwargs.pop("rank")
+                    kwargs.pop("world_size")
+                    all_patches_cls = InMemoryAllPatchesDataset  # type: ignore
+
+                dataset = all_patches_cls(**kwargs)  # type: ignore
+
+            if self.retries > 0:
+                dataset = RetryDataset(dataset, retries=self.retries)
             self.datasets[split] = dataset
-            print(f"got {len(self.datasets[split])} examples in split {split}")
 
-    def
+    def set_name(self, name: str) -> None:
+        """Set the name of the dataset.
+
+        Args:
+            name: the name of the dataset
+        """
+        self.name = name
+        for dataset in self.datasets.values():
+            dataset.set_name(name)
+
+    def _get_dataloader(
+        self,
+        split: str,
+    ) -> DataLoader[dict[str, torch.Tensor]]:
+        """Get a dataloader for the given split.
+
+        Args:
+            split: the split to get a dataloader for
+        """
         dataset = self.datasets[split]
-
+        split_config = self.split_configs[split]
+
+        # Enable persistent workers unless we are using main process.
+        persistent_workers = self.num_workers > 0
+
+        # If using all patches, limit number of workers to the number of windows.
+        # Otherwise it has to distribute the same window to different workers which can
+        # cause issues for RslearnWriter.
+        # If the number of windows is 0, then we can set positive number of workers
+        # since they won't yield anything anyway.
+        num_workers = self.num_workers
+        if split_config.load_all_patches and len(dataset.get_dataset_examples()) > 0:
+            num_workers = min(num_workers, len(dataset.get_dataset_examples()))
+
+        kwargs: dict[str, Any] = dict(
             dataset=dataset,
             batch_size=self.batch_size,
-            num_workers=
+            num_workers=num_workers,
             collate_fn=collate_fn,
-            persistent_workers=
+            persistent_workers=persistent_workers,
         )
-
+        should_shuffle = split == "train"
+
+        sampler_factory = split_config.sampler
         if sampler_factory:
             kwargs["sampler"] = sampler_factory.get_sampler(dataset)
-        elif
-
+        elif (
+            self.trainer is not None
+            and self.trainer.world_size is not None
+            and self.trainer.world_size > 1
+            and not isinstance(dataset, IterableDataset)
+        ):
+            # Use distributed sampler in case ddp is enabled.
+            kwargs["sampler"] = DistributedSampler(
+                dataset,
+                num_replicas=self.trainer.world_size,
+                rank=self.trainer.global_rank,
+                shuffle=should_shuffle,
+            )
+        else:
+            kwargs["shuffle"] = should_shuffle
+
         return DataLoader(**kwargs)
 
     def train_dataloader(self) -> DataLoader[dict[str, torch.Tensor]]:
@@ -167,3 +284,310 @@ class RslearnDataModule(L.LightningDataModule):
             dataset or sampler, or if the dataset or sampler has length 0.
         """
         return self._get_dataloader("predict")
+
+
+class MultiDatasetDataModule(L.LightningDataModule):
+    """Data module that manages multiple RslearnDataModule instances.
+
+    This module creates and manages multiple RslearnDataModule instances, each handling
+    a different dataset. It provides a unified interface for training on multiple datasets
+    with different modalities and labels.
+
+    Each dataset can have different:
+    - Input modalities (e.g., Sentinel-2 vs Landsat)
+    - Label schemas (e.g., different classification classes)
+    - Task types (e.g., classification vs detection)
+    - Transforms and preprocessing
+    """
+
+    def __init__(
+        self,
+        data_modules: dict[str, RslearnDataModule],
+        num_workers: int = 32,
+        sample_mode: str = "random_cycle",
+        batch_sizes: int | dict[str, int] | None = None,
+        refill_batches: bool = False,
+        per_dataset_patch_limit: int | None = None,
+        steps_per_dataset: int | None = None,
+        disabled_datasets: list[str] | None = None,
+    ) -> None:
+        """Initialize a new MultiDatasetDataModule.
+
+        Args:
+            data_modules: dict mapping dataset names to RslearnDataModule objects
+            num_workers: the maximum number of workers to use for the dataloader
+            sample_mode: the mode to sample from the datasets ("random", "cycle", "random_cycle", "reptile")
+            batch_sizes: the batch size for all datasets, or a dict mapping dataset
+                names to batch sizes, or None to use the batch size of the largest
+                dataset (default: None)
+            refill_batches: whether to refill empty dataset iterators
+                once they run out each epoch (default: False)
+            per_dataset_patch_limit: the maximum number of patches to sample from each dataset
+                per epoch during training. Does not affect validation (default: None = no limit)
+            steps_per_dataset: the number of steps to sample from each dataset in a row (requires that
+                sample_mode is "reptile")
+            disabled_datasets: list of datasets to disable (default: None = no disabled datasets)
+        """
+        super().__init__()
+        self.data_modules = data_modules
+        self.num_workers = num_workers
+        self.sample_mode = sample_mode
+        self.batch_sizes = batch_sizes
+        self.refill_batches = refill_batches
+        self.per_dataset_patch_limit = per_dataset_patch_limit
+        self.steps_per_dataset = steps_per_dataset
+        self.disabled_datasets = disabled_datasets or []
+
+        for dataset in self.disabled_datasets:
+            if dataset in self.data_modules:
+                del self.data_modules[dataset]
+                logger.info(f"Skipping disabled dataset {dataset}")
+            else:
+                logger.info(f"Could not find dataset {dataset} to skip")
+
+    def setup(self, stage: str | None = None) -> None:
+        """Set up the datasets for the given stage. Also assign dataset-specific names.
+
+        Args:
+            stage: The stage to set up ('fit', 'validate', 'test', 'predict')
+        """
+        for name, data_module in self.data_modules.items():
+            data_module.setup(stage, use_in_memory_all_patches_dataset=True)  # type: ignore
+            data_module.set_name(name)
+
+    def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]:
+        datasets = {name: dm.datasets[split] for name, dm in self.data_modules.items()}
+        if isinstance(self.batch_sizes, dict):
+            batch_sizes = self.batch_sizes
+        else:
+            batch_size: int | None = self.batch_sizes
+            if batch_size is None:
+                batch_size = max(
+                    self.data_modules.values(), key=lambda dm: dm.batch_size
+                ).batch_size
+            batch_sizes = {name: batch_size for name in self.data_modules.keys()}
+
+        logger.info(f"{split} is using batch_sizes {batch_sizes}")
+        logger.info(f"{split} is using sample_mode {self.sample_mode}")
+        if self.per_dataset_patch_limit:
+            logger.info(
+                f"{split} is using per_dataset_patch_limit {self.per_dataset_patch_limit}"
+            )
+
+        dataset = MultiDataset(datasets)
+        return DataLoader(
+            dataset=dataset,
+            pin_memory=True,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            collate_fn=collate_fn,
+            batch_sampler=DistributedPerDatasetBatchSampler(
+                multi_dataset=dataset,
+                batch_sizes=batch_sizes,
+                shuffle=(split == "train"),
+                num_replicas=self.trainer.world_size,  # type: ignore
+                rank=self.trainer.global_rank,  # type: ignore
+                sample_mode=self.sample_mode,
+                refill_batches=self.refill_batches,
+                per_dataset_patch_limit=(
+                    self.per_dataset_patch_limit if split == "train" else None
+                ),
+                steps_per_dataset=self.steps_per_dataset,
+            ),
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        """Get the training dataloader."""
+        return self._get_dataloader("train")
+
+    def val_dataloader(self) -> DataLoader:
+        """Get the validation dataloader."""
+        return self._get_dataloader("val")
+
+    def test_dataloader(self) -> DataLoader:
+        """Get the test dataloader."""
+        return self._get_dataloader("test")
+
+    def predict_dataloader(self) -> DataLoader:
+        """Get the predict dataloader."""
+        return self._get_dataloader("predict")
+
+
+class DistributedPerDatasetBatchSampler(torch.utils.data.Sampler[list[int]]):
+    """Distributed batch sampler yielding batches from one sub-dataset per batch.
+
+    Wraps torch DistributedSampler to first split indices across ranks,
+    then does "one-subdataset-per-batch" sampling in each process.
+    """
+
+    def __init__(
+        self,
+        multi_dataset: MultiDataset,
+        batch_sizes: dict[str, int],
+        shuffle: bool = True,
+        num_replicas: int | None = None,
+        rank: int | None = None,
+        sample_mode: str = "random_cycle",
+        refill_batches: bool = False,
+        steps_per_dataset: int | None = None,
+        per_dataset_patch_limit: int | None = None,
+    ) -> None:
+        """Initialize a new DistributedPerDatasetBatchSampler.
+
+        Args:
+            multi_dataset: the MultiDataset to sample from
+            batch_sizes: the batch size for each dataset
+            shuffle: whether to shuffle the indices
+            num_replicas: the number of replicas
+            rank: the rank
+            sample_mode: the mode to sample from the datasets ("random", "cycle", "random_cycle", "reptile")
+            refill_batches: whether to refill empty dataset iterators
+                once they run out each epoch
+            steps_per_dataset: the number of steps to sample from each dataset
+            per_dataset_patch_limit: the maximum number of patches to sample from each dataset
+                per epoch during training. Does not affect validation (default: None = no limit)
+            steps_per_dataset: the number of steps to sample from each dataset in a row (requires that
+                sample_mode is "reptile")
+        """
+        self.multi_dataset = multi_dataset
+        self.batch_sizes = batch_sizes
+        self.sample_mode = sample_mode
+        self.refill_batches = refill_batches
+        self.per_dataset_patch_limit = per_dataset_patch_limit
+        self.steps_per_dataset: int = steps_per_dataset  # type: ignore
+        self.epoch = 0
+
+        if sample_mode == "reptile":
+            assert steps_per_dataset is not None, (
+                "steps_per_dataset must be provided when sample_mode is 'reptile'"
+            )
+        assert sample_mode in (
+            "random",
+            "cycle",
+            "random_cycle",
+            "reptile",
+        ), f"Invalid sample_mode: {sample_mode}"
+
+        # For now, we just track the total number of batches if refill_batches is True,
+        # so we must the datasets come out balanced during each epoch
+        if refill_batches and self.sample_mode not in (
+            "cycle",
+            "random_cycle",
+            "reptile",
+        ):
+            raise ValueError("refill_batches is only supported with round_robin")
+
+        # Using one DistributedSampler per dataset guarantees equal splitting
+        # across all datasets across all ranks
+        self.dist_samplers = {
+            name: DistributedSampler(
+                dataset,
+                num_replicas=num_replicas,
+                rank=rank,
+                shuffle=shuffle,
+                drop_last=False,
+            )
+            for name, dataset in multi_dataset.datasets.items()
+        }
+
+        for k, v in self.dist_samplers.items():
+            logger.info(f"Dataset {k} has {len(v)} samples")
+
+    def set_epoch(self, epoch: int) -> None:
+        """Set the epoch for the distributed sampler.
+
+        Args:
+            epoch: the epoch to set
+        """
+        self.epoch = epoch
+        for dist_sampler in self.dist_samplers.values():
+            dist_sampler.set_epoch(epoch)
+
+    def __iter__(self) -> Iterator[list[int]]:
+        """Iterate over the batches."""
+        # Get the per-rank, per-epoch list of properly offset multi-dataset indices
+        partitioned: dict[str, list[int]] = {}
+        refill: dict[str, list[int]] = defaultdict(list)
+        for name, sampler in self.dist_samplers.items():
+            offset = self.multi_dataset.buckets[name].start
+            partitioned[name] = [idx + offset for idx in sampler]
+            if self.per_dataset_patch_limit:
+                partitioned[name] = partitioned[name][: self.per_dataset_patch_limit]
+
+        # Seed is shared aross all ranks but shuffled per epoch
+        rng = random.Random(self.epoch)
+        picks = list(partitioned.keys())
+        last_picked = -1
+        dataset_counter = 0
+
+        # Random mode samples uniformly across all datasets regardless of size
+        for n in range(len(self)):
+            available = [name for name, idxs in partitioned.items() if idxs]
+            if not self.refill_batches:
+                # For cycle, only pick from available datasets, but if
+                # we are refilling batches then all datasets are available
+                picks = [name for name in picks if name in available]
+            if not available:
+                logger.warning(f"Found no available batch on step {n} of {len(self)}")
+                break
+
+            if self.sample_mode == "cycle":
+                last_picked = (last_picked + 1) % len(picks)
+                name = picks[last_picked]
+
+            elif self.sample_mode == "reptile":
+                # Sample n times from the same dataset before moving onto the next,
+                # but still ensure we sample from all the datasets before repeating
+                # This is so that we can use the refill_batches feature
+                if dataset_counter == 0:
+                    name = rng.choice(picks)
+                    last_picked = picks.index(name)
+                else:
+                    name = picks[last_picked]
+                dataset_counter += 1
+                if dataset_counter >= self.steps_per_dataset:
+                    dataset_counter = 0
+                    picks.remove(name)
+                    if not picks:
+                        picks = list(partitioned.keys())
+
+            elif self.sample_mode == "random_cycle":
+                name = rng.choice(picks)
+                picks.remove(name)
+                if not picks:
+                    picks = list(partitioned.keys())
+
+            else:
+                name = rng.choice(available)
+
+            idxs = partitioned[name]
+            batch, partitioned[name] = (
+                idxs[: self.batch_sizes[name]],
+                idxs[self.batch_sizes[name] :],
+            )
+
+            # If we are refilling batches, we just keep adding the indexes
+            if self.refill_batches:
+                refill[name].extend(batch)
+                if len(partitioned[name]) == 0:
+                    # Shuffle batches again once we have to replenish them
+                    partitioned[name], refill[name] = refill[name], []
+                    rng.shuffle(partitioned[name])
+
+            yield batch
+
+    def __len__(self) -> int:
+        """Return the number of batches."""
+
+        def len_iter() -> Iterator[int]:
+            """Iterate over the number of batches for each dataset."""
+            for name, sampler in self.dist_samplers.items():
+                length = len(sampler)
+                if self.per_dataset_patch_limit:
+                    length = min(length, self.per_dataset_patch_limit)
+                yield math.ceil(length / self.batch_sizes[name])
+
+        if self.refill_batches:
+            return max(len_iter()) * len(self.dist_samplers)
+        else:
+            return sum(len_iter())