rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/{all_patches_dataset.py → all_crops_dataset.py}
@@ -1,4 +1,4 @@
-"""Wrapper around ModelDataset to load all patches (crops) in a window."""
+"""Wrapper around ModelDataset to load all crops in a window."""
 
 import itertools
 from collections.abc import Iterable, Iterator
@@ -14,70 +14,78 @@ from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
-def get_window_patch_options(
-    patch_size: tuple[int, int],
+def get_window_crop_options(
+    crop_size: tuple[int, int],
     overlap_size: tuple[int, int],
     bounds: PixelBounds,
 ) -> list[PixelBounds]:
-    """Get the bounds of each input patch within the window bounds.
+    """Get the bounds of each input crop within the window bounds.
 
-    This is used when running inference on all patches (crops) of a large window, to
-    compute the position of each patch.
+    This is used when running inference on all crops of a large window, to
+    compute the position of each crop.
 
     Args:
-        patch_size: the size of the patches to extract.
-        overlap_size: the size of the overlap between patches.
-        bounds: the window bounds to divide up into smaller patches.
+        crop_size: the size of the crops to extract.
+        overlap_size: the size of the overlap between crops.
+        bounds: the window bounds to divide up into smaller crops.
 
     Returns:
-        a list of patch bounds within the overall bounds. The rightmost and
-        bottommost patches may extend beyond the provided bounds.
+        a list of crop bounds within the overall bounds. The rightmost and
+        bottommost crops may extend beyond the provided bounds.
     """
-    # We stride the patches by patch_size - overlap_size until the last patch.
-    # We handle the first patch with a special case to ensure it is always used.
-    # We handle the last patch with a special case to ensure it does not exceed the
-    # window bounds. Instead, it may overlap the previous patch.
+    # We stride the crops by (crop_size - overlap_size) until the last crop.
+    # The first crop always starts at bounds[0]/bounds[1]. It's okay if it extends
+    # beyond the window bounds since pad_slice_protect pads the tensors.
+    # We handle the last crop with a special case to ensure it does not exceed the
+    # window bounds. Instead, it may overlap the previous crop.
+    # Here is a simple 1D example:
+    # - Suppose bounds is [0, 15] with crop_size=8, overlap_size=2
+    # - Then the first crop should be [0, 8] (from first crop special case)
+    # - There will only be one crop in the middle, [6, 14]
+    # - And the last crop will be at [7, 15]
+    # - Note that, if the bounds was [0, 14], we will only have the first/last crop
+    #   special cases with no crops in the middle from the range(...).
     cols = [bounds[0]] + list(
         range(
-            bounds[0] + patch_size[0],
-            bounds[2] - patch_size[0],
-            patch_size[0] - overlap_size[0],
+            bounds[0] + crop_size[0] - overlap_size[0],
+            bounds[2] - crop_size[0],
+            crop_size[0] - overlap_size[0],
         )
     )
     rows = [bounds[1]] + list(
         range(
-            bounds[1] + patch_size[1],
-            bounds[3] - patch_size[1],
-            patch_size[1] - overlap_size[1],
+            bounds[1] + crop_size[1] - overlap_size[1],
+            bounds[3] - crop_size[1],
+            crop_size[1] - overlap_size[1],
        )
    )
-    # Add last patches only if the input is larger than one patch.
-    if bounds[2] - patch_size[0] > bounds[0]:
-        cols.append(bounds[2] - patch_size[0])
-    if bounds[3] - patch_size[1] > bounds[1]:
-        rows.append(bounds[3] - patch_size[1])
+    # Add last crops only if the input is larger than one crop.
+    if bounds[2] - crop_size[0] > bounds[0]:
+        cols.append(bounds[2] - crop_size[0])
+    if bounds[3] - crop_size[1] > bounds[1]:
+        rows.append(bounds[3] - crop_size[1])
 
-    patch_bounds: list[PixelBounds] = []
+    crop_bounds: list[PixelBounds] = []
     for col in cols:
         for row in rows:
-            patch_bounds.append((col, row, col + patch_size[0], row + patch_size[1]))
-    return patch_bounds
+            crop_bounds.append((col, row, col + crop_size[0], row + crop_size[1]))
+    return crop_bounds
 
 
 def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
-    patch_size: tuple[int, int],
+    crop_size: tuple[int, int],
     inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
+    """Pad tensors in-place by crop size to protect slicing near right/bottom edges.
 
     The padding is scaled based on each input's resolution_factor.
 
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract (at window resolution).
+        crop_size: the size of the crops to extract (at window resolution).
        inputs: the DataInput definitions, used to get resolution_factor per input.
 
    Returns:
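As a sanity check on the 1D example in the new comments, here is a minimal standalone sketch of the same striding logic (crop_starts_1d is a hypothetical helper, not part of rslearn; the real get_window_crop_options applies this per axis to (x0, y0, x1, y1) bounds):

# Minimal sketch of the 1D striding logic in get_window_crop_options.
# A crop is the interval [s, s + crop_size] for each start s.
def crop_starts_1d(start: int, end: int, crop_size: int, overlap: int) -> list[int]:
    # The first crop is pinned to the window start.
    starts = [start] + list(
        range(start + crop_size - overlap, end - crop_size, crop_size - overlap)
    )
    # The last crop is pinned to the window end (it may overlap the previous
    # crop), and is added only if the window is larger than one crop.
    if end - crop_size > start:
        starts.append(end - crop_size)
    return starts

# Matches the comment's example: bounds [0, 15], crop_size=8, overlap_size=2.
assert crop_starts_1d(0, 15, 8, 2) == [0, 6, 7]  # crops [0, 8], [6, 14], [7, 15]
# With bounds [0, 14], only the first/last special cases produce crops.
assert crop_starts_1d(0, 14, 8, 2) == [0, 6]     # crops [0, 8], [6, 14]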
@@ -91,8 +99,8 @@ def pad_slice_protect(
         rf = inputs[input_name].resolution_factor
         scale = rf.numerator / rf.denominator
         # Scale the padding amount
-        scaled_pad_x = int(patch_size[0] * scale)
-        scaled_pad_y = int(patch_size[1] * scale)
+        scaled_pad_x = int(crop_size[0] * scale)
+        scaled_pad_y = int(crop_size[1] * scale)
         d[input_name] = torch.nn.functional.pad(
             value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
         )
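The hunk above shows that resolution_factor is a fraction-like value (numerator/denominator). A quick sketch of the arithmetic, with a hypothetical input stored at half the window resolution (the 1/2 factor and tensor shapes are illustrative, not from the diff):

import torch

crop_size = (128, 128)  # at window resolution
# Hypothetical resolution_factor of 1/2: this input has half as many pixels
# per axis as the window, so it needs half as much protective padding.
numerator, denominator = 1, 2
scale = numerator / denominator
scaled_pad_x = int(crop_size[0] * scale)  # 64
scaled_pad_y = int(crop_size[1] * scale)  # 64

value = torch.zeros(3, 100, 100)  # (C, H, W) tensor for this input
# pad=(left, right, top, bottom): pad only the right and bottom edges.
padded = torch.nn.functional.pad(value, pad=(0, scaled_pad_x, 0, scaled_pad_y))
assert padded.shape == (3, 164, 164)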
@@ -123,12 +131,12 @@ def crop_tensor_or_rasterimage(
     )
 
 
-class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
-    """This wraps a ModelDataset to iterate over all patches in that dataset.
+class IterableAllCropsDataset(torch.utils.data.IterableDataset):
+    """This wraps a ModelDataset to iterate over all crops in that dataset.
 
-    This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
-    is configured with no patch size (load entire windows), and the dataset is wrapped
-    in an AllPatchesDataset.
+    This should be used when SplitConfig.load_all_crops is enabled. The ModelDataset
+    is configured with no crop size (load entire windows), and the dataset is wrapped
+    in an AllCropsDataset.
 
     Similar to DistributedSampler, we add extra samples at each rank to ensure
     consistent number of batches across all ranks.
@@ -137,29 +145,27 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         dataset: ModelDataset,
-        patch_size: tuple[int, int],
-        overlap_ratio: float = 0.0,
+        crop_size: tuple[int, int],
+        overlap_pixels: int = 0,
         rank: int = 0,
         world_size: int = 1,
     ):
-        """Create a new IterableAllPatchesDataset.
+        """Create a new IterableAllCropsDataset.
 
         Args:
             dataset: the ModelDataset to wrap.
-            patch_size: the size of the patches to extract.
-            overlap_ratio: whether to include overlap between the patches. Note that
-                the right/bottom-most patches may still overlap since we ensure that
-                all patches are contained in the window bounds.
+            crop_size: the size of the crops to extract.
+            overlap_pixels: the number of pixels shared between adjacent crops. Note
+                that the right/bottom-most crops may still overlap with other crops even
+                if overlap_pixels=0 since we ensure that all crops are contained in the
+                window bounds.
             rank: the global rank of this train worker process.
             world_size: the total number of train worker processes.
         """
         super().__init__()
         self.dataset = dataset
-        self.patch_size = patch_size
-        self.overlap_size = (
-            round(self.patch_size[0] * overlap_ratio),
-            round(self.patch_size[1] * overlap_ratio),
-        )
+        self.crop_size = crop_size
+        self.overlap_size = (overlap_pixels, overlap_pixels)
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
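Note the API change here: 0.0.25's overlap_ratio (a fraction of the patch size, rounded per axis) is replaced by an explicit overlap_pixels count applied to both axes. To preserve behavior when migrating a config, convert the old ratio yourself:

# 0.0.25 derived the overlap from the patch size:
#     overlap_size = (round(patch_size[0] * overlap_ratio),
#                     round(patch_size[1] * overlap_ratio))
# 0.0.27 takes the overlap directly in pixels (same value on both axes):
patch_size = (256, 256)
overlap_ratio = 0.125
overlap_pixels = round(patch_size[0] * overlap_ratio)  # 32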
@@ -173,17 +179,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         """
         self.dataset.set_name(name)
 
-    def get_window_num_patches(self, bounds: PixelBounds) -> int:
-        """Get the number of patches for these bounds.
+    def get_window_num_crops(self, bounds: PixelBounds) -> int:
+        """Get the number of crops for these bounds.
 
-        This corresponds to the length of the list returned by get_patch_options.
+        This corresponds to the length of the list returned by get_window_crop_options.
         """
         num_cols = (
             len(
                 range(
                     bounds[0],
-                    bounds[2] - self.patch_size[0],
-                    self.patch_size[0] - self.overlap_size[0],
+                    bounds[2] - self.crop_size[0],
+                    self.crop_size[0] - self.overlap_size[0],
                 )
             )
             + 1
@@ -192,8 +198,8 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             len(
                 range(
                     bounds[1],
-                    bounds[3] - self.patch_size[1],
-                    self.patch_size[1] - self.overlap_size[1],
+                    bounds[3] - self.crop_size[1],
+                    self.crop_size[1] - self.overlap_size[1],
                 )
             )
             + 1
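The count computed here must agree with the length of the list built by get_window_crop_options. Checking it against the 1D example from the comments earlier in this file:

# Columns are counted as len(range(start, end - crop_size, stride)) + 1.
start, end, crop_size, overlap = 0, 15, 8, 2
stride = crop_size - overlap
num_cols = len(range(start, end - crop_size, stride)) + 1  # len(range(0, 7, 6)) + 1
assert num_cols == 3  # the same three crops: [0, 8], [6, 14], [7, 15]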
@@ -235,14 +241,14 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         ]
 
         # Now compute the maximum number of samples across workers.
-        max_num_patches = 0
+        max_num_crops = 0
         for worker_windows in windows_by_worker:
-            worker_num_patches = 0
+            worker_num_crops = 0
             for window_id in worker_windows:
-                worker_num_patches += self.get_window_num_patches(
+                worker_num_crops += self.get_window_num_crops(
                     self.windows[window_id].bounds
                 )
-            max_num_patches = max(max_num_patches, worker_num_patches)
+            max_num_crops = max(max_num_crops, worker_num_crops)
 
         # Each worker needs at least one window, otherwise it won't be able to pad.
         # Unless there are zero windows total, which is fine.
@@ -252,17 +258,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         # window in the end.
         # So now we raise an error instead, and require the number of workers to be
         # less than the number of windows.
-        if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
+        if len(windows_by_worker[global_worker_id]) == 0 and max_num_crops > 0:
             raise ValueError(
                 f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
             )
 
-        return (windows_by_worker[global_worker_id], max_num_patches)
+        return (windows_by_worker[global_worker_id], max_num_crops)
 
     def __iter__(
         self,
     ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
-        """Iterate over all patches in each element of the underlying ModelDataset."""
+        """Iterate over all crops in each element of the underlying ModelDataset."""
         # Iterate over the window IDs until we have returned enough samples.
         window_ids, num_samples_needed = self._get_worker_iteration_data()
         num_samples_returned = 0
@@ -272,32 +278,32 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                 window_id
             )
-            bounds = metadata.patch_bounds
+            bounds = metadata.crop_bounds
 
-            # For simplicity, pad tensors by patch size to ensure that any patch bounds
+            # For simplicity, pad tensors by crop size to ensure that any crop bounds
             # extending outside the window bounds will not have issues when we slice
             # the tensors later. Padding is scaled per-input based on resolution_factor.
             pad_slice_protect(
-                raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+                raw_inputs, passthrough_inputs, self.crop_size, self.inputs
             )
 
-            # Now iterate over the patches and extract/yield the crops.
+            # Now iterate over the crops and extract/yield them.
             # Note that, in case user is leveraging RslearnWriter, it is important that
-            # the patch_idx be increasing (as we iterate) within one window.
-            patches = get_window_patch_options(
-                self.patch_size, self.overlap_size, bounds
+            # the crop_idx be increasing (as we iterate) within one window.
+            crops = get_window_crop_options(
+                self.crop_size, self.overlap_size, bounds
             )
-            for patch_idx, patch_bounds in enumerate(patches):
+            for crop_idx, crop_bounds in enumerate(crops):
                 cur_geom = STGeometry(
-                    metadata.projection, shapely.box(*patch_bounds), None
+                    metadata.projection, shapely.box(*crop_bounds), None
                 )
                 start_offset = (
-                    patch_bounds[0] - bounds[0],
-                    patch_bounds[1] - bounds[1],
+                    crop_bounds[0] - bounds[0],
+                    crop_bounds[1] - bounds[1],
                 )
                 end_offset = (
-                    patch_bounds[2] - bounds[0],
-                    patch_bounds[3] - bounds[1],
+                    crop_bounds[2] - bounds[0],
+                    crop_bounds[3] - bounds[1],
                 )
 
                 # Define a helper function to handle each input dict.
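The offsets above are the crop's position relative to the window origin, in (x, y) order; they are what the cropping helper uses to slice the padded window tensors. A small worked example with hypothetical numbers:

bounds = (100, 200, 116, 216)       # window bounds (x0, y0, x1, y1)
crop_bounds = (107, 200, 115, 208)  # one 8x8 crop within that window
start_offset = (crop_bounds[0] - bounds[0], crop_bounds[1] - bounds[1])  # (7, 0)
end_offset = (crop_bounds[2] - bounds[0], crop_bounds[3] - bounds[1])    # (15, 8)
# A (C, H, W) tensor for this window would then be sliced along H using the y
# offsets and along W using the x offsets (see crop_tensor_or_rasterimage,
# which this diff only shows in passing).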
@@ -339,9 +345,9 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 # Adjust the metadata as well.
                 cur_metadata = replace(
                     metadata,
-                    patch_bounds=patch_bounds,
-                    patch_idx=patch_idx,
-                    num_patches_in_window=len(patches),
+                    crop_bounds=crop_bounds,
+                    crop_idx=crop_idx,
+                    num_crops_in_window=len(crops),
                 )
 
                 # Now we can compute input and target dicts via the task.
@@ -369,37 +375,34 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         return self.dataset.get_dataset_examples()
 
 
-class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
-    """This wraps a ModelDataset to iterate over all patches in that dataset.
+class InMemoryAllCropsDataset(torch.utils.data.Dataset):
+    """This wraps a ModelDataset to iterate over all crops in that dataset.
 
-    This should be used when SplitConfig.load_all_patches is enabled.
+    This should be used when SplitConfig.load_all_crops is enabled.
 
-    This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
+    This is a simpler version of IterableAllCropsDataset that caches all windows in memory.
     This is useful for small datasets that fit in memory.
     """
 
     def __init__(
         self,
         dataset: ModelDataset,
-        patch_size: tuple[int, int],
-        overlap_ratio: float = 0.0,
+        crop_size: tuple[int, int],
+        overlap_pixels: int = 0,
     ):
-        """Create a new InMemoryAllPatchesDataset.
+        """Create a new InMemoryAllCropsDataset.
 
         Args:
             dataset: the ModelDataset to wrap.
-            patch_size: the size of the patches to extract.
-            overlap_ratio: whether to include overlap between the patches. Note that
-                the right/bottom-most patches may still overlap since we ensure that
-                all patches are contained in the window bounds.
+            crop_size: the size of the crops to extract.
+            overlap_pixels: the number of pixels shared between adjacent crops. Note
+                that the right/bottom-most crops may still overlap since we ensure that
+                all crops are contained in the window bounds.
         """
         super().__init__()
         self.dataset = dataset
-        self.patch_size = patch_size
-        self.overlap_size = (
-            round(self.patch_size[0] * overlap_ratio),
-            round(self.patch_size[1] * overlap_ratio),
-        )
+        self.crop_size = crop_size
+        self.overlap_size = (overlap_pixels, overlap_pixels)
         self.windows = self.dataset.get_dataset_examples()
         self.inputs = dataset.inputs
         self.window_cache: dict[
@@ -407,23 +410,23 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         ] = {}
 
         # Precompute the batch boundaries for each window
-        self.patches = []
+        self.crops = []
         for window_id, window in enumerate(self.windows):
-            patch_bounds = get_window_patch_options(
-                self.patch_size, self.overlap_size, window.bounds
+            window_crop_bounds = get_window_crop_options(
+                self.crop_size, self.overlap_size, window.bounds
             )
-            for i, patch_bound in enumerate(patch_bounds):
-                self.patches.append((window_id, patch_bound, (i, len(patch_bounds))))
+            for i, crop_bound in enumerate(window_crop_bounds):
+                self.crops.append((window_id, crop_bound, (i, len(window_crop_bounds))))
 
     def get_raw_inputs(
         self, index: int
     ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
-        """Get the raw inputs for a single patch. Retrieve from cache if possible.
+        """Get the raw inputs for a single crop. Retrieve from cache if possible.
 
-        Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
+        Also crops/pads the tensors by crop size to protect slicing near right/bottom edges.
 
         Args:
-            index: the index of the patch.
+            index: the index of the crop.
 
         Returns:
             a tuple of (raw_inputs, passthrough_inputs, metadata).
@@ -432,7 +435,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             return self.window_cache[index]
 
         raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.crop_size, self.inputs)
 
         self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
         return self.window_cache[index]
@@ -476,20 +479,20 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         return cropped
 
     def __len__(self) -> int:
-        """Return the total number of patches in the dataset."""
-        return len(self.patches)
+        """Return the total number of crops in the dataset."""
+        return len(self.crops)
 
     def __getitem__(
         self, index: int
     ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
-        """Return (input_dict, target_dict, metadata) for a single flattened patch."""
-        (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
+        """Return (input_dict, target_dict, metadata) for a single flattened crop."""
+        (window_id, crop_bounds, (crop_idx, num_crops)) = self.crops[index]
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata.patch_bounds
+        bounds = metadata.crop_bounds
 
-        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
-        start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
-        end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
+        cur_geom = STGeometry(metadata.projection, shapely.box(*crop_bounds), None)
+        start_offset = (crop_bounds[0] - bounds[0], crop_bounds[1] - bounds[1])
+        end_offset = (crop_bounds[2] - bounds[0], crop_bounds[3] - bounds[1])
 
         cur_raw_inputs = self._crop_input_dict(
             raw_inputs, start_offset, end_offset, cur_geom
@@ -501,9 +504,9 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         # Adjust the metadata as well.
         cur_metadata = replace(
             metadata,
-            patch_bounds=patch_bounds,
-            patch_idx=patch_idx,
-            num_patches_in_window=num_patches,
+            crop_bounds=crop_bounds,
+            crop_idx=crop_idx,
+            num_crops_in_window=num_crops,
        )
 
        # Now we can compute input and target dicts via the task.
rslearn/train/data_module.py
@@ -15,12 +15,13 @@ from rslearn.dataset import Dataset
 from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 
-from .all_patches_dataset import (
-    InMemoryAllPatchesDataset,
-    IterableAllPatchesDataset,
+from .all_crops_dataset import (
+    InMemoryAllCropsDataset,
+    IterableAllCropsDataset,
 )
 from .dataset import (
     DataInput,
+    IndexMode,
     ModelDataset,
     MultiDataset,
     RetryDataset,
@@ -68,7 +69,8 @@ class RslearnDataModule(L.LightningDataModule):
         predict_config: SplitConfig = SplitConfig(),
         name: str | None = None,
         retries: int = 0,
-        use_in_memory_all_patches_dataset: bool = False,
+        use_in_memory_all_crops_dataset: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Initialize a new RslearnDataModule.
 
@@ -90,8 +92,9 @@ class RslearnDataModule(L.LightningDataModule):
             predict_config: split config for predict split
             name: name of the dataset
             retries: number of retries to attempt for getitem calls
-            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
-                instead of IterableAllPatchesDataset if load_all_patches is set to true.
+            use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
+                instead of IterableAllCropsDataset if load_all_crops is set to true.
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
        """
        super().__init__()
        self.inputs = inputs
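IndexMode is new in 0.0.27 (imported from .dataset above, with dataset_index.py added in this release); only IndexMode.OFF, the default, appears in this diff. A hedged sketch of passing it through the data module; the inputs, task, and path arguments stand in for the rest of the constructor signature, which this hunk does not show:

from rslearn.train.data_module import RslearnDataModule
from rslearn.train.dataset import IndexMode

data_module = RslearnDataModule(
    inputs=inputs,  # dict[str, DataInput], defined elsewhere
    task=task,      # the rslearn Task, defined elsewhere
    path="...",     # dataset location, as in your existing config
    index_mode=IndexMode.OFF,  # the default; other modes are not shown here
)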
@@ -102,7 +105,8 @@ class RslearnDataModule(L.LightningDataModule):
         self.init_workers = init_workers if init_workers > 0 else self.num_workers
         self.name = name
         self.retries = retries
-        self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
+        self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
+        self.index_mode = index_mode
         self.split_configs = {
             "train": default_config.update(train_config),
             "val": default_config.update(val_config),
@@ -111,15 +115,15 @@
         }
 
     def setup(
-        self, stage: str, use_in_memory_all_patches_dataset: bool | None = None
+        self, stage: str, use_in_memory_all_crops_dataset: bool | None = None
     ) -> None:
         """Set up datasets and samplers.
 
         Args:
             stage: Either 'fit', 'validate', 'test', or 'predict'.
-            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
-                instead of IterableAllPatchesDataset if load_all_patches is set to true.
-                If None, uses the value of self.use_in_memory_all_patches_dataset.
+            use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
+                instead of IterableAllCropsDataset if load_all_crops is set to true.
+                If None, uses the value of self.use_in_memory_all_crops_dataset.
         """
         stage_to_splits = {
             "fit": ["train", "val"],
@@ -138,36 +142,37 @@
                 workers=self.init_workers,
                 name=self.name,
                 fix_patch_pick=(split != "train"),
+                index_mode=self.index_mode,
             )
             logger.info(f"got {len(dataset)} examples in split {split}")
-            if split_config.get_load_all_patches():
-                if use_in_memory_all_patches_dataset is None:
-                    use_in_memory_all_patches_dataset = (
-                        self.use_in_memory_all_patches_dataset
+            if split_config.get_load_all_crops():
+                if use_in_memory_all_crops_dataset is None:
+                    use_in_memory_all_crops_dataset = (
+                        self.use_in_memory_all_crops_dataset
                     )
                 logger.info(
-                    f"using AllPatchesDataset (in_memory={use_in_memory_all_patches_dataset})"
+                    f"using AllCropsDataset (in_memory={use_in_memory_all_crops_dataset})"
                 )
-                patch_size = split_config.get_patch_size()
-                if patch_size is None:
+                crop_size = split_config.get_crop_size()
+                if crop_size is None:
                     raise ValueError(
-                        "patch_size is not set but must be set if load_all_patches is set"
+                        "crop_size is not set but must be set if load_all_crops is set"
                     )
 
-                all_patches_cls = IterableAllPatchesDataset
+                all_crops_cls = IterableAllCropsDataset
                 kwargs = dict(
                     dataset=dataset,
-                    patch_size=patch_size,
-                    overlap_ratio=split_config.get_overlap_ratio(),
+                    crop_size=crop_size,
+                    overlap_pixels=split_config.get_overlap_pixels(),
                     rank=self.trainer.global_rank if self.trainer else 0,
                     world_size=self.trainer.world_size if self.trainer else 1,
                 )
-                if use_in_memory_all_patches_dataset:
+                if use_in_memory_all_crops_dataset:
                     kwargs.pop("rank")
                     kwargs.pop("world_size")
-                    all_patches_cls = InMemoryAllPatchesDataset  # type: ignore
+                    all_crops_cls = InMemoryAllCropsDataset  # type: ignore
 
-                dataset = all_patches_cls(**kwargs)  # type: ignore
+                dataset = all_crops_cls(**kwargs)  # type: ignore
 
         if self.retries > 0:
            dataset = RetryDataset(dataset, retries=self.retries)
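Since setup() now raises unless crop_size is set whenever load_all_crops is enabled, a split config needs all three of the renamed fields. A hedged sketch; the field names follow the accessors used above (get_load_all_crops, get_crop_size, get_overlap_pixels) and the split_config.load_all_crops attribute in the next hunk, but the exact SplitConfig constructor is not shown in this diff:

from rslearn.train.dataset import SplitConfig  # assumed import location

predict_config = SplitConfig(
    load_all_crops=True,   # was load_all_patches in 0.0.25
    crop_size=(256, 256),  # was patch_size; required when load_all_crops is set
    overlap_pixels=32,     # was overlap_ratio (a fraction of patch_size)
)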
@@ -204,7 +209,7 @@
         # If the number of windows is 0, then we can set positive number of workers
         # since they won't yield anything anyway.
         num_workers = self.num_workers
-        if split_config.load_all_patches and len(dataset.get_dataset_examples()) > 0:
+        if split_config.load_all_crops and len(dataset.get_dataset_examples()) > 0:
             num_workers = min(num_workers, len(dataset.get_dataset_examples()))
 
         kwargs: dict[str, Any] = dict(
@@ -352,7 +357,7 @@ class MultiDatasetDataModule(L.LightningDataModule):
             stage: The stage to set up ('fit', 'validate', 'test', 'predict')
         """
         for name, data_module in self.data_modules.items():
-            data_module.setup(stage, use_in_memory_all_patches_dataset=True)  # type: ignore
+            data_module.setup(stage, use_in_memory_all_crops_dataset=True)  # type: ignore
             data_module.set_name(name)
 
     def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]: