rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/models/concatenate_features.py +6 -1
  30. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  31. rslearn/train/data_module.py +27 -27
  32. rslearn/train/dataset.py +109 -62
  33. rslearn/train/lightning_module.py +1 -1
  34. rslearn/train/model_context.py +3 -3
  35. rslearn/train/prediction_writer.py +69 -41
  36. rslearn/train/tasks/classification.py +1 -1
  37. rslearn/train/tasks/detection.py +5 -5
  38. rslearn/train/tasks/regression.py +1 -1
  39. rslearn/utils/__init__.py +2 -0
  40. rslearn/utils/geometry.py +21 -0
  41. rslearn/utils/m2m_api.py +251 -0
  42. rslearn/utils/retry_session.py +43 -0
  43. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  44. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
  45. rslearn/data_sources/earthdata_srtm.py +0 -282
  46. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- """Wrapper around ModelDataset to load all patches (crops) in a window."""
1
+ """Wrapper around ModelDataset to load all crops in a window."""
2
2
 
3
3
  import itertools
4
4
  from collections.abc import Iterable, Iterator
@@ -14,70 +14,78 @@ from rslearn.train.model_context import RasterImage, SampleMetadata
14
14
  from rslearn.utils.geometry import PixelBounds, STGeometry
15
15
 
16
16
 
17
- def get_window_patch_options(
18
- patch_size: tuple[int, int],
17
+ def get_window_crop_options(
18
+ crop_size: tuple[int, int],
19
19
  overlap_size: tuple[int, int],
20
20
  bounds: PixelBounds,
21
21
  ) -> list[PixelBounds]:
22
- """Get the bounds of each input patch within the window bounds.
22
+ """Get the bounds of each input crop within the window bounds.
23
23
 
24
- This is used when running inference on all patches (crops) of a large window, to
25
- compute the position of each patch.
24
+ This is used when running inference on all crops of a large window, to
25
+ compute the position of each crop.
26
26
 
27
27
  Args:
28
- patch_size: the size of the patches to extract.
29
- overlap_size: the size of the overlap between patches.
30
- bounds: the window bounds to divide up into smaller patches.
28
+ crop_size: the size of the crops to extract.
29
+ overlap_size: the size of the overlap between crops.
30
+ bounds: the window bounds to divide up into smaller crops.
31
31
 
32
32
  Returns:
33
- a list of patch bounds within the overall bounds. The rightmost and
34
- bottommost patches may extend beyond the provided bounds.
33
+ a list of crop bounds within the overall bounds. The rightmost and
34
+ bottommost crops may extend beyond the provided bounds.
35
35
  """
36
- # We stride the patches by patch_size - overlap_size until the last patch.
37
- # We handle the first patch with a special case to ensure it is always used.
38
- # We handle the last patch with a special case to ensure it does not exceed the
39
- # window bounds. Instead, it may overlap the previous patch.
36
+ # We stride the crops by (crop_size - overlap_size) until the last crop.
37
+ # The first crop always starts at bounds[0]/bounds[1]. It's okay if it extends
38
+ # beyond the window bounds since pad_slice_protect pads the tensors.
39
+ # We handle the last crop with a special case to ensure it does not exceed the
40
+ # window bounds. Instead, it may overlap the previous crop.
41
+ # Here is a simple 1D example:
42
+ # - Suppose bounds is [0, 15] with crop_size=8, overlap_size=2
43
+ # - Then the first crop should be [0, 8] (from first crop special case)
44
+ # - There will only be one crop in the middle, [6, 14]
45
+ # - And the last crop will be at [7, 15]
46
+ # - Note that, if the bounds was [0, 14], we will only have the first/last crop
47
+ # special cases with no crops in the middle from the range(...).
40
48
  cols = [bounds[0]] + list(
41
49
  range(
42
- bounds[0] + patch_size[0],
43
- bounds[2] - patch_size[0],
44
- patch_size[0] - overlap_size[0],
50
+ bounds[0] + crop_size[0] - overlap_size[0],
51
+ bounds[2] - crop_size[0],
52
+ crop_size[0] - overlap_size[0],
45
53
  )
46
54
  )
47
55
  rows = [bounds[1]] + list(
48
56
  range(
49
- bounds[1] + patch_size[1],
50
- bounds[3] - patch_size[1],
51
- patch_size[1] - overlap_size[1],
57
+ bounds[1] + crop_size[1] - overlap_size[1],
58
+ bounds[3] - crop_size[1],
59
+ crop_size[1] - overlap_size[1],
52
60
  )
53
61
  )
54
- # Add last patches only if the input is larger than one patch.
55
- if bounds[2] - patch_size[0] > bounds[0]:
56
- cols.append(bounds[2] - patch_size[0])
57
- if bounds[3] - patch_size[1] > bounds[1]:
58
- rows.append(bounds[3] - patch_size[1])
62
+ # Add last crops only if the input is larger than one crop.
63
+ if bounds[2] - crop_size[0] > bounds[0]:
64
+ cols.append(bounds[2] - crop_size[0])
65
+ if bounds[3] - crop_size[1] > bounds[1]:
66
+ rows.append(bounds[3] - crop_size[1])
59
67
 
60
- patch_bounds: list[PixelBounds] = []
68
+ crop_bounds: list[PixelBounds] = []
61
69
  for col in cols:
62
70
  for row in rows:
63
- patch_bounds.append((col, row, col + patch_size[0], row + patch_size[1]))
64
- return patch_bounds
71
+ crop_bounds.append((col, row, col + crop_size[0], row + crop_size[1]))
72
+ return crop_bounds
65
73
 
66
74
 
67
75
  def pad_slice_protect(
68
76
  raw_inputs: dict[str, Any],
69
77
  passthrough_inputs: dict[str, Any],
70
- patch_size: tuple[int, int],
78
+ crop_size: tuple[int, int],
71
79
  inputs: dict[str, DataInput],
72
80
  ) -> tuple[dict[str, Any], dict[str, Any]]:
73
- """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
81
+ """Pad tensors in-place by crop size to protect slicing near right/bottom edges.
74
82
 
75
83
  The padding is scaled based on each input's resolution_factor.
76
84
 
77
85
  Args:
78
86
  raw_inputs: the raw inputs to pad.
79
87
  passthrough_inputs: the passthrough inputs to pad.
80
- patch_size: the size of the patches to extract (at window resolution).
88
+ crop_size: the size of the crops to extract (at window resolution).
81
89
  inputs: the DataInput definitions, used to get resolution_factor per input.
82
90
 
83
91
  Returns:
@@ -91,8 +99,8 @@ def pad_slice_protect(
91
99
  rf = inputs[input_name].resolution_factor
92
100
  scale = rf.numerator / rf.denominator
93
101
  # Scale the padding amount
94
- scaled_pad_x = int(patch_size[0] * scale)
95
- scaled_pad_y = int(patch_size[1] * scale)
102
+ scaled_pad_x = int(crop_size[0] * scale)
103
+ scaled_pad_y = int(crop_size[1] * scale)
96
104
  d[input_name] = torch.nn.functional.pad(
97
105
  value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
98
106
  )
@@ -123,12 +131,12 @@ def crop_tensor_or_rasterimage(
123
131
  )
124
132
 
125
133
 
126
- class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
127
- """This wraps a ModelDataset to iterate over all patches in that dataset.
134
+ class IterableAllCropsDataset(torch.utils.data.IterableDataset):
135
+ """This wraps a ModelDataset to iterate over all crops in that dataset.
128
136
 
129
- This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
130
- is configured with no patch size (load entire windows), and the dataset is wrapped
131
- in an AllPatchesDataset.
137
+ This should be used when SplitConfig.load_all_crops is enabled. The ModelDataset
138
+ is configured with no crop size (load entire windows), and the dataset is wrapped
139
+ in an AllCropsDataset.
132
140
 
133
141
  Similar to DistributedSampler, we add extra samples at each rank to ensure
134
142
  consistent number of batches across all ranks.
@@ -137,29 +145,27 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
137
145
  def __init__(
138
146
  self,
139
147
  dataset: ModelDataset,
140
- patch_size: tuple[int, int],
141
- overlap_ratio: float = 0.0,
148
+ crop_size: tuple[int, int],
149
+ overlap_pixels: int = 0,
142
150
  rank: int = 0,
143
151
  world_size: int = 1,
144
152
  ):
145
- """Create a new IterableAllPatchesDataset.
153
+ """Create a new IterableAllCropsDataset.
146
154
 
147
155
  Args:
148
156
  dataset: the ModelDataset to wrap.
149
- patch_size: the size of the patches to extract.
150
- overlap_ratio: whether to include overlap between the patches. Note that
151
- the right/bottom-most patches may still overlap since we ensure that
152
- all patches are contained in the window bounds.
157
+ crop_size: the size of the crops to extract.
158
+ overlap_pixels: the number of pixels shared between adjacent crops. Note
159
+ that the right/bottom-most crops may still overlap with other crops even
160
+ if overlap_pixels=0 since we ensure that all crops are contained in the
161
+ window bounds.
153
162
  rank: the global rank of this train worker process.
154
163
  world_size: the total number of train worker processes.
155
164
  """
156
165
  super().__init__()
157
166
  self.dataset = dataset
158
- self.patch_size = patch_size
159
- self.overlap_size = (
160
- round(self.patch_size[0] * overlap_ratio),
161
- round(self.patch_size[1] * overlap_ratio),
162
- )
167
+ self.crop_size = crop_size
168
+ self.overlap_size = (overlap_pixels, overlap_pixels)
163
169
  self.rank = rank
164
170
  self.world_size = world_size
165
171
  self.windows = self.dataset.get_dataset_examples()
@@ -173,17 +179,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
173
179
  """
174
180
  self.dataset.set_name(name)
175
181
 
176
- def get_window_num_patches(self, bounds: PixelBounds) -> int:
177
- """Get the number of patches for these bounds.
182
+ def get_window_num_crops(self, bounds: PixelBounds) -> int:
183
+ """Get the number of crops for these bounds.
178
184
 
179
- This corresponds to the length of the list returned by get_patch_options.
185
+ This corresponds to the length of the list returned by get_window_crop_options.
180
186
  """
181
187
  num_cols = (
182
188
  len(
183
189
  range(
184
190
  bounds[0],
185
- bounds[2] - self.patch_size[0],
186
- self.patch_size[0] - self.overlap_size[0],
191
+ bounds[2] - self.crop_size[0],
192
+ self.crop_size[0] - self.overlap_size[0],
187
193
  )
188
194
  )
189
195
  + 1
@@ -192,8 +198,8 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
192
198
  len(
193
199
  range(
194
200
  bounds[1],
195
- bounds[3] - self.patch_size[1],
196
- self.patch_size[1] - self.overlap_size[1],
201
+ bounds[3] - self.crop_size[1],
202
+ self.crop_size[1] - self.overlap_size[1],
197
203
  )
198
204
  )
199
205
  + 1
@@ -235,14 +241,14 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
235
241
  ]
236
242
 
237
243
  # Now compute the maximum number of samples across workers.
238
- max_num_patches = 0
244
+ max_num_crops = 0
239
245
  for worker_windows in windows_by_worker:
240
- worker_num_patches = 0
246
+ worker_num_crops = 0
241
247
  for window_id in worker_windows:
242
- worker_num_patches += self.get_window_num_patches(
248
+ worker_num_crops += self.get_window_num_crops(
243
249
  self.windows[window_id].bounds
244
250
  )
245
- max_num_patches = max(max_num_patches, worker_num_patches)
251
+ max_num_crops = max(max_num_crops, worker_num_crops)
246
252
 
247
253
  # Each worker needs at least one window, otherwise it won't be able to pad.
248
254
  # Unless there are zero windows total, which is fine.
@@ -252,17 +258,17 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
252
258
  # window in the end.
253
259
  # So now we raise an error instead, and require the number of workers to be
254
260
  # less than the number of windows.
255
- if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
261
+ if len(windows_by_worker[global_worker_id]) == 0 and max_num_crops > 0:
256
262
  raise ValueError(
257
263
  f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
258
264
  )
259
265
 
260
- return (windows_by_worker[global_worker_id], max_num_patches)
266
+ return (windows_by_worker[global_worker_id], max_num_crops)
261
267
 
262
268
  def __iter__(
263
269
  self,
264
270
  ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
265
- """Iterate over all patches in each element of the underlying ModelDataset."""
271
+ """Iterate over all crops in each element of the underlying ModelDataset."""
266
272
  # Iterate over the window IDs until we have returned enough samples.
267
273
  window_ids, num_samples_needed = self._get_worker_iteration_data()
268
274
  num_samples_returned = 0
@@ -272,32 +278,32 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
272
278
  raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
273
279
  window_id
274
280
  )
275
- bounds = metadata.patch_bounds
281
+ bounds = metadata.crop_bounds
276
282
 
277
- # For simplicity, pad tensors by patch size to ensure that any patch bounds
283
+ # For simplicity, pad tensors by crop size to ensure that any crop bounds
278
284
  # extending outside the window bounds will not have issues when we slice
279
285
  # the tensors later. Padding is scaled per-input based on resolution_factor.
280
286
  pad_slice_protect(
281
- raw_inputs, passthrough_inputs, self.patch_size, self.inputs
287
+ raw_inputs, passthrough_inputs, self.crop_size, self.inputs
282
288
  )
283
289
 
284
- # Now iterate over the patches and extract/yield the crops.
290
+ # Now iterate over the crops and extract/yield them.
285
291
  # Note that, in case user is leveraging RslearnWriter, it is important that
286
- # the patch_idx be increasing (as we iterate) within one window.
287
- patches = get_window_patch_options(
288
- self.patch_size, self.overlap_size, bounds
292
+ # the crop_idx be increasing (as we iterate) within one window.
293
+ crops = get_window_crop_options(
294
+ self.crop_size, self.overlap_size, bounds
289
295
  )
290
- for patch_idx, patch_bounds in enumerate(patches):
296
+ for crop_idx, crop_bounds in enumerate(crops):
291
297
  cur_geom = STGeometry(
292
- metadata.projection, shapely.box(*patch_bounds), None
298
+ metadata.projection, shapely.box(*crop_bounds), None
293
299
  )
294
300
  start_offset = (
295
- patch_bounds[0] - bounds[0],
296
- patch_bounds[1] - bounds[1],
301
+ crop_bounds[0] - bounds[0],
302
+ crop_bounds[1] - bounds[1],
297
303
  )
298
304
  end_offset = (
299
- patch_bounds[2] - bounds[0],
300
- patch_bounds[3] - bounds[1],
305
+ crop_bounds[2] - bounds[0],
306
+ crop_bounds[3] - bounds[1],
301
307
  )
302
308
 
303
309
  # Define a helper function to handle each input dict.
@@ -339,9 +345,9 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
339
345
  # Adjust the metadata as well.
340
346
  cur_metadata = replace(
341
347
  metadata,
342
- patch_bounds=patch_bounds,
343
- patch_idx=patch_idx,
344
- num_patches_in_window=len(patches),
348
+ crop_bounds=crop_bounds,
349
+ crop_idx=crop_idx,
350
+ num_crops_in_window=len(crops),
345
351
  )
346
352
 
347
353
  # Now we can compute input and target dicts via the task.
@@ -369,37 +375,34 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
369
375
  return self.dataset.get_dataset_examples()
370
376
 
371
377
 
372
- class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
373
- """This wraps a ModelDataset to iterate over all patches in that dataset.
378
+ class InMemoryAllCropsDataset(torch.utils.data.Dataset):
379
+ """This wraps a ModelDataset to iterate over all crops in that dataset.
374
380
 
375
- This should be used when SplitConfig.load_all_patches is enabled.
381
+ This should be used when SplitConfig.load_all_crops is enabled.
376
382
 
377
- This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
383
+ This is a simpler version of IterableAllCropsDataset that caches all windows in memory.
378
384
  This is useful for small datasets that fit in memory.
379
385
  """
380
386
 
381
387
  def __init__(
382
388
  self,
383
389
  dataset: ModelDataset,
384
- patch_size: tuple[int, int],
385
- overlap_ratio: float = 0.0,
390
+ crop_size: tuple[int, int],
391
+ overlap_pixels: int = 0,
386
392
  ):
387
- """Create a new InMemoryAllPatchesDataset.
393
+ """Create a new InMemoryAllCropsDataset.
388
394
 
389
395
  Args:
390
396
  dataset: the ModelDataset to wrap.
391
- patch_size: the size of the patches to extract.
392
- overlap_ratio: whether to include overlap between the patches. Note that
393
- the right/bottom-most patches may still overlap since we ensure that
394
- all patches are contained in the window bounds.
397
+ crop_size: the size of the crops to extract.
398
+ overlap_pixels: the number of pixels shared between adjacent crops. Note
399
+ that the right/bottom-most crops may still overlap since we ensure that
400
+ all crops are contained in the window bounds.
395
401
  """
396
402
  super().__init__()
397
403
  self.dataset = dataset
398
- self.patch_size = patch_size
399
- self.overlap_size = (
400
- round(self.patch_size[0] * overlap_ratio),
401
- round(self.patch_size[1] * overlap_ratio),
402
- )
404
+ self.crop_size = crop_size
405
+ self.overlap_size = (overlap_pixels, overlap_pixels)
403
406
  self.windows = self.dataset.get_dataset_examples()
404
407
  self.inputs = dataset.inputs
405
408
  self.window_cache: dict[
@@ -407,23 +410,23 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
407
410
  ] = {}
408
411
 
409
412
  # Precompute the batch boundaries for each window
410
- self.patches = []
413
+ self.crops = []
411
414
  for window_id, window in enumerate(self.windows):
412
- patch_bounds = get_window_patch_options(
413
- self.patch_size, self.overlap_size, window.bounds
415
+ window_crop_bounds = get_window_crop_options(
416
+ self.crop_size, self.overlap_size, window.bounds
414
417
  )
415
- for i, patch_bound in enumerate(patch_bounds):
416
- self.patches.append((window_id, patch_bound, (i, len(patch_bounds))))
418
+ for i, crop_bound in enumerate(window_crop_bounds):
419
+ self.crops.append((window_id, crop_bound, (i, len(window_crop_bounds))))
417
420
 
418
421
  def get_raw_inputs(
419
422
  self, index: int
420
423
  ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
421
- """Get the raw inputs for a single patch. Retrieve from cache if possible.
424
+ """Get the raw inputs for a single crop. Retrieve from cache if possible.
422
425
 
423
- Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
426
+ Also crops/pads the tensors by crop size to protect slicing near right/bottom edges.
424
427
 
425
428
  Args:
426
- index: the index of the patch.
429
+ index: the index of the crop.
427
430
 
428
431
  Returns:
429
432
  a tuple of (raw_inputs, passthrough_inputs, metadata).
@@ -432,7 +435,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
432
435
  return self.window_cache[index]
433
436
 
434
437
  raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
435
- pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
438
+ pad_slice_protect(raw_inputs, passthrough_inputs, self.crop_size, self.inputs)
436
439
 
437
440
  self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
438
441
  return self.window_cache[index]
@@ -476,20 +479,20 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
476
479
  return cropped
477
480
 
478
481
  def __len__(self) -> int:
479
- """Return the total number of patches in the dataset."""
480
- return len(self.patches)
482
+ """Return the total number of crops in the dataset."""
483
+ return len(self.crops)
481
484
 
482
485
  def __getitem__(
483
486
  self, index: int
484
487
  ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
485
- """Return (input_dict, target_dict, metadata) for a single flattened patch."""
486
- (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
488
+ """Return (input_dict, target_dict, metadata) for a single flattened crop."""
489
+ (window_id, crop_bounds, (crop_idx, num_crops)) = self.crops[index]
487
490
  raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
488
- bounds = metadata.patch_bounds
491
+ bounds = metadata.crop_bounds
489
492
 
490
- cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
491
- start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
492
- end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
493
+ cur_geom = STGeometry(metadata.projection, shapely.box(*crop_bounds), None)
494
+ start_offset = (crop_bounds[0] - bounds[0], crop_bounds[1] - bounds[1])
495
+ end_offset = (crop_bounds[2] - bounds[0], crop_bounds[3] - bounds[1])
493
496
 
494
497
  cur_raw_inputs = self._crop_input_dict(
495
498
  raw_inputs, start_offset, end_offset, cur_geom
@@ -501,9 +504,9 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
501
504
  # Adjust the metadata as well.
502
505
  cur_metadata = replace(
503
506
  metadata,
504
- patch_bounds=patch_bounds,
505
- patch_idx=patch_idx,
506
- num_patches_in_window=num_patches,
507
+ crop_bounds=crop_bounds,
508
+ crop_idx=crop_idx,
509
+ num_crops_in_window=num_crops,
507
510
  )
508
511
 
509
512
  # Now we can compute input and target dicts via the task.
@@ -15,9 +15,9 @@ from rslearn.dataset import Dataset
15
15
  from rslearn.log_utils import get_logger
16
16
  from rslearn.train.tasks import Task
17
17
 
18
- from .all_patches_dataset import (
19
- InMemoryAllPatchesDataset,
20
- IterableAllPatchesDataset,
18
+ from .all_crops_dataset import (
19
+ InMemoryAllCropsDataset,
20
+ IterableAllCropsDataset,
21
21
  )
22
22
  from .dataset import (
23
23
  DataInput,
@@ -69,7 +69,7 @@ class RslearnDataModule(L.LightningDataModule):
69
69
  predict_config: SplitConfig = SplitConfig(),
70
70
  name: str | None = None,
71
71
  retries: int = 0,
72
- use_in_memory_all_patches_dataset: bool = False,
72
+ use_in_memory_all_crops_dataset: bool = False,
73
73
  index_mode: IndexMode = IndexMode.OFF,
74
74
  ) -> None:
75
75
  """Initialize a new RslearnDataModule.
@@ -92,8 +92,8 @@ class RslearnDataModule(L.LightningDataModule):
92
92
  predict_config: split config for predict split
93
93
  name: name of the dataset
94
94
  retries: number of retries to attempt for getitem calls
95
- use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
96
- instead of IterableAllPatchesDataset if load_all_patches is set to true.
95
+ use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
96
+ instead of IterableAllCropsDataset if load_all_crops is set to true.
97
97
  index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
98
98
  """
99
99
  super().__init__()
@@ -105,7 +105,7 @@ class RslearnDataModule(L.LightningDataModule):
105
105
  self.init_workers = init_workers if init_workers > 0 else self.num_workers
106
106
  self.name = name
107
107
  self.retries = retries
108
- self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
108
+ self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
109
109
  self.index_mode = index_mode
110
110
  self.split_configs = {
111
111
  "train": default_config.update(train_config),
@@ -115,15 +115,15 @@ class RslearnDataModule(L.LightningDataModule):
115
115
  }
116
116
 
117
117
  def setup(
118
- self, stage: str, use_in_memory_all_patches_dataset: bool | None = None
118
+ self, stage: str, use_in_memory_all_crops_dataset: bool | None = None
119
119
  ) -> None:
120
120
  """Set up datasets and samplers.
121
121
 
122
122
  Args:
123
123
  stage: Either 'fit', 'validate', 'test', or 'predict'.
124
- use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
125
- instead of IterableAllPatchesDataset if load_all_patches is set to true.
126
- If None, uses the value of self.use_in_memory_all_patches_dataset.
124
+ use_in_memory_all_crops_dataset: whether to use InMemoryAllCropsDataset
125
+ instead of IterableAllCropsDataset if load_all_crops is set to true.
126
+ If None, uses the value of self.use_in_memory_all_crops_dataset.
127
127
  """
128
128
  stage_to_splits = {
129
129
  "fit": ["train", "val"],
@@ -145,34 +145,34 @@ class RslearnDataModule(L.LightningDataModule):
145
145
  index_mode=self.index_mode,
146
146
  )
147
147
  logger.info(f"got {len(dataset)} examples in split {split}")
148
- if split_config.get_load_all_patches():
149
- if use_in_memory_all_patches_dataset is None:
150
- use_in_memory_all_patches_dataset = (
151
- self.use_in_memory_all_patches_dataset
148
+ if split_config.get_load_all_crops():
149
+ if use_in_memory_all_crops_dataset is None:
150
+ use_in_memory_all_crops_dataset = (
151
+ self.use_in_memory_all_crops_dataset
152
152
  )
153
153
  logger.info(
154
- f"using AllPatchesDataset (in_memory={use_in_memory_all_patches_dataset})"
154
+ f"using AllCropsDataset (in_memory={use_in_memory_all_crops_dataset})"
155
155
  )
156
- patch_size = split_config.get_patch_size()
157
- if patch_size is None:
156
+ crop_size = split_config.get_crop_size()
157
+ if crop_size is None:
158
158
  raise ValueError(
159
- "patch_size is not set but must be set if load_all_patches is set"
159
+ "crop_size is not set but must be set if load_all_crops is set"
160
160
  )
161
161
 
162
- all_patches_cls = IterableAllPatchesDataset
162
+ all_crops_cls = IterableAllCropsDataset
163
163
  kwargs = dict(
164
164
  dataset=dataset,
165
- patch_size=patch_size,
166
- overlap_ratio=split_config.get_overlap_ratio(),
165
+ crop_size=crop_size,
166
+ overlap_pixels=split_config.get_overlap_pixels(),
167
167
  rank=self.trainer.global_rank if self.trainer else 0,
168
168
  world_size=self.trainer.world_size if self.trainer else 1,
169
169
  )
170
- if use_in_memory_all_patches_dataset:
170
+ if use_in_memory_all_crops_dataset:
171
171
  kwargs.pop("rank")
172
172
  kwargs.pop("world_size")
173
- all_patches_cls = InMemoryAllPatchesDataset # type: ignore
173
+ all_crops_cls = InMemoryAllCropsDataset # type: ignore
174
174
 
175
- dataset = all_patches_cls(**kwargs) # type: ignore
175
+ dataset = all_crops_cls(**kwargs) # type: ignore
176
176
 
177
177
  if self.retries > 0:
178
178
  dataset = RetryDataset(dataset, retries=self.retries)
@@ -209,7 +209,7 @@ class RslearnDataModule(L.LightningDataModule):
209
209
  # If the number of windows is 0, then we can set positive number of workers
210
210
  # since they won't yield anything anyway.
211
211
  num_workers = self.num_workers
212
- if split_config.load_all_patches and len(dataset.get_dataset_examples()) > 0:
212
+ if split_config.load_all_crops and len(dataset.get_dataset_examples()) > 0:
213
213
  num_workers = min(num_workers, len(dataset.get_dataset_examples()))
214
214
 
215
215
  kwargs: dict[str, Any] = dict(
@@ -357,7 +357,7 @@ class MultiDatasetDataModule(L.LightningDataModule):
357
357
  stage: The stage to set up ('fit', 'validate', 'test', 'predict')
358
358
  """
359
359
  for name, data_module in self.data_modules.items():
360
- data_module.setup(stage, use_in_memory_all_patches_dataset=True) # type: ignore
360
+ data_module.setup(stage, use_in_memory_all_crops_dataset=True) # type: ignore
361
361
  data_module.set_name(name)
362
362
 
363
363
  def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]: