rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """Default LightningModule for rslearn."""
2
2
 
3
+ import json
3
4
  import os
4
5
  from typing import Any
5
6
 
@@ -7,12 +8,17 @@ import lightning as L
7
8
  import torch
8
9
  from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
9
10
  from PIL import Image
10
- from torch.optim import AdamW
11
- from torch.optim.lr_scheduler import ReduceLROnPlateau
12
11
  from upath import UPath
13
12
 
13
+ from rslearn.log_utils import get_logger
14
+
15
+ from .model_context import ModelContext, ModelOutput
16
+ from .optimizer import AdamW, OptimizerFactory
17
+ from .scheduler import PlateauScheduler, SchedulerFactory
14
18
  from .tasks import Task
15
19
 
20
+ logger = get_logger(__name__)
21
+
16
22
 
17
23
  class RestoreConfig:
18
24
  """Configuration for restoring model parameters.
@@ -36,7 +42,7 @@ class RestoreConfig:
36
42
  restore_path_options: additional options for the restore_path to pass to
37
43
  fsspec.
38
44
  selector: path in the torch dict containing the model parameters.
39
- ignore_prefixes: prefixes to restore.
45
+ ignore_prefixes: prefixes to ignore from the state dict.
40
46
  remap_prefixes: list of (old_prefix, new_prefix) to rename parameters
41
47
  starting with old_prefix to start with new_prefix instead.
42
48
  """
@@ -47,9 +53,9 @@ class RestoreConfig:
47
53
 
48
54
  def get_state_dict(self) -> dict[str, Any]:
49
55
  """Returns the state dict configured in this RestoreConfig."""
50
- print(f"loading state dict from {self.restore_path}")
56
+ logger.info(f"loading state dict from {self.restore_path}")
51
57
  with self.restore_path.open("rb") as f:
52
- state_dict = torch.load(f)
58
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
53
59
  for k in self.selector:
54
60
  state_dict = state_dict[k]
55
61
 
@@ -82,48 +88,71 @@ class RslearnLightningModule(L.LightningModule):
82
88
  self,
83
89
  model: torch.nn.Module,
84
90
  task: Task,
91
+ optimizer: OptimizerFactory | None = None,
92
+ scheduler: SchedulerFactory | None = None,
93
+ visualize_dir: str | None = None,
94
+ metrics_file: str | None = None,
95
+ restore_config: RestoreConfig | None = None,
96
+ print_parameters: bool = False,
97
+ print_model: bool = False,
98
+ # Deprecated options.
85
99
  lr: float = 1e-3,
86
100
  plateau: bool = False,
87
101
  plateau_factor: float = 0.1,
88
102
  plateau_patience: int = 10,
89
103
  plateau_min_lr: float = 0,
90
104
  plateau_cooldown: int = 0,
91
- visualize_dir: str | None = None,
92
- restore_config: RestoreConfig | None = None,
93
- print_parameters: bool = False,
94
- print_model: bool = False,
95
105
  ):
96
106
  """Initialize a new RslearnLightningModule.
97
107
 
98
108
  Args:
99
109
  model: the model
100
110
  task: the task to train on
101
- lr: the initial learning rate
102
- plateau: whether to enable plateau scheduler (default false)
103
- plateau_factor: on plateau, factor to multiply learning rate by
104
- plateau_patience: number of iterations with no improvement in val loss
105
- before reducing learning rate
106
- plateau_min_lr: minimum learning rate to reduce to
107
- plateau_cooldown: number of iterations after reducing learning rate before
108
- resetting plateau scheduler
111
+ optimizer: the optimizer factory.
112
+ scheduler: the learning rate scheduler factory.
109
113
  visualize_dir: during validation or testing, output visualizations to this
110
114
  directory
115
+ metrics_file: file to save metrics to
111
116
  restore_config: specification of configuration to restore parameters from
112
117
  a non-Lightning checkpoint.
113
118
  print_parameters: whether to print the list of model parameters after model
114
119
  initialization
115
120
  print_model: whether to print the model after model initialization
121
+ lr: deprecated.
122
+ plateau: deprecated.
123
+ plateau_factor: deprecated.
124
+ plateau_patience: deprecated.
125
+ plateau_min_lr: deprecated.
126
+ plateau_cooldown: deprecated.
116
127
  """
117
128
  super().__init__()
118
129
  self.model = model
119
130
  self.task = task
120
- self.lr = lr
121
- self.plateau = plateau
122
- self.plateau_factor = plateau_factor
123
- self.plateau_patience = plateau_patience
124
- self.plateau_min_lr = plateau_min_lr
125
- self.plateau_cooldown = plateau_cooldown
126
131
  self.visualize_dir = visualize_dir
132
+ self.metrics_file = metrics_file
133
+ self.restore_config = restore_config
134
+
135
+ self.scheduler_factory: SchedulerFactory | None = None
136
+ if scheduler:
137
+ self.scheduler_factory = scheduler
138
+ elif plateau:
139
+ logger.warning(
140
+ "The plateau argument to RslearnLightningModule is deprecated and will be removed in a future version"
141
+ )
142
+ self.scheduler_factory = PlateauScheduler(
143
+ factor=plateau_factor,
144
+ patience=plateau_patience,
145
+ min_lr=plateau_min_lr,
146
+ cooldown=plateau_cooldown,
147
+ )
148
+
149
+ if optimizer:
150
+ self.optimizer_factory = optimizer
151
+ else:
152
+ logger.warning(
153
+ "Defaulting the optimizer to AdamW since an OptimizerFactory was not provided. In a future version, the optimizer will be a required argument."
154
+ )
155
+ self.optimizer_factory = AdamW(lr=lr)
127
156
 
128
157
  if print_parameters:
129
158
  for name, param in self.named_parameters():
@@ -132,23 +161,26 @@ class RslearnLightningModule(L.LightningModule):
132
161
  if print_model:
133
162
  print(self.model)
134
163
 
135
- if restore_config:
136
- state_dict = restore_config.get_state_dict()
137
- missing_keys, unexpected_keys = self.model.load_state_dict(
138
- state_dict, strict=False
139
- )
140
- if missing_keys or unexpected_keys:
141
- print(
142
- f"warning: restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
143
- )
144
-
145
164
  self.epochs = 0
146
165
 
147
166
  metrics = self.task.get_metrics()
148
167
  self.val_metrics = metrics.clone(prefix="val_")
149
168
  self.test_metrics = metrics.clone(prefix="test_")
150
169
 
151
- self.schedulers = {}
170
+ self.schedulers: dict = {}
171
+
def on_fit_start(self) -> None:
    """Load weights from the RestoreConfig when a fresh fit begins.

    Nothing is restored when the trainer has a checkpoint path set (i.e. the
    fit is not a fresh one) or when no restore_config was provided.
    """
    # A checkpoint path on the trainer means this is not a fresh fit.
    if self.trainer.ckpt_path is not None:
        return
    if not self.restore_config:
        return

    restored_state = self.restore_config.get_state_dict()
    # strict=False so that mismatched keys are reported rather than raised.
    load_result = self.model.load_state_dict(restored_state, strict=False)
    missing_keys, unexpected_keys = load_result
    if not (missing_keys or unexpected_keys):
        return

    logger.warning(
        f"restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
    )
152
184
 
153
185
  def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
154
186
  """Initialize the optimizer and learning rate scheduler.
@@ -156,27 +188,37 @@ class RslearnLightningModule(L.LightningModule):
156
188
  Returns:
157
189
  Optimizer and learning rate scheduler.
158
190
  """
159
- params = [p for p in self.parameters() if p.requires_grad]
160
- optimizer = AdamW(params, lr=self.lr)
191
+ optimizer = self.optimizer_factory.build(self)
161
192
  d = dict(
162
193
  optimizer=optimizer,
163
194
  )
164
- if self.plateau:
165
- scheduler = ReduceLROnPlateau(
166
- optimizer,
167
- factor=self.plateau_factor,
168
- patience=self.plateau_patience,
169
- min_lr=self.plateau_min_lr,
170
- cooldown=self.plateau_cooldown,
171
- )
195
+ if self.scheduler_factory is not None:
196
+ scheduler = self.scheduler_factory.build(optimizer)
172
197
  d["lr_scheduler"] = {
173
198
  "scheduler": scheduler,
174
199
  "monitor": "train_loss",
175
200
  "interval": "epoch",
176
201
  }
177
- self.schedulers["plateau"] = scheduler
202
+ self.schedulers["scheduler"] = scheduler
178
203
  return d
179
204
 
def on_train_epoch_start(self) -> None:
    """Forward the current epoch to the train batch sampler, if it accepts one.

    When the train dataloader's batch sampler exposes ``set_epoch`` it is
    called with ``self.current_epoch``. When the dataloader, its batch sampler,
    or the ``set_epoch`` attribute is missing, this hook does nothing.
    """
    try:
        # Only the attribute lookup is guarded: a missing sampler or a sampler
        # without set_epoch is an expected configuration and is ignored.
        set_epoch = self.trainer.train_dataloader.batch_sampler.set_epoch
    except AttributeError:
        return

    # Called outside the try block so that an AttributeError raised *inside*
    # set_epoch (a genuine bug in the sampler) propagates instead of being
    # silently swallowed.
    set_epoch(self.current_epoch)
212
+
def on_test_epoch_end(self) -> None:
    """Write the computed test metrics to ``self.metrics_file`` as JSON.

    Does nothing when ``metrics_file`` is unset. Each computed metric value is
    converted with ``.item()`` before being serialized.
    """
    if not self.metrics_file:
        return

    # Compute and convert the metrics *before* opening the file: opening with
    # "w" truncates immediately, so a failure in compute()/item() would
    # otherwise leave behind an empty or truncated metrics file.
    computed = self.test_metrics.compute()
    metrics_dict = {name: value.item() for name, value in computed.items()}

    with open(self.metrics_file, "w") as f:
        json.dump(metrics_dict, f, indent=4)
    logger.info(f"Saved metrics to {self.metrics_file}")
221
+
180
222
  def training_step(
181
223
  self, batch: Any, batch_idx: int, dataloader_idx: int = 0
182
224
  ) -> torch.Tensor:
@@ -190,9 +232,16 @@ class RslearnLightningModule(L.LightningModule):
190
232
  Returns:
191
233
  The loss tensor.
192
234
  """
193
- inputs, targets, _ = batch
235
+ inputs, targets, metadatas = batch
236
+ context = ModelContext(
237
+ inputs=inputs,
238
+ metadatas=metadatas,
239
+ )
194
240
  batch_size = len(inputs)
195
- _, loss_dict = self(inputs, targets)
241
+ model_outputs = self(context, targets)
242
+ self.on_train_forward(context, targets, model_outputs)
243
+
244
+ loss_dict = model_outputs.loss_dict
196
245
  train_loss = sum(loss_dict.values())
197
246
  self.log_dict(
198
247
  {"train_" + k: v for k, v in loss_dict.items()},
@@ -200,6 +249,7 @@ class RslearnLightningModule(L.LightningModule):
200
249
  prog_bar=True,
201
250
  on_step=False,
202
251
  on_epoch=True,
252
+ sync_dist=True,
203
253
  )
204
254
  self.log(
205
255
  "train_loss",
@@ -207,6 +257,7 @@ class RslearnLightningModule(L.LightningModule):
207
257
  batch_size=batch_size,
208
258
  on_step=False,
209
259
  on_epoch=True,
260
+ sync_dist=True,
210
261
  )
211
262
  return train_loss
212
263
 
@@ -220,15 +271,24 @@ class RslearnLightningModule(L.LightningModule):
220
271
  batch_idx: Integer displaying index of this batch.
221
272
  dataloader_idx: Index of the current dataloader.
222
273
  """
223
- inputs, targets, _ = batch
274
+ inputs, targets, metadatas = batch
275
+ context = ModelContext(
276
+ inputs=inputs,
277
+ metadatas=metadatas,
278
+ )
224
279
  batch_size = len(inputs)
225
- outputs, loss_dict = self(inputs, targets)
280
+ model_outputs = self(context, targets)
281
+ self.on_val_forward(context, targets, model_outputs)
282
+
283
+ loss_dict = model_outputs.loss_dict
284
+ outputs = model_outputs.outputs
226
285
  val_loss = sum(loss_dict.values())
227
286
  self.log_dict(
228
287
  {"val_" + k: v for k, v in loss_dict.items()},
229
288
  batch_size=batch_size,
230
289
  on_step=False,
231
290
  on_epoch=True,
291
+ sync_dist=True,
232
292
  )
233
293
  self.log(
234
294
  "val_loss",
@@ -237,9 +297,12 @@ class RslearnLightningModule(L.LightningModule):
237
297
  prog_bar=True,
238
298
  on_step=False,
239
299
  on_epoch=True,
300
+ sync_dist=True,
240
301
  )
241
302
  self.val_metrics.update(outputs, targets)
242
- self.log_dict(self.val_metrics, batch_size=batch_size, on_epoch=True)
303
+ self.log_dict(
304
+ self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
305
+ )
243
306
 
244
307
  def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
245
308
  """Compute the test loss and additional metrics.
@@ -250,20 +313,36 @@ class RslearnLightningModule(L.LightningModule):
250
313
  dataloader_idx: Index of the current dataloader.
251
314
  """
252
315
  inputs, targets, metadatas = batch
316
+ context = ModelContext(
317
+ inputs=inputs,
318
+ metadatas=metadatas,
319
+ )
253
320
  batch_size = len(inputs)
254
- outputs, loss_dict = self(inputs, targets)
321
+ model_outputs = self(context, targets)
322
+ self.on_test_forward(context, targets, model_outputs)
323
+
324
+ loss_dict = model_outputs.loss_dict
325
+ outputs = model_outputs.outputs
255
326
  test_loss = sum(loss_dict.values())
256
327
  self.log_dict(
257
328
  {"test_" + k: v for k, v in loss_dict.items()},
258
329
  batch_size=batch_size,
259
330
  on_step=False,
260
331
  on_epoch=True,
332
+ sync_dist=True,
261
333
  )
262
334
  self.log(
263
- "test_loss", test_loss, batch_size=batch_size, on_step=False, on_epoch=True
335
+ "test_loss",
336
+ test_loss,
337
+ batch_size=batch_size,
338
+ on_step=False,
339
+ on_epoch=True,
340
+ sync_dist=True,
264
341
  )
265
342
  self.test_metrics.update(outputs, targets)
266
- self.log_dict(self.test_metrics, batch_size=batch_size, on_epoch=True)
343
+ self.log_dict(
344
+ self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
345
+ )
267
346
 
268
347
  if self.visualize_dir:
269
348
  for idx, (inp, target, output, metadata) in enumerate(
@@ -273,13 +352,13 @@ class RslearnLightningModule(L.LightningModule):
273
352
  for image_suffix, image in images.items():
274
353
  out_fname = os.path.join(
275
354
  self.visualize_dir,
276
- f'{metadata["window_name"]}_{metadata["bounds"][0]}_{metadata["bounds"][1]}_{image_suffix}.png',
355
+ f"{metadata['window_name']}_{metadata['bounds'][0]}_{metadata['bounds'][1]}_{image_suffix}.png",
277
356
  )
278
357
  Image.fromarray(image).save(out_fname)
279
358
 
280
359
  def predict_step(
281
360
  self, batch: Any, batch_idx: int, dataloader_idx: int = 0
282
- ) -> torch.Tensor:
361
+ ) -> ModelOutput:
283
362
  """Compute the predicted class probabilities.
284
363
 
285
364
  Args:
@@ -290,18 +369,69 @@ class RslearnLightningModule(L.LightningModule):
290
369
  Returns:
291
370
  Output predicted probabilities.
292
371
  """
293
- inputs, _, _ = batch
294
- outputs, _ = self(inputs)
295
- return outputs
372
+ inputs, _, metadatas = batch
373
+ context = ModelContext(
374
+ inputs=inputs,
375
+ metadatas=metadatas,
376
+ )
377
+ model_outputs = self(context)
378
+ return model_outputs
296
379
 
297
- def forward(self, *args: Any, **kwargs: Any) -> Any:
380
+ def forward(
381
+ self, context: ModelContext, targets: list[dict[str, Any]] | None = None
382
+ ) -> ModelOutput:
298
383
  """Forward pass of the model.
299
384
 
300
385
  Args:
301
- args: Arguments to pass to model.
302
- kwargs: Keyword arguments to pass to model.
386
+ context: the model context.
387
+ targets: the target dicts.
303
388
 
304
389
  Returns:
305
390
  Output of the model.
306
391
  """
307
- return self.model(*args, **kwargs)
392
+ return self.model(context, targets)
393
+
def on_train_forward(
    self,
    context: ModelContext,
    targets: list[dict[str, Any]],
    model_outputs: ModelOutput,
) -> None:
    """Extension point invoked after the model's forward pass in training.

    The default implementation is a no-op.

    Args:
        context: the model context that was passed to the model.
        targets: the target dicts for the batch.
        model_outputs: what the model returned for this batch.
    """
    return None
408
+
def on_val_forward(
    self,
    context: ModelContext,
    targets: list[dict[str, Any]],
    model_outputs: ModelOutput,
) -> None:
    """Extension point invoked after the model's forward pass in validation.

    The default implementation is a no-op.

    Args:
        context: the model context that was passed to the model.
        targets: the target dicts for the batch.
        model_outputs: what the model returned for this batch.
    """
    return None
423
+
def on_test_forward(
    self,
    context: ModelContext,
    targets: list[dict[str, Any]],
    model_outputs: ModelOutput,
) -> None:
    """Extension point invoked after the model's forward pass in testing.

    The default implementation is a no-op.

    Args:
        context: the model context that was passed to the model.
        targets: the target dicts for the batch.
        model_outputs: what the model returned for this batch.
    """
    return None
@@ -0,0 +1,88 @@
1
+ """Data classes to provide various context to models."""
2
+
3
+ from collections.abc import Iterable
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+ from rslearn.utils.geometry import PixelBounds, Projection
11
+
12
+
13
+ @dataclass
14
+ class RasterImage:
15
+ """A raster image is a torch.tensor containing the images and their associated timestamps."""
16
+
17
+ # image is a 4D CTHW tensor
18
+ image: torch.Tensor
19
+ # if timestamps is not None, len(timestamps) must match the T dimension of the tensor
20
+ timestamps: list[tuple[datetime, datetime]] | None = None
21
+
22
+ @property
23
+ def shape(self) -> torch.Size:
24
+ """The shape of the image."""
25
+ return self.image.shape
26
+
27
+ def dim(self) -> int:
28
+ """The dim of the image."""
29
+ return self.image.dim()
30
+
31
+ @property
32
+ def dtype(self) -> torch.dtype:
33
+ """The image dtype."""
34
+ return self.image.dtype
35
+
36
+ def single_ts_to_chw_tensor(self) -> torch.Tensor:
37
+ """Single timestep models expect single timestep inputs.
38
+
39
+ This function (1) checks this raster image only has 1 timestep and
40
+ (2) returns the tensor for that (single) timestep (going from CTHW to CHW).
41
+ """
42
+ if self.image.shape[1] != 1:
43
+ raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
44
+ return self.image[:, 0]
45
+
46
+
@dataclass
class SampleMetadata:
    """Metadata describing a single example.

    NOTE(review): the per-field comments below are inferred from the field
    names and types only -- confirm against the code that builds this object.
    """

    # Presumably the group and name identifying the window the example is from.
    window_group: str
    window_name: str
    # Presumably the pixel bounds of the whole window and of this patch.
    window_bounds: PixelBounds
    patch_bounds: PixelBounds
    # Presumably this patch's index among the patches of its window.
    patch_idx: int
    num_patches_in_window: int
    # Optional pair of datetimes; None when no time range is available.
    time_range: tuple[datetime, datetime] | None
    projection: Projection

    # Task name to differentiate different tasks.
    dataset_source: str | None
62
+
63
+
@dataclass
class ModelContext:
    """Context object handed to every model component."""

    # Per-example input dicts; the list has one entry per example in the batch.
    inputs: list[dict[str, torch.Tensor | RasterImage]]
    # Per-example metadata; the list has one entry per example in the batch.
    metadatas: list[SampleMetadata]
    # Free-form dict that components may add entries to. Each instance gets
    # its own fresh dict.
    context_dict: dict[str, Any] = field(default_factory=dict)
74
+
75
+
@dataclass
class ModelOutput:
    """The output from the Predictor.

    Args:
        outputs: output compatible with the configured Task.
        loss_dict: mapping from each loss name to its scalar tensor.
        metadata: free-form dict for storing any other outputs. Each instance
            gets its own fresh dict when none is passed.
    """

    outputs: Iterable[Any]
    loss_dict: dict[str, torch.Tensor]
    metadata: dict[str, Any] = field(default_factory=dict)
@@ -0,0 +1,31 @@
1
+ """Optimizers for rslearn."""
2
+
3
+ from dataclasses import asdict, dataclass
4
+
5
+ import lightning as L
6
+ import torch.optim
7
+ from torch.optim import Optimizer
8
+
9
+
class OptimizerFactory:
    """Base class for factories that create an optimizer for a LightningModule.

    This base class does not build anything itself: its ``build`` always
    raises NotImplementedError.
    """

    def build(self, lm: L.LightningModule) -> Optimizer:
        """Create the optimizer described by this factory.

        Args:
            lm: the LightningModule to build the optimizer for.

        Returns:
            the constructed optimizer.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
16
+
17
+
@dataclass
class AdamW(OptimizerFactory):
    """Factory for the AdamW optimizer.

    Options left as None are not forwarded to ``torch.optim.AdamW``.
    """

    lr: float = 0.001
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float | None = None
    weight_decay: float | None = None

    def build(self, lm: L.LightningModule) -> Optimizer:
        """Build an AdamW optimizer over the module's trainable parameters.

        Args:
            lm: the LightningModule whose parameters should be optimized. Only
                parameters with ``requires_grad`` set are handed to the
                optimizer.

        Returns:
            the constructed ``torch.optim.AdamW`` instance.
        """
        trainable = [param for param in lm.parameters() if param.requires_grad]

        # Forward only the options that were explicitly set.
        options = {}
        for name, value in asdict(self).items():
            if value is None:
                continue
            options[name] = value

        return torch.optim.AdamW(trainable, **options)
+ return torch.optim.AdamW(params, **kwargs)