rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/classification.py +55 -28

@@ -15,6 +15,13 @@ from torchmetrics.classification import (
     MulticlassRecall,
 )
 
+from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
@@ -27,7 +34,7 @@ class ClassificationTask(BasicTask):
         self,
         property_name: str,
         classes: list[str],
-        filters: list[tuple[str, str]] | None = None,
+        filters: list[tuple[str, str]] = [],
         read_class_id: bool = False,
         allow_invalid: bool = False,
         skip_unknown_categories: bool = False,
@@ -37,7 +44,7 @@
         f1_metric_kwargs: dict[str, Any] = {},
         positive_class: str | None = None,
         positive_class_threshold: float = 0.5,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Initialize a new ClassificationTask.
 
@@ -49,8 +56,8 @@
                 features with matching properties.
             read_class_id: whether to read an integer class ID instead of the class
                 name.
-            allow_invalid: instead of throwing error when no regression label is found
-                at a window, simply mark the example invalid for this task
+            allow_invalid: instead of throwing error when no classification label is
+                found at a window, simply mark the example invalid for this task
             skip_unknown_categories: whether to skip examples with categories that are
                 not passed via classes, instead of throwing error
             prob_property: when predicting, write probabilities in addition to class ID
@@ -95,13 +102,10 @@
         else:
             self.positive_class_id = self.classes.index(self.positive_class)
 
-        if not self.filters:
-            self.filters = []
-
     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -119,7 +123,10 @@
             return {}, {}
 
         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
+            if feat.properties is None:
+                continue
             for property_name, property_value in self.filters:
                 if feat.properties.get(property_name) != property_value:
                     continue
@@ -155,17 +162,25 @@
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a tensor
+                containing output probabilities (one dimension).
             metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            a list with one Feature corresponding to the input patch extent with a
+                property name containing the predicted class. It will have another
+                property containing the probabilities if prob_property was set.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
+            raise ValueError(
+                "expected output for ClassificationTask to be a Tensor with one dimension"
+            )
+
         probs = raw_output.cpu().numpy()
         if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
             positive_class_prob = probs[self.positive_class_id]
@@ -175,24 +190,25 @@
                 class_idx = 1 - self.positive_class_id
         else:
             # For multiclass classification or when using the default threshold
-            class_idx = probs.argmax()
+            class_idx = probs.argmax().item()
 
+        value: str | int
         if not self.read_class_id:
-            value = self.classes[class_idx]
+            value = self.classes[class_idx]  # type: ignore
         else:
            value = class_idx
 
         feature = Feature(
             STGeometry(
-                metadata["projection"],
-                shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
+                metadata.projection,
+                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
                 None,
             ),
             {
                 self.property_name: value,
             },
         )
-        if self.prob_property:
+        if self.prob_property is not None and feature.properties is not None:
             feature.properties[self.prob_property] = probs.tolist()
         return [feature]
 
@@ -215,6 +231,8 @@
         image = super().visualize(input_dict, target_dict, output)["image"]
         image = Image.fromarray(image)
         draw = ImageDraw.Draw(image)
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         target_class = self.classes[target_dict["class"]]
         output_class = self.classes[output.argmax()]
         text = f"Label: {target_class}\nOutput: {output_class}"
@@ -263,28 +281,34 @@
         return MetricCollection(metrics)
 
 
-class ClassificationHead(torch.nn.Module):
+class ClassificationHead(Predictor):
     """Head for classification task."""
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the classification outputs and loss from logits and targets.
 
         Args:
-            logits: tensor that is (BatchSize, NumClasses) in shape.
-            inputs: original inputs (ignored).
-            targets: should contain class key that stores the class label.
+            intermediates: output from the previous model component, it should be a
+                FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
+            context: the model context.
+            targets: must contain "class" key that stores the class label, along with
+                "valid" key indicating whether the label is valid for each example.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureVector):
+            raise ValueError("the input to ClassificationHead must be a FeatureVector")
+
+        logits = intermediates.feature_vector
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
-        loss = None
+        losses = {}
         if targets:
             class_labels = torch.stack([target["class"] for target in targets], dim=0)
             mask = torch.stack([target["valid"] for target in targets], dim=0)
@@ -294,9 +318,12 @@ class ClassificationHead(torch.nn.Module):
                 )
                 * mask
             )
-            loss = torch.mean(loss)
+            losses["cls"] = torch.mean(loss)
 
-        return outputs, {"cls": loss}
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class ClassificationMetric(Metric):
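
The hunks above migrate ClassificationHead from a plain torch.nn.Module returning an (outputs, loss_dict) tuple to the new Predictor component API, which takes a FeatureVector and returns a ModelOutput. A minimal sketch of the new contract, assuming FeatureVector accepts a feature_vector keyword and that context may be None (the forward body shown above never reads it); anything not in the diff is illustrative:

    import torch

    from rslearn.models.component import FeatureVector
    from rslearn.train.tasks.classification import ClassificationHead

    head = ClassificationHead()
    logits = torch.randn(4, 3)  # (BatchSize, NumClasses)
    targets = [
        {"class": torch.tensor(i % 3), "valid": torch.tensor(1.0)}
        for i in range(4)
    ]
    # A ModelOutput replaces the old (outputs, loss_dict) tuple.
    out = head.forward(
        FeatureVector(feature_vector=logits), context=None, targets=targets
    )
    print(out.outputs.shape)     # softmax probabilities, torch.Size([4, 3])
    print(out.loss_dict["cls"])  # masked cross-entropy loss
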
rslearn/train/tasks/detection.py +132 -76

@@ -12,26 +12,27 @@ import torchmetrics.classification
 import torchvision
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
 
 DEFAULT_COLORS = [
-    [255, 0, 0],
-    [0, 255, 0],
-    [0, 0, 255],
-    [255, 255, 0],
-    [0, 255, 255],
-    [255, 0, 255],
-    [0, 128, 0],
-    [255, 160, 122],
-    [139, 69, 19],
-    [128, 128, 128],
-    [255, 255, 255],
-    [143, 188, 143],
-    [95, 158, 160],
-    [255, 200, 0],
-    [128, 0, 0],
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
 ]
 
 
@@ -53,14 +54,30 @@ class DetectionTask(BasicTask):
         score_threshold: float = 0.5,
         enable_map_metric: bool = True,
         enable_f1_metric: bool = False,
+        enable_precision_recall: bool = False,
+        f1_metric_thresholds: list[list[float]] = [
+            [
+                0.05,
+                0.1,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.6,
+                0.7,
+                0.8,
+                0.9,
+                0.95,
+            ]
+        ],
         f1_metric_kwargs: dict[str, Any] = {},
-        **kwargs,
-    ):
-        """Initialize a new SegmentationTask.
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new DetectionTask.
 
         Args:
-            property_name: the property from which to extract the class name. The class
-                is read from the first matching feature.
+            property_name: the property from which to extract the class name. Features
+                without this property name are ignored.
             classes: a list of class names.
             filters: optional list of (property_name, property_value) to only consider
                 features with matching properties.
@@ -70,14 +87,20 @@ class DetectionTask(BasicTask):
                 not passed via classes, instead of throwing error
             skip_empty_examples: whether to skip examples with zero labels.
             colors: optional colors for each class
-            box_size: force all boxes to be this size, centered at the centroid of the
-                geometry. Required for Point geometries.
+            box_size: force all boxes to be two times this size, centered at the
+                centroid of the geometry. Required for Point geometries.
             clip_boxes: whether to clip boxes to the image bounds.
             exclude_by_center: before optionally clipping boxes, exclude boxes if the
                 center is outside the image bounds.
             score_threshold: confidence threshold for visualization and prediction.
             enable_map_metric: whether to compute mAP (default true)
             enable_f1_metric: whether to compute F1 (default false)
+            enable_precision_recall: whether to compute precision and recall.
+            f1_metric_thresholds: list of list of thresholds to apply for F1 metric, as
+                well as for precision and recall if enabled. Each inner list is used to
+                initialize a separate F1 metric where the best F1 across the thresholds
+                within the inner list is computed. If there are multiple inner lists,
+                then multiple F1 scores will be reported.
             f1_metric_kwargs: extra arguments to pass to F1 metric.
             kwargs: additional arguments to pass to BasicTask
         """
@@ -95,6 +118,8 @@ class DetectionTask(BasicTask):
         self.score_threshold = score_threshold
         self.enable_map_metric = enable_map_metric
         self.enable_f1_metric = enable_f1_metric
+        self.enable_precision_recall = enable_precision_recall
+        self.f1_metric_thresholds = f1_metric_thresholds
         self.f1_metric_kwargs = f1_metric_kwargs
 
         if not self.filters:
@@ -102,8 +127,8 @@ class DetectionTask(BasicTask):
 
     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -120,15 +145,21 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
+        bounds = metadata.patch_bounds
+
         boxes = []
         class_labels = []
         valid = 1
 
         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
-            for property_name, property_value in self.filters:
-                if feat.properties.get(property_name) != property_value:
-                    continue
+            if feat.properties is None:
+                continue
+            if self.filters is not None:
+                for property_name, property_value in self.filters:
+                    if feat.properties.get(property_name) != property_value:
+                        continue
             if self.property_name not in feat.properties:
                 continue
 
@@ -159,39 +190,33 @@ class DetectionTask(BasicTask):
             else:
                 box = [int(val) for val in shp.bounds]
 
-            if box[0] >= metadata["bounds"][2] or box[2] <= metadata["bounds"][0]:
+            if box[0] >= bounds[2] or box[2] <= bounds[0]:
                 continue
-            if box[1] >= metadata["bounds"][3] or box[3] <= metadata["bounds"][1]:
+            if box[1] >= bounds[3] or box[3] <= bounds[1]:
                 continue
 
             if self.exclude_by_center:
                 center_col = (box[0] + box[2]) // 2
                 center_row = (box[1] + box[3]) // 2
-                if (
-                    center_col <= metadata["bounds"][0]
-                    or center_col >= metadata["bounds"][2]
-                ):
+                if center_col <= bounds[0] or center_col >= bounds[2]:
                     continue
-                if (
-                    center_row <= metadata["bounds"][1]
-                    or center_row >= metadata["bounds"][3]
-                ):
+                if center_row <= bounds[1] or center_row >= bounds[3]:
                     continue
 
             if self.clip_boxes:
                 box = [
-                    np.clip(box[0], metadata["bounds"][0], metadata["bounds"][2]),
-                    np.clip(box[1], metadata["bounds"][1], metadata["bounds"][3]),
-                    np.clip(box[2], metadata["bounds"][0], metadata["bounds"][2]),
-                    np.clip(box[3], metadata["bounds"][1], metadata["bounds"][3]),
+                    np.clip(box[0], bounds[0], bounds[2]),
+                    np.clip(box[1], bounds[1], bounds[3]),
+                    np.clip(box[2], bounds[0], bounds[2]),
+                    np.clip(box[3], bounds[1], bounds[3]),
                 ]
 
             # Convert to relative coordinates.
             box = [
-                box[0] - metadata["bounds"][0],
-                box[1] - metadata["bounds"][1],
-                box[2] - metadata["bounds"][0],
-                box[3] - metadata["bounds"][1],
+                box[0] - bounds[0],
+                box[1] - bounds[1],
+                box[2] - bounds[0],
+                box[3] - bounds[1],
             ]
 
             boxes.append(box)
@@ -211,16 +236,12 @@ class DetectionTask(BasicTask):
             "valid": torch.tensor(valid, dtype=torch.int32),
             "boxes": boxes,
             "labels": class_labels,
-            "width": torch.tensor(
-                metadata["bounds"][2] - metadata["bounds"][0], dtype=torch.float32
-            ),
-            "height": torch.tensor(
-                metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
-            ),
+            "width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
+            "height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature]:
         """Processes an output into raster or vector data.
 
@@ -240,13 +261,13 @@ class DetectionTask(BasicTask):
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata["bounds"][0] + float(box[0]),
-                metadata["bounds"][1] + float(box[1]),
-                metadata["bounds"][0] + float(box[2]),
-                metadata["bounds"][1] + float(box[3]),
+                metadata.patch_bounds[0] + float(box[0]),
+                metadata.patch_bounds[1] + float(box[1]),
+                metadata.patch_bounds[0] + float(box[2]),
+                metadata.patch_bounds[1] + float(box[3]),
             )
-            geom = STGeometry(metadata["projection"], shp, None)
-            properties = {
+            geom = STGeometry(metadata.projection, shp, None)
+            properties: dict[str, Any] = {
                 "score": float(score),
             }
 
@@ -278,7 +299,9 @@ class DetectionTask(BasicTask):
         """
         image = super().visualize(input_dict, target_dict, output)["image"]
 
-        def draw_boxes(image: npt.NDArray[Any], d: dict[str, torch.Tensor]):
+        def draw_boxes(
+            image: npt.NDArray[Any], d: dict[str, torch.Tensor]
+        ) -> npt.NDArray[Any]:
             boxes = d["boxes"].cpu().numpy()
             class_ids = d["labels"].cpu().numpy()
             if "scores" in d:
@@ -299,6 +322,8 @@ class DetectionTask(BasicTask):
 
             return image
 
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         return {
             "gt": draw_boxes(image.copy(), target_dict),
             "pred": draw_boxes(image.copy(), output),
@@ -307,17 +332,46 @@ class DetectionTask(BasicTask):
     def get_metrics(self) -> MetricCollection:
         """Get the metrics for this task."""
         metrics = {}
+
         if self.enable_map_metric:
             metrics["mAP"] = DetectionMetric(
                 torchmetrics.detection.mean_ap.MeanAveragePrecision(),
                 output_key="map",
             )
-        if self.enable_f1_metric:
+
+        if self.enable_f1_metric or self.enable_precision_recall:
             kwargs = dict(
                 num_classes=len(self.classes),
             )
             kwargs.update(self.f1_metric_kwargs)
-            metrics["F1"] = DetectionMetric(F1Metric(**kwargs))
+
+            for thresholds in self.f1_metric_thresholds:
+                if len(self.f1_metric_thresholds) == 1:
+                    suffix = ""
+                else:
+                    # Metric name can't contain "." so change to ",".
+                    suffix = "_" + str(thresholds[0]).replace(".", ",")
+
+                if self.enable_f1_metric:
+                    metrics["F1" + suffix] = DetectionMetric(
+                        F1Metric(score_thresholds=thresholds, **kwargs)  # type: ignore
+                    )
+                if self.enable_precision_recall:
+                    metrics["precision" + suffix] = DetectionMetric(
+                        F1Metric(
+                            score_thresholds=thresholds,
+                            metric_mode="precision",
+                            **kwargs,  # type: ignore
+                        )
+                    )
+                    metrics["recall" + suffix] = DetectionMetric(
+                        F1Metric(
+                            score_thresholds=thresholds,
+                            metric_mode="recall",
+                            **kwargs,  # type: ignore
+                        )
+                    )
+
         return MetricCollection(metrics)
 
 
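
To make the suffix logic in get_metrics concrete, a hedged sketch (constructor arguments beyond those documented in this diff are omitted): with multiple inner threshold lists, each metric name is suffixed with the first threshold of its list, with "." changed to "," because metric names cannot contain ".".

    task = DetectionTask(
        property_name="category",
        classes=["unknown", "vessel"],
        enable_f1_metric=True,
        enable_precision_recall=True,
        f1_metric_thresholds=[[0.5], [0.9]],
    )
    metrics = task.get_metrics()
    # Expected keys: "mAP", "F1_0,5", "precision_0,5", "recall_0,5",
    # "F1_0,9", "precision_0,9", "recall_0,9".
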
@@ -377,22 +431,11 @@ class F1Metric(Metric):
     def __init__(
         self,
         num_classes: int,
+        score_thresholds: list[float],
         cmp_mode: str = "iou",
         cmp_threshold: float = 0.5,
-        score_thresholds: list[float] = [
-            0.05,
-            0.1,
-            0.2,
-            0.3,
-            0.4,
-            0.5,
-            0.6,
-            0.7,
-            0.8,
-            0.9,
-            0.95,
-        ],
         flatten_classes: bool = False,
+        metric_mode: str = "f1",
     ):
         """Create a new F1Metric.
 
@@ -406,6 +449,8 @@ class F1Metric(Metric):
             flatten_classes: sum true positives, false positives, and false negatives
                 across classes and report combined F1 instead of computing F1 score for
                 each class and then reporting the average.
+            metric_mode: set to "precision" or "recall" to return that instead of F1
+                (default "f1")
         """
         super().__init__()
         self.num_classes = num_classes
@@ -413,6 +458,10 @@ class F1Metric(Metric):
         self.cmp_threshold = cmp_threshold
         self.score_thresholds = score_thresholds
         self.flatten_classes = flatten_classes
+        self.metric_mode = metric_mode
+
+        assert self.cmp_mode in ["iou", "distance"]
+        assert self.metric_mode in ["f1", "precision", "recall"]
 
         for cls_idx in range(self.num_classes):
             for thr_idx in range(len(self.score_thresholds)):
@@ -531,8 +580,15 @@ class F1Metric(Metric):
                 else:
                     f1 = 2 * precision * recall / (precision + recall)
 
-                if best_score is None or f1 > best_score:
-                    best_score = f1
+                if self.metric_mode == "f1":
+                    score = f1
+                elif self.metric_mode == "precision":
+                    score = precision
+                elif self.metric_mode == "recall":
+                    score = recall
+
+                if best_score is None or score > best_score:
+                    best_score = score
 
             best_scores.append(best_score)
 
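
The compute logic above evaluates precision, recall, and F1 at each score threshold, picks whichever statistic metric_mode names, and reports the best value across the thresholds. A self-contained numeric sketch of that selection (plain Python, not the torchmetrics API):

    def best_score(counts: list[tuple[int, int, int]], metric_mode: str = "f1") -> float:
        """Report the best metric_mode value across score thresholds.

        counts holds (tp, fp, fn) tallies at each score threshold.
        """
        best = 0.0
        for tp, fp, fn in counts:
            precision = tp / (tp + fp) if tp + fp > 0 else 0.0
            recall = tp / (tp + fn) if tp + fn > 0 else 0.0
            if precision + recall == 0:
                f1 = 0.0
            else:
                f1 = 2 * precision * recall / (precision + recall)
            score = {"f1": f1, "precision": precision, "recall": recall}[metric_mode]
            best = max(best, score)
        return best

    # (tp, fp, fn) at score thresholds 0.5 and 0.9:
    print(best_score([(8, 4, 2), (6, 1, 4)], "f1"))         # ~0.727, from threshold 0.5
    print(best_score([(8, 4, 2), (6, 1, 4)], "precision"))  # ~0.857, from threshold 0.9
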
rslearn/train/tasks/embedding.py +120 -0 (new file)

@@ -0,0 +1,120 @@
+"""Embedding task."""
+
+from typing import Any
+
+import numpy.typing as npt
+import torch
+from torchmetrics import MetricCollection
+
+from rslearn.models.component import FeatureMaps
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.utils import Feature
+
+from .task import Task
+
+
+class EmbeddingTask(Task):
+    """A dummy task for computing embeddings.
+
+    This task does not compute any targets or loss. Instead, it is just set up for
+    inference, to save embeddings from the configured model.
+    """
+
+    def process_inputs(
+        self,
+        raw_inputs: dict[str, torch.Tensor],
+        metadata: SampleMetadata,
+        load_targets: bool = True,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Processes the data into targets.
+
+        Args:
+            raw_inputs: raster or vector data to process
+            metadata: metadata about the patch being read
+            load_targets: whether to load the targets or only inputs
+
+        Returns:
+            tuple (input_dict, target_dict) containing the processed inputs and targets
+            that are compatible with both metrics and loss functions
+        """
+        return {}, {}
+
+    def process_output(
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any] | list[Feature]:
+        """Processes an output into raster or vector data.
+
+        Args:
+            raw_output: the output from prediction head, which must be a CxHxW tensor.
+            metadata: metadata about the patch being read
+
+        Returns:
+            either raster or vector data.
+        """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
+            raise ValueError(
+                "output for EmbeddingTask must be a tensor with three dimensions"
+            )
+
+        # Just convert the raw output to numpy array that can be saved to GeoTIFF.
+        return raw_output.cpu().numpy()
+
+    def visualize(
+        self,
+        input_dict: dict[str, Any],
+        target_dict: dict[str, Any] | None,
+        output: Any,
+    ) -> dict[str, npt.NDArray[Any]]:
+        """Visualize the outputs and targets.
+
+        Args:
+            input_dict: the input dict from process_inputs
+            target_dict: the target dict from process_inputs
+            output: the prediction
+
+        Returns:
+            a dictionary mapping image name to visualization image
+        """
+        # EmbeddingTask is only set up to support `model predict`.
+        raise NotImplementedError
+
+    def get_metrics(self) -> MetricCollection:
+        """Get the metrics for this task."""
+        return MetricCollection({})
+
+
+class EmbeddingHead:
+    """Head for embedding task.
+
+    It just adds a dummy loss to act as a Predictor.
+    """
+
+    def forward(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None = None,
+    ) -> ModelOutput:
+        """Return the feature map along with a dummy loss.
+
+        Args:
+            intermediates: output from the previous model component, which must be a
+                FeatureMaps consisting of a single feature map.
+            context: the model context.
+            targets: the targets (ignored).
+
+        Returns:
+            model output with the feature map that was input to this component along
+            with a dummy loss.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to EmbeddingHead must be a FeatureMaps")
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
+            )
+
+        return ModelOutput(
+            outputs=intermediates.feature_maps[0],
+            loss_dict={"loss": 0},
+        )
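
A hedged sketch of the EmbeddingHead contract added above, assuming FeatureMaps accepts a feature_maps keyword and that context may be None (the forward body never reads it):

    import torch

    from rslearn.models.component import FeatureMaps
    from rslearn.train.tasks.embedding import EmbeddingHead

    head = EmbeddingHead()
    fmap = torch.randn(2, 128, 16, 16)  # a single feature map
    out = head.forward(FeatureMaps(feature_maps=[fmap]), context=None)
    assert out.outputs is fmap           # the embedding passes through unchanged
    assert out.loss_dict == {"loss": 0}  # dummy loss; the task is inference-only
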