rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/classification.py
CHANGED

```diff
@@ -15,6 +15,13 @@ from torchmetrics.classification import (
     MulticlassRecall,
 )
 
+from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
@@ -27,7 +34,7 @@ class ClassificationTask(BasicTask):
         self,
         property_name: str,
         classes: list[str],
-        filters: list[tuple[str, str]]
+        filters: list[tuple[str, str]] = [],
         read_class_id: bool = False,
         allow_invalid: bool = False,
         skip_unknown_categories: bool = False,
@@ -37,7 +44,7 @@ class ClassificationTask(BasicTask):
         f1_metric_kwargs: dict[str, Any] = {},
         positive_class: str | None = None,
         positive_class_threshold: float = 0.5,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Initialize a new ClassificationTask.
 
@@ -49,8 +56,8 @@ class ClassificationTask(BasicTask):
                 features with matching properties.
             read_class_id: whether to read an integer class ID instead of the class
                 name.
-            allow_invalid: instead of throwing error when no
-                at a window, simply mark the example invalid for this task
+            allow_invalid: instead of throwing error when no classification label is
+                found at a window, simply mark the example invalid for this task
             skip_unknown_categories: whether to skip examples with categories that are
                 not passed via classes, instead of throwing error
             prob_property: when predicting, write probabilities in addition to class ID
```
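The constructor changes above give `filters` a default value and tighten the `**kwargs` annotation. For orientation, a hypothetical instantiation sketch based on the signature shown in this diff; the property name, class list, and filter tuple are illustrative only:

```python
# Hypothetical values; only the keyword arguments come from the signature above.
task = ClassificationTask(
    property_name="label",
    classes=["negative", "positive"],
    filters=[("source", "manual")],   # keep only features whose "source" is "manual"
    allow_invalid=True,               # mark unlabeled windows invalid instead of erroring
    positive_class="positive",
    positive_class_threshold=0.7,
)
```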
```diff
@@ -95,13 +102,10 @@ class ClassificationTask(BasicTask):
         else:
             self.positive_class_id = self.classes.index(self.positive_class)
 
-        if not self.filters:
-            self.filters = []
-
     def process_inputs(
         self,
-        raw_inputs: dict[str,
-        metadata:
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -119,7 +123,10 @@ class ClassificationTask(BasicTask):
             return {}, {}
 
         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
+            if feat.properties is None:
+                continue
             for property_name, property_value in self.filters:
                 if feat.properties.get(property_name) != property_value:
                     continue
@@ -155,17 +162,25 @@ class ClassificationTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata:
-    ) ->
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head
+            raw_output: the output from prediction head, which must be a tensor
+                containing output probabilities (one dimension).
             metadata: metadata about the patch being read
 
         Returns:
-
+            a list with one Feature corresponding to the input patch extent with a
+            property name containing the predicted class. It will have another
+            property containing the probabilities if prob_property was set.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
+            raise ValueError(
+                "expected output for ClassificationTask to be a Tensor with one dimension"
+            )
+
         probs = raw_output.cpu().numpy()
         if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
             positive_class_prob = probs[self.positive_class_id]
@@ -175,24 +190,25 @@ class ClassificationTask(BasicTask):
                 class_idx = 1 - self.positive_class_id
         else:
             # For multiclass classification or when using the default threshold
-            class_idx = probs.argmax()
+            class_idx = probs.argmax().item()
 
+        value: str | int
         if not self.read_class_id:
-            value = self.classes[class_idx]
+            value = self.classes[class_idx]  # type: ignore
         else:
            value = class_idx
 
         feature = Feature(
             STGeometry(
-                metadata
-                shapely.Point(metadata[
+                metadata.projection,
+                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
                 None,
             ),
             {
                 self.property_name: value,
             },
         )
-        if self.prob_property:
+        if self.prob_property is not None and feature.properties is not None:
             feature.properties[self.prob_property] = probs.tolist()
         return [feature]
 
```
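A quick numeric sketch of the thresholded binary decision in `process_output` above. The `if positive_class_prob >= self.positive_class_threshold` branch sits between the hunks shown, so its exact form is inferred from the surrounding lines:

```python
# Assumed setup for illustration: positive_class_id = 1, threshold = 0.7.
probs = [0.45, 0.55]
positive_class_id = 1
positive_class_threshold = 0.7
if probs[positive_class_id] >= positive_class_threshold:
    class_idx = positive_class_id
else:
    # 0.55 < 0.7, so the prediction falls back to the other class.
    class_idx = 1 - positive_class_id
assert class_idx == 0
```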
```diff
@@ -215,6 +231,8 @@ class ClassificationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         image = Image.fromarray(image)
         draw = ImageDraw.Draw(image)
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         target_class = self.classes[target_dict["class"]]
         output_class = self.classes[output.argmax()]
         text = f"Label: {target_class}\nOutput: {output_class}"
@@ -263,28 +281,34 @@ class ClassificationTask(BasicTask):
         return MetricCollection(metrics)
 
 
-class ClassificationHead(
+class ClassificationHead(Predictor):
     """Head for classification task."""
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the classification outputs and loss from logits and targets.
 
         Args:
-
-
-
+            intermediates: output from the previous model component, it should be a
+                FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
+            context: the model context.
+            targets: must contain "class" key that stores the class label, along with
+                "valid" key indicating whether the label is valid for each example.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureVector):
+            raise ValueError("the input to ClassificationHead must be a FeatureVector")
+
+        logits = intermediates.feature_vector
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
-
+        losses = {}
         if targets:
             class_labels = torch.stack([target["class"] for target in targets], dim=0)
             mask = torch.stack([target["valid"] for target in targets], dim=0)
@@ -294,9 +318,12 @@ class ClassificationHead(torch.nn.Module):
                 )
                 * mask
             )
-
+            losses["cls"] = torch.mean(loss)
 
-        return
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class ClassificationMetric(Metric):
```
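The rewritten `ClassificationHead.forward` computes a masked cross-entropy over `FeatureVector` logits and wraps the result in a `ModelOutput`. A minimal standalone sketch of that computation; the `reduction="none"` argument is assumed, since the diff truncates the `cross_entropy` call but per-example masking requires it:

```python
import torch

# Batch of 3-class logits for 4 examples; the second target is marked invalid.
logits = torch.randn(4, 3)
targets = [
    {"class": torch.tensor(0), "valid": torch.tensor(1.0)},
    {"class": torch.tensor(2), "valid": torch.tensor(0.0)},
    {"class": torch.tensor(1), "valid": torch.tensor(1.0)},
    {"class": torch.tensor(1), "valid": torch.tensor(1.0)},
]
outputs = torch.nn.functional.softmax(logits, dim=1)
class_labels = torch.stack([t["class"] for t in targets], dim=0)
mask = torch.stack([t["valid"] for t in targets], dim=0)
# Per-example loss, zeroed where "valid" is 0, then averaged into losses["cls"].
loss = torch.nn.functional.cross_entropy(logits, class_labels, reduction="none") * mask
losses = {"cls": torch.mean(loss)}
```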
rslearn/train/tasks/detection.py
CHANGED
```diff
@@ -12,26 +12,27 @@ import torchmetrics.classification
 import torchvision
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
 
 DEFAULT_COLORS = [
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
 ]
 
 
@@ -53,14 +54,30 @@ class DetectionTask(BasicTask):
         score_threshold: float = 0.5,
         enable_map_metric: bool = True,
         enable_f1_metric: bool = False,
+        enable_precision_recall: bool = False,
+        f1_metric_thresholds: list[list[float]] = [
+            [
+                0.05,
+                0.1,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.6,
+                0.7,
+                0.8,
+                0.9,
+                0.95,
+            ]
+        ],
         f1_metric_kwargs: dict[str, Any] = {},
-        **kwargs,
-    ):
-        """Initialize a new
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new DetectionTask.
 
         Args:
-            property_name: the property from which to extract the class name.
-
+            property_name: the property from which to extract the class name. Features
+                without this property name are ignored.
             classes: a list of class names.
             filters: optional list of (property_name, property_value) to only consider
                 features with matching properties.
@@ -70,14 +87,20 @@ class DetectionTask(BasicTask):
                 not passed via classes, instead of throwing error
             skip_empty_examples: whether to skip examples with zero labels.
             colors: optional colors for each class
-            box_size: force all boxes to be this size, centered at the
-                geometry. Required for Point geometries.
+            box_size: force all boxes to be two times this size, centered at the
+                centroid of the geometry. Required for Point geometries.
             clip_boxes: whether to clip boxes to the image bounds.
             exclude_by_center: before optionally clipping boxes, exclude boxes if the
                 center is outside the image bounds.
             score_threshold: confidence threshold for visualization and prediction.
             enable_map_metric: whether to compute mAP (default true)
             enable_f1_metric: whether to compute F1 (default false)
+            enable_precision_recall: whether to compute precision and recall.
+            f1_metric_thresholds: list of list of thresholds to apply for F1 metric, as
+                well as for precision and recall if enabled. Each inner list is used to
+                initialize a separate F1 metric where the best F1 across the thresholds
+                within the inner list is computed. If there are multiple inner lists,
+                then multiple F1 scores will be reported.
             f1_metric_kwargs: extra arguments to pass to F1 metric.
             kwargs: additional arguments to pass to BasicTask
         """
@@ -95,6 +118,8 @@ class DetectionTask(BasicTask):
         self.score_threshold = score_threshold
         self.enable_map_metric = enable_map_metric
         self.enable_f1_metric = enable_f1_metric
+        self.enable_precision_recall = enable_precision_recall
+        self.f1_metric_thresholds = f1_metric_thresholds
         self.f1_metric_kwargs = f1_metric_kwargs
 
         if not self.filters:
@@ -102,8 +127,8 @@ class DetectionTask(BasicTask):
 
     def process_inputs(
         self,
-        raw_inputs: dict[str,
-        metadata:
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -120,15 +145,21 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
+        bounds = metadata.patch_bounds
+
         boxes = []
         class_labels = []
         valid = 1
 
         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
-
-
-
+            if feat.properties is None:
+                continue
+            if self.filters is not None:
+                for property_name, property_value in self.filters:
+                    if feat.properties.get(property_name) != property_value:
+                        continue
             if self.property_name not in feat.properties:
                 continue
 
```
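A recurring change across both files is the move from dict-style metadata to `SampleMetadata` attributes. A minimal sketch of the migration; the old `"bounds"` key is visible in removed lines below, while the old key for the projection is truncated in this diff, so that part is assumed:

```python
# Old dict-style access (removed lines):  bounds = metadata["bounds"]
# New attribute-style access:
bounds = metadata.patch_bounds      # (min_col, min_row, max_col, max_row)
projection = metadata.projection    # the old dict equivalent is truncated/assumed
width = bounds[2] - bounds[0]
height = bounds[3] - bounds[1]
```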
```diff
@@ -159,39 +190,33 @@ class DetectionTask(BasicTask):
             else:
                 box = [int(val) for val in shp.bounds]
 
-            if box[0] >=
+            if box[0] >= bounds[2] or box[2] <= bounds[0]:
                 continue
-            if box[1] >=
+            if box[1] >= bounds[3] or box[3] <= bounds[1]:
                 continue
 
             if self.exclude_by_center:
                 center_col = (box[0] + box[2]) // 2
                 center_row = (box[1] + box[3]) // 2
-                if
-                    center_col <= metadata["bounds"][0]
-                    or center_col >= metadata["bounds"][2]
-                ):
+                if center_col <= bounds[0] or center_col >= bounds[2]:
                     continue
-                if
-                    center_row <= metadata["bounds"][1]
-                    or center_row >= metadata["bounds"][3]
-                ):
+                if center_row <= bounds[1] or center_row >= bounds[3]:
                     continue
 
             if self.clip_boxes:
                 box = [
-                    np.clip(box[0],
-                    np.clip(box[1],
-                    np.clip(box[2],
-                    np.clip(box[3],
+                    np.clip(box[0], bounds[0], bounds[2]),
+                    np.clip(box[1], bounds[1], bounds[3]),
+                    np.clip(box[2], bounds[0], bounds[2]),
+                    np.clip(box[3], bounds[1], bounds[3]),
                 ]
 
             # Convert to relative coordinates.
             box = [
-                box[0] -
-                box[1] -
-                box[2] -
-                box[3] -
+                box[0] - bounds[0],
+                box[1] - bounds[1],
+                box[2] - bounds[0],
+                box[3] - bounds[1],
             ]
 
             boxes.append(box)
```
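A worked numeric sketch of the clipping and relative-coordinate conversion above, for a patch with `patch_bounds = (100, 100, 200, 200)` (values made up for illustration):

```python
import numpy as np

bounds = (100, 100, 200, 200)
box = [90, 150, 130, 250]  # partially outside the patch
# Clip each coordinate to the patch bounds (the clip_boxes branch):
box = [
    np.clip(box[0], bounds[0], bounds[2]),  # 100
    np.clip(box[1], bounds[1], bounds[3]),  # 150
    np.clip(box[2], bounds[0], bounds[2]),  # 130
    np.clip(box[3], bounds[1], bounds[3]),  # 200
]
# Convert to patch-relative coordinates:
box = [box[0] - bounds[0], box[1] - bounds[1], box[2] - bounds[0], box[3] - bounds[1]]
# box is now [0, 50, 30, 100]
```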
```diff
@@ -211,16 +236,12 @@ class DetectionTask(BasicTask):
             "valid": torch.tensor(valid, dtype=torch.int32),
             "boxes": boxes,
             "labels": class_labels,
-            "width": torch.tensor(
-
-            ),
-            "height": torch.tensor(
-                metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
-            ),
+            "width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
+            "height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
         }
 
     def process_output(
-        self, raw_output: Any, metadata:
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature]:
         """Processes an output into raster or vector data.
 
@@ -240,13 +261,13 @@ class DetectionTask(BasicTask):
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata[
-                metadata[
-                metadata[
-                metadata[
+                metadata.patch_bounds[0] + float(box[0]),
+                metadata.patch_bounds[1] + float(box[1]),
+                metadata.patch_bounds[0] + float(box[2]),
+                metadata.patch_bounds[1] + float(box[3]),
             )
-            geom = STGeometry(metadata
-            properties = {
+            geom = STGeometry(metadata.projection, shp, None)
+            properties: dict[str, Any] = {
                 "score": float(score),
             }
 
@@ -278,7 +299,9 @@ class DetectionTask(BasicTask):
         """
         image = super().visualize(input_dict, target_dict, output)["image"]
 
-        def draw_boxes(
+        def draw_boxes(
+            image: npt.NDArray[Any], d: dict[str, torch.Tensor]
+        ) -> npt.NDArray[Any]:
             boxes = d["boxes"].cpu().numpy()
             class_ids = d["labels"].cpu().numpy()
             if "scores" in d:
@@ -299,6 +322,8 @@ class DetectionTask(BasicTask):
 
             return image
 
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         return {
             "gt": draw_boxes(image.copy(), target_dict),
             "pred": draw_boxes(image.copy(), output),
@@ -307,17 +332,46 @@ class DetectionTask(BasicTask):
     def get_metrics(self) -> MetricCollection:
         """Get the metrics for this task."""
         metrics = {}
+
         if self.enable_map_metric:
             metrics["mAP"] = DetectionMetric(
                 torchmetrics.detection.mean_ap.MeanAveragePrecision(),
                 output_key="map",
             )
-
+
+        if self.enable_f1_metric or self.enable_precision_recall:
             kwargs = dict(
                 num_classes=len(self.classes),
             )
             kwargs.update(self.f1_metric_kwargs)
-
+
+            for thresholds in self.f1_metric_thresholds:
+                if len(self.f1_metric_thresholds) == 1:
+                    suffix = ""
+                else:
+                    # Metric name can't contain "." so change to ",".
+                    suffix = "_" + str(thresholds[0]).replace(".", ",")
+
+                if self.enable_f1_metric:
+                    metrics["F1" + suffix] = DetectionMetric(
+                        F1Metric(score_thresholds=thresholds, **kwargs)  # type: ignore
+                    )
+                if self.enable_precision_recall:
+                    metrics["precision" + suffix] = DetectionMetric(
+                        F1Metric(
+                            score_thresholds=thresholds,
+                            metric_mode="precision",
+                            **kwargs,  # type: ignore
+                        )
+                    )
+                    metrics["recall" + suffix] = DetectionMetric(
+                        F1Metric(
+                            score_thresholds=thresholds,
+                            metric_mode="recall",
+                            **kwargs,  # type: ignore
+                        )
+                    )
+
         return MetricCollection(metrics)
 
 
```
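A hypothetical configuration sketch of the threshold groups wired up in `get_metrics` above: each inner list becomes one `F1Metric` (plus precision/recall variants when enabled), and with more than one group the metric names gain a suffix built from the group's first threshold, with `.` replaced by `,`:

```python
# Hypothetical task arguments; the metric-naming behavior follows the code above.
task = DetectionTask(
    property_name="category",
    classes=["vessel"],
    enable_f1_metric=True,
    enable_precision_recall=True,
    f1_metric_thresholds=[[0.4], [0.5, 0.6, 0.7]],
)
metrics = task.get_metrics()
# Expected keys: "mAP", "F1_0,4", "precision_0,4", "recall_0,4",
#                "F1_0,5", "precision_0,5", "recall_0,5".
```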
```diff
@@ -377,22 +431,11 @@ class F1Metric(Metric):
     def __init__(
         self,
         num_classes: int,
+        score_thresholds: list[float],
         cmp_mode: str = "iou",
         cmp_threshold: float = 0.5,
-        score_thresholds: list[float] = [
-            0.05,
-            0.1,
-            0.2,
-            0.3,
-            0.4,
-            0.5,
-            0.6,
-            0.7,
-            0.8,
-            0.9,
-            0.95,
-        ],
         flatten_classes: bool = False,
+        metric_mode: str = "f1",
     ):
         """Create a new F1Metric.
 
@@ -406,6 +449,8 @@ class F1Metric(Metric):
             flatten_classes: sum true positives, false positives, and false negatives
                 across classes and report combined F1 instead of computing F1 score for
                 each class and then reporting the average.
+            metric_mode: set to "precision" or "recall" to return that instead of F1
+                (default "f1")
         """
         super().__init__()
         self.num_classes = num_classes
@@ -413,6 +458,10 @@ class F1Metric(Metric):
         self.cmp_threshold = cmp_threshold
         self.score_thresholds = score_thresholds
         self.flatten_classes = flatten_classes
+        self.metric_mode = metric_mode
+
+        assert self.cmp_mode in ["iou", "distance"]
+        assert self.metric_mode in ["f1", "precision", "recall"]
 
         for cls_idx in range(self.num_classes):
             for thr_idx in range(len(self.score_thresholds)):
@@ -531,8 +580,15 @@ class F1Metric(Metric):
                 else:
                     f1 = 2 * precision * recall / (precision + recall)
 
-                if
-
+                if self.metric_mode == "f1":
+                    score = f1
+                elif self.metric_mode == "precision":
+                    score = precision
+                elif self.metric_mode == "recall":
+                    score = recall
+
+                if best_score is None or score > best_score:
+                    best_score = score
 
 
             best_scores.append(best_score)
```
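A standalone numeric sketch of the best-across-thresholds reduction above, using made-up per-threshold precision/recall pairs for a single class:

```python
per_threshold = [(0.60, 0.90), (0.70, 0.75), (0.85, 0.50)]  # (precision, recall)
metric_mode = "f1"  # or "precision" / "recall", per the new metric_mode option

best_score = None
for precision, recall in per_threshold:
    f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    score = {"f1": f1, "precision": precision, "recall": recall}[metric_mode]
    if best_score is None or score > best_score:
        best_score = score

print(round(best_score, 4))  # 0.7241, from the middle threshold
```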
rslearn/train/tasks/embedding.py
ADDED

```diff
@@ -0,0 +1,120 @@
+"""Embedding task."""
+
+from typing import Any
+
+import numpy.typing as npt
+import torch
+from torchmetrics import MetricCollection
+
+from rslearn.models.component import FeatureMaps
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.utils import Feature
+
+from .task import Task
+
+
+class EmbeddingTask(Task):
+    """A dummy task for computing embeddings.
+
+    This task does not compute any targets or loss. Instead, it is just set up for
+    inference, to save embeddings from the configured model.
+    """
+
+    def process_inputs(
+        self,
+        raw_inputs: dict[str, torch.Tensor],
+        metadata: SampleMetadata,
+        load_targets: bool = True,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Processes the data into targets.
+
+        Args:
+            raw_inputs: raster or vector data to process
+            metadata: metadata about the patch being read
+            load_targets: whether to load the targets or only inputs
+
+        Returns:
+            tuple (input_dict, target_dict) containing the processed inputs and targets
+            that are compatible with both metrics and loss functions
+        """
+        return {}, {}
+
+    def process_output(
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any] | list[Feature]:
+        """Processes an output into raster or vector data.
+
+        Args:
+            raw_output: the output from prediction head, which must be a CxHxW tensor.
+            metadata: metadata about the patch being read
+
+        Returns:
+            either raster or vector data.
+        """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
+            raise ValueError(
+                "output for EmbeddingTask must be a tensor with three dimensions"
+            )
+
+        # Just convert the raw output to numpy array that can be saved to GeoTIFF.
+        return raw_output.cpu().numpy()
+
+    def visualize(
+        self,
+        input_dict: dict[str, Any],
+        target_dict: dict[str, Any] | None,
+        output: Any,
+    ) -> dict[str, npt.NDArray[Any]]:
+        """Visualize the outputs and targets.
+
+        Args:
+            input_dict: the input dict from process_inputs
+            target_dict: the target dict from process_inputs
+            output: the prediction
+
+        Returns:
+            a dictionary mapping image name to visualization image
+        """
+        # EmbeddingTask is only set up to support `model predict`.
+        raise NotImplementedError
+
+    def get_metrics(self) -> MetricCollection:
+        """Get the metrics for this task."""
+        return MetricCollection({})
+
+
+class EmbeddingHead:
+    """Head for embedding task.
+
+    It just adds a dummy loss to act as a Predictor.
+    """
+
+    def forward(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None = None,
+    ) -> ModelOutput:
+        """Return the feature map along with a dummy loss.
+
+        Args:
+            intermediates: output from the previous model component, which must be a
+                FeatureMaps consisting of a single feature map.
+            context: the model context.
+            targets: the targets (ignored).
+
+        Returns:
+            model output with the feature map that was input to this component along
+            with a dummy loss.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to EmbeddingHead must be a FeatureMaps")
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
+            )
+
+        return ModelOutput(
+            outputs=intermediates.feature_maps[0],
+            loss_dict={"loss": 0},
+        )
```
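Finally, a minimal usage sketch of the new `EmbeddingHead`. Constructing `FeatureMaps` from a list of tensors is an assumption (only its `feature_maps` attribute is confirmed by this diff), and a real `ModelContext` would normally be passed instead of `None`:

```python
import torch

from rslearn.models.component import FeatureMaps
from rslearn.train.tasks.embedding import EmbeddingHead

feature_map = torch.randn(1, 128, 8, 8)  # a single BxCxHxW feature map
head = EmbeddingHead()
# forward() ignores targets and passes the feature map through with a dummy loss.
out = head.forward(FeatureMaps([feature_map]), context=None)
assert out.outputs is feature_map
assert out.loss_dict == {"loss": 0}
```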