rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,47 @@
 """Default LightningDataModule for rslearn."""
 
+import math
+import random
+from collections import defaultdict
+from collections.abc import Iterator
 from typing import Any
 
 import lightning as L
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
 from upath import UPath
 
 from rslearn.dataset import Dataset
+from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 
-from .dataset import DataInput, ModelDataset, RetryDataset, SplitConfig
+from .all_patches_dataset import (
+    InMemoryAllPatchesDataset,
+    IterableAllPatchesDataset,
+)
+from .dataset import (
+    DataInput,
+    ModelDataset,
+    MultiDataset,
+    RetryDataset,
+    SplitConfig,
+)
+
+logger = get_logger(__name__)
 
 
 def collate_fn(
-    batch: list[tuple[dict[str, Any], dict[str, Any]]],
-) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    batch: list[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]],
+) -> tuple:
     """Collate batch of training examples.
 
     We just make list of the inputs and another of the targets.
 
     Args:
-        batch: list of input/target for each example
+        batch: list of input/target/metadata for each example
 
     Returns:
-        a tuple (inputs, targets)
+        a tuple (inputs, targets, metadatas)
     """
     return tuple(zip(*batch))
 
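The reworked collate_fn now carries a third per-example metadata dict through to the batch. A minimal sketch of the transpose it performs, using toy dicts rather than rslearn's actual input/target schema:

```python
from typing import Any

# Toy batch of (input, target, metadata) triples; the keys are illustrative,
# not rslearn's actual input/target schema.
batch: list[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]] = [
    ({"image": 1}, {"label": 0}, {"window": "a"}),
    ({"image": 2}, {"label": 1}, {"window": "b"}),
]

# Same transpose as collate_fn: one tuple per field, one entry per example.
inputs, targets, metadatas = tuple(zip(*batch))
assert inputs == ({"image": 1}, {"image": 2})
assert metadatas == ({"window": "a"}, {"window": "b"})
```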
@@ -43,27 +60,38 @@ class RslearnDataModule(L.LightningDataModule):
         path_options: dict[str, Any] = {},
         batch_size: int = 1,
         num_workers: int = 0,
+        init_workers: int = 0,
         default_config: SplitConfig = SplitConfig(),
         train_config: SplitConfig = SplitConfig(),
         val_config: SplitConfig = SplitConfig(),
         test_config: SplitConfig = SplitConfig(),
         predict_config: SplitConfig = SplitConfig(),
-    ):
+        name: str | None = None,
+        retries: int = 0,
+        use_in_memory_all_patches_dataset: bool = False,
+    ) -> None:
         """Initialize a new RslearnDataModule.
 
         Args:
             inputs: what to read from the underlying dataset
             task: the task to train on
-            path: the dataset path.
+            path: the dataset path
             path_options: additional options for path to pass to fsspec.
             batch_size: the batch size
             num_workers: number of data loader worker processes, or 0 to use main
                 process only
+            init_workers: number of workers used to initialize the dataset, e.g. for
+                loading the list of windows. Defaults to 0 which uses num_workers for
+                this setting
             default_config: default split configuration
             train_config: split config for train split
            val_config: split config for val split
            test_config: split config for test split
            predict_config: split config for predict split
+            name: name of the dataset
+            retries: number of retries to attempt for getitem calls
+            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
+                instead of IterableAllPatchesDataset if load_all_patches is set to true.
        """
        super().__init__()
        self.inputs = inputs
@@ -71,7 +99,10 @@ class RslearnDataModule(L.LightningDataModule):
         self.path = UPath(path, **path_options)
         self.batch_size = batch_size
         self.num_workers = num_workers
-
+        self.init_workers = init_workers if init_workers > 0 else self.num_workers
+        self.name = name
+        self.retries = retries
+        self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
         self.split_configs = {
             "train": default_config.update(train_config),
             "val": default_config.update(val_config),
@@ -79,11 +110,16 @@ class RslearnDataModule(L.LightningDataModule):
             "predict": default_config.update(predict_config),
         }
 
-    def setup(self, stage: str):
+    def setup(
+        self, stage: str, use_in_memory_all_patches_dataset: bool | None = None
+    ) -> None:
         """Set up datasets and samplers.
 
         Args:
             stage: Either 'fit', 'validate', 'test', or 'predict'.
+            use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
+                instead of IterableAllPatchesDataset if load_all_patches is set to true.
+                If None, uses the value of self.use_in_memory_all_patches_dataset.
         """
         stage_to_splits = {
             "fit": ["train", "val"],
@@ -93,31 +129,112 @@ class RslearnDataModule(L.LightningDataModule):
         }
         self.datasets = {}
         for split in stage_to_splits[stage]:
+            split_config = self.split_configs[split]
             dataset = ModelDataset(
                 dataset=Dataset(path=self.path),
                 split_config=self.split_configs[split],
                 inputs=self.inputs,
                 task=self.task,
-                workers=self.num_workers,
+                workers=self.init_workers,
+                name=self.name,
+                fix_patch_pick=(split != "train"),
             )
-            dataset = RetryDataset(dataset)
+            logger.info(f"got {len(dataset)} examples in split {split}")
+            if split_config.get_load_all_patches():
+                if use_in_memory_all_patches_dataset is None:
+                    use_in_memory_all_patches_dataset = (
+                        self.use_in_memory_all_patches_dataset
+                    )
+                logger.info(
+                    f"using AllPatchesDataset (in_memory={use_in_memory_all_patches_dataset})"
+                )
+                patch_size = split_config.get_patch_size()
+                if patch_size is None:
+                    raise ValueError(
+                        "patch_size is not set but must be set if load_all_patches is set"
+                    )
+
+                all_patches_cls = IterableAllPatchesDataset
+                kwargs = dict(
+                    dataset=dataset,
+                    patch_size=patch_size,
+                    overlap_ratio=split_config.get_overlap_ratio(),
+                    rank=self.trainer.global_rank if self.trainer else 0,
+                    world_size=self.trainer.world_size if self.trainer else 1,
+                )
+                if use_in_memory_all_patches_dataset:
+                    kwargs.pop("rank")
+                    kwargs.pop("world_size")
+                    all_patches_cls = InMemoryAllPatchesDataset  # type: ignore
+
+                dataset = all_patches_cls(**kwargs)  # type: ignore
+
+            if self.retries > 0:
+                dataset = RetryDataset(dataset, retries=self.retries)
             self.datasets[split] = dataset
-            print(f"got {len(self.datasets[split])} examples in split {split}")
 
-    def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]:
+    def set_name(self, name: str) -> None:
+        """Set the name of the dataset.
+
+        Args:
+            name: the name of the dataset
+        """
+        self.name = name
+        for dataset in self.datasets.values():
+            dataset.set_name(name)
+
+    def _get_dataloader(
+        self,
+        split: str,
+    ) -> DataLoader[dict[str, torch.Tensor]]:
+        """Get a dataloader for the given split.
+
+        Args:
+            split: the split to get a dataloader for
+        """
         dataset = self.datasets[split]
-        kwargs = dict(
+        split_config = self.split_configs[split]
+
+        # Enable persistent workers unless we are using main process.
+        persistent_workers = self.num_workers > 0
+
+        # If using all patches, limit number of workers to the number of windows.
+        # Otherwise it has to distribute the same window to different workers which can
+        # cause issues for RslearnWriter.
+        # If the number of windows is 0, then we can set positive number of workers
+        # since they won't yield anything anyway.
+        num_workers = self.num_workers
+        if split_config.load_all_patches and len(dataset.get_dataset_examples()) > 0:
+            num_workers = min(num_workers, len(dataset.get_dataset_examples()))
+
+        kwargs: dict[str, Any] = dict(
             dataset=dataset,
             batch_size=self.batch_size,
-            num_workers=self.num_workers,
+            num_workers=num_workers,
             collate_fn=collate_fn,
-            persistent_workers=True,
+            persistent_workers=persistent_workers,
         )
-        sampler_factory = self.split_configs[split].sampler
+        should_shuffle = split == "train"
+
+        sampler_factory = split_config.sampler
         if sampler_factory:
             kwargs["sampler"] = sampler_factory.get_sampler(dataset)
-        elif split == "train":
-            kwargs["shuffle"] = True
+        elif (
+            self.trainer is not None
+            and self.trainer.world_size is not None
+            and self.trainer.world_size > 1
+            and not isinstance(dataset, IterableDataset)
+        ):
+            # Use distributed sampler in case ddp is enabled.
+            kwargs["sampler"] = DistributedSampler(
+                dataset,
+                num_replicas=self.trainer.world_size,
+                rank=self.trainer.global_rank,
+                shuffle=should_shuffle,
+            )
+        else:
+            kwargs["shuffle"] = should_shuffle
+
         return DataLoader(**kwargs)
 
     def train_dataloader(self) -> DataLoader[dict[str, torch.Tensor]]:
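For reference, the sampler selection in the new _get_dataloader reduces to the following decision order. This standalone paraphrase is only illustrative; the sampler_factory and trainer parameters stand in for the attributes used above.

```python
from torch.utils.data import DistributedSampler, IterableDataset


def choose_sampling_kwargs(dataset, split, sampler_factory, trainer) -> dict:
    """Paraphrase of the branching above: explicit sampler factory first, then
    DistributedSampler under multi-process training (map-style datasets only),
    otherwise plain shuffling for the train split."""
    should_shuffle = split == "train"
    if sampler_factory:
        return {"sampler": sampler_factory.get_sampler(dataset)}
    if (
        trainer is not None
        and trainer.world_size is not None
        and trainer.world_size > 1
        and not isinstance(dataset, IterableDataset)
    ):
        return {
            "sampler": DistributedSampler(
                dataset,
                num_replicas=trainer.world_size,
                rank=trainer.global_rank,
                shuffle=should_shuffle,
            )
        }
    return {"shuffle": should_shuffle}
```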
@@ -167,3 +284,310 @@
             dataset or sampler, or if the dataset or sampler has length 0.
         """
         return self._get_dataloader("predict")
+
+
+class MultiDatasetDataModule(L.LightningDataModule):
+    """Data module that manages multiple RslearnDataModule instances.
+
+    This module creates and manages multiple RslearnDataModule instances, each handling
+    a different dataset. It provides a unified interface for training on multiple datasets
+    with different modalities and labels.
+
+    Each dataset can have different:
+    - Input modalities (e.g., Sentinel-2 vs Landsat)
+    - Label schemas (e.g., different classification classes)
+    - Task types (e.g., classification vs detection)
+    - Transforms and preprocessing
+    """
+
+    def __init__(
+        self,
+        data_modules: dict[str, RslearnDataModule],
+        num_workers: int = 32,
+        sample_mode: str = "random_cycle",
+        batch_sizes: int | dict[str, int] | None = None,
+        refill_batches: bool = False,
+        per_dataset_patch_limit: int | None = None,
+        steps_per_dataset: int | None = None,
+        disabled_datasets: list[str] | None = None,
+    ) -> None:
+        """Initialize a new MultiDatasetDataModule.
+
+        Args:
+            data_modules: dict mapping dataset names to RslearnDataModule objects
+            num_workers: the maximum number of workers to use for the dataloader
+            sample_mode: the mode to sample from the datasets ("random", "cycle", "random_cycle", "reptile")
+            batch_sizes: the batch size for all datasets, or a dict mapping dataset
+                names to batch sizes, or None to use the batch size of the largest
+                dataset (default: None)
+            refill_batches: whether to refill empty dataset iterators
+                once they run out each epoch (default: False)
+            per_dataset_patch_limit: the maximum number of patches to sample from each dataset
+                per epoch during training. Does not affect validation (default: None = no limit)
+            steps_per_dataset: the number of steps to sample from each dataset in a row (requires that
+                sample_mode is "reptile")
+            disabled_datasets: list of datasets to disable (default: None = no disabled datasets)
+        """
+        super().__init__()
+        self.data_modules = data_modules
+        self.num_workers = num_workers
+        self.sample_mode = sample_mode
+        self.batch_sizes = batch_sizes
+        self.refill_batches = refill_batches
+        self.per_dataset_patch_limit = per_dataset_patch_limit
+        self.steps_per_dataset = steps_per_dataset
+        self.disabled_datasets = disabled_datasets or []
+
+        for dataset in self.disabled_datasets:
+            if dataset in self.data_modules:
+                del self.data_modules[dataset]
+                logger.info(f"Skipping disabled dataset {dataset}")
+            else:
+                logger.info(f"Could not find dataset {dataset} to skip")
+
+    def setup(self, stage: str | None = None) -> None:
+        """Set up the datasets for the given stage. Also assign dataset-specific names.
+
+        Args:
+            stage: The stage to set up ('fit', 'validate', 'test', 'predict')
+        """
+        for name, data_module in self.data_modules.items():
+            data_module.setup(stage, use_in_memory_all_patches_dataset=True)  # type: ignore
+            data_module.set_name(name)
+
+    def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]:
+        datasets = {name: dm.datasets[split] for name, dm in self.data_modules.items()}
+        if isinstance(self.batch_sizes, dict):
+            batch_sizes = self.batch_sizes
+        else:
+            batch_size: int | None = self.batch_sizes
+            if batch_size is None:
+                batch_size = max(
+                    self.data_modules.values(), key=lambda dm: dm.batch_size
+                ).batch_size
+            batch_sizes = {name: batch_size for name in self.data_modules.keys()}
+
+        logger.info(f"{split} is using batch_sizes {batch_sizes}")
+        logger.info(f"{split} is using sample_mode {self.sample_mode}")
+        if self.per_dataset_patch_limit:
+            logger.info(
+                f"{split} is using per_dataset_patch_limit {self.per_dataset_patch_limit}"
+            )
+
+        dataset = MultiDataset(datasets)
+        return DataLoader(
+            dataset=dataset,
+            pin_memory=True,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            collate_fn=collate_fn,
+            batch_sampler=DistributedPerDatasetBatchSampler(
+                multi_dataset=dataset,
+                batch_sizes=batch_sizes,
+                shuffle=(split == "train"),
+                num_replicas=self.trainer.world_size,  # type: ignore
+                rank=self.trainer.global_rank,  # type: ignore
+                sample_mode=self.sample_mode,
+                refill_batches=self.refill_batches,
+                per_dataset_patch_limit=(
+                    self.per_dataset_patch_limit if split == "train" else None
+                ),
+                steps_per_dataset=self.steps_per_dataset,
+            ),
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        """Get the training dataloader."""
+        return self._get_dataloader("train")
+
+    def val_dataloader(self) -> DataLoader:
+        """Get the validation dataloader."""
+        return self._get_dataloader("val")
+
+    def test_dataloader(self) -> DataLoader:
+        """Get the test dataloader."""
+        return self._get_dataloader("test")
+
+    def predict_dataloader(self) -> DataLoader:
+        """Get the predict dataloader."""
+        return self._get_dataloader("predict")
+
+
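A hypothetical sketch of combining two data modules; s2_data_module and landsat_data_module stand for independently configured RslearnDataModule instances, and the names, batch sizes, and limits are illustrative.

```python
from rslearn.train.data_module import MultiDatasetDataModule

multi_dm = MultiDatasetDataModule(
    data_modules={
        "sentinel2_classification": s2_data_module,
        "landsat_detection": landsat_data_module,
    },
    sample_mode="random_cycle",  # every batch is drawn from a single sub-dataset
    batch_sizes={"sentinel2_classification": 16, "landsat_detection": 4},
    per_dataset_patch_limit=10_000,  # cap training patches per dataset per epoch
    refill_batches=False,
)

# setup() forwards to each RslearnDataModule (with the in-memory all-patches
# dataset) and assigns the dict keys as dataset names via set_name().
multi_dm.setup("fit")
```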
+class DistributedPerDatasetBatchSampler(torch.utils.data.Sampler[list[int]]):
+    """Distributed batch sampler yielding batches from one sub-dataset per batch.
+
+    Wraps torch DistributedSampler to first split indices across ranks,
+    then does "one-subdataset-per-batch" sampling in each process.
+    """
+
+    def __init__(
+        self,
+        multi_dataset: MultiDataset,
+        batch_sizes: dict[str, int],
+        shuffle: bool = True,
+        num_replicas: int | None = None,
+        rank: int | None = None,
+        sample_mode: str = "random_cycle",
+        refill_batches: bool = False,
+        steps_per_dataset: int | None = None,
+        per_dataset_patch_limit: int | None = None,
+    ) -> None:
+        """Initialize a new DistributedPerDatasetBatchSampler.
+
+        Args:
+            multi_dataset: the MultiDataset to sample from
+            batch_sizes: the batch size for each dataset
+            shuffle: whether to shuffle the indices
+            num_replicas: the number of replicas
+            rank: the rank
+            sample_mode: the mode to sample from the datasets ("random", "cycle", "random_cycle", "reptile")
+            refill_batches: whether to refill empty dataset iterators
+                once they run out each epoch
+            steps_per_dataset: the number of steps to sample from each dataset
+            per_dataset_patch_limit: the maximum number of patches to sample from each dataset
+                per epoch during training. Does not affect validation (default: None = no limit)
+            steps_per_dataset: the number of steps to sample from each dataset in a row (requires that
+                sample_mode is "reptile")
+        """
+        self.multi_dataset = multi_dataset
+        self.batch_sizes = batch_sizes
+        self.sample_mode = sample_mode
+        self.refill_batches = refill_batches
+        self.per_dataset_patch_limit = per_dataset_patch_limit
+        self.steps_per_dataset: int = steps_per_dataset  # type: ignore
+        self.epoch = 0
+
+        if sample_mode == "reptile":
+            assert steps_per_dataset is not None, (
+                "steps_per_dataset must be provided when sample_mode is 'reptile'"
+            )
+        assert sample_mode in (
+            "random",
+            "cycle",
+            "random_cycle",
+            "reptile",
+        ), f"Invalid sample_mode: {sample_mode}"
+
+        # For now, we just track the total number of batches if refill_batches is True,
+        # so we must the datasets come out balanced during each epoch
+        if refill_batches and self.sample_mode not in (
+            "cycle",
+            "random_cycle",
+            "reptile",
+        ):
+            raise ValueError("refill_batches is only supported with round_robin")
+
+        # Using one DistributedSampler per dataset guarantees equal splitting
+        # across all datasets across all ranks
+        self.dist_samplers = {
+            name: DistributedSampler(
+                dataset,
+                num_replicas=num_replicas,
+                rank=rank,
+                shuffle=shuffle,
+                drop_last=False,
+            )
+            for name, dataset in multi_dataset.datasets.items()
+        }
+
+        for k, v in self.dist_samplers.items():
+            logger.info(f"Dataset {k} has {len(v)} samples")
+
+    def set_epoch(self, epoch: int) -> None:
+        """Set the epoch for the distributed sampler.
+
+        Args:
+            epoch: the epoch to set
+        """
+        self.epoch = epoch
+        for dist_sampler in self.dist_samplers.values():
+            dist_sampler.set_epoch(epoch)
+
+    def __iter__(self) -> Iterator[list[int]]:
+        """Iterate over the batches."""
+        # Get the per-rank, per-epoch list of properly offset multi-dataset indices
+        partitioned: dict[str, list[int]] = {}
+        refill: dict[str, list[int]] = defaultdict(list)
+        for name, sampler in self.dist_samplers.items():
+            offset = self.multi_dataset.buckets[name].start
+            partitioned[name] = [idx + offset for idx in sampler]
+            if self.per_dataset_patch_limit:
+                partitioned[name] = partitioned[name][: self.per_dataset_patch_limit]
+
+        # Seed is shared aross all ranks but shuffled per epoch
+        rng = random.Random(self.epoch)
+        picks = list(partitioned.keys())
+        last_picked = -1
+        dataset_counter = 0
+
+        # Random mode samples uniformly across all datasets regardless of size
+        for n in range(len(self)):
+            available = [name for name, idxs in partitioned.items() if idxs]
+            if not self.refill_batches:
+                # For cycle, only pick from available datasets, but if
+                # we are refilling batches then all datasets are available
+                picks = [name for name in picks if name in available]
+            if not available:
+                logger.warning(f"Found no available batch on step {n} of {len(self)}")
+                break
+
+            if self.sample_mode == "cycle":
+                last_picked = (last_picked + 1) % len(picks)
+                name = picks[last_picked]
+
+            elif self.sample_mode == "reptile":
+                # Sample n times from the same dataset before moving onto the next,
+                # but still ensure we sample from all the datasets before repeating
+                # This is so that we can use the refill_batches feature
+                if dataset_counter == 0:
+                    name = rng.choice(picks)
+                    last_picked = picks.index(name)
+                else:
+                    name = picks[last_picked]
+                dataset_counter += 1
+                if dataset_counter >= self.steps_per_dataset:
+                    dataset_counter = 0
+                    picks.remove(name)
+                    if not picks:
+                        picks = list(partitioned.keys())
+
+            elif self.sample_mode == "random_cycle":
+                name = rng.choice(picks)
+                picks.remove(name)
+                if not picks:
+                    picks = list(partitioned.keys())
+
+            else:
+                name = rng.choice(available)
+
+            idxs = partitioned[name]
+            batch, partitioned[name] = (
+                idxs[: self.batch_sizes[name]],
+                idxs[self.batch_sizes[name] :],
+            )
+
+            # If we are refilling batches, we just keep adding the indexes
+            if self.refill_batches:
+                refill[name].extend(batch)
+                if len(partitioned[name]) == 0:
+                    # Shuffle batches again once we have to replenish them
+                    partitioned[name], refill[name] = refill[name], []
+                    rng.shuffle(partitioned[name])
+
+            yield batch
+
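The "random_cycle" mode above amounts to sampling datasets without replacement and replenishing the pool once it empties, so every dataset contributes a batch before any dataset repeats. A minimal standalone sketch of that picking rule (dataset names and step count are illustrative):

```python
import random

names = ["sentinel2", "landsat", "naip"]
rng = random.Random(0)  # per-epoch seed, shared across ranks

picks = list(names)
order = []
for _ in range(7):
    name = rng.choice(picks)  # pick a dataset without replacement
    picks.remove(name)
    if not picks:             # replenish once every dataset has been used
        picks = list(names)
    order.append(name)

print(order)  # each group of three consecutive picks covers all three datasets
```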
+    def __len__(self) -> int:
+        """Return the number of batches."""
+
+        def len_iter() -> Iterator[int]:
+            """Iterate over the number of batches for each dataset."""
+            for name, sampler in self.dist_samplers.items():
+                length = len(sampler)
+                if self.per_dataset_patch_limit:
+                    length = min(length, self.per_dataset_patch_limit)
+                yield math.ceil(length / self.batch_sizes[name])
+
+        if self.refill_batches:
+            return max(len_iter()) * len(self.dist_samplers)
+        else:
+            return sum(len_iter())
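A worked example of the __len__ arithmetic with hypothetical per-rank sizes: two sub-datasets of 100 and 40 samples after the DistributedSampler split, with batch sizes 8 and 4.

```python
import math

per_rank_lengths = {"A": 100, "B": 40}
batch_sizes = {"A": 8, "B": 4}

per_dataset_batches = {
    name: math.ceil(length / batch_sizes[name])
    for name, length in per_rank_lengths.items()
}  # {"A": 13, "B": 10}

# refill_batches=False: the epoch ends once every dataset is exhausted.
assert sum(per_dataset_batches.values()) == 23

# refill_batches=True: the epoch length follows the largest dataset, and the
# smaller dataset is reshuffled and reused so it keeps contributing batches.
assert max(per_dataset_batches.values()) * len(per_dataset_batches) == 26
```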