rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/dinov3.py
ADDED
@@ -0,0 +1,177 @@
+"""DinoV3 model.
+
+This code loads the DINOv3 model. You must obtain the model separately from Meta to use
+it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
+information.
+"""
+
+from enum import StrEnum
+from pathlib import Path
+from typing import Any
+
+import torch
+import torchvision
+from einops import rearrange
+
+from rslearn.train.model_context import ModelContext
+from rslearn.train.transforms.normalize import Normalize
+from rslearn.train.transforms.transform import Transform
+
+from .component import FeatureExtractor, FeatureMaps
+
+
+class DinoV3Models(StrEnum):
+    """Names for different DinoV3 images on torch hub."""
+
+    SMALL_WEB = "dinov3_vits16"
+    SMALL_PLUS_WEB = "dinov3_vits16plus"
+    BASE_WEB = "dinov3_vitb16"
+    LARGE_WEB = "dinov3_vitl16"
+    HUGE_PLUS_WEB = "dinov3_vith16plus"
+    FULL_7B_WEB = "dinov3_vit7b16"
+    LARGE_SATELLITE = "dinov3_vitl16_sat"
+    FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
+
+
+DINOV3_PTHS: dict[str, str] = {
+    DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
+    DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
+    DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
+    DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
+    DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
+    DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
+}
+
+
+class DinoV3(FeatureExtractor):
+    """DinoV3 Backbones.
+
+    Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
+    See https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models
+
+    Only takes RGB as input. Expects normalized data (use the below normalizer).
+
+    Uses patch size 16. The input is resized to 256x256; when applying DinoV3 on
+    segmentation or detection tasks with inputs larger than 256x256, it may be best to
+    train and predict on 256x256 crops (using SplitConfig.patch_size argument).
+    """
+
+    image_size: int = 256
+    patch_size: int = 16
+    output_dim: int = 1024
+
+    def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
+        model_name = size.replace("_sat", "")
+        if checkpoint_dir is not None:
+            weights = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
+            return torch.hub.load(
+                "facebookresearch/dinov3",
+                model_name,
+                weights=weights,
+            )  # nosec
+        return torch.hub.load("facebookresearch/dinov3", model_name, pretrained=False)  # nosec
+
+    def __init__(
+        self,
+        checkpoint_dir: str | None,
+        size: str = DinoV3Models.LARGE_SATELLITE,
+        use_cls_token: bool = False,
+        do_resizing: bool = True,
+    ) -> None:
+        """Instantiate a new DinoV3 instance.
+
+        Args:
+            checkpoint_dir: the local path to the pretrained weight dir. If None, we load the architecture
+                only (randomly initialized).
+            size: the model size, see class for various models.
+            use_cls_token: use pooled class token (for classification), otherwise returns spatial feature map.
+            do_resizing: whether to resize inputs to 256x256. Default true.
+        """
+        super().__init__()
+        self.size = size
+        self.checkpoint_dir = checkpoint_dir
+        self.use_cls_token = use_cls_token
+        self.do_resizing = do_resizing
+        self.model = self._load_model(size, checkpoint_dir)
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the dinov3 model.
+
+        Args:
+            context: the model context. Input dicts must include "image" key.
+
+        Returns:
+            a FeatureMaps with one feature map.
+        """
+        cur = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
+        )  # (B, C, H, W)
+
+        if self.do_resizing and (
+            cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
+        ):
+            cur = torchvision.transforms.functional.resize(
+                cur,
+                [self.image_size, self.image_size],
+            )
+
+        if self.use_cls_token:
+            features = self.model(cur)
+        else:
+            features = self.model.forward_features(cur)["x_norm_patchtokens"]
+            batch_size, num_patches, _ = features.shape
+            height, width = int(num_patches**0.5), int(num_patches**0.5)
+            features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
+
+        return FeatureMaps([features])
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+        """
+        return [(self.patch_size, self.output_dim)]
+
+
+class DinoV3Normalize(Transform):
+    """Normalize inputs using DinoV3 normalization.
+
+    Normalize "image" key in input according to Dino statistics from pretraining. Satellite pretraining has slightly different normalizing than the base image model so set 'satellite' depending on what pretrained model you are using.
+
+    Input "image" should be RGB-like image between 0-255.
+    """
+
+    def __init__(self, satellite: bool = True):
+        """Initialize a new DinoV3Normalize."""
+        super().__init__()
+        self.satellite = satellite
+        if satellite:
+            mean = [0.430, 0.411, 0.296]
+            std = [0.213, 0.156, 0.143]
+        else:
+            mean = [0.485, 0.456, 0.406]
+            std = [0.229, 0.224, 0.225]
+
+        self.normalize = Normalize(
+            [value * 255 for value in mean],
+            [value * 255 for value in std],
+            num_bands=3,
+        )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Normalize the specified image with DinoV3 normalization.
+
+        Args:
+            input_dict: the input dictionary.
+            target_dict: the target dictionary.
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        return self.normalize(input_dict, target_dict)
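The DinoV3 docstrings above describe the expected input: RGB imagery scaled 0-255, normalized with DinoV3Normalize, and (by default) resized to 256x256 with patch size 16. Below is a minimal sketch of wiring the backbone and normalizer together, assuming the DINOv3 weights have already been downloaded from Meta; the checkpoint directory path is hypothetical.

```python
from rslearn.models.dinov3 import DinoV3, DinoV3Models, DinoV3Normalize

# Hypothetical local directory containing the .pth files listed in DINOV3_PTHS.
CHECKPOINT_DIR = "/data/checkpoints/dinov3"

# Satellite-pretrained ViT-L/16 backbone; pass checkpoint_dir=None to get a
# randomly initialized architecture instead.
backbone = DinoV3(
    checkpoint_dir=CHECKPOINT_DIR,
    size=DinoV3Models.LARGE_SATELLITE,
    use_cls_token=False,  # keep the spatial feature map rather than the pooled token
)

# Normalizer matching the satellite pretraining statistics; it scales the
# 0-255 "image" input by the mean/std values defined in DinoV3Normalize.
normalize = DinoV3Normalize(satellite=True)

# One feature map at 1/16 of the input resolution with 1024 channels: [(16, 1024)].
print(backbone.get_backbone_channels())
```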
rslearn/models/faster_rcnn.py
CHANGED
@@ -6,14 +6,24 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureMaps, Predictor
+
 
 class NoopTransform(torch.nn.Module):
     """A placeholder transform used with torchvision detection model."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         """Create a new NoopTransform."""
         super().__init__()
 
+        # We initialize a GeneralizedRCNNTransform just to use its batch_images
+        # function, which concatenates the images (padding to the dimensions of the
+        # largest image as needed) to the form needed by the Faster R-CNN head.
+        # We pass an arbitrary min_size and max_size here, but these are ignored since
+        # we call GeneralizedRCNNTransform.batch_images directly rather than calling
+        # its forward function.
         self.transform = (
             torchvision.models.detection.transform.GeneralizedRCNNTransform(
                 min_size=800,
@@ -39,32 +49,17 @@ class NoopTransform(torch.nn.Module):
         Returns:
             wrapped images and unmodified targets
         """
+        # See comment above, this just pads/concatenates the images without resizing.
         images = self.transform.batch_images(images, size_divisible=32)
+        # Now convert to ImageList object needed by Faster R-CNN head.
         image_sizes = [(image.shape[1], image.shape[2]) for image in images]
         image_list = torchvision.models.detection.image_list.ImageList(
            images, image_sizes
         )
         return image_list, targets
 
-    def postprocess(
-        self, detections: dict[str, torch.Tensor], image_sizes, orig_sizes
-    ) -> dict[str, torch.Tensor]:
-        """Post-process the detections to reflect original image size.
-
-        Since we didn't transform the images, we don't need to do anything here.
 
-
-            detections: the raw detections
-            image_sizes: the transformed image sizes
-            orig_sizes: the original image sizes
-
-        Returns:
-            the post-processed detections (unmodified from the provided detections)
-        """
-        return detections
-
-
-class FasterRCNN(torch.nn.Module):
+class FasterRCNN(Predictor):
     """Faster R-CNN head for predicting bounding boxes.
 
     It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -80,7 +75,7 @@ class FasterRCNN(torch.nn.Module):
         anchor_sizes: list[list[int]],
         instance_segmentation: bool = False,
         box_score_thresh: float = 0.05,
-    ):
+    ) -> None:
         """Create a new FasterRCNN.
 
         Args:
@@ -185,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context. Input dicts must contain image key for original image size.
            targets: should contain class key that stores the class label.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FasterRCNN must be FeatureMaps")
+
         # Fix target labels to be 1 size in case it's empty.
         # For some reason this is needed.
         if targets:
@@ -212,11 +210,12 @@ class FasterRCNN(torch.nn.Module):
                 ),
             )
 
-
+        # take the first (and assumed to be only) timestep
+        image_list = [inp["image"].image[:, 0] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
-        for i, feat_map in enumerate(
+        for i, feat_map in enumerate(intermediates.feature_maps):
            feature_dict[f"feat{i}"] = feat_map
 
         proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -228,4 +227,7 @@
         losses.update(proposal_losses)
         losses.update(detector_losses)
 
-        return
+        return ModelOutput(
+            outputs=detections,
+            loss_dict=losses,
+        )
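The FasterRCNN changes above show the component contract this release moves prediction heads onto: forward() receives the previous component's output as intermediates plus a ModelContext, checks that it got a FeatureMaps, and returns a ModelOutput carrying both the predictions and a loss dict. Below is a rough sketch of a custom head following the same pattern; the GlobalAveragePool class and its pooling logic are illustrative placeholders, not part of rslearn, and the Predictor base class may impose requirements beyond what is visible in this diff.

```python
from typing import Any

from rslearn.models.component import FeatureMaps, Predictor
from rslearn.train.model_context import ModelContext, ModelOutput


class GlobalAveragePool(Predictor):
    """Illustrative head that global-average-pools the first feature map."""

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        # Same guard the FasterRCNN head uses: the previous component must
        # have produced a FeatureMaps.
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to GlobalAveragePool must be FeatureMaps")

        # Pool (B, C, H, W) -> (B, C) on the first feature map.
        pooled = intermediates.feature_maps[0].mean(dim=(2, 3))
        # No loss is computed in this sketch; a real head would fill loss_dict
        # from targets during training.
        return ModelOutput(outputs=pooled, loss_dict={})
```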
rslearn/models/feature_center_crop.py
ADDED
@@ -0,0 +1,53 @@
+"""Apply center cropping on a feature map."""
+
+from typing import Any
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureMaps, IntermediateComponent
+
+
+class FeatureCenterCrop(IntermediateComponent):
+    """Apply center cropping on the input feature maps."""
+
+    def __init__(
+        self,
+        sizes: list[tuple[int, int]],
+    ) -> None:
+        """Create a new FeatureCenterCrop.
+
+        Only the center of each feature map will be retained and passed to the next
+        module.
+
+        Args:
+            sizes: a list of (height, width) tuples, with one tuple for each input
+                feature map.
+        """
+        super().__init__()
+        self.sizes = sizes
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Apply center cropping on the feature maps.
+
+        Args:
+            intermediates: output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
+
+        Returns:
+            center cropped feature maps.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
+        new_features = []
+        for i, feat in enumerate(intermediates.feature_maps):
+            height, width = self.sizes[i]
+            if feat.shape[2] < height or feat.shape[3] < width:
+                raise ValueError(
+                    "feature map is smaller than the desired height and width"
+                )
+            start_h = feat.shape[2] // 2 - height // 2
+            start_w = feat.shape[3] // 2 - width // 2
+            feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
+            new_features.append(feat)
+        return FeatureMaps(new_features)
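The cropping arithmetic in FeatureCenterCrop.forward is plain integer slicing: the start index on each axis is the feature map's center minus half the requested size. A standalone sketch of the same computation on a dummy tensor (independent of rslearn; shapes chosen only for illustration):

```python
import torch

feat = torch.randn(1, 64, 32, 40)  # (batch, channels, height, width)
height, width = 16, 24             # requested crop size

# Same start-index computation as FeatureCenterCrop.forward.
start_h = feat.shape[2] // 2 - height // 2  # 32 // 2 - 16 // 2 = 8
start_w = feat.shape[3] // 2 - width // 2   # 40 // 2 - 24 // 2 = 8
cropped = feat[:, :, start_h : start_h + height, start_w : start_w + width]

assert cropped.shape == (1, 64, 16, 24)
```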
rslearn/models/fpn.py
CHANGED
@@ -1,12 +1,16 @@
 """Feature pyramid network."""
 
 import collections
+from typing import Any
 
-import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
 
-
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Fpn(IntermediateComponent):
     """A feature pyramid network (FPN).
 
     The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
             in_channels_list=in_channels, out_channels=out_channels
         )
 
-    def forward(self,
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute outputs of the FPN.
 
         Args:
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            new multi-scale feature maps from the FPN
+            new multi-scale feature maps from the FPN.
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Fpn must be FeatureMaps")
+
+        feature_maps = intermediates.feature_maps
+        inp = collections.OrderedDict(
+            [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+        )
         output = self.fpn(inp)
         output = list(output.values())
 
         if self.prepend:
-            return output +
+            return FeatureMaps(output + feature_maps)
         else:
-            return output
+            return FeatureMaps(output)
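The updated Fpn.forward converts the incoming FeatureMaps list into an OrderedDict keyed feat0, feat1, ..., runs it through the wrapped feature pyramid, and optionally prepends the FPN outputs to the original maps. A standalone sketch of that conversion using torchvision's FeaturePyramidNetwork directly, which matches the in_channels_list/out_channels arguments visible in the hunk above (the dummy shapes and channel counts are illustrative):

```python
import collections

import torch
from torchvision.ops import FeaturePyramidNetwork

# Dummy multi-scale features: 1/4 and 1/8 resolution maps with 64 and 128 channels.
feature_maps = [torch.randn(1, 64, 56, 56), torch.randn(1, 128, 28, 28)]

fpn = FeaturePyramidNetwork(in_channels_list=[64, 128], out_channels=128)

# Same conversion the Fpn component performs: list -> OrderedDict keyed feat0, feat1, ...
inp = collections.OrderedDict(
    [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
)
output = list(fpn(inp).values())

# Every FPN output keeps its spatial size but now has out_channels channels.
print([tuple(t.shape) for t in output])  # [(1, 128, 56, 56), (1, 128, 28, 28)]
```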