rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +420 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +14 -20
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +377 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/concatenate_features.py +93 -0
  33. rslearn/models/olmoearth_pretrain/model.py +2 -5
  34. rslearn/tile_stores/__init__.py +0 -11
  35. rslearn/train/dataset.py +4 -12
  36. rslearn/train/prediction_writer.py +16 -32
  37. rslearn/train/tasks/classification.py +2 -1
  38. rslearn/utils/fsspec.py +20 -0
  39. rslearn/utils/jsonargparse.py +79 -0
  40. rslearn/utils/raster_format.py +1 -41
  41. rslearn/utils/vector_format.py +1 -38
  42. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
  43. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
  44. rslearn/data_sources/geotiff.py +0 -1
  45. rslearn/data_sources/raster_source.py +0 -23
  46. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
rslearn/lightning_cli.py CHANGED
@@ -1,12 +1,107 @@
 """LightningCLI for rslearn."""

+import hashlib
+import json
+import os
+import shutil
 import sys
+import tempfile

+import fsspec
+import jsonargparse
+import wandb
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
+from lightning.pytorch.utilities import rank_zero_only
+from upath import UPath

 from rslearn.arg_parser import RslearnArgumentParser
+from rslearn.log_utils import get_logger
 from rslearn.train.data_module import RslearnDataModule
 from rslearn.train.lightning_module import RslearnLightningModule
+from rslearn.utils.fsspec import open_atomic
+
+WANDB_ID_FNAME = "wandb_id"
+
+logger = get_logger(__name__)
+
+
+def get_cached_checkpoint(checkpoint_fname: UPath) -> str:
+    """Get a local cached version of the specified checkpoint.
+
+    If checkpoint_fname is already local, then it is returned. Otherwise, it is saved
+    in a deterministic local cache directory under the system temporary directory, and
+    the cached filename is returned.
+
+    Note that the cache is not deleted when the program exits.
+
+    Args:
+        checkpoint_fname: the potentially non-local checkpoint file to load.
+
+    Returns:
+        a local filename containing the same checkpoint.
+    """
+    is_local = isinstance(
+        checkpoint_fname.fs, fsspec.implementations.local.LocalFileSystem
+    )
+    if is_local:
+        return checkpoint_fname.path
+
+    cache_id = hashlib.sha256(str(checkpoint_fname).encode()).hexdigest()
+    local_fname = os.path.join(
+        tempfile.gettempdir(), "rslearn_cache", "checkpoints", f"{cache_id}.ckpt"
+    )
+
+    if os.path.exists(local_fname):
+        logger.info(
+            "using cached checkpoint for %s at %s", str(checkpoint_fname), local_fname
+        )
+        return local_fname
+
+    logger.info("caching checkpoint %s to %s", str(checkpoint_fname), local_fname)
+    os.makedirs(os.path.dirname(local_fname), exist_ok=True)
+    with checkpoint_fname.open("rb") as src:
+        with open_atomic(UPath(local_fname), "wb") as dst:
+            shutil.copyfileobj(src, dst)
+
+    return local_fname
+
+
+class SaveWandbRunIdCallback(Callback):
+    """Callback to save the wandb run ID to the project directory in case of resume."""
+
+    def __init__(
+        self,
+        project_dir: str,
+        config_str: str,
+    ) -> None:
+        """Create a new SaveWandbRunIdCallback.
+
+        Args:
+            project_dir: the project directory.
+            config_str: the JSON-encoded configuration of this experiment.
+        """
+        self.project_dir = project_dir
+        self.config_str = config_str
+
+    @rank_zero_only
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Called just before fit starts.
+
+        Args:
+            trainer: the Trainer object.
+            pl_module: the LightningModule object.
+        """
+        wandb_id = wandb.run.id
+
+        project_dir = UPath(self.project_dir)
+        project_dir.mkdir(parents=True, exist_ok=True)
+        with (project_dir / WANDB_ID_FNAME).open("w") as f:
+            f.write(wandb_id)
+
+        if self.config_str is not None and "project_name" not in wandb.config:
+            wandb.config.update(json.loads(self.config_str))


 class RslearnLightningCLI(LightningCLI):
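The cache location in get_cached_checkpoint is fully deterministic, so repeated runs reuse the same download. A minimal standalone sketch of the derivation, using only the standard library (the checkpoint URL is a hypothetical example):

import hashlib
import os
import tempfile

# Hypothetical remote checkpoint; str(UPath(...)) yields the same kind of string.
checkpoint_fname = "s3://my-bucket/project/run_001/last.ckpt"

cache_id = hashlib.sha256(checkpoint_fname.encode()).hexdigest()
local_fname = os.path.join(
    tempfile.gettempdir(), "rslearn_cache", "checkpoints", f"{cache_id}.ckpt"
)
print(local_fname)  # e.g. /tmp/rslearn_cache/checkpoints/<64 hex chars>.ckpt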
@@ -23,6 +118,273 @@ class RslearnLightningCLI(LightningCLI):
             "data.init_args.task", "model.init_args.task", apply_on="instantiate"
         )

+        # Project management option to have rslearn manage checkpoints and W&B run.
+        parser.add_argument(
+            "--management_dir",
+            type=str | None,
+            help="Enable project management, and use this directory to store checkpoints and configs. If enabled, rslearn will automatically manage checkpoint directory/loading and the W&B run",
+            default=None,
+        )
+        parser.add_argument(
+            "--project_name",
+            type=str | None,
+            help="The project name (used with --management_dir)",
+            default=None,
+        )
+        parser.add_argument(
+            "--run_name",
+            type=str | None,
+            help="A unique name for this experiment (used with --management_dir)",
+            default=None,
+        )
+        parser.add_argument(
+            "--run_description",
+            type=str,
+            help="Optional description of this experiment (used with --management_dir)",
+            default="",
+        )
+        parser.add_argument(
+            "--load_checkpoint_mode",
+            type=str,
+            help="Which checkpoint to load, if any (used with --management_dir). 'none' never loads any checkpoint, 'last' loads the most recent checkpoint, and 'best' loads the best checkpoint. 'auto' will use 'last' during fit and 'best' during val/test/predict.",
+            default="auto",
+        )
+        parser.add_argument(
+            "--load_checkpoint_required",
+            type=str,
+            help="Whether to fail if the expected checkpoint based on load_checkpoint_mode does not exist (used with --management_dir). 'yes' will fail while 'no' won't. 'auto' will use 'no' during fit and 'yes' during val/test/predict.",
+            default="auto",
+        )
+        parser.add_argument(
+            "--log_mode",
+            type=str,
+            help="Whether to log to W&B (used with --management_dir). 'yes' will enable logging, 'no' will disable logging, and 'auto' will use 'yes' during fit and 'no' during val/test/predict.",
+            default="auto",
+        )
+
+    def _get_checkpoint_path(
+        self,
+        project_dir: UPath,
+        load_checkpoint_mode: str,
+        load_checkpoint_required: str,
+        stage: str,
+    ) -> str | None:
+        """Get path to checkpoint to load from, or None to not restore a checkpoint.
+
+        Args:
+            project_dir: the project directory determined from the project management
+                directory.
+            load_checkpoint_mode: "none" to not load any checkpoint, "last" to load the
+                most recent checkpoint, "best" to load the best checkpoint. "auto" to
+                use "last" during fit and "best" during val/test/predict.
+            load_checkpoint_required: "yes" to fail if no checkpoint exists, "no" to
+                ignore. "auto" will use "no" during fit and "yes" during
+                val/test/predict.
+            stage: the lightning stage (fit/val/test/predict).
+
+        Returns:
+            the path to the checkpoint for setting c.ckpt_path, or None if no
+            checkpoint should be restored.
+        """
+        # Resolve auto options if used.
+        if load_checkpoint_mode == "auto":
+            if stage == "fit":
+                load_checkpoint_mode = "last"
+            else:
+                load_checkpoint_mode = "best"
+        if load_checkpoint_required == "auto":
+            if stage == "fit":
+                load_checkpoint_required = "no"
+            else:
+                load_checkpoint_required = "yes"
+
+        if load_checkpoint_required == "yes" and load_checkpoint_mode == "none":
+            raise ValueError(
+                "load_checkpoint_required cannot be 'yes' when load_checkpoint_mode is 'none'"
+            )
+
+        ckpt_path: str | None = None
+
+        if load_checkpoint_mode == "best":
+            # Checkpoints should be either:
+            # - last.ckpt
+            # - of the form "A=B-C=D-....ckpt" with one key being epoch=X
+            # So we want the one with the highest epoch, and only use last.ckpt if
+            # it's the only option.
+            # User should set save_top_k=1 so there's just one, otherwise we won't
+            # actually know which one is the best.
+            best_checkpoint = None
+            best_epochs = None
+
+            # Avoid error in case project_dir doesn't exist.
+            fnames = project_dir.iterdir() if project_dir.exists() else []
+
+            for option in fnames:
+                if not option.name.endswith(".ckpt"):
+                    continue
+
+                # Try to see what epoch this checkpoint is at.
+                # If it is some other format, then set it 0 so we only use it if it's
+                # the only option.
+                # If it is last.ckpt then we set it -100 to only use it if there is not
+                # even another format like "best.ckpt".
+                extracted_epochs = 0
+                if option.name == "last.ckpt":
+                    extracted_epochs = -100
+
+                parts = option.name.split(".ckpt")[0].split("-")
+                for part in parts:
+                    kv_parts = part.split("=")
+                    if len(kv_parts) != 2:
+                        continue
+                    if kv_parts[0] != "epoch":
+                        continue
+                    extracted_epochs = int(kv_parts[1])
+
+                if best_epochs is None or extracted_epochs > best_epochs:
+                    best_checkpoint = option
+                    best_epochs = extracted_epochs
+
+            if best_checkpoint is not None:
+                # Cache the checkpoint so we only need to download once in case we
+                # reuse it later.
+                # We only cache in "best" mode since this is the only scenario where
+                # we expect to keep reusing the same checkpoint.
+                ckpt_path = get_cached_checkpoint(best_checkpoint)
+
+        elif load_checkpoint_mode == "last":
+            last_checkpoint_path = project_dir / "last.ckpt"
+            if last_checkpoint_path.exists():
+                ckpt_path = str(last_checkpoint_path)
+
+        elif load_checkpoint_mode == "none":
+            # Don't restore any checkpoint.
+            pass
+
+        else:
+            raise ValueError(f"unknown load_checkpoint_mode {load_checkpoint_mode}")
+
+        if load_checkpoint_required == "yes" and ckpt_path is None:
+            raise ValueError(
+                "load_checkpoint_required is set but no checkpoint was found"
+            )
+
+        return ckpt_path
+
+    def enable_project_management(self, management_dir: str) -> None:
+        """Enable project management in the specified directory.
+
+        Args:
+            management_dir: the directory to store checkpoints and the W&B run ID.
+        """
+        subcommand = self.config.subcommand
+        c = self.config[subcommand]
+
+        # Project name and run name are required with project management.
+        if not c.project_name or not c.run_name:
+            raise ValueError(
+                "project name and run name must be set when using project management"
+            )
+
+        # Get project directory within the project management directory.
+        project_dir = UPath(management_dir) / c.project_name / c.run_name
+
+        # Add the W&B logger if it isn't already set, and (re-)configure it.
+        should_log = False
+        if c.log_mode == "yes":
+            should_log = True
+        elif c.log_mode == "auto":
+            should_log = subcommand == "fit"
+        if should_log:
+            if not c.trainer.logger:
+                c.trainer.logger = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.loggers.WandbLogger",
+                        "init_args": jsonargparse.Namespace(),
+                    }
+                )
+            c.trainer.logger.init_args.project = c.project_name
+            c.trainer.logger.init_args.name = c.run_name
+            if c.run_description:
+                c.trainer.logger.init_args.notes = c.run_description
+
+            # Add callback to save config to W&B.
+            upload_wandb_callback = None
+            if "callbacks" in c.trainer and c.trainer.callbacks:
+                for existing_callback in c.trainer.callbacks:
+                    if existing_callback.class_path == "SaveWandbRunIdCallback":
+                        upload_wandb_callback = existing_callback
+            else:
+                c.trainer.callbacks = []
+
+            if not upload_wandb_callback:
+                config_str = json.dumps(
+                    c.as_dict(), default=lambda _: "<not serializable>"
+                )
+                upload_wandb_callback = jsonargparse.Namespace(
+                    {
+                        "class_path": "SaveWandbRunIdCallback",
+                        "init_args": jsonargparse.Namespace(
+                            {
+                                "project_dir": str(project_dir),
+                                "config_str": config_str,
+                            }
+                        ),
+                    }
+                )
+                c.trainer.callbacks.append(upload_wandb_callback)
+        elif c.trainer.logger:
+            logger.warning(
+                "Model management is enabled and logging should be off, but the model config specifies a logger. "
+                + "The logger should be removed from the model config, since it will not be automatically disabled."
+            )
+
+        if subcommand == "fit":
+            # Set the checkpoint directory to match the project directory.
+            checkpoint_callback = None
+            if "callbacks" in c.trainer and c.trainer.callbacks:
+                for existing_callback in c.trainer.callbacks:
+                    if (
+                        existing_callback.class_path
+                        == "lightning.pytorch.callbacks.ModelCheckpoint"
+                    ):
+                        checkpoint_callback = existing_callback
+            else:
+                c.trainer.callbacks = []
+
+            if not checkpoint_callback:
+                checkpoint_callback = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
+                        "init_args": jsonargparse.Namespace(
+                            {
+                                "save_last": True,
+                                "save_top_k": 1,
+                                "monitor": "val_loss",
+                            }
+                        ),
+                    }
+                )
+                c.trainer.callbacks.append(checkpoint_callback)
+            checkpoint_callback.init_args.dirpath = str(project_dir)
+
+        # Load existing checkpoint.
+        checkpoint_path = self._get_checkpoint_path(
+            project_dir=project_dir,
+            load_checkpoint_mode=c.load_checkpoint_mode,
+            load_checkpoint_required=c.load_checkpoint_required,
+            stage=subcommand,
+        )
+        if checkpoint_path is not None:
+            logger.info(f"found checkpoint to resume from at {checkpoint_path}")
+            c.ckpt_path = checkpoint_path
+
+        # If we are resuming from a checkpoint for training, we also try to resume the W&B run.
+        if (
+            subcommand == "fit"
+            and (project_dir / WANDB_ID_FNAME).exists()
+            and should_log
+        ):
+            with (project_dir / WANDB_ID_FNAME).open("r") as f:
+                wandb_id = f.read().strip()
+            c.trainer.logger.init_args.id = wandb_id
+

     def before_instantiate_classes(self) -> None:
         """Called before Lightning class initialization.
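The "best" selection in _get_checkpoint_path relies only on checkpoint filename conventions. The same scoring rule, shown standalone on illustrative filenames:

# Score a checkpoint name the way _get_checkpoint_path does: take epoch=N when
# present, fall back to 0 for other names, and rank last.ckpt lowest (-100).
def score(name: str) -> int:
    value = -100 if name == "last.ckpt" else 0
    for part in name.split(".ckpt")[0].split("-"):
        kv = part.split("=")
        if len(kv) == 2 and kv[0] == "epoch":
            value = int(kv[1])
    return value

names = ["last.ckpt", "epoch=4-val_loss=0.31.ckpt", "epoch=9-val_loss=0.28.ckpt"]
print(max(names, key=score))  # epoch=9-val_loss=0.28.ckpt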
@@ -33,7 +395,7 @@ class RslearnLightningCLI(LightningCLI):

         # If there is a RslearnPredictionWriter, set its path.
         prediction_writer_callback = None
-        if "callbacks" in c.trainer:
+        if "callbacks" in c.trainer and c.trainer.callbacks:
             for existing_callback in c.trainer.callbacks:
                 if (
                     existing_callback.class_path
@@ -53,6 +415,20 @@ class RslearnLightningCLI(LightningCLI):
         if subcommand == "predict":
             c.return_predictions = False

+        # For now we use DDP strategy with find_unused_parameters=True.
+        if subcommand == "fit":
+            c.trainer.strategy = jsonargparse.Namespace(
+                {
+                    "class_path": "lightning.pytorch.strategies.DDPStrategy",
+                    "init_args": jsonargparse.Namespace(
+                        {"find_unused_parameters": True}
+                    ),
+                }
+            )
+
+        if c.management_dir:
+            self.enable_project_management(c.management_dir)
+


 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
rslearn/main.py CHANGED
@@ -15,7 +15,7 @@ from upath import UPath

 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
-from rslearn.data_sources import Item, data_source_from_config
+from rslearn.data_sources import Item
 from rslearn.dataset import Dataset, Window, WindowLayerData
 from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
 from rslearn.dataset.handler_summaries import (
@@ -544,7 +544,7 @@ class IngestHandler:
                 tile_store, layer_name, layer_cfg
             )
             layer_cfg = self.dataset.layers[layer_name]
-            data_source = data_source_from_config(layer_cfg, self.dataset.path)
+            data_source = layer_cfg.instantiate_data_source(self.dataset.path)

             attempts_counter = AttemptsCounter()
             ingest_counts: IngestCounts | UnknownIngestCounts
@@ -640,7 +640,7 @@ class IngestHandler:
             if not layer_cfg.data_source.ingest:
                 continue

-            data_source = data_source_from_config(layer_cfg, self.dataset.path)
+            data_source = layer_cfg.instantiate_data_source(self.dataset.path)

             geometries_by_item: dict = {}
             for window, layer_datas in windows_and_layer_datas:
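Both call sites switch from the module-level factory data_source_from_config to a method on the layer config itself, so the config object owns instantiation. A minimal sketch of that pattern; the classes and internals here are hypothetical stand-ins, not rslearn's actual LayerConfig:

import importlib
from dataclasses import dataclass, field

@dataclass
class DataSourceConfig:  # hypothetical stand-in
    class_path: str
    init_args: dict = field(default_factory=dict)

@dataclass
class LayerConfig:  # hypothetical stand-in
    data_source: DataSourceConfig

    def instantiate_data_source(self, ds_path):
        # Resolve the configured class and construct it rooted at the dataset path.
        module_name, _, cls_name = self.data_source.class_path.rpartition(".")
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(ds_path, **self.data_source.init_args)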
rslearn/models/concatenate_features.py ADDED
@@ -0,0 +1,93 @@
+"""Concatenate feature map with features from input data."""
+
+from typing import Any
+
+import torch
+
+
+class ConcatenateFeatures(torch.nn.Module):
+    """Concatenate feature map with additional raw data inputs."""
+
+    def __init__(
+        self,
+        key: str,
+        in_channels: int | None = None,
+        conv_channels: int = 64,
+        out_channels: int | None = None,
+        num_conv_layers: int = 1,
+        kernel_size: int = 3,
+        final_relu: bool = False,
+    ):
+        """Create a new ConcatenateFeatures.
+
+        Args:
+            key: the key of the input_dict to concatenate.
+            in_channels: number of input channels of the additional features.
+            conv_channels: number of channels of the convolutional layers.
+            out_channels: number of output channels of the additional features.
+            num_conv_layers: number of convolutional layers to apply to the additional features.
+            kernel_size: kernel size of the convolutional layers.
+            final_relu: whether to apply a ReLU activation to the final output, default False.
+        """
+        super().__init__()
+        self.key = key
+
+        if num_conv_layers > 0:
+            if in_channels is None or out_channels is None:
+                raise ValueError(
+                    "in_channels and out_channels must be specified if num_conv_layers > 0"
+                )
+
+            conv_layers = []
+            for i in range(num_conv_layers):
+                conv_in = in_channels if i == 0 else conv_channels
+                conv_out = out_channels if i == num_conv_layers - 1 else conv_channels
+                conv_layers.append(
+                    torch.nn.Conv2d(
+                        in_channels=conv_in,
+                        out_channels=conv_out,
+                        kernel_size=kernel_size,
+                        padding="same",
+                    )
+                )
+                if i < num_conv_layers - 1 or final_relu:
+                    conv_layers.append(torch.nn.ReLU(inplace=True))
+
+            self.conv_layers = torch.nn.Sequential(*conv_layers)
+        else:
+            # With no conv layers, pass the additional inputs through unchanged.
+            self.conv_layers = torch.nn.Identity()
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> list[torch.Tensor]:
+        """Concatenate the feature map with the raw data inputs.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs.
+
+        Returns:
+            concatenated feature maps.
+        """
+        if not features:
+            raise ValueError("Expected at least one feature map, got none.")
+
+        add_data = torch.stack([input_data[self.key] for input_data in inputs], dim=0)
+        add_features = self.conv_layers(add_data)
+
+        new_features: list[torch.Tensor] = []
+        for feature_map in features:
+            # Shape of feature map: BCHW
+            feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
+
+            resized_add_features = add_features
+            # Resize additional features to match each feature map size if needed
+            if add_features.shape[2] != feat_h or add_features.shape[3] != feat_w:
+                resized_add_features = torch.nn.functional.interpolate(
+                    add_features,
+                    size=(feat_h, feat_w),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+
+            new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
+
+        return new_features
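For reference, a usage sketch with illustrative shapes: the per-example CHW tensors stored under the configured key are stacked to BCHW, run through the conv stack, resized to each feature map, and concatenated on the channel dimension:

import torch

from rslearn.models.concatenate_features import ConcatenateFeatures

# Two examples; a two-level BCHW feature pyramid; a single-band auxiliary input
# (e.g. elevation) projected to 8 channels. All shapes are illustrative.
module = ConcatenateFeatures(key="aux", in_channels=1, conv_channels=16, out_channels=8)
features = [torch.randn(2, 32, 64, 64), torch.randn(2, 64, 32, 32)]
inputs = [{"aux": torch.randn(1, 128, 128)} for _ in range(2)]
out = module(features, inputs)
print([tuple(f.shape) for f in out])  # [(2, 40, 64, 64), (2, 72, 32, 32)]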
rslearn/models/olmoearth_pretrain/model.py CHANGED
@@ -153,11 +153,8 @@ class OlmoEarth(torch.nn.Module):
         # Load the checkpoint.
         if not random_initialization:
             train_module_dir = checkpoint_upath / "model_and_optim"
-            if train_module_dir.exists():
-                load_model_and_optim_state(str(train_module_dir), model)
-                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
-            else:
-                logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+            load_model_and_optim_state(str(train_module_dir), model)
+            logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")

         return model

rslearn/tile_stores/__init__.py CHANGED
@@ -22,17 +22,6 @@ def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore:
     Returns:
         the TileStore
     """
-    if config is None:
-        tile_store = DefaultTileStore()
-        tile_store.set_dataset_path(ds_path)
-        return tile_store
-
-    # Backwards compatability.
-    if "name" in config and "root_dir" in config and config["name"] == "file":
-        tile_store = DefaultTileStore(config["root_dir"])
-        tile_store.set_dataset_path(ds_path)
-        return tile_store
-
     init_jsonargparse()
     parser = jsonargparse.ArgumentParser()
     parser.add_argument("--tile_store", type=TileStore)
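With the None and legacy {"name": "file"} fallbacks removed, the tile store config must use jsonargparse's class_path/init_args convention. A self-contained sketch of that convention with stand-in classes (not rslearn's actual TileStore):

import jsonargparse

class Store:  # stand-in base class
    pass

class FileStore(Store):  # stand-in implementation
    def __init__(self, root_dir: str = "tiles"):
        self.root_dir = root_dir

parser = jsonargparse.ArgumentParser()
parser.add_argument("--tile_store", type=Store)
cfg = parser.parse_args(
    ['--tile_store={"class_path": "__main__.FileStore", "init_args": {"root_dir": "/data/tiles"}}']
)
store = parser.instantiate_classes(cfg).tile_store
print(type(store).__name__, store.root_dir)  # FileStore /data/tiles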
rslearn/train/dataset.py CHANGED
@@ -17,9 +17,7 @@ from rasterio.warp import Resampling
 import rslearn.train.transforms.transform
 from rslearn.config import (
     DType,
-    RasterFormatConfig,
-    RasterLayerConfig,
-    VectorLayerConfig,
+    LayerConfig,
 )
 from rslearn.dataset.dataset import Dataset
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
@@ -28,8 +26,6 @@ from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
 from rslearn.utils.mp import star_imap_unordered
-from rslearn.utils.raster_format import load_raster_format
-from rslearn.utils.vector_format import load_vector_format

 from .transforms import Sequential

@@ -185,7 +181,7 @@ def read_raster_layer_for_data_input(
     bounds: PixelBounds,
     layer_name: str,
     group_idx: int,
-    layer_config: RasterLayerConfig,
+    layer_config: LayerConfig,
     data_input: DataInput,
 ) -> torch.Tensor:
     """Read a raster layer for a DataInput.
@@ -246,9 +242,7 @@ def read_raster_layer_for_data_input(
         )
     if band_set.format is None:
         raise ValueError(f"No format specified for {layer_name}")
-    raster_format = load_raster_format(
-        RasterFormatConfig(band_set.format["name"], band_set.format)
-    )
+    raster_format = band_set.instantiate_raster_format()
     raster_dir = window.get_raster_dir(
         layer_name, band_set.bands, group_idx=group_idx
     )
@@ -349,7 +343,6 @@ def read_data_input(
     images: list[torch.Tensor] = []
     for layer_name, group_idx in layers_to_read:
         layer_config = dataset.layers[layer_name]
-        assert isinstance(layer_config, RasterLayerConfig)
         images.append(
             read_raster_layer_for_data_input(
                 window, bounds, layer_name, group_idx, layer_config, data_input
@@ -363,8 +356,7 @@ def read_data_input(
     features: list[Feature] = []
     for layer_name, group_idx in layers_to_read:
         layer_config = dataset.layers[layer_name]
-        assert isinstance(layer_config, VectorLayerConfig)
-        vector_format = load_vector_format(layer_config.format)
+        vector_format = layer_config.instantiate_vector_format()
         layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
         cur_features = vector_format.decode_vector(
             layer_dir, window.projection, window.bounds