rslearn 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/dataset/manage.py CHANGED
@@ -124,12 +124,24 @@ def prepare_dataset_windows(
             )
             continue
         data_source_cfg = layer_cfg.data_source
+        min_matches = data_source_cfg.query_config.min_matches

         # Get windows that need to be prepared for this layer.
+        # Also track which windows are skipped vs previously rejected.
         needed_windows = []
+        windows_skipped = 0
+        windows_rejected = 0
         for window in windows:
             layer_datas = window.load_layer_datas()
             if layer_name in layer_datas and not force:
+                # Window already has layer data - check if it was previously rejected
+                layer_data = layer_datas[layer_name]
+                if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
+                    # Previously rejected due to min_matches
+                    windows_rejected += 1
+                else:
+                    # Successfully prepared previously
+                    windows_skipped += 1
                 continue
             needed_windows.append(window)
         logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
@@ -141,8 +153,8 @@ def prepare_dataset_windows(
                     data_source_name=data_source_cfg.name,
                     duration_seconds=time.monotonic() - layer_start_time,
                     windows_prepared=0,
-                    windows_skipped=len(windows),
-                    windows_rejected=0,
+                    windows_skipped=windows_skipped,
+                    windows_rejected=windows_rejected,
                     get_items_attempts=0,
                 )
             )
@@ -184,8 +196,6 @@ def prepare_dataset_windows(
         )

         windows_prepared = 0
-        windows_rejected = 0
-        min_matches = data_source_cfg.query_config.min_matches
         for window, result in zip(needed_windows, results):
             layer_datas = window.load_layer_datas()
             layer_datas[layer_name] = WindowLayerData(
@@ -202,8 +212,6 @@ def prepare_dataset_windows(
             else:
                 windows_prepared += 1

-        windows_skipped = len(windows) - len(needed_windows)
-
         layer_summaries.append(
             LayerPrepareSummary(
                 layer_name=layer_name,
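
The prepare summary now distinguishes windows that were skipped because they were already prepared from windows that were previously rejected by min_matches. A minimal sketch of that classification rule (simplified names, not an actual helper in manage.py):

def classify_existing_window(num_item_groups: int, min_matches: int) -> str:
    # A window with existing layer data but zero serialized item groups, while
    # min_matches > 0, was previously rejected; any other already-prepared window
    # counts as skipped.
    if num_item_groups == 0 and min_matches > 0:
        return "rejected"
    return "skipped"

assert classify_existing_window(0, 1) == "rejected"
assert classify_existing_window(3, 1) == "skipped"
assert classify_existing_window(0, 0) == "skipped"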
rslearn/models/clay/clay.py CHANGED
@@ -8,6 +8,7 @@ from importlib.resources import files
 from typing import Any

 import torch
+import torch.nn.functional as F
 import yaml
 from einops import rearrange
 from huggingface_hub import hf_hub_download
@@ -30,6 +31,7 @@ PATCH_SIZE = 8
 CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
 CONFIG_DIR = files("rslearn.models.clay.configs")
 CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+DEFAULT_IMAGE_RESOLUTION = 128  # image resolution during pretraining


 def get_clay_checkpoint_path(
@@ -49,6 +51,7 @@ class Clay(torch.nn.Module):
         modality: str = "sentinel-2-l2a",
         checkpoint_path: str | None = None,
         metadata_path: str = CLAY_METADATA_PATH,
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Clay model.

@@ -57,6 +60,7 @@ class Clay(torch.nn.Module):
             modality: The modality to use (subset of CLAY_MODALITIES).
             checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
             metadata_path: Path to metadata.yaml.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()

@@ -95,6 +99,14 @@ class Clay(torch.nn.Module):

         self.model_size = model_size
         self.modality = modality
+        self.do_resizing = do_resizing
+
+    def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
+        return F.interpolate(
+            image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
+        )

     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Clay model.
@@ -114,7 +126,8 @@ class Clay(torch.nn.Module):
         chips = torch.stack(
             [inp[self.modality] for inp in inputs], dim=0
         )  # (B, C, H, W)
-
+        if self.do_resizing:
+            chips = self._resize_image(chips, chips.shape[2])
         order = self.metadata[self.modality]["band_order"]
         wavelengths = []
         for band in self.metadata[self.modality]["band_order"]:
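
A standalone sketch of what the new do_resizing flag does to a batch of chips, mirroring _resize_image above without loading the model (PATCH_SIZE = 8 and DEFAULT_IMAGE_RESOLUTION = 128 are the module-level constants shown in this diff):

import torch
import torch.nn.functional as F

PATCH_SIZE = 8
DEFAULT_IMAGE_RESOLUTION = 128

def resize_like_clay(image: torch.Tensor) -> torch.Tensor:
    # 1x1 inputs are upsampled to one patch; larger inputs go to the pretraining resolution.
    new_hw = PATCH_SIZE if image.shape[2] == 1 else DEFAULT_IMAGE_RESOLUTION
    return F.interpolate(image, size=(new_hw, new_hw), mode="bilinear", align_corners=False)

print(resize_like_clay(torch.rand(2, 4, 1, 1)).shape)    # torch.Size([2, 4, 8, 8])
print(resize_like_clay(torch.rand(2, 4, 32, 32)).shape)  # torch.Size([2, 4, 128, 128])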
rslearn/models/croma.py CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from typing import Any

 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from upath import UPath

@@ -99,6 +100,7 @@ class Croma(torch.nn.Module):
         modality: CromaModality,
         pretrained_path: str | None = None,
         image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
+        do_resizing: bool = False,
     ) -> None:
         """Instantiate a new Croma instance.

@@ -107,12 +109,21 @@ class Croma(torch.nn.Module):
             modality: the modalities to configure the model to accept.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
-            image_resolution: the width and height of the input images.
+            image_resolution: the width and height of the input images passed to the model. if do_resizing is True, the image will be resized to this resolution.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()
         self.size = size
         self.modality = modality
-        self.image_resolution = image_resolution
+        self.do_resizing = do_resizing
+        if not do_resizing:
+            self.image_resolution = image_resolution
+        else:
+            # With single pixel input, we always resample to the patch size.
+            if image_resolution == 1:
+                self.image_resolution = PATCH_SIZE
+            else:
+                self.image_resolution = DEFAULT_IMAGE_RESOLUTION

         # Cache the CROMA weights to a deterministic path in temporary directory if the
         # path is not provided by the user.
@@ -137,7 +148,16 @@ class Croma(torch.nn.Module):
             pretrained_path=pretrained_path,
             size=size.value,
             modality=modality.value,
-            image_resolution=image_resolution,
+            image_resolution=self.image_resolution,
+        )
+
+    def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        return F.interpolate(
+            image,
+            size=(self.image_resolution, self.image_resolution),
+            mode="bilinear",
+            align_corners=False,
         )

     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
@@ -151,8 +171,11 @@ class Croma(torch.nn.Module):
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
             sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
             sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
+
         outputs = self.model(
             SAR_images=sentinel1,
             optical_images=sentinel2,
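
A sketch of how the constructor now picks the effective image_resolution when do_resizing is enabled; the CROMA constants are passed in explicitly here because their values are defined elsewhere in croma.py and are not part of this diff:

def select_resolution(
    image_resolution: int,
    do_resizing: bool,
    patch_size: int,
    default_resolution: int,
) -> int:
    # Without resizing, the user-supplied resolution is used as-is.
    if not do_resizing:
        return image_resolution
    # Single-pixel inputs are resampled to one patch; everything else to the default resolution.
    return patch_size if image_resolution == 1 else default_resolution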
rslearn/models/satlaspretrain.py CHANGED
@@ -4,15 +4,14 @@ from typing import Any

 import satlaspretrain_models
 import torch
+import torch.nn.functional as F


 class SatlasPretrain(torch.nn.Module):
     """SatlasPretrain backbones."""

     def __init__(
-        self,
-        model_identifier: str,
-        fpn: bool = False,
+        self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
     ) -> None:
         """Instantiate a new SatlasPretrain instance.

@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
                 https://github.com/allenai/satlaspretrain_models
             fpn: whether to include the feature pyramid network, otherwise only the
                 Swin-v2-Transformer is used.
+            resize_to_pretrain: whether to resize inputs to the pretraining input
+                size (512 x 512)
         """
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
             [16, 1024],
             [32, 2048],
         ]
+        self.resize_to_pretrain = resize_to_pretrain
+
+    def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
+        """Resize to pretraining sizes if resize_to_pretrain == True."""
+        if self.resize_to_pretrain:
+            return F.interpolate(
+                data,
+                size=(512, 512),
+                mode="bilinear",
+                align_corners=False,
+            )
+        else:
+            return data

     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the SatlasPretrain backbone.
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
             process.
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        return self.model(images)
+        return self.model(self.maybe_resize(images))

     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/terramind.py CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Any

 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY

@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
     LARGE = "large"


+# Pretraining image size for Terramind
+IMAGE_SIZE = 224
 # Default patch size for Terramind
 PATCH_SIZE = 16

@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
         self,
         model_size: TerramindSize,
         modalities: list[str] = ["S2L2A"],
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Terramind model.

         Args:
             model_size: The size of the Terramind model.
             modalities: The modalities to use.
+            do_resizing: Whether to resize the input images to the pretraining resolution.
         """
         super().__init__()

@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):

         self.model_size = model_size
         self.modalities = modalities
+        self.do_resizing = do_resizing

     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Terramind model.
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
             if modality not in inputs[0]:
                 continue
             cur = torch.stack([inp[modality] for inp in inputs], dim=0)  # (B, C, H, W)
+            if self.do_resizing and (
+                cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
+            ):
+                if cur.shape[2] == 1 and cur.shape[3] == 1:
+                    new_height, new_width = PATCH_SIZE, PATCH_SIZE
+                else:
+                    new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
+                cur = F.interpolate(
+                    cur,
+                    size=(new_height, new_width),
+                    mode="bilinear",
+                    align_corners=False,
+                )
             model_inputs[modality] = cur

         # By default, the patch embeddings are averaged over all modalities to reduce output tokens
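
The Terramind resizing rule as a standalone sketch: inputs already at 224 x 224 pass through untouched, 1 x 1 inputs are upsampled to one 16 x 16 patch, and anything else is resized to the 224 x 224 pretraining size (IMAGE_SIZE and PATCH_SIZE are the module constants shown in this diff):

import torch
import torch.nn.functional as F

IMAGE_SIZE = 224
PATCH_SIZE = 16

def resize_like_terramind(cur: torch.Tensor) -> torch.Tensor:
    # Already at the pretraining size: nothing to do.
    if cur.shape[2] == IMAGE_SIZE and cur.shape[3] == IMAGE_SIZE:
        return cur
    # Single-pixel input becomes one patch; everything else becomes 224 x 224.
    new_hw = PATCH_SIZE if (cur.shape[2] == 1 and cur.shape[3] == 1) else IMAGE_SIZE
    return F.interpolate(cur, size=(new_hw, new_hw), mode="bilinear", align_corners=False)

print(resize_like_terramind(torch.rand(1, 12, 1, 1)).shape)      # torch.Size([1, 12, 16, 16])
print(resize_like_terramind(torch.rand(1, 12, 100, 100)).shape)  # torch.Size([1, 12, 224, 224])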
{rslearn-0.0.14 → rslearn-0.0.15}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.14
+Version: 0.0.15
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
{rslearn-0.0.14 → rslearn-0.0.15}.dist-info/RECORD CHANGED
@@ -40,7 +40,7 @@ rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVk
 rslearn/dataset/dataset.py,sha256=bjf9nI55j-MF0bIQWSNPjNbpfqnLK4jy-96TAcwO0MM,5214
 rslearn/dataset/handler_summaries.py,sha256=wI99RDk5erCWkzl1A7Uc4chatQ9KWIr4F_0Hxr9Co6s,2607
 rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
-rslearn/dataset/manage.py,sha256=owelBiBqvoIQYLhFMDK4ULzcoGBNE27JV8kl68jf3wg,18563
+rslearn/dataset/manage.py,sha256=IURlbCtm9a5f4d52AXfte1yyodlf6MgjfYn3__GdIL4,19062
 rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWBw,22056
 rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
 rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
@@ -48,7 +48,7 @@ rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
 rslearn/models/anysat.py,sha256=3Oh2gWxicVdUzOjevBEZf0PuolmCy0KC5Ad7JY-0Plc,7949
 rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
 rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
-rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
+rslearn/models/croma.py,sha256=n7yunpT7lo8vWWaOpx4yt8jZSXjgWqfgZcZKFW5zuEQ,10591
 rslearn/models/dinov3.py,sha256=9k9kNlXCorQQwKjLGptooANd48TUBsITQ1e4fUomlM4,6337
 rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
 rslearn/models/feature_center_crop.py,sha256=24eOrvLEGGVWPw7kPHyUes5HtYNAX7GZ_NpqDGMILEY,1553
@@ -63,18 +63,18 @@ rslearn/models/prithvi.py,sha256=AIzcO5xk1ggR0MjbfhIzqPVgUKFN7odxygmgyAelfW8,401
 rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
 rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
 rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
-rslearn/models/satlaspretrain.py,sha256=YpjXl-uClhTZMDmyhN64Fg3AszzT-ymZgJB0fO9RyoY,2419
+rslearn/models/satlaspretrain.py,sha256=b6FR_il6MnWU4UpB9OxInZSK9n0IS0PcQuLrWH4YD8g,3046
 rslearn/models/simple_time_series.py,sha256=oTg_akabYFBExJu7JCpbuM211-ZgQS4WerG2nEYrIZY,12774
 rslearn/models/singletask.py,sha256=z4vN9Yvzz0I-U4KJdVZxLJK2ZV-MIv9tzwCGcOWoUPY,1604
 rslearn/models/ssl4eo_s12.py,sha256=sOGEHcDo-rNdmEyoLu2AVEqfxRM_cv6zpfAmyn5c6tw,3553
 rslearn/models/swin.py,sha256=bMlGePXMFou4A_YSUZzjHgN9NniGXaCWdGQ31xHDKis,5511
 rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
-rslearn/models/terramind.py,sha256=kipar8sMaHJJ3b8vIgL0-s4qhHcA0Vb854vmlZ9cWh4,7524
+rslearn/models/terramind.py,sha256=5POVk_y29LlbVswa6ojd9gdB70iO41yB9Y2aqVY4WdQ,8327
 rslearn/models/trunk.py,sha256=H1QPQGAKsmocq3OiF66GW8MQI4LffupTDrgZR4Ta7QM,4708
 rslearn/models/unet.py,sha256=WUgLgvvlgV8l_6MIDBl6aX1HNOkb24DfnVRIyYXHCjo,6865
 rslearn/models/upsample.py,sha256=3kWbyWZIk56JJxj8en9pieitbrk3XnbIsTKlEkiDQQY,938
 rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
-rslearn/models/clay/clay.py,sha256=5RO5H8EM0tKjCwWMQ4xDkKkUCwKpm2K_Yw1alnhvVhU,7773
+rslearn/models/clay/clay.py,sha256=29CGCOysx9duEX4Y6LUNHXck_sHjCFrlV4w8CP_hKmI,8460
 rslearn/models/clay/configs/metadata.yaml,sha256=rZTFh4Yb9htEfbQNOPl4HTbFogEhzwIRqFzG-1uT01Y,4652
 rslearn/models/detr/__init__.py,sha256=GGAnTIhyuvl34IRrJ_4gXjm_01OlM5rbQQ3c3TGfbK8,84
 rslearn/models/detr/box_ops.py,sha256=ORCF6EwMpMBB_VgQT05SjR47dCR2rN2gPhL_gsuUWJs,3236
@@ -154,10 +154,10 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
 rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
 rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
 rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
-rslearn-0.0.14.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
-rslearn-0.0.14.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
-rslearn-0.0.14.dist-info/METADATA,sha256=Jbm6ySbM4gkT_5o-RWbRr8APS8TYXq3Q-bWyeda-Uc8,36319
-rslearn-0.0.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-rslearn-0.0.14.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
-rslearn-0.0.14.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
-rslearn-0.0.14.dist-info/RECORD,,
+rslearn-0.0.15.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+rslearn-0.0.15.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
+rslearn-0.0.15.dist-info/METADATA,sha256=HRkJjQfvxCEosmCBvLcLd9nZnXjbmfAgPIknMy_ORBo,36319
+rslearn-0.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+rslearn-0.0.15.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+rslearn-0.0.15.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+rslearn-0.0.15.dist-info/RECORD,,