rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/{all_patches_dataset.py → all_crops_dataset.py}
RENAMED

@@ -1,4 +1,4 @@
-"""Wrapper around ModelDataset to load all patches in a window."""
+"""Wrapper around ModelDataset to load all crops in a window."""
 
 import itertools
 from collections.abc import Iterable, Iterator
@@ -14,70 +14,78 @@ from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
-def get_window_patch_options(
-    patch_size: tuple[int, int],
+def get_window_crop_options(
+    crop_size: tuple[int, int],
     overlap_size: tuple[int, int],
     bounds: PixelBounds,
 ) -> list[PixelBounds]:
-    """Get the bounds of each input patch within the window bounds.
+    """Get the bounds of each input crop within the window bounds.
 
-    This is used when running inference on all patches of a large window, to
-    compute the position of each patch.
+    This is used when running inference on all crops of a large window, to
+    compute the position of each crop.
 
     Args:
-        patch_size: the size of the patches to extract.
-        overlap_size: the size of the overlap between patches.
-        bounds: the window bounds to divide up into smaller patches.
+        crop_size: the size of the crops to extract.
+        overlap_size: the size of the overlap between crops.
+        bounds: the window bounds to divide up into smaller crops.
 
     Returns:
-        a list of patch bounds within the overall bounds. The rightmost and
-        bottommost patches may extend beyond the provided bounds.
+        a list of crop bounds within the overall bounds. The rightmost and
+        bottommost crops may extend beyond the provided bounds.
     """
-    # We stride the
-    #
-    #
-    #
+    # We stride the crops by (crop_size - overlap_size) until the last crop.
+    # The first crop always starts at bounds[0]/bounds[1]. It's okay if it extends
+    # beyond the window bounds since pad_slice_protect pads the tensors.
+    # We handle the last crop with a special case to ensure it does not exceed the
+    # window bounds. Instead, it may overlap the previous crop.
+    # Here is a simple 1D example:
+    # - Suppose bounds is [0, 15] with crop_size=8, overlap_size=2
+    # - Then the first crop should be [0, 8] (from first crop special case)
+    # - There will only be one crop in the middle, [6, 14]
+    # - And the last crop will be at [7, 15]
+    # - Note that, if the bounds was [0, 14], we will only have the first/last crop
+    #   special cases with no crops in the middle from the range(...).
     cols = [bounds[0]] + list(
         range(
-            bounds[0] + patch_size[0] - overlap_size[0],
-            bounds[2] - patch_size[0],
-            patch_size[0] - overlap_size[0],
+            bounds[0] + crop_size[0] - overlap_size[0],
+            bounds[2] - crop_size[0],
+            crop_size[0] - overlap_size[0],
         )
     )
     rows = [bounds[1]] + list(
         range(
-            bounds[1] + patch_size[1] - overlap_size[1],
-            bounds[3] - patch_size[1],
-            patch_size[1] - overlap_size[1],
+            bounds[1] + crop_size[1] - overlap_size[1],
+            bounds[3] - crop_size[1],
+            crop_size[1] - overlap_size[1],
        )
    )
-    # Add last patches only if the input is larger than one patch.
-    if bounds[2] - patch_size[0] > bounds[0]:
-        cols.append(bounds[2] - patch_size[0])
-    if bounds[3] - patch_size[1] > bounds[1]:
-        rows.append(bounds[3] - patch_size[1])
+    # Add last crops only if the input is larger than one crop.
+    if bounds[2] - crop_size[0] > bounds[0]:
+        cols.append(bounds[2] - crop_size[0])
+    if bounds[3] - crop_size[1] > bounds[1]:
+        rows.append(bounds[3] - crop_size[1])
 
-    patch_bounds: list[PixelBounds] = []
+    crop_bounds: list[PixelBounds] = []
     for col in cols:
         for row in rows:
-            patch_bounds.append((col, row, col + patch_size[0], row + patch_size[1]))
-    return patch_bounds
+            crop_bounds.append((col, row, col + crop_size[0], row + crop_size[1]))
+    return crop_bounds
 
 
 def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
-    patch_size: tuple[int, int],
+    crop_size: tuple[int, int],
     inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
+    """Pad tensors in-place by crop size to protect slicing near right/bottom edges.
 
     The padding is scaled based on each input's resolution_factor.
 
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract (at window resolution).
+        crop_size: the size of the crops to extract (at window resolution).
         inputs: the DataInput definitions, used to get resolution_factor per input.
 
     Returns:
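The striding above reduces to an identical 1D start computation per axis. Below is a minimal standalone sketch of that logic, reproducing the 1D example from the new comments; the names `crop_starts` and `crop_options` are illustrative, not part of rslearn, and `PixelBounds` is approximated as a plain `(x0, y0, x1, y1)` tuple.

```python
def crop_starts(lo: int, hi: int, crop: int, overlap: int) -> list[int]:
    # First crop is pinned to lo; interior crops are strided by (crop - overlap);
    # the last crop is pinned to hi - crop so it stays inside the bounds.
    starts = [lo] + list(range(lo + crop - overlap, hi - crop, crop - overlap))
    if hi - crop > lo:
        starts.append(hi - crop)
    return starts


def crop_options(
    crop_size: tuple[int, int],
    overlap_size: tuple[int, int],
    bounds: tuple[int, int, int, int],
) -> list[tuple[int, int, int, int]]:
    cols = crop_starts(bounds[0], bounds[2], crop_size[0], overlap_size[0])
    rows = crop_starts(bounds[1], bounds[3], crop_size[1], overlap_size[1])
    return [(c, r, c + crop_size[0], r + crop_size[1]) for c in cols for r in rows]


# The 1D example from the new comments: bounds [0, 15], crop_size=8, overlap_size=2.
assert crop_starts(0, 15, 8, 2) == [0, 6, 7]  # crops [0,8], [6,14], [7,15]
# With bounds [0, 14], only the first/last special cases remain.
assert crop_starts(0, 14, 8, 2) == [0, 6]
```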
@@ -91,8 +99,8 @@ def pad_slice_protect(
         rf = inputs[input_name].resolution_factor
         scale = rf.numerator / rf.denominator
         # Scale the padding amount
-        scaled_pad_x = int(patch_size[0] * scale)
-        scaled_pad_y = int(patch_size[1] * scale)
+        scaled_pad_x = int(crop_size[0] * scale)
+        scaled_pad_y = int(crop_size[1] * scale)
         d[input_name] = torch.nn.functional.pad(
             value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
         )
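To see what the resolution_factor scaling does here: an input stored at half the window resolution should be padded by half as many pixels. A hedged sketch follows; the `Fraction`-valued `resolution_factor` is an assumption based on the `numerator`/`denominator` access above.

```python
from fractions import Fraction

import torch

crop_size = (32, 32)  # at window resolution
rf = Fraction(1, 2)  # hypothetical resolution_factor for one input
scale = rf.numerator / rf.denominator
scaled_pad_x = int(crop_size[0] * scale)  # 16
scaled_pad_y = int(crop_size[1] * scale)  # 16

# Pad only on the right and bottom, matching pad=(0, x, 0, y) above.
value = torch.zeros(4, 64, 64)  # (channels, height, width)
padded = torch.nn.functional.pad(value, pad=(0, scaled_pad_x, 0, scaled_pad_y))
assert tuple(padded.shape) == (4, 80, 80)
```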
@@ -123,12 +131,12 @@ def crop_tensor_or_rasterimage(
     )
 
 
-class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
-    """This wraps a ModelDataset to iterate over all patches in that dataset.
+class IterableAllCropsDataset(torch.utils.data.IterableDataset):
+    """This wraps a ModelDataset to iterate over all crops in that dataset.
 
-    This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
-    is configured with no patch size (load entire windows), and the dataset is wrapped
-    in an AllPatchesDataset.
+    This should be used when SplitConfig.load_all_crops is enabled. The ModelDataset
+    is configured with no crop size (load entire windows), and the dataset is wrapped
+    in an AllCropsDataset.
 
     Similar to DistributedSampler, we add extra samples at each rank to ensure
     consistent number of batches across all ranks.
@@ -137,29 +145,27 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         dataset: ModelDataset,
-        patch_size: tuple[int, int],
-        overlap_ratio: float = 0.0,
+        crop_size: tuple[int, int],
+        overlap_pixels: int = 0,
         rank: int = 0,
         world_size: int = 1,
     ):
-        """Create a new IterableAllPatchesDataset.
+        """Create a new IterableAllCropsDataset.
 
         Args:
             dataset: the ModelDataset to wrap.
-            patch_size: the size of the patches to extract.
-            overlap_ratio: the ratio of overlap between adjacent patches. Note that
-                the right/bottom-most patches may still overlap since we ensure that
-                all patches are contained in the window bounds.
+            crop_size: the size of the crops to extract.
+            overlap_pixels: the number of pixels shared between adjacent crops. Note
+                that the right/bottom-most crops may still overlap with other crops even
+                if overlap_pixels=0 since we ensure that all crops are contained in the
+                window bounds.
             rank: the global rank of this train worker process.
             world_size: the total number of train worker processes.
         """
         super().__init__()
         self.dataset = dataset
-        self.patch_size = patch_size
-        self.overlap_size = (
-            round(self.patch_size[0] * overlap_ratio),
-            round(self.patch_size[1] * overlap_ratio),
-        )
+        self.crop_size = crop_size
+        self.overlap_size = (overlap_pixels, overlap_pixels)
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
@@ -173,17 +179,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         """
         self.dataset.set_name(name)
 
-    def get_window_num_patches(self, bounds: PixelBounds) -> int:
-        """Get the number of patches for these bounds.
+    def get_window_num_crops(self, bounds: PixelBounds) -> int:
+        """Get the number of crops for these bounds.
 
-        This corresponds to the length of the list returned by get_window_patch_options.
+        This corresponds to the length of the list returned by get_window_crop_options.
         """
         num_cols = (
             len(
                 range(
                     bounds[0],
-                    bounds[2] - self.patch_size[0],
-                    self.patch_size[0] - self.overlap_size[0],
+                    bounds[2] - self.crop_size[0],
+                    self.crop_size[0] - self.overlap_size[0],
                 )
             )
             + 1
@@ -192,8 +198,8 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             len(
                 range(
                     bounds[1],
-                    bounds[3] - self.patch_size[1],
-                    self.patch_size[1] - self.overlap_size[1],
+                    bounds[3] - self.crop_size[1],
+                    self.crop_size[1] - self.overlap_size[1],
                 )
             )
             + 1
@@ -235,14 +241,14 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         ]
 
         # Now compute the maximum number of samples across workers.
-        max_num_patches = 0
+        max_num_crops = 0
         for worker_windows in windows_by_worker:
-            worker_num_patches = 0
+            worker_num_crops = 0
             for window_id in worker_windows:
-                worker_num_patches += self.get_window_num_patches(
+                worker_num_crops += self.get_window_num_crops(
                     self.windows[window_id].bounds
                 )
-            max_num_patches = max(max_num_patches, worker_num_patches)
+            max_num_crops = max(max_num_crops, worker_num_crops)
 
         # Each worker needs at least one window, otherwise it won't be able to pad.
         # Unless there are zero windows total, which is fine.
@@ -252,17 +258,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         # window in the end.
         # So now we raise an error instead, and require the number of workers to be
         # less than the number of windows.
-        if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
+        if len(windows_by_worker[global_worker_id]) == 0 and max_num_crops > 0:
             raise ValueError(
                 f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
             )
 
-        return (windows_by_worker[global_worker_id], max_num_patches)
+        return (windows_by_worker[global_worker_id], max_num_crops)
 
     def __iter__(
         self,
     ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
-        """Iterate over all patches in each element of the underlying ModelDataset."""
+        """Iterate over all crops in each element of the underlying ModelDataset."""
         # Iterate over the window IDs until we have returned enough samples.
        window_ids, num_samples_needed = self._get_worker_iteration_data()
        num_samples_returned = 0
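The max_num_crops bookkeeping above is what keeps ranks in sync: every worker yields exactly the maximum per-worker crop count, cycling through its own windows again when it has fewer. A hedged sketch with hypothetical numbers; the window-to-worker assignment shown is an assumption, and only the max computation mirrors the code above.

```python
# window_id -> crop count, as get_window_num_crops would report (hypothetical).
num_crops_per_window = {0: 9, 1: 4, 2: 9, 3: 1}

# Hypothetical assignment of window IDs to 3 global workers.
windows_by_worker = [[0, 3], [1], [2]]

max_num_crops = 0
for worker_windows in windows_by_worker:
    worker_num_crops = sum(num_crops_per_window[w] for w in worker_windows)
    max_num_crops = max(max_num_crops, worker_num_crops)

# Every worker then yields 10 samples: worker 0 has exactly 10 crops, while
# workers 1 and 2 pad up to 10 by repeating crops from their own windows.
assert max_num_crops == 10
```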
@@ -272,32 +278,32 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                 window_id
             )
-            bounds = metadata.patch_bounds
+            bounds = metadata.crop_bounds
 
-            # For simplicity, pad tensors by patch size to ensure that any patch bounds
+            # For simplicity, pad tensors by crop size to ensure that any crop bounds
             # extending outside the window bounds will not have issues when we slice
             # the tensors later. Padding is scaled per-input based on resolution_factor.
             pad_slice_protect(
-                raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+                raw_inputs, passthrough_inputs, self.crop_size, self.inputs
             )
 
-            # Now iterate over the patches and extract/yield them.
+            # Now iterate over the crops and extract/yield them.
             # Note that, in case user is leveraging RslearnWriter, it is important that
-            # the patch_idx be increasing (as we iterate) within one window.
-            patches = get_window_patch_options(
-                self.patch_size, self.overlap_size, bounds
+            # the crop_idx be increasing (as we iterate) within one window.
+            crops = get_window_crop_options(
+                self.crop_size, self.overlap_size, bounds
             )
-            for patch_idx, patch_bounds in enumerate(patches):
+            for crop_idx, crop_bounds in enumerate(crops):
                 cur_geom = STGeometry(
-                    metadata.projection, shapely.box(*patch_bounds), None
+                    metadata.projection, shapely.box(*crop_bounds), None
                 )
                 start_offset = (
-                    patch_bounds[0] - bounds[0],
-                    patch_bounds[1] - bounds[1],
+                    crop_bounds[0] - bounds[0],
+                    crop_bounds[1] - bounds[1],
                 )
                 end_offset = (
-                    patch_bounds[2] - bounds[0],
-                    patch_bounds[3] - bounds[1],
+                    crop_bounds[2] - bounds[0],
+                    crop_bounds[3] - bounds[1],
                 )
 
                 # Define a helper function to handle each input dict.
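The offset math in this hunk rebases global crop bounds to the window origin before slicing. A small worked example with hypothetical numbers:

```python
bounds = (100, 200, 1124, 1224)  # a 1024x1024 window in global pixel coordinates
crop_bounds = (868, 968, 1124, 1224)  # the last 256x256 crop, pinned to the edge

start_offset = (crop_bounds[0] - bounds[0], crop_bounds[1] - bounds[1])
end_offset = (crop_bounds[2] - bounds[0], crop_bounds[3] - bounds[1])
assert start_offset == (768, 768)
assert end_offset == (1024, 1024)

# When a window is smaller than crop_size, end_offset can exceed the tensor
# extent; pad_slice_protect has already padded by one crop size to cover that.
```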
@@ -339,9 +345,9 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 # Adjust the metadata as well.
                 cur_metadata = replace(
                     metadata,
-                    patch_bounds=patch_bounds,
-                    patch_idx=patch_idx,
-                    num_patches_in_window=len(patches),
+                    crop_bounds=crop_bounds,
+                    crop_idx=crop_idx,
+                    num_crops_in_window=len(crops),
                 )
 
                 # Now we can compute input and target dicts via the task.
@@ -369,37 +375,34 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         return self.dataset.get_dataset_examples()
 
 
-class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
-    """This wraps a ModelDataset to iterate over all patches in that dataset.
+class InMemoryAllCropsDataset(torch.utils.data.Dataset):
+    """This wraps a ModelDataset to iterate over all crops in that dataset.
 
-    This should be used when SplitConfig.load_all_patches is enabled.
+    This should be used when SplitConfig.load_all_crops is enabled.
 
-    This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
+    This is a simpler version of IterableAllCropsDataset that caches all windows in memory.
     This is useful for small datasets that fit in memory.
     """
 
     def __init__(
         self,
         dataset: ModelDataset,
-        patch_size: tuple[int, int],
-        overlap_ratio: float = 0.0,
+        crop_size: tuple[int, int],
+        overlap_pixels: int = 0,
     ):
-        """Create a new InMemoryAllPatchesDataset.
+        """Create a new InMemoryAllCropsDataset.
 
         Args:
             dataset: the ModelDataset to wrap.
-            patch_size: the size of the patches to extract.
-            overlap_ratio: the ratio of overlap between adjacent patches. Note that
-                the right/bottom-most patches may still overlap since we ensure that
-                all patches are contained in the window bounds.
+            crop_size: the size of the crops to extract.
+            overlap_pixels: the number of pixels shared between adjacent crops. Note
+                that the right/bottom-most crops may still overlap since we ensure that
+                all crops are contained in the window bounds.
         """
         super().__init__()
         self.dataset = dataset
-        self.patch_size = patch_size
-        self.overlap_size = (
-            round(self.patch_size[0] * overlap_ratio),
-            round(self.patch_size[1] * overlap_ratio),
-        )
+        self.crop_size = crop_size
+        self.overlap_size = (overlap_pixels, overlap_pixels)
         self.windows = self.dataset.get_dataset_examples()
         self.inputs = dataset.inputs
         self.window_cache: dict[
@@ -407,23 +410,23 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         ] = {}
 
         # Precompute the batch boundaries for each window
-        self.patches = []
+        self.crops = []
         for window_id, window in enumerate(self.windows):
-            window_patch_bounds = get_window_patch_options(
-                self.patch_size, self.overlap_size, window.bounds
+            window_crop_bounds = get_window_crop_options(
+                self.crop_size, self.overlap_size, window.bounds
             )
-            for i, patch_bound in enumerate(window_patch_bounds):
-                self.patches.append((window_id, patch_bound, (i, len(window_patch_bounds))))
+            for i, crop_bound in enumerate(window_crop_bounds):
+                self.crops.append((window_id, crop_bound, (i, len(window_crop_bounds))))
 
     def get_raw_inputs(
         self, index: int
     ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
-        """Get the raw inputs for a single patch. Retrieve from cache if possible.
+        """Get the raw inputs for a single crop. Retrieve from cache if possible.
 
-        Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
+        Also crops/pads the tensors by crop size to protect slicing near right/bottom edges.
 
         Args:
-            index: the index of the patch.
+            index: the index of the crop.
 
         Returns:
             a tuple of (raw_inputs, passthrough_inputs, metadata).
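The precomputed self.crops list above flattens every (window, crop) pair into one indexable sequence, so __getitem__ can address any crop in any window by a single integer. A hedged sketch, reusing the illustrative crop_options helper from the earlier sketch; the window bounds are hypothetical.

```python
window_bounds = [(0, 0, 15, 15), (0, 0, 8, 8)]  # hypothetical windows

crops = []
for window_id, bounds in enumerate(window_bounds):
    options = crop_options((8, 8), (2, 2), bounds)
    for i, crop_bound in enumerate(options):
        crops.append((window_id, crop_bound, (i, len(options))))

# Window 0 tiles into 3x3 crops (starts [0, 6, 7] per axis); window 1 is a
# single crop, so __len__ would report 10 here.
assert len(crops) == 10
window_id, crop_bound, (crop_idx, num_crops) = crops[0]
assert (window_id, crop_bound, crop_idx, num_crops) == (0, (0, 0, 8, 8), 0, 9)
```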
@@ -432,7 +435,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             return self.window_cache[index]
 
         raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.crop_size, self.inputs)
 
         self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
         return self.window_cache[index]
@@ -476,20 +479,20 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         return cropped
 
     def __len__(self) -> int:
-        """Return the total number of patches in the dataset."""
-        return len(self.patches)
+        """Return the total number of crops in the dataset."""
+        return len(self.crops)
 
     def __getitem__(
         self, index: int
     ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
-        """Return (input_dict, target_dict, metadata) for a single flattened patch."""
-        (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
+        """Return (input_dict, target_dict, metadata) for a single flattened crop."""
+        (window_id, crop_bounds, (crop_idx, num_crops)) = self.crops[index]
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata.patch_bounds
+        bounds = metadata.crop_bounds
 
-        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
-        start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
-        end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
+        cur_geom = STGeometry(metadata.projection, shapely.box(*crop_bounds), None)
+        start_offset = (crop_bounds[0] - bounds[0], crop_bounds[1] - bounds[1])
+        end_offset = (crop_bounds[2] - bounds[0], crop_bounds[3] - bounds[1])
 
         cur_raw_inputs = self._crop_input_dict(
             raw_inputs, start_offset, end_offset, cur_geom
@@ -501,9 +504,9 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         # Adjust the metadata as well.
         cur_metadata = replace(
             metadata,
-            patch_bounds=patch_bounds,
-            patch_idx=patch_idx,
-            num_patches_in_window=num_patches,
+            crop_bounds=crop_bounds,
+            crop_idx=crop_idx,
+            num_crops_in_window=num_crops,
         )
 
         # Now we can compute input and target dicts via the task.
rslearn/train/data_module.py
CHANGED

@@ -15,12 +15,13 @@ from rslearn.dataset import Dataset
 from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 
-from .all_patches_dataset import (
-    InMemoryAllPatchesDataset,
-    IterableAllPatchesDataset,
+from .all_crops_dataset import (
+    InMemoryAllCropsDataset,
+    IterableAllCropsDataset,
 )
 from .dataset import (
     DataInput,
+    IndexMode,
     ModelDataset,
     MultiDataset,
     RetryDataset,
@@ -68,7 +69,8 @@ class RslearnDataModule(L.LightningDataModule):
         predict_config: SplitConfig = SplitConfig(),
         name: str | None = None,
         retries: int = 0,
-        use_in_memory_all_patches_dataset: bool = False,
+        use_in_memory_all_crops_dataset: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Initialize a new RslearnDataModule.
 
@@ -90,8 +92,9 @@ class RslearnDataModule(L.LightningDataModule):
             predict_config: split config for predict split
             name: name of the dataset
             retries: number of retries to attempt for getitem calls
-            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
-                instead of IterableAllPatchesDataset if load_all_patches is set to true.
+            use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
+                instead of IterableAllCropsDataset if load_all_crops is set to true.
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
         super().__init__()
         self.inputs = inputs
@@ -102,7 +105,8 @@ class RslearnDataModule(L.LightningDataModule):
         self.init_workers = init_workers if init_workers > 0 else self.num_workers
         self.name = name
         self.retries = retries
-        self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
+        self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
+        self.index_mode = index_mode
         self.split_configs = {
             "train": default_config.update(train_config),
             "val": default_config.update(val_config),
@@ -111,15 +115,15 @@ class RslearnDataModule(L.LightningDataModule):
         }
 
     def setup(
-        self, stage: str, use_in_memory_all_patches_dataset: bool | None = None
+        self, stage: str, use_in_memory_all_crops_dataset: bool | None = None
     ) -> None:
         """Set up datasets and samplers.
 
         Args:
             stage: Either 'fit', 'validate', 'test', or 'predict'.
-            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
-                instead of IterableAllPatchesDataset if load_all_patches is set to true.
-                If None, uses the value of self.use_in_memory_all_patches_dataset.
+            use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
+                instead of IterableAllCropsDataset if load_all_crops is set to true.
+                If None, uses the value of self.use_in_memory_all_crops_dataset.
         """
         stage_to_splits = {
             "fit": ["train", "val"],
@@ -138,36 +142,37 @@ class RslearnDataModule(L.LightningDataModule):
                 workers=self.init_workers,
                 name=self.name,
                 fix_patch_pick=(split != "train"),
+                index_mode=self.index_mode,
             )
             logger.info(f"got {len(dataset)} examples in split {split}")
-            if split_config.get_load_all_patches():
-                if use_in_memory_all_patches_dataset is None:
-                    use_in_memory_all_patches_dataset = (
-                        self.use_in_memory_all_patches_dataset
+            if split_config.get_load_all_crops():
+                if use_in_memory_all_crops_dataset is None:
+                    use_in_memory_all_crops_dataset = (
+                        self.use_in_memory_all_crops_dataset
                     )
                 logger.info(
-                    f"using AllPatchesDataset (in_memory={use_in_memory_all_patches_dataset})"
+                    f"using AllCropsDataset (in_memory={use_in_memory_all_crops_dataset})"
                 )
-                patch_size = split_config.get_patch_size()
-                if patch_size is None:
+                crop_size = split_config.get_crop_size()
+                if crop_size is None:
                     raise ValueError(
-                        "patch_size is not set but must be set if load_all_patches is set"
+                        "crop_size is not set but must be set if load_all_crops is set"
                     )
 
-                all_patches_cls = IterableAllPatchesDataset
+                all_crops_cls = IterableAllCropsDataset
                 kwargs = dict(
                     dataset=dataset,
-                    patch_size=patch_size,
-                    overlap_ratio=split_config.get_overlap_ratio(),
+                    crop_size=crop_size,
+                    overlap_pixels=split_config.get_overlap_pixels(),
                     rank=self.trainer.global_rank if self.trainer else 0,
                     world_size=self.trainer.world_size if self.trainer else 1,
                 )
-                if use_in_memory_all_patches_dataset:
+                if use_in_memory_all_crops_dataset:
                     kwargs.pop("rank")
                     kwargs.pop("world_size")
-                    all_patches_cls = InMemoryAllPatchesDataset  # type: ignore
+                    all_crops_cls = InMemoryAllCropsDataset  # type: ignore
 
-                dataset = all_patches_cls(**kwargs)  # type: ignore
+                dataset = all_crops_cls(**kwargs)  # type: ignore
 
             if self.retries > 0:
                 dataset = RetryDataset(dataset, retries=self.retries)
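Note the semantic change in this hunk: 0.0.25 configured the overlap as a ratio of the patch size (overlap_ratio), while 0.0.27 takes an absolute pixel count (overlap_pixels). A sketch of an equivalent configuration under both schemes, with hypothetical values:

```python
patch_size = (256, 256)

# 0.0.25: overlap derived from a ratio, so it scaled with the patch size.
overlap_ratio = 0.125
old_overlap_size = (
    round(patch_size[0] * overlap_ratio),
    round(patch_size[1] * overlap_ratio),
)

# 0.0.27: overlap given directly in pixels, independent of the crop size.
overlap_pixels = 32
new_overlap_size = (overlap_pixels, overlap_pixels)

assert old_overlap_size == new_overlap_size == (32, 32)
```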
@@ -204,7 +209,7 @@ class RslearnDataModule(L.LightningDataModule):
         # If the number of windows is 0, then we can set positive number of workers
         # since they won't yield anything anyway.
         num_workers = self.num_workers
-        if split_config.load_all_patches and len(dataset.get_dataset_examples()) > 0:
+        if split_config.load_all_crops and len(dataset.get_dataset_examples()) > 0:
             num_workers = min(num_workers, len(dataset.get_dataset_examples()))
 
         kwargs: dict[str, Any] = dict(
@@ -352,7 +357,7 @@ class MultiDatasetDataModule(L.LightningDataModule):
             stage: The stage to set up ('fit', 'validate', 'test', 'predict')
         """
         for name, data_module in self.data_modules.items():
-            data_module.setup(stage, use_in_memory_all_patches_dataset=True)  # type: ignore
+            data_module.setup(stage, use_in_memory_all_crops_dataset=True)  # type: ignore
             data_module.set_name(name)
 
     def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]: