rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/croma.py (new file)
@@ -0,0 +1,306 @@
+ """CROMA models."""
+
+ import shutil
+ import tempfile
+ import urllib.request
+ from enum import Enum
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from upath import UPath
+
+ from rslearn.log_utils import get_logger
+ from rslearn.train.model_context import ModelContext
+ from rslearn.train.transforms.transform import Transform
+ from rslearn.utils.fsspec import open_atomic
+
+ from .component import FeatureExtractor, FeatureMaps
+ from .use_croma import PretrainedCROMA
+
+ logger = get_logger(__name__)
+
+
+ class CromaSize(str, Enum):
+     """CROMA model size."""
+
+     BASE = "base"
+     LARGE = "large"
+
+
+ class CromaModality(str, Enum):
+     """CROMA model configured input modalities."""
+
+     BOTH = "both"
+     SENTINEL1 = "SAR"
+     SENTINEL2 = "optical"
+
+
+ PATCH_SIZE = 8
+ DEFAULT_IMAGE_RESOLUTION = 120
+ PRETRAINED_URLS: dict[CromaSize, str] = {
+     CromaSize.BASE: "https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_base.pt",
+     CromaSize.LARGE: "https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt",
+ }
+ MEAN_AND_STD_BY_BAND: dict[tuple[str, str], tuple[float, float]] = {
+     ("sentinel1", "vv"): (0.15, 0.82),
+     ("sentinel1", "vh"): (0.03, 0.15),
+     ("sentinel2", "B01"): (1116, 1956),
+     ("sentinel2", "B02"): (1189, 1859),
+     ("sentinel2", "B03"): (1408, 1728),
+     ("sentinel2", "B04"): (1513, 1741),
+     ("sentinel2", "B05"): (1891, 1755),
+     ("sentinel2", "B06"): (2484, 1622),
+     ("sentinel2", "B07"): (2723, 1622),
+     ("sentinel2", "B08"): (2755, 1612),
+     ("sentinel2", "B8A"): (2886, 1611),
+     ("sentinel2", "B09"): (3270, 2651),
+     ("sentinel2", "B11"): (2563, 1442),
+     ("sentinel2", "B12"): (1914, 1329),
+ }
+ MODALITY_BANDS = {
+     "sentinel1": ["vv", "vh"],
+     "sentinel2": [
+         "B01",
+         "B02",
+         "B03",
+         "B04",
+         "B05",
+         "B06",
+         "B07",
+         "B08",
+         "B8A",
+         "B09",
+         "B11",
+         "B12",
+     ],
+ }
+
+
+ class Croma(FeatureExtractor):
+     """CROMA backbones.
+
+     There are two model sizes, base and large.
+
+     The model can be applied with just Sentinel-1, just Sentinel-2, or both. The input
+     modalities must be defined a priori by passing the corresponding CromaModality.
+     Sentinel-1 images should be passed under the "sentinel1" key while Sentinel-2
+     images should be passed under the "sentinel2" key. Only a single timestep can be
+     provided.
+
+     The band order for Sentinel-1 is: vv, vh.
+
+     The band order for Sentinel-2 is: B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09,
+     B11, B12. The model is trained on L1C images with B10 removed.
+
+     See https://github.com/antofuller/CROMA for more details.
+     """
+
+     def __init__(
+         self,
+         size: CromaSize,
+         modality: CromaModality,
+         pretrained_path: str | None = None,
+         image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
+         do_resizing: bool = False,
+     ) -> None:
+         """Instantiate a new Croma instance.
+
+         Args:
+             size: the model size, either base or large.
+             modality: the modalities to configure the model to accept.
+             pretrained_path: the local path to the pretrained weights. If not set, the
+                 weights are downloaded and cached in a temporary directory.
+             image_resolution: the width and height of the input images passed to the
+                 model. If do_resizing is True, the image will be resized to this
+                 resolution.
+             do_resizing: whether to resize the image to the input resolution.
+         """
+         super().__init__()
+         self.size = size
+         self.modality = modality
+         self.do_resizing = do_resizing
+         if not do_resizing:
+             self.image_resolution = image_resolution
+         else:
+             # With single pixel input, we always resample to the patch size.
+             if image_resolution == 1:
+                 self.image_resolution = PATCH_SIZE
+             else:
+                 self.image_resolution = DEFAULT_IMAGE_RESOLUTION
+
+         # Cache the CROMA weights to a deterministic path in the temporary directory
+         # if the path is not provided by the user.
+         if pretrained_path is None:
+             pretrained_url = PRETRAINED_URLS[self.size]
+             local_fname = UPath(
+                 tempfile.gettempdir(), "rslearn_cache", "croma", f"{self.size.value}.pt"
+             )
+             if not local_fname.exists():
+                 logger.info(
+                     "caching CROMA weights from %s to %s", pretrained_url, local_fname
+                 )
+                 local_fname.parent.mkdir(parents=True, exist_ok=True)
+                 with urllib.request.urlopen(pretrained_url) as response:
+                     with open_atomic(local_fname, "wb") as f:
+                         shutil.copyfileobj(response, f)
+             else:
+                 logger.info("using cached CROMA weights at %s", local_fname)
+             pretrained_path = local_fname.path
+
+         self.model = PretrainedCROMA(
+             pretrained_path=pretrained_path,
+             size=size.value,
+             modality=modality.value,
+             image_resolution=self.image_resolution,
+         )
+
+     def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
+         """Resize the image to the input resolution."""
+         return F.interpolate(
+             image,
+             size=(self.image_resolution, self.image_resolution),
+             mode="bilinear",
+             align_corners=False,
+         )
+
+     def forward(self, context: ModelContext) -> FeatureMaps:
+         """Compute feature maps from the Croma backbone.
+
+         Args:
+             context: the model context. Input dicts must include either or both of the
+                 "sentinel2" and "sentinel1" keys depending on the configured modality.
+
+         Returns:
+             a FeatureMaps with one feature map at 1/8 the input resolution.
+         """
+         sentinel1: torch.Tensor | None = None
+         sentinel2: torch.Tensor | None = None
+         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
+             sentinel1 = torch.stack(
+                 [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
+                 dim=0,
+             )
+             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
+         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
+             sentinel2 = torch.stack(
+                 [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
+                 dim=0,
+             )
+             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
+
+         outputs = self.model(
+             SAR_images=sentinel1,
+             optical_images=sentinel2,
+         )
+
+         # Pick which encoding to use. If the modality is both, there are three options;
+         # we could concatenate the SAR and optical encodings, but for now we just use
+         # the joint encodings.
+         if self.modality == CromaModality.BOTH:
+             features = outputs["joint_encodings"]
+         elif self.modality == CromaModality.SENTINEL1:
+             features = outputs["SAR_encodings"]
+         elif self.modality == CromaModality.SENTINEL2:
+             features = outputs["optical_encodings"]
+
+         # Rearrange from patch embeddings to 2D feature map.
+         num_patches_per_dim = self.image_resolution // PATCH_SIZE
+         features = rearrange(
+             features,
+             "b (h w) d -> b d h w",
+             h=num_patches_per_dim,
+             w=num_patches_per_dim,
+         )
+
+         return FeatureMaps([features])
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels is a list of (downsample_factor, depth) tuples that
+         corresponds to the feature maps that the backbone returns. For example, an
+         element [2, 32] indicates that the corresponding feature map is 1/2 the input
+         resolution and has 32 channels.
+
+         Returns:
+             the output channels of the backbone as a list of (downsample_factor, depth)
+             tuples.
+         """
+         if self.size == CromaSize.BASE:
+             depth = 768
+         elif self.size == CromaSize.LARGE:
+             depth = 1024
+         else:
+             raise ValueError(f"unknown CromaSize {self.size}")
+         return [(PATCH_SIZE, depth)]
+
+
+ class CromaNormalize(Transform):
+     """Normalize inputs using CROMA normalization.
+
+     It will apply normalization to the "sentinel1" and "sentinel2" input keys (if set).
+     """
+
+     def __init__(self) -> None:
+         """Initialize a new CromaNormalize."""
+         super().__init__()
+
+     def apply_image(self, image: torch.Tensor, modality: str) -> torch.Tensor:
+         """Normalize the specified image with CROMA normalization.
+
+         CROMA normalizes based on batch statistics, but we may apply the model with
+         small batches, so we instead use preset statistics corresponding to the dataset
+         distribution.
+
+         The normalized value is based on clipping to [mean-2*std, mean+2*std] and then
+         linearly rescaling to [0, 1].
+
+         Args:
+             image: the image to transform.
+             modality: the modality of the image.
+         """
+         image = image.float()
+
+         # The number of channels must be a multiple of the expected number of bands for
+         # this modality. It can be a multiple since we accept stacked time series.
+         band_names = MODALITY_BANDS[modality]
+         if image.shape[0] % len(band_names) != 0:
+             raise ValueError(
+                 f"image has {image.shape[0]} channels for modality {modality} which is not a multiple of expected number of bands {len(band_names)}"
+             )
+
+         normalized_bands = []
+         for band_idx in range(image.shape[0]):
+             band_name = band_names[band_idx % len(band_names)]
+             mean, std = MEAN_AND_STD_BY_BAND[(modality, band_name)]
+
+             orig = image[band_idx, :, :]
+             min_value = mean - 2 * std
+             max_value = mean + 2 * std
+
+             normalized = (orig - min_value) / (max_value - min_value)
+             normalized = torch.clip(normalized, 0, 1)
+             normalized_bands.append(normalized)
+
+         return torch.stack(normalized_bands, dim=0)
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply normalization over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             normalized (input_dict, target_dict) tuple
+         """
+         for modality in MODALITY_BANDS.keys():
+             if modality not in input_dict:
+                 continue
+             input_dict[modality].image = self.apply_image(
+                 input_dict[modality].image, modality
+             )
+         return input_dict, target_dict
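
Note (editorial, not part of the diff): the new CromaNormalize transform replaces CROMA's batch-statistics normalization with the fixed per-band statistics in MEAN_AND_STD_BY_BAND. Below is a minimal standalone sketch of that per-band math, assuming only torch and the statistics table shown above; the helper name normalize_band is illustrative and not an rslearn API.

    import torch

    # Statistics copied from MEAN_AND_STD_BY_BAND for the two Sentinel-1 bands.
    STATS = {"vv": (0.15, 0.82), "vh": (0.03, 0.15)}

    def normalize_band(band: torch.Tensor, mean: float, std: float) -> torch.Tensor:
        # Clip to [mean - 2*std, mean + 2*std], then rescale linearly to [0, 1].
        min_value = mean - 2 * std
        max_value = mean + 2 * std
        return torch.clip((band - min_value) / (max_value - min_value), 0, 1)

    # A toy 2-band (vv, vh) chip; in rslearn this is the "sentinel1" input image.
    image = torch.rand(2, 8, 8)
    normalized = torch.stack(
        [normalize_band(image[i], *STATS[band]) for i, band in enumerate(["vv", "vh"])]
    )

The backbone itself would be configured as, for example, Croma(size=CromaSize.BASE, modality=CromaModality.SENTINEL2); its forward pass expects a ModelContext built by the rslearn training pipeline, so it is not reproduced here.
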
rslearn/models/detr/__init__.py (new file)
@@ -0,0 +1,5 @@
+ """DETR object detection model code."""
+
+ from .detr import Detr
+
+ __all__ = ["Detr"]
rslearn/models/detr/box_ops.py (new file)
@@ -0,0 +1,103 @@
+ """Utilities for bounding box manipulation and GIoU.
+
+ This is copied from https://github.com/facebookresearch/detr/.
+ The original code is:
+ Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+
+ import torch
+ from torchvision.ops.boxes import box_area
+
+
+ def box_cxcywh_to_xyxy(x: torch.Tensor) -> torch.Tensor:
+     """Convert boxes from cxcywh format to xyxy format."""
+     x_c, y_c, w, h = x.unbind(-1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=-1)
+
+
+ def box_xyxy_to_cxcywh(x: torch.Tensor) -> torch.Tensor:
+     """Convert boxes from xyxy format to cxcywh format."""
+     x0, y0, x1, y1 = x.unbind(-1)
+     b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+     return torch.stack(b, dim=-1)
+
+
+ # modified from torchvision to also return the union
+ def box_iou(
+     boxes1: torch.Tensor, boxes2: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Compute the intersection-over-union score between the two lists of boxes.
+
+     The boxes should be in xyxy format.
+
+     Args:
+         boxes1: the first list of boxes (Nx4).
+         boxes2: the second list of boxes (Mx4).
+
+     Returns:
+         the intersection-over-union scores along with the union areas.
+     """
+     area1 = box_area(boxes1)
+     area2 = box_area(boxes2)
+
+     lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+     rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+     wh = (rb - lt).clamp(min=0)  # [N,M,2]
+     inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+     union = area1[:, None] + area2 - inter
+
+     iou = inter / union
+     return iou, union
+
+
+ def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
+     """Generalized IoU from https://giou.stanford.edu/.
+
+     The boxes should be in [x0, y0, x1, y1] format.
+
+     Returns a [N, M] pairwise matrix, where N = len(boxes1)
+     and M = len(boxes2).
+     """
+     # degenerate boxes give inf / nan results
+     # so do an early check
+     assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+     assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+     iou, union = box_iou(boxes1, boxes2)
+
+     lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+     rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+     wh = (rb - lt).clamp(min=0)  # [N,M,2]
+     area = wh[:, :, 0] * wh[:, :, 1]
+
+     return iou - (area - union) / area
+
+
+ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+     """Compute the bounding boxes around the provided masks.
+
+     The masks should be in format [N, H, W] where N is the number of masks and
+     (H, W) are the spatial dimensions.
+
+     Returns an [N, 4] tensor, with the boxes in xyxy format.
+     """
+     if masks.numel() == 0:
+         return torch.zeros((0, 4), device=masks.device)
+
+     h, w = masks.shape[-2:]
+
+     y = torch.arange(0, h, dtype=torch.float)
+     x = torch.arange(0, w, dtype=torch.float)
+     y, x = torch.meshgrid(y, x)
+
+     x_mask = masks * x.unsqueeze(0)
+     x_max = x_mask.flatten(1).max(-1)[0]
+     x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+     y_mask = masks * y.unsqueeze(0)
+     y_max = y_mask.flatten(1).max(-1)[0]
+     y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+     return torch.stack([x_min, y_min, x_max, y_max], 1)
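
Note (editorial, not part of the diff): the box utilities above are plain tensor functions. The following short sketch shows how they compose on toy data, converting center-format boxes to corner format and computing the pairwise generalized IoU matrix (the kind of cost term a DETR-style matcher typically uses).

    import torch

    from rslearn.models.detr.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

    # Two predicted boxes and one target box in (cx, cy, w, h) format.
    pred_cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2]])
    target_cxcywh = torch.tensor([[0.5, 0.5, 0.5, 0.5]])

    # generalized_box_iou expects xyxy boxes and returns an [N, M] matrix,
    # here [2, 1]: one row per predicted box, one column per target box.
    giou = generalized_box_iou(
        box_cxcywh_to_xyxy(pred_cxcywh), box_cxcywh_to_xyxy(target_cxcywh)
    )
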