rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/dinov3.py ADDED
@@ -0,0 +1,177 @@
+ """DinoV3 model.
+
+ This code loads the DINOv3 model. You must obtain the model separately from Meta to use
+ it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
+ information.
+ """
+
+ from enum import StrEnum
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torchvision
+ from einops import rearrange
+
+ from rslearn.train.model_context import ModelContext
+ from rslearn.train.transforms.normalize import Normalize
+ from rslearn.train.transforms.transform import Transform
+
+ from .component import FeatureExtractor, FeatureMaps
+
+
+ class DinoV3Models(StrEnum):
+     """Names for different DinoV3 models on torch hub."""
+
+     SMALL_WEB = "dinov3_vits16"
+     SMALL_PLUS_WEB = "dinov3_vits16plus"
+     BASE_WEB = "dinov3_vitb16"
+     LARGE_WEB = "dinov3_vitl16"
+     HUGE_PLUS_WEB = "dinov3_vith16plus"
+     FULL_7B_WEB = "dinov3_vit7b16"
+     LARGE_SATELLITE = "dinov3_vitl16_sat"
+     FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
+
+
+ DINOV3_PTHS: dict[str, str] = {
+     DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
+     DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
+     DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
+     DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
+     DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
+     DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
+ }
+
+
+ class DinoV3(FeatureExtractor):
+     """DinoV3 Backbones.
+
+     Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
+     See https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models
+
+     Only takes RGB as input. Expects normalized data (use the below normalizer).
+
+     Uses patch size 16. The input is resized to 256x256; when applying DinoV3 on
+     segmentation or detection tasks with inputs larger than 256x256, it may be best to
+     train and predict on 256x256 crops (using SplitConfig.patch_size argument).
+     """
+
+     image_size: int = 256
+     patch_size: int = 16
+     output_dim: int = 1024
+
+     def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
+         model_name = size.replace("_sat", "")
+         if checkpoint_dir is not None:
+             weights = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
+             return torch.hub.load(
+                 "facebookresearch/dinov3",
+                 model_name,
+                 weights=weights,
+             )  # nosec
+         return torch.hub.load("facebookresearch/dinov3", model_name, pretrained=False)  # nosec
+
+     def __init__(
+         self,
+         checkpoint_dir: str | None,
+         size: str = DinoV3Models.LARGE_SATELLITE,
+         use_cls_token: bool = False,
+         do_resizing: bool = True,
+     ) -> None:
+         """Instantiate a new DinoV3 instance.
+
+         Args:
+             checkpoint_dir: the local path to the pretrained weight dir. If None, we load the architecture
+                 only (randomly initialized).
+             size: the model size, see class for various models.
+             use_cls_token: use the pooled class token (for classification); otherwise return the spatial feature map.
+             do_resizing: whether to resize inputs to 256x256. Default true.
+         """
+         super().__init__()
+         self.size = size
+         self.checkpoint_dir = checkpoint_dir
+         self.use_cls_token = use_cls_token
+         self.do_resizing = do_resizing
+         self.model = self._load_model(size, checkpoint_dir)
+
+     def forward(self, context: ModelContext) -> FeatureMaps:
+         """Forward pass for the dinov3 model.
+
+         Args:
+             context: the model context. Input dicts must include "image" key.
+
+         Returns:
+             a FeatureMaps with one feature map.
+         """
+         cur = torch.stack(
+             [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+             dim=0,
+         )  # (B, C, H, W)
+
+         if self.do_resizing and (
+             cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
+         ):
+             cur = torchvision.transforms.functional.resize(
+                 cur,
+                 [self.image_size, self.image_size],
+             )
+
+         if self.use_cls_token:
+             features = self.model(cur)
+         else:
+             features = self.model.forward_features(cur)["x_norm_patchtokens"]
+             batch_size, num_patches, _ = features.shape
+             height, width = int(num_patches**0.5), int(num_patches**0.5)
+             features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
+
+         return FeatureMaps([features])
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels are a list of (downsample_factor, depth) tuples that correspond
+         to the feature maps that the backbone returns. For example, an element [2, 32]
+         indicates that the corresponding feature map is 1/2 the input resolution and
+         has 32 channels.
+         """
+         return [(self.patch_size, self.output_dim)]
+
+
+ class DinoV3Normalize(Transform):
+     """Normalize inputs using DinoV3 normalization.
+
+     Normalize the "image" key in the input according to the DINO statistics from pretraining. Satellite pretraining uses slightly different normalization than the base image model, so set 'satellite' according to which pretrained model you are using.
+
+     Input "image" should be an RGB-like image with values between 0-255.
+     """
+
+     def __init__(self, satellite: bool = True):
+         """Initialize a new DinoV3Normalize."""
+         super().__init__()
+         self.satellite = satellite
+         if satellite:
+             mean = [0.430, 0.411, 0.296]
+             std = [0.213, 0.156, 0.143]
+         else:
+             mean = [0.485, 0.456, 0.406]
+             std = [0.229, 0.224, 0.225]
+
+         self.normalize = Normalize(
+             [value * 255 for value in mean],
+             [value * 255 for value in std],
+             num_bands=3,
+         )
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Normalize the specified image with DinoV3 normalization.
+
+         Args:
+             input_dict: the input dictionary.
+             target_dict: the target dictionary.
+
+         Returns:
+             normalized (input_dict, target_dict) tuple
+         """
+         return self.normalize(input_dict, target_dict)
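For orientation, here is a minimal usage sketch of the new backbone and its companion transform. This example is not part of the diff; the checkpoint directory is hypothetical, and per the module docstring the .pth weights must first be obtained from Meta and placed there.

    from rslearn.models.dinov3 import DinoV3, DinoV3Models, DinoV3Normalize

    # Hypothetical local directory containing dinov3_vitl16_pretrain_sat493m-*.pth.
    backbone = DinoV3(
        checkpoint_dir="/path/to/dinov3/weights",
        size=DinoV3Models.LARGE_SATELLITE,
    )
    # One feature map: 1024 channels at 1/16 of the (resized) input resolution.
    print(backbone.get_backbone_channels())  # [(16, 1024)]

    # satellite=True matches the sat493m checkpoints; use False for the web models.
    normalize = DinoV3Normalize(satellite=True)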
rslearn/models/faster_rcnn.py CHANGED
@@ -6,14 +6,24 @@ from typing import Any
  import torch
  import torchvision
 
+ from rslearn.train.model_context import ModelContext, ModelOutput
+
+ from .component import FeatureMaps, Predictor
+
 
  class NoopTransform(torch.nn.Module):
      """A placeholder transform used with torchvision detection model."""
 
-     def __init__(self):
+     def __init__(self) -> None:
          """Create a new NoopTransform."""
          super().__init__()
 
+         # We initialize a GeneralizedRCNNTransform just to use its batch_images
+         # function, which concatenates the images (padding to the dimensions of the
+         # largest image as needed) to the form needed by the Faster R-CNN head.
+         # We pass an arbitrary min_size and max_size here, but these are ignored since
+         # we call GeneralizedRCNNTransform.batch_images directly rather than calling
+         # its forward function.
          self.transform = (
              torchvision.models.detection.transform.GeneralizedRCNNTransform(
                  min_size=800,
@@ -39,32 +49,17 @@ class NoopTransform(torch.nn.Module):
          Returns:
              wrapped images and unmodified targets
          """
+         # See comment above, this just pads/concatenates the images without resizing.
          images = self.transform.batch_images(images, size_divisible=32)
+         # Now convert to ImageList object needed by Faster R-CNN head.
          image_sizes = [(image.shape[1], image.shape[2]) for image in images]
          image_list = torchvision.models.detection.image_list.ImageList(
              images, image_sizes
          )
          return image_list, targets
 
-     def postprocess(
-         self, detections: dict[str, torch.Tensor], image_sizes, orig_sizes
-     ) -> dict[str, torch.Tensor]:
-         """Post-process the detections to reflect original image size.
-
-         Since we didn't transform the images, we don't need to do anything here.
 
-         Args:
-             detections: the raw detections
-             image_sizes: the transformed image sizes
-             orig_sizes: the original image sizes
-
-         Returns:
-             the post-processed detections (unmodified from the provided detections)
-         """
-         return detections
-
-
- class FasterRCNN(torch.nn.Module):
+ class FasterRCNN(Predictor):
      """Faster R-CNN head for predicting bounding boxes.
 
      It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -80,7 +75,7 @@ class FasterRCNN(torch.nn.Module):
          anchor_sizes: list[list[int]],
          instance_segmentation: bool = False,
          box_score_thresh: float = 0.05,
-     ):
+     ) -> None:
          """Create a new FasterRCNN.
 
          Args:
@@ -185,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
      def forward(
          self,
-         features: list[torch.Tensor],
-         inputs: list[dict[str, Any]],
+         intermediates: Any,
+         context: ModelContext,
          targets: list[dict[str, Any]] | None = None,
-     ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+     ) -> ModelOutput:
          """Compute the detection outputs and loss from features.
 
          Args:
-             features: multi-scale feature maps.
-             inputs: original inputs, should cotnain image key for original image size.
+             intermediates: the output from the previous component, which must be a FeatureMaps.
+             context: the model context. Input dicts must contain image key for original image size.
              targets: should contain class key that stores the class label.
 
          Returns:
              tuple of outputs and loss dict
          """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to FasterRCNN must be FeatureMaps")
+
          # Fix target labels to be 1 size in case it's empty.
          # For some reason this is needed.
          if targets:
@@ -212,11 +210,12 @@ class FasterRCNN(torch.nn.Module):
              ),
          )
 
-         image_list = [inp["image"] for inp in inputs]
+         # take the first (and assumed to be only) timestep
+         image_list = [inp["image"].image[:, 0] for inp in context.inputs]
          images, targets = self.noop_transform(image_list, targets)
 
          feature_dict = collections.OrderedDict()
-         for i, feat_map in enumerate(features):
+         for i, feat_map in enumerate(intermediates.feature_maps):
              feature_dict[f"feat{i}"] = feat_map
 
          proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -228,4 +227,7 @@ class FasterRCNN(torch.nn.Module):
          losses.update(proposal_losses)
          losses.update(detector_losses)
 
-         return detections, losses
+         return ModelOutput(
+             outputs=detections,
+             loss_dict=losses,
+         )
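The signature changes above reflect a move to a component interface: a Predictor now receives the previous component's output plus a ModelContext, and returns a ModelOutput. Below is a hedged sketch of a custom head following the same contract; the pooling head itself is invented for illustration, only the names mirrored from the diff (Predictor, FeatureMaps, ModelContext, ModelOutput) are real.

    from typing import Any

    import torch

    from rslearn.models.component import FeatureMaps, Predictor
    from rslearn.train.model_context import ModelContext, ModelOutput


    class ToyHead(Predictor):
        """A minimal Predictor mirroring the contract used by the new FasterRCNN."""

        def forward(
            self,
            intermediates: Any,
            context: ModelContext,
            targets: list[dict[str, Any]] | None = None,
        ) -> ModelOutput:
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("input to ToyHead must be FeatureMaps")
            # Global-average-pool the first feature map as a stand-in prediction.
            pooled = intermediates.feature_maps[0].mean(dim=(2, 3))
            # Dummy loss so the returned ModelOutput is complete.
            loss = pooled.square().mean() if targets is not None else torch.tensor(0.0)
            return ModelOutput(outputs=pooled, loss_dict={"toy": loss})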
rslearn/models/feature_center_crop.py ADDED
@@ -0,0 +1,53 @@
+ """Apply center cropping on a feature map."""
+
+ from typing import Any
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class FeatureCenterCrop(IntermediateComponent):
+     """Apply center cropping on the input feature maps."""
+
+     def __init__(
+         self,
+         sizes: list[tuple[int, int]],
+     ) -> None:
+         """Create a new FeatureCenterCrop.
+
+         Only the center of each feature map will be retained and passed to the next
+         module.
+
+         Args:
+             sizes: a list of (height, width) tuples, with one tuple for each input
+                 feature map.
+         """
+         super().__init__()
+         self.sizes = sizes
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Apply center cropping on the feature maps.
+
+         Args:
+             intermediates: output from the previous model component, which must be a FeatureMaps.
+             context: the model context.
+
+         Returns:
+             center cropped feature maps.
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
+         new_features = []
+         for i, feat in enumerate(intermediates.feature_maps):
+             height, width = self.sizes[i]
+             if feat.shape[2] < height or feat.shape[3] < width:
+                 raise ValueError(
+                     "feature map is smaller than the desired height and width"
+                 )
+             start_h = feat.shape[2] // 2 - height // 2
+             start_w = feat.shape[3] // 2 - width // 2
+             feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
+             new_features.append(feat)
+         return FeatureMaps(new_features)
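A quick sketch of the crop arithmetic in use. FeatureCenterCrop never reads its ModelContext argument in the code above, so the sketch simply passes None for it; everything else is taken from the class as added.

    import torch

    from rslearn.models.component import FeatureMaps
    from rslearn.models.feature_center_crop import FeatureCenterCrop

    crop = FeatureCenterCrop(sizes=[(4, 4)])
    feat = torch.randn(2, 256, 16, 16)  # (B, C, H, W)
    # Keeps the central 4x4 window: rows/cols 6 through 9 of each 16x16 map.
    out = crop.forward(FeatureMaps([feat]), None)
    print(out.feature_maps[0].shape)  # torch.Size([2, 256, 4, 4])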
rslearn/models/fpn.py CHANGED
@@ -1,12 +1,16 @@
  """Feature pyramid network."""
 
  import collections
+ from typing import Any
 
- import torch
  import torchvision
 
+ from rslearn.train.model_context import ModelContext
 
- class Fpn(torch.nn.Module):
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class Fpn(IntermediateComponent):
      """A feature pyramid network (FPN).
 
      The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
              in_channels_list=in_channels, out_channels=out_channels
          )
 
-     def forward(self, x: list[torch.Tensor]):
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
          """Compute outputs of the FPN.
 
          Args:
-             x: the multi-scale feature maps
+             intermediates: the output from the previous component, which must be a FeatureMaps.
+             context: the model context.
 
          Returns:
-             new multi-scale feature maps from the FPN
+             new multi-scale feature maps from the FPN.
          """
-         inp = collections.OrderedDict([(f"feat{i}", el) for i, el in enumerate(x)])
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Fpn must be FeatureMaps")
+
+         feature_maps = intermediates.feature_maps
+         inp = collections.OrderedDict(
+             [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+         )
          output = self.fpn(inp)
          output = list(output.values())
 
          if self.prepend:
-             return output + x
+             return FeatureMaps(output + feature_maps)
          else:
-             return output
+             return FeatureMaps(output)
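A hedged sketch of the new calling convention: Fpn now consumes and produces FeatureMaps rather than bare tensor lists. The constructor arguments below are an assumption (only in_channels/out_channels are visible in this hunk), and context is unused by the code shown, so None is passed.

    import torch

    from rslearn.models.component import FeatureMaps
    from rslearn.models.fpn import Fpn

    # Assumed constructor signature; channel counts are illustrative.
    fpn = Fpn(in_channels=[128, 256, 512], out_channels=128)
    maps = FeatureMaps(
        [
            torch.randn(1, 128, 64, 64),
            torch.randn(1, 256, 32, 32),
            torch.randn(1, 512, 16, 16),
        ]
    )
    out = fpn.forward(maps, None)
    print([f.shape for f in out.feature_maps])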
rslearn/models/galileo/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Galileo model."""
+
+ from .galileo import GalileoModel, GalileoSize
+
+ __all__ = ["GalileoModel", "GalileoSize"]