rslearn 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/config/dataset.py CHANGED
@@ -8,7 +8,6 @@ from typing import Any
8
8
  import numpy as np
9
9
  import numpy.typing as npt
10
10
  import pytimeparse
11
- import torch
12
11
  from rasterio.enums import Resampling
13
12
 
14
13
  from rslearn.utils import PixelBounds, Projection
@@ -49,15 +48,6 @@ class DType(Enum):
49
48
  return np.float32
50
49
  raise ValueError(f"unable to handle numpy dtype {self}")
51
50
 
52
- def get_torch_dtype(self) -> torch.dtype:
53
- """Returns pytorch dtype object corresponding to this DType."""
54
- if self == DType.INT32:
55
- return torch.int32
56
- elif self == DType.FLOAT32:
57
- return torch.float32
58
- else:
59
- raise ValueError(f"unable to handle torch dtype {self}")
60
-
61
51
 
62
52
  RESAMPLING_METHODS = {
63
53
  "nearest": Resampling.nearest,
rslearn/dataset/manage.py CHANGED
@@ -124,12 +124,24 @@ def prepare_dataset_windows(
124
124
  )
125
125
  continue
126
126
  data_source_cfg = layer_cfg.data_source
127
+ min_matches = data_source_cfg.query_config.min_matches
127
128
 
128
129
  # Get windows that need to be prepared for this layer.
130
+ # Also track which windows are skipped vs previously rejected.
129
131
  needed_windows = []
132
+ windows_skipped = 0
133
+ windows_rejected = 0
130
134
  for window in windows:
131
135
  layer_datas = window.load_layer_datas()
132
136
  if layer_name in layer_datas and not force:
137
+ # Window already has layer data - check if it was previously rejected
138
+ layer_data = layer_datas[layer_name]
139
+ if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
140
+ # Previously rejected due to min_matches
141
+ windows_rejected += 1
142
+ else:
143
+ # Successfully prepared previously
144
+ windows_skipped += 1
133
145
  continue
134
146
  needed_windows.append(window)
135
147
  logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
@@ -141,8 +153,8 @@ def prepare_dataset_windows(
141
153
  data_source_name=data_source_cfg.name,
142
154
  duration_seconds=time.monotonic() - layer_start_time,
143
155
  windows_prepared=0,
144
- windows_skipped=len(windows),
145
- windows_rejected=0,
156
+ windows_skipped=windows_skipped,
157
+ windows_rejected=windows_rejected,
146
158
  get_items_attempts=0,
147
159
  )
148
160
  )
@@ -184,8 +196,6 @@ def prepare_dataset_windows(
184
196
  )
185
197
 
186
198
  windows_prepared = 0
187
- windows_rejected = 0
188
- min_matches = data_source_cfg.query_config.min_matches
189
199
  for window, result in zip(needed_windows, results):
190
200
  layer_datas = window.load_layer_datas()
191
201
  layer_datas[layer_name] = WindowLayerData(
@@ -202,8 +212,6 @@ def prepare_dataset_windows(
202
212
  else:
203
213
  windows_prepared += 1
204
214
 
205
- windows_skipped = len(windows) - len(needed_windows)
206
-
207
215
  layer_summaries.append(
208
216
  LayerPrepareSummary(
209
217
  layer_name=layer_name,
@@ -0,0 +1,67 @@
1
+ """LightningCLI for rslearn."""
2
+
3
+ import sys
4
+
5
+ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
6
+
7
+ from rslearn.arg_parser import RslearnArgumentParser
8
+ from rslearn.train.data_module import RslearnDataModule
9
+ from rslearn.train.lightning_module import RslearnLightningModule
10
+
11
+
12
+ class RslearnLightningCLI(LightningCLI):
13
+ """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
14
+
15
+ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
16
+ """Link data.tasks to model.tasks.
17
+
18
+ Args:
19
+ parser: the argument parser
20
+ """
21
+ # Link data.tasks to model.tasks
22
+ parser.link_arguments(
23
+ "data.init_args.task", "model.init_args.task", apply_on="instantiate"
24
+ )
25
+
26
+ def before_instantiate_classes(self) -> None:
27
+ """Called before Lightning class initialization.
28
+
29
+ Sets the dataset path for any configured RslearnPredictionWriter callbacks.
30
+ """
31
+ subcommand = self.config.subcommand
32
+ c = self.config[subcommand]
33
+
34
+ # If there is a RslearnPredictionWriter, set its path.
35
+ prediction_writer_callback = None
36
+ if "callbacks" in c.trainer:
37
+ for existing_callback in c.trainer.callbacks:
38
+ if (
39
+ existing_callback.class_path
40
+ == "rslearn.train.prediction_writer.RslearnWriter"
41
+ ):
42
+ prediction_writer_callback = existing_callback
43
+ if prediction_writer_callback:
44
+ prediction_writer_callback.init_args.path = c.data.init_args.path
45
+
46
+ # Disable the sampler replacement, since the rslearn data module will set the
47
+ # sampler as needed.
48
+ c.trainer.use_distributed_sampler = False
49
+
50
+ # For predict, make sure that return_predictions is False.
51
+ # Otherwise all the predictions would be stored in memory which can lead to
52
+ # high memory consumption.
53
+ if subcommand == "predict":
54
+ c.return_predictions = False
55
+
56
+
57
+ def model_handler() -> None:
58
+ """Handler for any rslearn model X commands."""
59
+ RslearnLightningCLI(
60
+ model_class=RslearnLightningModule,
61
+ datamodule_class=RslearnDataModule,
62
+ args=sys.argv[2:],
63
+ subclass_mode_model=True,
64
+ subclass_mode_data=True,
65
+ save_config_kwargs={"overwrite": True},
66
+ parser_class=RslearnArgumentParser,
67
+ )
rslearn/main.py CHANGED
@@ -10,11 +10,9 @@ from datetime import UTC, datetime, timedelta
10
10
  from typing import Any, TypeVar
11
11
 
12
12
  import tqdm
13
- from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
14
13
  from rasterio.crs import CRS
15
14
  from upath import UPath
16
15
 
17
- from rslearn.arg_parser import RslearnArgumentParser
18
16
  from rslearn.config import LayerConfig
19
17
  from rslearn.const import WGS84_EPSG
20
18
  from rslearn.data_sources import Item, data_source_from_config
@@ -38,8 +36,6 @@ from rslearn.dataset.manage import (
38
36
  )
39
37
  from rslearn.log_utils import get_logger
40
38
  from rslearn.tile_stores import get_tile_store_with_layer
41
- from rslearn.train.data_module import RslearnDataModule
42
- from rslearn.train.lightning_module import RslearnLightningModule
43
39
  from rslearn.utils import Projection, STGeometry
44
40
 
45
41
  logger = get_logger(__name__)
@@ -831,85 +827,35 @@ def dataset_build_index() -> None:
831
827
  index.save_index(ds_path)
832
828
 
833
829
 
834
- class RslearnLightningCLI(LightningCLI):
835
- """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
836
-
837
- def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
838
- """Link data.tasks to model.tasks.
839
-
840
- Args:
841
- parser: the argument parser
842
- """
843
- # Link data.tasks to model.tasks
844
- parser.link_arguments(
845
- "data.init_args.task", "model.init_args.task", apply_on="instantiate"
846
- )
847
-
848
- def before_instantiate_classes(self) -> None:
849
- """Called before Lightning class initialization.
850
-
851
- Sets the dataset path for any configured RslearnPredictionWriter callbacks.
852
- """
853
- subcommand = self.config.subcommand
854
- c = self.config[subcommand]
855
-
856
- # If there is a RslearnPredictionWriter, set its path.
857
- prediction_writer_callback = None
858
- if "callbacks" in c.trainer:
859
- for existing_callback in c.trainer.callbacks:
860
- if (
861
- existing_callback.class_path
862
- == "rslearn.train.prediction_writer.RslearnWriter"
863
- ):
864
- prediction_writer_callback = existing_callback
865
- if prediction_writer_callback:
866
- prediction_writer_callback.init_args.path = c.data.init_args.path
867
-
868
- # Disable the sampler replacement, since the rslearn data module will set the
869
- # sampler as needed.
870
- c.trainer.use_distributed_sampler = False
871
-
872
- # For predict, make sure that return_predictions is False.
873
- # Otherwise all the predictions would be stored in memory which can lead to
874
- # high memory consumption.
875
- if subcommand == "predict":
876
- c.return_predictions = False
877
-
878
-
879
- def model_handler() -> None:
880
- """Handler for any rslearn model X commands."""
881
- RslearnLightningCLI(
882
- model_class=RslearnLightningModule,
883
- datamodule_class=RslearnDataModule,
884
- args=sys.argv[2:],
885
- subclass_mode_model=True,
886
- subclass_mode_data=True,
887
- save_config_kwargs={"overwrite": True},
888
- parser_class=RslearnArgumentParser,
889
- )
890
-
891
-
892
830
  @register_handler("model", "fit")
893
831
  def model_fit() -> None:
894
832
  """Handler for rslearn model fit."""
833
+ from .lightning_cli import model_handler
834
+
895
835
  model_handler()
896
836
 
897
837
 
898
838
  @register_handler("model", "validate")
899
839
  def model_validate() -> None:
900
840
  """Handler for rslearn model validate."""
841
+ from .lightning_cli import model_handler
842
+
901
843
  model_handler()
902
844
 
903
845
 
904
846
  @register_handler("model", "test")
905
847
  def model_test() -> None:
906
848
  """Handler for rslearn model test."""
849
+ from .lightning_cli import model_handler
850
+
907
851
  model_handler()
908
852
 
909
853
 
910
854
  @register_handler("model", "predict")
911
855
  def model_predict() -> None:
912
856
  """Handler for rslearn model predict."""
857
+ from .lightning_cli import model_handler
858
+
913
859
  model_handler()
914
860
 
915
861
 
@@ -8,6 +8,7 @@ from importlib.resources import files
8
8
  from typing import Any
9
9
 
10
10
  import torch
11
+ import torch.nn.functional as F
11
12
  import yaml
12
13
  from einops import rearrange
13
14
  from huggingface_hub import hf_hub_download
@@ -30,6 +31,7 @@ PATCH_SIZE = 8
30
31
  CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
31
32
  CONFIG_DIR = files("rslearn.models.clay.configs")
32
33
  CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
34
+ DEFAULT_IMAGE_RESOLUTION = 128 # image resolution during pretraining
33
35
 
34
36
 
35
37
  def get_clay_checkpoint_path(
@@ -49,6 +51,7 @@ class Clay(torch.nn.Module):
49
51
  modality: str = "sentinel-2-l2a",
50
52
  checkpoint_path: str | None = None,
51
53
  metadata_path: str = CLAY_METADATA_PATH,
54
+ do_resizing: bool = False,
52
55
  ) -> None:
53
56
  """Initialize the Clay model.
54
57
 
@@ -57,6 +60,7 @@ class Clay(torch.nn.Module):
57
60
  modality: The modality to use (subset of CLAY_MODALITIES).
58
61
  checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
59
62
  metadata_path: Path to metadata.yaml.
63
+ do_resizing: Whether to resize the image to the input resolution.
60
64
  """
61
65
  super().__init__()
62
66
 
@@ -95,6 +99,14 @@ class Clay(torch.nn.Module):
95
99
 
96
100
  self.model_size = model_size
97
101
  self.modality = modality
102
+ self.do_resizing = do_resizing
103
+
104
+ def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
105
+ """Resize the image to the input resolution."""
106
+ new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
107
+ return F.interpolate(
108
+ image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
109
+ )
98
110
 
99
111
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
100
112
  """Forward pass for the Clay model.
@@ -114,7 +126,8 @@ class Clay(torch.nn.Module):
114
126
  chips = torch.stack(
115
127
  [inp[self.modality] for inp in inputs], dim=0
116
128
  ) # (B, C, H, W)
117
-
129
+ if self.do_resizing:
130
+ chips = self._resize_image(chips, chips.shape[2])
118
131
  order = self.metadata[self.modality]["band_order"]
119
132
  wavelengths = []
120
133
  for band in self.metadata[self.modality]["band_order"]:
rslearn/models/croma.py CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
7
7
  from typing import Any
8
8
 
9
9
  import torch
10
+ import torch.nn.functional as F
10
11
  from einops import rearrange
11
12
  from upath import UPath
12
13
 
@@ -99,6 +100,7 @@ class Croma(torch.nn.Module):
99
100
  modality: CromaModality,
100
101
  pretrained_path: str | None = None,
101
102
  image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
103
+ do_resizing: bool = False,
102
104
  ) -> None:
103
105
  """Instantiate a new Croma instance.
104
106
 
@@ -107,12 +109,21 @@ class Croma(torch.nn.Module):
107
109
  modality: the modalities to configure the model to accept.
108
110
  pretrained_path: the local path to the pretrained weights. Otherwise it is
109
111
  downloaded and cached in temp directory.
110
- image_resolution: the width and height of the input images.
112
+ image_resolution: the width and height of the input images passed to the model. if do_resizing is True, the image will be resized to this resolution.
113
+ do_resizing: Whether to resize the image to the input resolution.
111
114
  """
112
115
  super().__init__()
113
116
  self.size = size
114
117
  self.modality = modality
115
- self.image_resolution = image_resolution
118
+ self.do_resizing = do_resizing
119
+ if not do_resizing:
120
+ self.image_resolution = image_resolution
121
+ else:
122
+ # With single pixel input, we always resample to the patch size.
123
+ if image_resolution == 1:
124
+ self.image_resolution = PATCH_SIZE
125
+ else:
126
+ self.image_resolution = DEFAULT_IMAGE_RESOLUTION
116
127
 
117
128
  # Cache the CROMA weights to a deterministic path in temporary directory if the
118
129
  # path is not provided by the user.
@@ -137,7 +148,16 @@ class Croma(torch.nn.Module):
137
148
  pretrained_path=pretrained_path,
138
149
  size=size.value,
139
150
  modality=modality.value,
140
- image_resolution=image_resolution,
151
+ image_resolution=self.image_resolution,
152
+ )
153
+
154
+ def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
155
+ """Resize the image to the input resolution."""
156
+ return F.interpolate(
157
+ image,
158
+ size=(self.image_resolution, self.image_resolution),
159
+ mode="bilinear",
160
+ align_corners=False,
141
161
  )
142
162
 
143
163
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
@@ -151,8 +171,11 @@ class Croma(torch.nn.Module):
151
171
  sentinel2: torch.Tensor | None = None
152
172
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
153
173
  sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
174
+ sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
154
175
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
155
176
  sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
177
+ sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
178
+
156
179
  outputs = self.model(
157
180
  SAR_images=sentinel1,
158
181
  optical_images=sentinel2,
@@ -4,15 +4,14 @@ from typing import Any
4
4
 
5
5
  import satlaspretrain_models
6
6
  import torch
7
+ import torch.nn.functional as F
7
8
 
8
9
 
9
10
  class SatlasPretrain(torch.nn.Module):
10
11
  """SatlasPretrain backbones."""
11
12
 
12
13
  def __init__(
13
- self,
14
- model_identifier: str,
15
- fpn: bool = False,
14
+ self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
16
15
  ) -> None:
17
16
  """Instantiate a new SatlasPretrain instance.
18
17
 
@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
21
20
  https://github.com/allenai/satlaspretrain_models
22
21
  fpn: whether to include the feature pyramid network, otherwise only the
23
22
  Swin-v2-Transformer is used.
23
+ resize_to_pretrain: whether to resize inputs to the pretraining input
24
+ size (512 x 512)
24
25
  """
25
26
  super().__init__()
26
27
  weights_manager = satlaspretrain_models.Weights()
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
49
50
  [16, 1024],
50
51
  [32, 2048],
51
52
  ]
53
+ self.resize_to_pretrain = resize_to_pretrain
54
+
55
+ def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
56
+ """Resize to pretraining sizes if resize_to_pretrain == True."""
57
+ if self.resize_to_pretrain:
58
+ return F.interpolate(
59
+ data,
60
+ size=(512, 512),
61
+ mode="bilinear",
62
+ align_corners=False,
63
+ )
64
+ else:
65
+ return data
52
66
 
53
67
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
54
68
  """Compute feature maps from the SatlasPretrain backbone.
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
58
72
  process.
59
73
  """
60
74
  images = torch.stack([inp["image"] for inp in inputs], dim=0)
61
- return self.model(images)
75
+ return self.model(self.maybe_resize(images))
62
76
 
63
77
  def get_backbone_channels(self) -> list:
64
78
  """Returns the output channels of this model when used as a backbone.
@@ -4,6 +4,7 @@ from enum import Enum
4
4
  from typing import Any
5
5
 
6
6
  import torch
7
+ import torch.nn.functional as F
7
8
  from einops import rearrange
8
9
  from terratorch.registry import BACKBONE_REGISTRY
9
10
 
@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
18
19
  LARGE = "large"
19
20
 
20
21
 
22
+ # Pretraining image size for Terramind
23
+ IMAGE_SIZE = 224
21
24
  # Default patch size for Terramind
22
25
  PATCH_SIZE = 16
23
26
 
@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
89
92
  self,
90
93
  model_size: TerramindSize,
91
94
  modalities: list[str] = ["S2L2A"],
95
+ do_resizing: bool = False,
92
96
  ) -> None:
93
97
  """Initialize the Terramind model.
94
98
 
95
99
  Args:
96
100
  model_size: The size of the Terramind model.
97
101
  modalities: The modalities to use.
102
+ do_resizing: Whether to resize the input images to the pretraining resolution.
98
103
  """
99
104
  super().__init__()
100
105
 
@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):
116
121
 
117
122
  self.model_size = model_size
118
123
  self.modalities = modalities
124
+ self.do_resizing = do_resizing
119
125
 
120
126
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
121
127
  """Forward pass for the Terramind model.
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
132
138
  if modality not in inputs[0]:
133
139
  continue
134
140
  cur = torch.stack([inp[modality] for inp in inputs], dim=0) # (B, C, H, W)
141
+ if self.do_resizing and (
142
+ cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
143
+ ):
144
+ if cur.shape[2] == 1 and cur.shape[3] == 1:
145
+ new_height, new_width = PATCH_SIZE, PATCH_SIZE
146
+ else:
147
+ new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
148
+ cur = F.interpolate(
149
+ cur,
150
+ size=(new_height, new_width),
151
+ mode="bilinear",
152
+ align_corners=False,
153
+ )
135
154
  model_inputs[modality] = cur
136
155
 
137
156
  # By default, the patch embeddings are averaged over all modalities to reduce output tokens