rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -1,30 +1,55 @@
1
1
  """Default Dataset for rslearn."""
2
2
 
3
+ import hashlib
4
+ import json
3
5
  import multiprocessing
4
6
  import os
5
7
  import random
8
+ import tempfile
6
9
  import time
10
+ import uuid
11
+ from datetime import datetime
7
12
  from typing import Any
8
13
 
9
14
  import torch
10
15
  import tqdm
16
+ from rasterio.warp import Resampling
11
17
 
12
18
  import rslearn.train.transforms.transform
13
19
  from rslearn.config import (
14
20
  DType,
15
- RasterFormatConfig,
16
- RasterLayerConfig,
17
- VectorLayerConfig,
21
+ LayerConfig,
18
22
  )
19
- from rslearn.dataset import Dataset, Window
20
- from rslearn.train.tasks import Task
21
- from rslearn.utils import logger
23
+ from rslearn.data_sources.data_source import Item
24
+ from rslearn.dataset.dataset import Dataset
25
+ from rslearn.dataset.storage.file import FileWindowStorage
26
+ from rslearn.dataset.window import (
27
+ Window,
28
+ WindowLayerData,
29
+ get_layer_and_group_from_dir_name,
30
+ )
31
+ from rslearn.log_utils import get_logger
32
+ from rslearn.train.model_context import RasterImage
33
+ from rslearn.utils.feature import Feature
34
+ from rslearn.utils.geometry import PixelBounds, ResolutionFactor
22
35
  from rslearn.utils.mp import star_imap_unordered
23
- from rslearn.utils.raster_format import load_raster_format
24
- from rslearn.utils.vector_format import load_vector_format
25
36
 
37
+ from .model_context import SampleMetadata
38
+ from .tasks import Task
26
39
  from .transforms import Sequential
27
40
 
41
+ logger = get_logger(__name__)
42
+
43
+
44
+ def get_torch_dtype(dtype: DType) -> torch.dtype:
45
+ """Convert rslearn DType to torch dtype."""
46
+ if dtype == DType.INT32:
47
+ return torch.int32
48
+ elif dtype == DType.FLOAT32:
49
+ return torch.float32
50
+ else:
51
+ raise ValueError(f"unable to handle {dtype} as a torch dtype")
52
+
28
53
 
29
54
  class SamplerFactory:
30
55
  """Factory to produce a Sampler.
@@ -47,7 +72,9 @@ class SamplerFactory:
47
72
  class RandomSamplerFactory(SamplerFactory):
48
73
  """A sampler factory for RandomSampler."""
49
74
 
50
- def __init__(self, replacement: bool = False, num_samples: int | None = None):
75
+ def __init__(
76
+ self, replacement: bool = False, num_samples: int | None = None
77
+ ) -> None:
51
78
  """Initialize a RandomSamplerFactory.
52
79
 
53
80
  Args:
@@ -75,7 +102,9 @@ class RandomSamplerFactory(SamplerFactory):
75
102
  class WeightedRandomSamplerFactory(SamplerFactory):
76
103
  """A sampler factory for WeightedRandomSampler."""
77
104
 
78
- def __init__(self, option_key: str, num_samples: int, replacement: bool = True):
105
+ def __init__(
106
+ self, option_key: str, num_samples: int, replacement: bool = True
107
+ ) -> None:
79
108
  """Initialize a WeightedRandomSamplerFactory.
80
109
 
81
110
  Args:
@@ -97,7 +126,7 @@ class WeightedRandomSamplerFactory(SamplerFactory):
97
126
  a RandomSampler
98
127
  """
99
128
  weights = []
100
- for window in dataset.get_windows():
129
+ for window in dataset.get_dataset_examples():
101
130
  weights.append(window.options[self.option_key])
102
131
  return torch.utils.data.WeightedRandomSampler(
103
132
  weights, self.num_samples, replacement=self.replacement
@@ -108,6 +137,10 @@ class DataInput:
108
137
  """Specification of a piece of data from a window that is needed for training.
109
138
 
110
139
  The DataInput includes which layer(s) the data can be obtained from for each window.
140
+
141
+ Note that this class is not a dataclass because jsonargparse does not play well
142
+ with dataclasses without enabling specialized options which we have not validated
143
+ will work with the rest of our code.
111
144
  """
112
145
 
113
146
  def __init__(
@@ -119,6 +152,10 @@ class DataInput:
119
152
  passthrough: bool = False,
120
153
  is_target: bool = False,
121
154
  dtype: DType = DType.FLOAT32,
155
+ load_all_layers: bool = False,
156
+ load_all_item_groups: bool = False,
157
+ resolution_factor: ResolutionFactor = ResolutionFactor(),
158
+ resampling: Resampling = Resampling.nearest,
122
159
  ):
123
160
  """Initialize a new DataInput.
124
161
 
@@ -132,6 +169,21 @@ class DataInput:
132
169
  is_target: whether this DataInput represents a target for the task. Targets
133
170
  are not read during prediction phase.
134
171
  dtype: data type to load the raster as
172
+ load_all_layers: whether to load all of the layers specified in the list of
173
+ layer names. By default, we randomly pick one layer to read. When
174
+ reading multiple layers, the images are stacked on the channel
175
+ dimension. This option will also cause the dataset to only include
176
+ windows where all of the layers are materialized (by default, only
177
+ windows with none of the layers materialized would be excluded).
178
+ load_all_item_groups: whether to load all item groups in the layer(s) we
179
+ are reading from. By default, we assume the specified layer name is of
180
+ the form "{layer_name}.{group_idx}" and read that item group only. With
181
+ this option enabled, we ignore the group_idx and read all item groups.
182
+ resolution_factor: controls the resolution at which raster data is loaded for training.
183
+ By default (factor=1), data is loaded at the window resolution.
184
+ E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
185
+ the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
186
+ resampling: resampling method (default nearest neighbor).
135
187
  """
136
188
  self.data_type = data_type
137
189
  self.layers = layers
@@ -140,6 +192,241 @@ class DataInput:
140
192
  self.passthrough = passthrough
141
193
  self.is_target = is_target
142
194
  self.dtype = dtype
195
+ self.load_all_layers = load_all_layers
196
+ self.load_all_item_groups = load_all_item_groups
197
+ self.resolution_factor = resolution_factor
198
+ self.resampling = resampling
199
+
200
+
201
+ def read_raster_layer_for_data_input(
202
+ window: Window,
203
+ bounds: PixelBounds,
204
+ layer_name: str,
205
+ group_idx: int,
206
+ layer_config: LayerConfig,
207
+ data_input: DataInput,
208
+ ) -> torch.Tensor:
209
+ """Read a raster layer for a DataInput.
210
+
211
+ This scans the available rasters for the layer at the window to determine which
212
+ ones are needed to get all of the configured bands.
213
+
214
+ Args:
215
+ window: the window to read from.
216
+ bounds: the bounds to read.
217
+ layer_name: the layer.
218
+ group_idx: the item group.
219
+ layer_config: the layer configuration.
220
+ data_input: the DataInput that specifies the bands and dtype.
221
+
222
+ Returns:
223
+ Raster data as a tensor.
224
+ """
225
+ # See what different sets of bands we need to read to get all the
226
+ # configured bands.
227
+ needed_bands = data_input.bands
228
+ if needed_bands is None:
229
+ raise ValueError(f"No bands specified for {layer_name}")
230
+ needed_band_indexes = {}
231
+ for i, band in enumerate(needed_bands):
232
+ needed_band_indexes[band] = i
233
+ needed_sets_and_indexes = []
234
+ for band_set in layer_config.band_sets:
235
+ needed_src_indexes = []
236
+ needed_dst_indexes = []
237
+ if band_set.bands is None:
238
+ continue
239
+ for i, band in enumerate(band_set.bands):
240
+ if band not in needed_band_indexes:
241
+ continue
242
+ needed_src_indexes.append(i)
243
+ needed_dst_indexes.append(needed_band_indexes[band])
244
+ del needed_band_indexes[band]
245
+ if len(needed_src_indexes) == 0:
246
+ continue
247
+ needed_sets_and_indexes.append(
248
+ (band_set, needed_src_indexes, needed_dst_indexes)
249
+ )
250
+ if len(needed_band_indexes) > 0:
251
+ raise ValueError(
252
+ "could not get all the needed bands from "
253
+ + f"window {window.name} layer {layer_name} group {group_idx}"
254
+ )
255
+
256
+ # Get the projection and bounds to read under (multiply window resolution by
257
+ # the specified resolution factor).
258
+ final_projection = data_input.resolution_factor.multiply_projection(
259
+ window.projection
260
+ )
261
+ final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
262
+
263
+ image = torch.zeros(
264
+ (
265
+ len(needed_bands),
266
+ final_bounds[3] - final_bounds[1],
267
+ final_bounds[2] - final_bounds[0],
268
+ ),
269
+ dtype=get_torch_dtype(data_input.dtype),
270
+ )
271
+
272
+ for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
273
+ if band_set.format is None:
274
+ raise ValueError(f"No format specified for {layer_name}")
275
+ raster_format = band_set.instantiate_raster_format()
276
+ raster_dir = window.get_raster_dir(
277
+ layer_name, band_set.bands, group_idx=group_idx
278
+ )
279
+
280
+ # TODO: previously we try to read based on band_set.zoom_offset when possible,
281
+ # and handle zooming in with torch.repeat (if resampling method is nearest
282
+ # neighbor). However, we have not benchmarked whether this actually improves
283
+ # data loading speed, so for simplicity, for now we let rasterio handle the
284
+ # resampling. If it really is much faster to handle it via torch, then it may
285
+ # make sense to bring back that functionality.
286
+
287
+ src = raster_format.decode_raster(
288
+ raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
289
+ )
290
+ image[dst_indexes, :, :] = torch.as_tensor(
291
+ src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
292
+ )
293
+
294
+ return image
295
+
296
+
297
+ def read_layer_time_range(
298
+ layer_data: WindowLayerData | None, group_idx: int
299
+ ) -> tuple[datetime, datetime] | None:
300
+ """Extract the combined time range from all items in a layer data group.
301
+
302
+ Returns the min start time and max end time across all items, or None if
303
+ no items have time ranges.
304
+
305
+ Raises:
306
+ ValueError: If some items have time_range and others don't.
307
+ """
308
+ if layer_data is None:
309
+ return None
310
+
311
+ serialized_items = layer_data.serialized_item_groups[group_idx]
312
+ if not serialized_items:
313
+ return None
314
+
315
+ first_item = Item.deserialize(serialized_items[0])
316
+ if first_item.geometry.time_range is None:
317
+ return None
318
+
319
+ # If the first item has a time_range, all items must have one
320
+ time_ranges: list[tuple[datetime, datetime]] = []
321
+ for serialized_item in serialized_items:
322
+ item = Item.deserialize(serialized_item)
323
+ if item.geometry.time_range is None:
324
+ raise ValueError(
325
+ f"Item '{item.name}' has no time_range, but first item does. "
326
+ "All items in a group must consistently have or lack time_range."
327
+ )
328
+ time_ranges.append(item.geometry.time_range)
329
+
330
+ return (
331
+ min(tr[0] for tr in time_ranges),
332
+ max(tr[1] for tr in time_ranges),
333
+ )
334
+
335
+
336
+ def read_data_input(
337
+ dataset: Dataset,
338
+ window: Window,
339
+ bounds: PixelBounds,
340
+ data_input: DataInput,
341
+ rng: random.Random,
342
+ ) -> RasterImage | list[Feature]:
343
+ """Read the data specified by the DataInput from the window.
344
+
345
+ Args:
346
+ dataset: the dataset, to get layer configs.
347
+ window: the window to read from.
348
+ bounds: the bounds of the patch we are reading.
349
+ data_input: the DataInput that specifies what layers to read.
350
+ rng: random number generator
351
+
352
+ Returns:
353
+ the raster or vector data.
354
+ """
355
+ # We first enumerate which layers are available.
356
+ # If load_all_item_groups is set, we need to check each item group within the
357
+ # layer.
358
+ layer_options: list[tuple[str, int]] = []
359
+ if data_input.load_all_item_groups:
360
+ wanted_layers = set(data_input.layers)
361
+ for layer_name, group_idx in window.list_completed_layers():
362
+ if layer_name not in wanted_layers:
363
+ continue
364
+ layer_options.append((layer_name, group_idx))
365
+ else:
366
+ for option in data_input.layers:
367
+ layer_name, group_idx = get_layer_and_group_from_dir_name(option)
368
+ if not window.is_layer_completed(layer_name, group_idx):
369
+ continue
370
+ layer_options.append((layer_name, group_idx))
371
+
372
+ # Now determine the layers that we should actually read.
373
+ # We randomly pick one, unless load_all_layers is set, in which case we read all of
374
+ # them.
375
+ layers_to_read: list[tuple[str, int]]
376
+ if data_input.load_all_layers:
377
+ # We assume that the user has ensured the layers are compatible, e.g. raster
378
+ # layers will need to have the same number of bands.
379
+ layers_to_read = layer_options
380
+ else:
381
+ layers_to_read = [rng.choice(layer_options)]
382
+
383
+ if data_input.data_type == "raster":
384
+ # load it once here
385
+ layer_datas = window.load_layer_datas()
386
+ images: list[torch.Tensor] = []
387
+ time_ranges: list[tuple[datetime, datetime] | None] = []
388
+ for layer_name, group_idx in layers_to_read:
389
+ layer_config = dataset.layers[layer_name]
390
+ image = read_raster_layer_for_data_input(
391
+ window,
392
+ bounds,
393
+ layer_name,
394
+ group_idx,
395
+ layer_config,
396
+ data_input,
397
+ )
398
+ # some layers (e.g. "label_raster") won't have associated layer datas
399
+ layer_data = layer_datas.get(layer_name)
400
+ time_range = read_layer_time_range(layer_data, group_idx)
401
+ if len(time_ranges) > 0:
402
+ if type(time_ranges[-1]) is not type(time_range):
403
+ raise ValueError(
404
+ f"All time ranges should be datetime tuples or None. Got {type(time_range)} and {type(time_ranges[-1])}"
405
+ )
406
+ images.append(image)
407
+ time_ranges.append(time_range)
408
+ return RasterImage(
409
+ torch.stack(images, dim=1),
410
+ time_ranges if time_ranges[0] is not None else None, # type: ignore
411
+ )
412
+
413
+ elif data_input.data_type == "vector":
414
+ # We don't really support time series for vector data currently, we just
415
+ # concatenate the features together.
416
+ features: list[Feature] = []
417
+ for layer_name, group_idx in layers_to_read:
418
+ layer_config = dataset.layers[layer_name]
419
+ vector_format = layer_config.instantiate_vector_format()
420
+ layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
421
+ cur_features = vector_format.decode_vector(
422
+ layer_dir, window.projection, window.bounds
423
+ )
424
+ features.extend(cur_features)
425
+
426
+ return features
427
+
428
+ else:
429
+ raise ValueError(f"unknown data type {data_input.data_type}")
143
430
 
144
431
 
145
432
  class SplitConfig:
@@ -149,15 +436,16 @@ class SplitConfig:
149
436
  self,
150
437
  groups: list[str] | None = None,
151
438
  names: list[str] | None = None,
152
- tags: dict[str, str] | None = None,
439
+ tags: dict[str, Any] | None = None,
153
440
  num_samples: int | None = None,
441
+ num_patches: int | None = None,
154
442
  transforms: list[torch.nn.Module] | None = None,
155
443
  sampler: SamplerFactory | None = None,
156
444
  patch_size: int | tuple[int, int] | None = None,
157
445
  overlap_ratio: float | None = None,
158
446
  load_all_patches: bool | None = None,
159
447
  skip_targets: bool | None = None,
160
- ):
448
+ ) -> None:
161
449
  """Initialize a new SplitConfig.
162
450
 
163
451
  Args:
@@ -168,6 +456,7 @@ class SplitConfig:
168
456
  value. If value is empty, then only the existence of the key in the
169
457
  window options is checked.
170
458
  num_samples: limit this split to this many examples
459
+ num_patches: limit this split to this many patches
171
460
  transforms: transforms to apply
172
461
  sampler: SamplerFactory for this split
173
462
  patch_size: an optional square size or (width, height) tuple. If set, read
@@ -183,15 +472,19 @@ class SplitConfig:
183
472
  self.names = names
184
473
  self.tags = tags
185
474
  self.num_samples = num_samples
475
+ self.num_patches = num_patches
186
476
  self.transforms = transforms
187
477
  self.sampler = sampler
188
478
  self.patch_size = patch_size
189
- self.load_all_patches = load_all_patches
190
479
  self.skip_targets = skip_targets
480
+
481
+ # Note that load_all_patches are handled by the RslearnDataModule rather than
482
+ # the ModelDataset.
483
+ self.load_all_patches = load_all_patches
191
484
  self.overlap_ratio = overlap_ratio
192
- if self.overlap_ratio is not None:
193
- if not (0 < self.overlap_ratio < 1):
194
- raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
485
+
486
+ if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
487
+ raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
195
488
 
196
489
  def update(self, other: "SplitConfig") -> "SplitConfig":
197
490
  """Override settings in this SplitConfig with those in another.
@@ -204,6 +497,7 @@ class SplitConfig:
204
497
  names=self.names,
205
498
  tags=self.tags,
206
499
  num_samples=self.num_samples,
500
+ num_patches=self.num_patches,
207
501
  transforms=self.transforms,
208
502
  sampler=self.sampler,
209
503
  patch_size=self.patch_size,
@@ -219,6 +513,8 @@ class SplitConfig:
219
513
  result.tags = other.tags
220
514
  if other.num_samples:
221
515
  result.num_samples = other.num_samples
516
+ if other.num_patches:
517
+ result.num_patches = other.num_patches
222
518
  if other.transforms:
223
519
  result.transforms = other.transforms
224
520
  if other.sampler:
@@ -233,6 +529,18 @@ class SplitConfig:
233
529
  result.skip_targets = other.skip_targets
234
530
  return result
235
531
 
532
+ def get_patch_size(self) -> tuple[int, int] | None:
533
+ """Get patch size normalized to int tuple."""
534
+ if self.patch_size is None:
535
+ return None
536
+ if isinstance(self.patch_size, int):
537
+ return (self.patch_size, self.patch_size)
538
+ return self.patch_size
539
+
540
+ def get_overlap_ratio(self) -> float:
541
+ """Get the overlap ratio (default 0)."""
542
+ return self.overlap_ratio if self.overlap_ratio is not None else 0.0
543
+
236
544
  def get_load_all_patches(self) -> bool:
237
545
  """Returns whether loading all patches is enabled (default False)."""
238
546
  return True if self.load_all_patches is True else False
@@ -242,7 +550,7 @@ class SplitConfig:
242
550
  return True if self.skip_targets is True else False
243
551
 
244
552
 
245
- def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
553
+ def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
246
554
  """Verify that the window has the required layers based on the specified inputs.
247
555
 
248
556
  Args:
@@ -254,17 +562,25 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
254
562
  """
255
563
 
256
564
  # Make sure window has all the needed layers.
257
- def is_any_layer_available(data_input):
565
+ def is_available(data_input: DataInput) -> bool:
566
+ # If load_all_layers is enabled, we should check that all the layers are
567
+ # present. Otherwise, we just need one layer.
568
+ is_any_layer_available = False
569
+ are_all_layers_available = True
258
570
  for layer_name in data_input.layers:
259
- completed_fname = window.path / "layers" / layer_name / "completed"
260
- if completed_fname.exists():
261
- return True
262
- return False
571
+ if window.is_layer_completed(layer_name):
572
+ is_any_layer_available = True
573
+ else:
574
+ are_all_layers_available = False
575
+ if data_input.load_all_layers:
576
+ return are_all_layers_available
577
+ else:
578
+ return is_any_layer_available
263
579
 
264
580
  for data_input in inputs.values():
265
581
  if not data_input.required:
266
582
  continue
267
- if not is_any_layer_available(data_input):
583
+ if not is_available(data_input):
268
584
  logger.debug(
269
585
  "Skipping window %s since check for layers %s failed",
270
586
  window.name,
@@ -285,7 +601,9 @@ class ModelDataset(torch.utils.data.Dataset):
285
601
  inputs: dict[str, DataInput],
286
602
  task: Task,
287
603
  workers: int,
288
- ):
604
+ name: str | None = None,
605
+ fix_patch_pick: bool = False,
606
+ ) -> None:
289
607
  """Instantiate a new ModelDataset.
290
608
 
291
609
  Args:
@@ -294,50 +612,30 @@ class ModelDataset(torch.utils.data.Dataset):
294
612
  inputs: data to read from the dataset for training
295
613
  task: the task to train on
296
614
  workers: number of workers to use for initializing the dataset
615
+ name: name of the dataset (default: None)
616
+ fix_patch_pick: if True, fix the patch pick to be the same every time
617
+ for a given window. Useful for testing (default: False)
297
618
  """
298
619
  self.dataset = dataset
299
620
  self.split_config = split_config
300
621
  self.inputs = inputs
301
622
  self.task = task
302
-
623
+ self.name = name
624
+ self.fix_patch_pick = fix_patch_pick
303
625
  if split_config.transforms:
304
626
  self.transforms = Sequential(*split_config.transforms)
305
627
  else:
306
628
  self.transforms = rslearn.train.transforms.transform.Identity()
307
629
 
308
- # Convert patch size to (width, height) format if needed.
309
- if not split_config.patch_size:
630
+ # Get normalized patch size from the SplitConfig.
631
+ # But if load all patches is enabled, this is handled by AllPatchesDataset, so
632
+ # here we instead load the entire windows.
633
+ if split_config.get_load_all_patches():
310
634
  self.patch_size = None
311
- elif isinstance(split_config.patch_size, int):
312
- self.patch_size = (split_config.patch_size, split_config.patch_size)
313
635
  else:
314
- self.patch_size = split_config.patch_size
636
+ self.patch_size = split_config.get_patch_size()
315
637
 
316
- if split_config.names:
317
- windows = self.dataset.load_windows(
318
- groups=split_config.groups,
319
- names=split_config.names,
320
- show_progress=True,
321
- workers=workers,
322
- )
323
- elif split_config.groups:
324
- windows = self.dataset.load_windows(
325
- groups=split_config.groups, show_progress=True, workers=workers
326
- )
327
- else:
328
- windows = self.dataset.load_windows(show_progress=True, workers=workers)
329
-
330
- if split_config.tags:
331
- # Filter the window.options.
332
- new_windows = []
333
- for window in windows:
334
- for k, v in split_config.tags.items():
335
- if k not in window.options:
336
- continue
337
- if v and window.options[k] != v:
338
- continue
339
- new_windows.append(window)
340
- windows = new_windows
638
+ windows = self._get_initial_windows(split_config, workers)
341
639
 
342
640
  # If targets are not needed, remove them from the inputs.
343
641
  if split_config.get_skip_targets():
@@ -347,98 +645,178 @@ class ModelDataset(torch.utils.data.Dataset):
347
645
 
348
646
  # Eliminate windows that are missing either a requisite input layer, or missing
349
647
  # all target layers.
350
- p = multiprocessing.Pool(workers)
351
- outputs = star_imap_unordered(
352
- p,
353
- check_window,
354
- [
355
- dict(
356
- inputs=self.inputs,
357
- window=window,
358
- )
359
- for window in windows
360
- ],
361
- )
362
648
  new_windows = []
363
- for window in tqdm.tqdm(
364
- outputs, total=len(windows), desc="Checking available layers in windows"
365
- ):
366
- if window is None:
367
- continue
368
- new_windows.append(window)
369
- p.close()
649
+ if workers == 0:
650
+ for window in windows:
651
+ if check_window(self.inputs, window) is None:
652
+ continue
653
+ new_windows.append(window)
654
+ else:
655
+ p = multiprocessing.Pool(workers)
656
+ outputs = star_imap_unordered(
657
+ p,
658
+ check_window,
659
+ [
660
+ dict(
661
+ inputs=self.inputs,
662
+ window=window,
663
+ )
664
+ for window in windows
665
+ ],
666
+ )
667
+ for window in tqdm.tqdm(
668
+ outputs, total=len(windows), desc="Checking available layers in windows"
669
+ ):
670
+ if window is None:
671
+ continue
672
+ new_windows.append(window)
673
+ p.close()
370
674
  windows = new_windows
371
675
 
676
+ # Sort the windows to ensure that the dataset is consistent across GPUs.
677
+ # Inconsistent ordering can lead to a subset of windows being processed during
678
+ # "model test" / "model predict" when using multiple GPUs.
679
+ # We use a hash so that functionality like num_samples limit gets a random
680
+ # subset of windows (with respect to the hash function choice).
681
+ windows.sort(
682
+ key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
683
+ )
684
+
372
685
  # Limit windows to num_samples if requested.
373
686
  if split_config.num_samples:
374
- # TODO: use hash of window names so this is deterministic and not arbitrarily ordered according to load_windows
687
+ # The windows are sorted by hash of window name so this distribution should
688
+ # be representative of the population.
375
689
  windows = windows[0 : split_config.num_samples]
376
690
 
377
- self.windows = windows
691
+ # Write dataset_examples to a file so that we can load it lazily in the worker
692
+ # processes. Otherwise it takes a long time to transmit it when spawning each
693
+ # process.
694
+ self.dataset_examples_fname = os.path.join(
695
+ tempfile.gettempdir(),
696
+ "rslearn_dataset_examples",
697
+ f"{os.getpid()}_{uuid.uuid4()}.json",
698
+ )
699
+ self.num_dataset_examples = len(windows)
700
+ self.dataset_examples: list[Window] | None = None
701
+ logger.info(
702
+ f"Writing {len(windows)} dataset examples to {self.dataset_examples_fname}"
703
+ )
704
+ os.makedirs(os.path.dirname(self.dataset_examples_fname), exist_ok=True)
705
+ with open(self.dataset_examples_fname, "w") as f:
706
+ json.dump([self._serialize_item(example) for example in windows], f)
378
707
 
379
- # If we're loading all patches, we need to include the patch details.
380
- if split_config.get_load_all_patches():
381
- patches = []
382
- overlap_size = int(
383
- self.patch_size[0] * split_config.overlap_ratio
384
- if split_config.overlap_ratio
385
- else 0
708
+ def _get_initial_windows(
709
+ self, split_config: SplitConfig, workers: int
710
+ ) -> list[Window]:
711
+ """Get the initial windows before input layer filtering.
712
+
713
+ The windows are filtered based on configured window names, groups, and tags.
714
+
715
+ This is a helper for the init function.
716
+
717
+ Args:
718
+ split_config: the split configuration.
719
+ workers: number of worker processes.
720
+
721
+ Returns:
722
+ list of windows from the dataset after applying the aforementioned filters.
723
+ """
724
+ # Load windows from dataset.
725
+ # If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
726
+ kwargs: dict[str, Any] = {}
727
+ if isinstance(self.dataset.storage, FileWindowStorage):
728
+ kwargs["workers"] = workers
729
+ kwargs["show_progress"] = True
730
+ # We also add the name/group filters to the kwargs.
731
+ if split_config.names:
732
+ kwargs["names"] = split_config.names
733
+ if split_config.groups:
734
+ kwargs["groups"] = split_config.groups
735
+
736
+ windows = self.dataset.load_windows(**kwargs)
737
+
738
+ # Filter by tags (if provided) using the window.options.
739
+ if split_config.tags:
740
+ new_windows = []
741
+ num_removed: dict[str, int] = {}
742
+ for window in windows:
743
+ for k, v in split_config.tags.items():
744
+ if k not in window.options or (v and window.options[k] != v):
745
+ num_removed[k] = num_removed.get(k, 0) + 1
746
+ break
747
+ else:
748
+ new_windows.append(window)
749
+ logger.info(
750
+ f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
751
+ )
752
+ for k, v in num_removed.items():
753
+ logger.info(f"Removed {v} windows due to tag {k}")
754
+ windows = new_windows
755
+
756
+ return windows
757
+
758
+ def _serialize_item(self, example: Window) -> dict[str, Any]:
759
+ return example.get_metadata()
760
+
761
+ def _deserialize_item(self, d: dict[str, Any]) -> Window:
762
+ return Window.from_metadata(
763
+ self.dataset.storage,
764
+ d,
765
+ )
766
+
767
+ def get_dataset_examples(self) -> list[Window]:
768
+ """Get a list of examples in the dataset.
769
+
770
+ If load_all_patches is False, this is a list of Windows. Otherwise, this is a
771
+ list of (window, patch_bounds, (patch_idx, # patches)) tuples.
772
+ """
773
+ if self.dataset_examples is None:
774
+ logger.debug(
775
+ f"Loading dataset examples from {self.dataset_examples_fname} in process {os.getpid()}"
386
776
  )
387
- for window in self.windows:
388
- cur_patches = []
389
- for col in range(
390
- window.bounds[0],
391
- window.bounds[2],
392
- self.patch_size[0] - overlap_size,
393
- ):
394
- for row in range(
395
- window.bounds[1],
396
- window.bounds[3],
397
- self.patch_size[1] - overlap_size,
398
- ):
399
- cur_patches.append(
400
- (
401
- col,
402
- row,
403
- col + self.patch_size[0],
404
- row + self.patch_size[1],
405
- )
406
- )
407
- for i, patch_bounds in enumerate(cur_patches):
408
- patches.append((window, patch_bounds, (i, len(cur_patches))))
409
- self.windows = patches
777
+ with open(self.dataset_examples_fname) as f:
778
+ self.dataset_examples = [
779
+ self._deserialize_item(d) for d in json.load(f)
780
+ ]
781
+ logger.debug(f"Finished loading dataset examples in process {os.getpid()}")
782
+ return self.dataset_examples
410
783
 
411
784
  def __len__(self) -> int:
412
785
  """Returns the dataset length."""
413
- return len(self.windows)
786
+ return self.num_dataset_examples
414
787
 
415
- def __getitem__(self, idx) -> tuple[dict[str, Any], dict[str, Any]]:
416
- """Read one training example.
788
+ def get_raw_inputs(
789
+ self, idx: int
790
+ ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
791
+ """Get the raw inputs and base metadata for this example.
792
+
793
+ This is the raster or vector data before being processed by the Task. So it
794
+ should be a Tensor for raster and list[Feature] for vector.
417
795
 
418
796
  Args:
419
797
  idx: the index in the dataset.
420
798
 
421
799
  Returns:
422
- a tuple (input_dict, target_dict)
800
+ a tuple (raw_inputs, passthrough_inputs, metadata).
423
801
  """
424
- logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)
425
- window = self.windows[idx]
802
+ dataset_examples = self.get_dataset_examples()
803
+ example = dataset_examples[idx]
804
+ rng = random.Random(idx if self.fix_patch_pick else None)
426
805
 
427
806
  # Select bounds to read.
428
- if self.split_config.get_load_all_patches():
429
- window, bounds, (patch_idx, num_patches) = window
430
- elif self.patch_size:
807
+ if self.patch_size:
808
+ window = example
431
809
 
432
- def get_patch_range(n_patch, n_window):
810
+ def get_patch_range(n_patch: int, n_window: int) -> list[int]:
433
811
  if n_patch > n_window:
434
812
  # Select arbitrary range containing the entire window.
435
813
  # Basically arbitrarily padding the window to get to patch size.
436
- start = random.randint(n_window - n_patch, 0)
814
+ start = rng.randint(n_window - n_patch, 0)
437
815
  return [start, start + n_patch]
438
816
 
439
817
  else:
440
818
  # Select arbitrary patch within the window.
441
- start = random.randint(0, n_window - n_patch)
819
+ start = rng.randint(0, n_window - n_patch)
442
820
  return [start, start + n_patch]
443
821
 
444
822
  window_size = (
@@ -449,128 +827,56 @@ class ModelDataset(torch.utils.data.Dataset):
449
827
  get_patch_range(self.patch_size[0], window_size[0]),
450
828
  get_patch_range(self.patch_size[1], window_size[1]),
451
829
  ]
452
- bounds = [
830
+ bounds = (
453
831
  window.bounds[0] + patch_ranges[0][0],
454
832
  window.bounds[1] + patch_ranges[1][0],
455
833
  window.bounds[0] + patch_ranges[0][1],
456
834
  window.bounds[1] + patch_ranges[1][1],
457
- ]
835
+ )
836
+
458
837
  else:
838
+ window = example
459
839
  bounds = window.bounds
460
840
 
461
- # Read the inputs and targets.
462
- def read_input(data_input: DataInput):
463
- # First enumerate all options of individual layers to read.
464
- layer_options = []
465
- for layer_name in data_input.layers:
466
- completed_fname = window.path / "layers" / layer_name / "completed"
467
- if not completed_fname.exists():
468
- continue
469
- layer_options.append(layer_name)
470
-
471
- # For now we just randomly pick one option.
472
- # In the future we need to support different configuration for how to pick
473
- # the options, as well as picking multiple for series inputs.
474
- layer = random.choice(layer_options)
475
- layer_dir = window.path / "layers" / layer
476
- layer_config = self.dataset.layers[layer]
477
-
478
- if data_input.data_type == "raster":
479
- assert isinstance(layer_config, RasterLayerConfig)
480
-
481
- # See what different sets of bands we need to read to get all the
482
- # configured bands.
483
- needed_bands = data_input.bands
484
- needed_band_indexes = {}
485
- for i, band in enumerate(needed_bands):
486
- needed_band_indexes[band] = i
487
- needed_sets_and_indexes = []
488
- for band_set in layer_config.band_sets:
489
- needed_src_indexes = []
490
- needed_dst_indexes = []
491
- for i, band in enumerate(band_set.bands):
492
- if band not in needed_band_indexes:
493
- continue
494
- needed_src_indexes.append(i)
495
- needed_dst_indexes.append(needed_band_indexes[band])
496
- del needed_band_indexes[band]
497
- if len(needed_src_indexes) == 0:
498
- continue
499
- needed_sets_and_indexes.append(
500
- (band_set, needed_src_indexes, needed_dst_indexes)
501
- )
502
- if len(needed_band_indexes) > 0:
503
- raise Exception(
504
- "could not get all the needed bands from "
505
- + f"window {window.name} layer {layer}"
506
- )
507
-
508
- image = torch.zeros(
509
- (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
510
- dtype=data_input.dtype.get_torch_dtype(),
511
- )
512
-
513
- for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
514
- _, final_bounds = band_set.get_final_projection_and_bounds(
515
- window.projection, bounds
516
- )
517
- raster_format = load_raster_format(
518
- RasterFormatConfig(band_set.format["name"], band_set.format)
519
- )
520
- cur_path = layer_dir / "_".join(band_set.bands)
521
- src = raster_format.decode_raster(cur_path, final_bounds)
522
-
523
- # Resize to patch size if needed.
524
- # This is for band sets that are stored at a lower resolution.
525
- # Here we assume that it is a multiple.
526
- if src.shape[1:3] != image.shape[1:3]:
527
- if src.shape[1] < image.shape[1]:
528
- factor = image.shape[1] // src.shape[1]
529
- src = src.repeat(repeats=factor, axis=1).repeat(
530
- repeats=factor, axis=2
531
- )
532
- else:
533
- factor = src.shape[1] // image.shape[1]
534
- src = src[:, ::factor, ::factor]
535
-
536
- image[dst_indexes, :, :] = torch.as_tensor(
537
- src[src_indexes, :, :].astype(
538
- data_input.dtype.get_numpy_dtype()
539
- )
540
- )
541
-
542
- return image
543
-
544
- elif data_input.data_type == "vector":
545
- assert isinstance(layer_config, VectorLayerConfig)
546
- vector_format = load_vector_format(layer_config.format)
547
- features = vector_format.decode_vector(layer_dir, bounds)
548
- return features
549
-
550
- else:
551
- raise Exception(f"unknown data type {data_input.data_type}")
841
+ assert isinstance(window, Window)
552
842
 
553
843
  raw_inputs = {}
554
844
  passthrough_inputs = {}
555
845
  for name, data_input in self.inputs.items():
556
- raw_inputs[name] = read_input(data_input)
846
+ raw_inputs[name] = read_data_input(
847
+ self.dataset, window, bounds, data_input, rng
848
+ )
557
849
  if data_input.passthrough:
558
850
  passthrough_inputs[name] = raw_inputs[name]
559
851
 
560
- metadata = {
561
- "group": window.group,
562
- "window_name": window.name,
563
- "window_bounds": window.bounds,
564
- "bounds": bounds,
565
- "time_range": window.time_range,
566
- "projection": window.projection,
567
- }
568
- if self.split_config.get_load_all_patches():
569
- metadata["patch_idx"] = patch_idx
570
- metadata["num_patches"] = num_patches
571
- else:
572
- metadata["patch_idx"] = 0
573
- metadata["num_patches"] = 1
852
+ metadata = SampleMetadata(
853
+ window_group=window.group,
854
+ window_name=window.name,
855
+ window_bounds=window.bounds,
856
+ patch_bounds=bounds,
857
+ patch_idx=0,
858
+ num_patches_in_window=1,
859
+ time_range=window.time_range,
860
+ projection=window.projection,
861
+ dataset_source=self.name,
862
+ )
863
+
864
+ return raw_inputs, passthrough_inputs, metadata
865
+
866
+ def __getitem__(
867
+ self, idx: int
868
+ ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
869
+ """Read one training example.
870
+
871
+ Args:
872
+ idx: the index in the dataset.
873
+
874
+ Returns:
875
+ a tuple (input_dict, target_dict, metadata)
876
+ """
877
+ logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)
878
+
879
+ raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
574
880
 
575
881
  input_dict, target_dict = self.task.process_inputs(
576
882
  raw_inputs,
@@ -584,17 +890,21 @@ class ModelDataset(torch.utils.data.Dataset):
584
890
 
585
891
  return input_dict, target_dict, metadata
586
892
 
587
- def get_windows(self) -> list[Window]:
588
- """Returns a list of windows in this dataset."""
589
- return self.windows
893
+ def set_name(self, name: str) -> None:
894
+ """Set the name of the dataset.
895
+
896
+ Args:
897
+ name: the name to set.
898
+ """
899
+ self.name = name
590
900
 
591
901
 
592
902
  class RetryDataset(torch.utils.data.Dataset):
593
903
  """A dataset wrapper that retries getitem upon encountering error."""
594
904
 
595
905
  def __init__(
596
- self, dataset: torch.utils.data.Dataset, retries: int = 3, delay: float = 5
597
- ):
906
+ self, dataset: ModelDataset, retries: int = 3, delay: float = 5
907
+ ) -> None:
598
908
  """Create a new RetryDataset.
599
909
 
600
910
  Args:
@@ -606,7 +916,15 @@ class RetryDataset(torch.utils.data.Dataset):
606
916
  self.retries = retries
607
917
  self.delay = delay
608
918
 
609
- def __len__(self):
919
+ def set_name(self, name: str) -> None:
920
+ """Set the name of the dataset.
921
+
922
+ Args:
923
+ name: the name to set.
924
+ """
925
+ self.dataset.set_name(name)
926
+
927
+ def __len__(self) -> int:
610
928
  """Return length of the dataset."""
611
929
  return len(self.dataset)
612
930
 
@@ -632,6 +950,41 @@ class RetryDataset(torch.utils.data.Dataset):
632
950
  # One last try -- but don't catch any more errors.
633
951
  return self.dataset[idx]
634
952
 
635
- def get_windows(self) -> list[Window]:
953
+ def get_dataset_examples(self) -> list[Window]:
636
954
  """Returns a list of windows in this dataset."""
637
- return self.dataset.get_windows()
955
+ return self.dataset.get_dataset_examples()
956
+
957
+
958
class MultiDataset(torch.utils.data.Dataset):
    """A dataset that combines multiple datasets.

    Examples are indexed by concatenating the member datasets in insertion
    order: global indices [0, len(first)) map to the first dataset, and so on.
    Member dataset lengths are assumed fixed after construction (the index
    buckets are computed once here).
    """

    def __init__(self, datasets: dict[str, "RetryDataset"]) -> None:
        """Create a new MultiDataset.

        Args:
            datasets: map of dataset name to dataset.
        """
        self.datasets = datasets
        # Map each dataset name to the half-open range of global indices that
        # it owns.
        self.buckets: dict[str, range] = {}
        curr_offset = 0
        for name, ds in datasets.items():
            # Hoist len(ds) so it is computed once per dataset.
            size = len(ds)
            self.buckets[name] = range(curr_offset, curr_offset + size)
            curr_offset += size
        # The total size is fixed at construction (the buckets above already
        # assume it), so cache it rather than re-summing member lengths on
        # every __len__ call.
        self._total_len = curr_offset

    def __len__(self) -> int:
        """Return length of the dataset."""
        return self._total_len

    def __getitem__(self, idx: int) -> Any:
        """Get item from the dataset.

        Args:
            idx: the item index.

        Returns:
            the item data.

        Raises:
            IndexError: if idx is outside [0, len(self)).
        """
        # range membership tests are O(1), so this scan is linear only in the
        # number of member datasets.
        for name, bucket in self.buckets.items():
            if idx in bucket:
                return self.datasets[name][idx - bucket.start]
        raise IndexError(f"Index {idx} out of range (len={len(self)})")