rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,454 @@
1
+ """LightningCLI for rslearn."""
2
+
3
+ import hashlib
4
+ import json
5
+ import os
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+
10
+ import fsspec
11
+ import jsonargparse
12
+ import wandb
13
+ from lightning.pytorch import LightningModule, Trainer
14
+ from lightning.pytorch.callbacks import Callback
15
+ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
16
+ from lightning.pytorch.utilities import rank_zero_only
17
+ from upath import UPath
18
+
19
+ from rslearn.arg_parser import RslearnArgumentParser
20
+ from rslearn.log_utils import get_logger
21
+ from rslearn.train.data_module import RslearnDataModule
22
+ from rslearn.train.lightning_module import RslearnLightningModule
23
+ from rslearn.utils.fsspec import open_atomic
24
+ from rslearn.utils.jsonargparse import init_jsonargparse
25
+
26
+ WANDB_ID_FNAME = "wandb_id"
27
+
28
+ logger = get_logger(__name__)
29
+
30
+
31
def get_cached_checkpoint(checkpoint_fname: UPath) -> str:
    """Return a local filename for the given checkpoint, caching it if remote.

    A checkpoint that already lives on the local filesystem is returned as-is.
    Otherwise it is downloaded once into a deterministic cache location under
    the system temporary directory (keyed by a SHA-256 of the full path), and
    the cached filename is returned. The cache is never cleaned up
    automatically.

    Args:
        checkpoint_fname: the potentially non-local checkpoint file to load.

    Returns:
        a local filename containing the same checkpoint.
    """
    # Local checkpoints need no caching at all.
    if isinstance(checkpoint_fname.fs, fsspec.implementations.local.LocalFileSystem):
        return checkpoint_fname.path

    # Deterministic cache path: same remote path always maps to the same file.
    digest = hashlib.sha256(str(checkpoint_fname).encode()).hexdigest()
    cache_dir = os.path.join(tempfile.gettempdir(), "rslearn_cache", "checkpoints")
    local_fname = os.path.join(cache_dir, f"{digest}.ckpt")

    if os.path.exists(local_fname):
        logger.info(
            "using cached checkpoint for %s at %s", str(checkpoint_fname), local_fname
        )
        return local_fname

    logger.info("caching checkpoint %s to %s", str(checkpoint_fname), local_fname)
    os.makedirs(cache_dir, exist_ok=True)
    # open_atomic ensures a concurrent reader never sees a half-written file.
    with checkpoint_fname.open("rb") as src, open_atomic(UPath(local_fname), "wb") as dst:
        shutil.copyfileobj(src, dst)

    return local_fname
70
+
71
+
72
class SaveWandbRunIdCallback(Callback):
    """Callback to save the wandb run ID to project directory in case of resume."""

    def __init__(
        self,
        project_dir: str,
        config_str: str,
    ) -> None:
        """Create a new SaveWandbRunIdCallback.

        Args:
            project_dir: the project directory.
            config_str: the JSON-encoded configuration of this experiment
        """
        self.project_dir = project_dir
        self.config_str = config_str

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Persist the active W&B run ID (and config) when fitting begins.

        The run ID is written to the project directory so a later fit can
        resume the same W&B run. Only executed on rank zero.

        Args:
            trainer: the Trainer object.
            pl_module: the LightningModule object.
        """
        run_id = wandb.run.id

        out_dir = UPath(self.project_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        with (out_dir / WANDB_ID_FNAME).open("w") as f:
            f.write(run_id)

        # Upload the experiment config once; "project_name" presence marks a
        # run whose config was already recorded.
        if self.config_str is not None and "project_name" not in wandb.config:
            wandb.config.update(json.loads(self.config_str))
106
+
107
+
108
class RslearnLightningCLI(LightningCLI):
    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Link data.tasks to model.tasks.

        Also registers the project-management options (--management_dir,
        --project_name, --run_name, etc.) that control automatic checkpoint
        and W&B handling.

        Args:
            parser: the argument parser
        """
        # Link data.tasks to model.tasks
        parser.link_arguments(
            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
        )

        # Project management option to have rslearn manage checkpoints and W&B run.
        parser.add_argument(
            "--management_dir",
            type=str | None,
            help="Enable project management, and use this directory to store checkpoints and configs. If enabled, rslearn will automatically manages checkpoint directory/loading and W&B run",
            default=None,
        )
        parser.add_argument(
            "--project_name",
            type=str | None,
            help="The project name (used with --management_dir)",
            default=None,
        )
        parser.add_argument(
            "--run_name",
            type=str | None,
            help="A unique name for this experiment (used with --management_dir)",
            default=None,
        )
        parser.add_argument(
            "--run_description",
            type=str,
            help="Optional description of this experiment (used with --management_dir)",
            default="",
        )
        parser.add_argument(
            "--load_checkpoint_mode",
            type=str,
            help="Which checkpoint to load, if any (used with --management_dir). 'none' never loads any checkpoint, 'last' loads the most recent checkpoint, and 'best' loads the best checkpoint. 'auto' will use 'last' during fit and 'best' during val/test/predict.",
            default="auto",
        )
        parser.add_argument(
            "--load_checkpoint_required",
            type=str,
            help="Whether to fail if the expected checkpoint based on load_checkpoint_mode does not exist (used with --management_dir). 'yes' will fail while 'no' won't. 'auto' will use 'no' during fit and 'yes' during val/test/predict.",
            default="auto",
        )
        parser.add_argument(
            "--log_mode",
            type=str,
            help="Whether to log to W&B (used with --management_dir). 'yes' will enable logging, 'no' will disable logging, and 'auto' will use 'yes' during fit and 'no' during val/test/predict.",
            default="auto",
        )

    def _get_checkpoint_path(
        self,
        project_dir: UPath,
        load_checkpoint_mode: str,
        load_checkpoint_required: str,
        stage: str,
    ) -> str | None:
        """Get path to checkpoint to load from, or None to not restore checkpoint.

        Args:
            project_dir: the project directory determined from the project management
                directory.
            load_checkpoint_mode: "none" to not load any checkpoint, "last" to load the
                most recent checkpoint, "best" to load the best checkpoint. "auto" to
                use "last" during fit and "best" during val/test/predict.
            load_checkpoint_required: "yes" to fail if no checkpoint exists, "no" to
                ignore. "auto" will use "no" during fit and "yes" during
                val/test/predict.
            stage: the lightning stage (fit/val/test/predict).

        Returns:
            the path to the checkpoint for setting c.ckpt_path, or None if no
            checkpoint should be restored.
        """
        # Resolve auto options if used.
        if load_checkpoint_mode == "auto":
            if stage == "fit":
                load_checkpoint_mode = "last"
            else:
                load_checkpoint_mode = "best"
        if load_checkpoint_required == "auto":
            if stage == "fit":
                load_checkpoint_required = "no"
            else:
                load_checkpoint_required = "yes"

        # "required" and "none" contradict each other: "none" guarantees no
        # checkpoint is ever loaded.
        if load_checkpoint_required == "yes" and load_checkpoint_mode == "none":
            raise ValueError(
                "load_checkpoint_required cannot be set when load_checkpoint_mode is none"
            )

        ckpt_path: str | None = None

        if load_checkpoint_mode == "best":
            # Checkpoints should be either:
            # - last.ckpt
            # - of the form "A=B-C=D-....ckpt" with one key being epoch=X
            # So we want the one with the highest epoch, and only use last.ckpt if
            # it's the only option.
            # User should set save_top_k=1 so there's just one, otherwise we won't
            # actually know which one is the best.
            best_checkpoint = None
            best_epochs = None

            # Avoid error in case project_dir doesn't exist.
            fnames = project_dir.iterdir() if project_dir.exists() else []

            for option in fnames:
                if not option.name.endswith(".ckpt"):
                    continue

                # Try to see what epochs this checkpoint is at.
                # If it is some other format, then set it 0 so we only use it if it's
                # the only option.
                # If it is last.ckpt then we set it -100 to only use it if there is not
                # even another format like "best.ckpt".
                extracted_epochs = 0
                if option.name == "last.ckpt":
                    extracted_epochs = -100

                # Parse "key=value" segments of the filename; only "epoch=N"
                # contributes to the ranking.
                parts = option.name.split(".ckpt")[0].split("-")
                for part in parts:
                    kv_parts = part.split("=")
                    if len(kv_parts) != 2:
                        continue
                    if kv_parts[0] != "epoch":
                        continue
                    extracted_epochs = int(kv_parts[1])

                if best_epochs is None or extracted_epochs > best_epochs:
                    best_checkpoint = option
                    best_epochs = extracted_epochs

            if best_checkpoint is not None:
                # Cache the checkpoint so we only need to download once in case we
                # reuse it later.
                # We only cache with --load_best since this is the only scenario where we
                # expect to keep reusing the same checkpoint.
                ckpt_path = get_cached_checkpoint(best_checkpoint)

        elif load_checkpoint_mode == "last":
            # "last" relies on ModelCheckpoint(save_last=True) writing last.ckpt.
            last_checkpoint_path = project_dir / "last.ckpt"
            if last_checkpoint_path.exists():
                ckpt_path = str(last_checkpoint_path)

        else:
            raise ValueError(f"unknown load_checkpoint_mode {load_checkpoint_mode}")

        if load_checkpoint_required == "yes" and ckpt_path is None:
            raise ValueError(
                "load_checkpoint_required is set but no checkpoint was found"
            )

        return ckpt_path

    def enable_project_management(self, management_dir: str) -> None:
        """Enable project management in the specified directory.

        Configures the W&B logger, SaveWandbRunIdCallback, and ModelCheckpoint
        callback (during fit), and restores a checkpoint / W&B run ID when
        resuming. Mutates the parsed jsonargparse config in place; must run
        before Lightning instantiates classes.

        Args:
            management_dir: the directory to store checkpoints and W&B.
        """
        subcommand = self.config.subcommand
        c = self.config[subcommand]

        # Project name and run name are required with project management.
        if not c.project_name or not c.run_name:
            raise ValueError(
                "project name and run name must be set when using project management"
            )

        # Get project directory within the project management directory.
        project_dir = UPath(management_dir) / c.project_name / c.run_name

        # Add the W&B logger if it isn't already set, and (re-)configure it.
        should_log = False
        if c.log_mode == "yes":
            should_log = True
        elif c.log_mode == "auto":
            # "auto" only logs during fit.
            should_log = subcommand == "fit"
        if should_log:
            if not c.trainer.logger:
                c.trainer.logger = jsonargparse.Namespace(
                    {
                        "class_path": "lightning.pytorch.loggers.WandbLogger",
                        "init_args": jsonargparse.Namespace(),
                    }
                )
            # Re-configure even a user-provided logger so W&B project/run names
            # match the managed project directory.
            c.trainer.logger.init_args.project = c.project_name
            c.trainer.logger.init_args.name = c.run_name
            if c.run_description:
                c.trainer.logger.init_args.notes = c.run_description

            # Add callback to save config to W&B.
            upload_wandb_callback = None
            if "callbacks" in c.trainer and c.trainer.callbacks:
                for existing_callback in c.trainer.callbacks:
                    if existing_callback.class_path == "SaveWandbRunIdCallback":
                        upload_wandb_callback = existing_callback
            else:
                c.trainer.callbacks = []

            if not upload_wandb_callback:
                config_str = json.dumps(
                    c.as_dict(), default=lambda _: "<not serializable>"
                )
                upload_wandb_callback = jsonargparse.Namespace(
                    {
                        "class_path": "SaveWandbRunIdCallback",
                        "init_args": jsonargparse.Namespace(
                            {
                                "project_dir": str(project_dir),
                                "config_str": config_str,
                            }
                        ),
                    }
                )
                c.trainer.callbacks.append(upload_wandb_callback)
        elif c.trainer.logger:
            # Cannot safely remove a logger the user configured explicitly, so
            # just warn that it stays active.
            logger.warning(
                "Model management is enabled and logging should be off, but the model config specifies a logger. "
                + "The logger should be removed from the model config, since it will not be automatically disabled."
            )

        if subcommand == "fit":
            # Set the checkpoint directory to match the project directory.
            checkpoint_callback = None
            if "callbacks" in c.trainer and c.trainer.callbacks:
                for existing_callback in c.trainer.callbacks:
                    if (
                        existing_callback.class_path
                        == "lightning.pytorch.callbacks.ModelCheckpoint"
                    ):
                        checkpoint_callback = existing_callback
            else:
                c.trainer.callbacks = []

            if not checkpoint_callback:
                checkpoint_callback = jsonargparse.Namespace(
                    {
                        "class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
                        "init_args": jsonargparse.Namespace(
                            {
                                "save_last": True,
                                "save_top_k": 1,
                                "monitor": "val_loss",
                            }
                        ),
                    }
                )
                c.trainer.callbacks.append(checkpoint_callback)
            # Force the dirpath even for a user-provided ModelCheckpoint so
            # checkpoints land in the managed project directory.
            checkpoint_callback.init_args.dirpath = str(project_dir)

        # Load existing checkpoint.
        checkpoint_path = self._get_checkpoint_path(
            project_dir=project_dir,
            load_checkpoint_mode=c.load_checkpoint_mode,
            load_checkpoint_required=c.load_checkpoint_required,
            stage=subcommand,
        )
        if checkpoint_path is not None:
            logger.info(f"found checkpoint to resume from at {checkpoint_path}")
            c.ckpt_path = checkpoint_path

        # If we are resuming from a checkpoint for training, we also try to resume the W&B run.
        if (
            subcommand == "fit"
            and (project_dir / WANDB_ID_FNAME).exists()
            and should_log
        ):
            with (project_dir / WANDB_ID_FNAME).open("r") as f:
                wandb_id = f.read().strip()
            c.trainer.logger.init_args.id = wandb_id

    def before_instantiate_classes(self) -> None:
        """Called before Lightning class initialization.

        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
        """
        # With run=False there is no subcommand; the config itself is the
        # top-level namespace.
        if not hasattr(self.config, "subcommand"):
            logger.warning(
                "Config does not have subcommand attribute, assuming we are in run=False mode"
            )
            subcommand = None
            c = self.config
        else:
            subcommand = self.config.subcommand
            c = self.config[subcommand]

        # If there is a RslearnPredictionWriter, set its path.
        prediction_writer_callback = None
        if "callbacks" in c.trainer and c.trainer.callbacks:
            for existing_callback in c.trainer.callbacks:
                if (
                    existing_callback.class_path
                    == "rslearn.train.prediction_writer.RslearnWriter"
                ):
                    prediction_writer_callback = existing_callback
        if prediction_writer_callback:
            prediction_writer_callback.init_args.path = c.data.init_args.path

        # Disable the sampler replacement, since the rslearn data module will set the
        # sampler as needed.
        c.trainer.use_distributed_sampler = False

        # For predict, make sure that return_predictions is False.
        # Otherwise all the predictions would be stored in memory which can lead to
        # high memory consumption.
        if subcommand == "predict":
            c.return_predictions = False

        # Default to DDP with find_unused_parameters. Likely won't get called with unified config
        if subcommand == "fit":
            if not c.trainer.strategy:
                c.trainer.strategy = jsonargparse.Namespace(
                    {
                        "class_path": "lightning.pytorch.strategies.DDPStrategy",
                        "init_args": jsonargparse.Namespace(
                            {"find_unused_parameters": True}
                        ),
                    }
                )

        if c.management_dir:
            self.enable_project_management(c.management_dir)
440
+
441
+
442
def model_handler() -> None:
    """Handler for any rslearn model X commands."""
    init_jsonargparse()

    # sys.argv[2:] skips the program name and the "model" group token, leaving
    # the Lightning subcommand and its options for the CLI to parse.
    cli_kwargs = {
        "model_class": RslearnLightningModule,
        "datamodule_class": RslearnDataModule,
        "args": sys.argv[2:],
        "subclass_mode_model": True,
        "subclass_mode_data": True,
        "save_config_kwargs": {"overwrite": True},
        "parser_class": RslearnArgumentParser,
    }
    RslearnLightningCLI(**cli_kwargs)
rslearn/log_utils.py ADDED
@@ -0,0 +1,24 @@
1
+ """Logging utilities."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
LOG_FORMAT = "format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s"
# DETAILED_LOG_FORMAT = "format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s call_trace=%(pathname)s L%(lineno)-4d" # noqa


def get_logger(name: str) -> logging.Logger:
    """Get a logger with a console handler."""
    result = logging.getLogger(name)
    # Level is taken from the environment so callers can tune verbosity
    # without code changes; defaults to INFO.
    level = os.environ.get("RSLEARN_LOGLEVEL", "INFO")

    # Only attach a handler on first use so repeated get_logger() calls for
    # the same name do not produce duplicate log lines.
    if not result.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(level)
        handler.setFormatter(logging.Formatter(LOG_FORMAT))
        result.addHandler(handler)

    result.setLevel(level)
    result.propagate = True
    return result