rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,24 @@
1
- """Classification task."""
1
+ """Regression task."""
2
2
 
3
- from typing import Any
3
+ from typing import Any, Literal
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
+ import shapely
7
8
  import torch
8
9
  import torchmetrics
9
10
  from PIL import Image, ImageDraw
10
11
  from torchmetrics import Metric, MetricCollection
11
12
 
12
- from rslearn.utils import Feature
13
+ from rslearn.models.component import FeatureVector, Predictor
14
+ from rslearn.train.model_context import (
15
+ ModelContext,
16
+ ModelOutput,
17
+ RasterImage,
18
+ SampleMetadata,
19
+ )
20
+ from rslearn.utils.feature import Feature
21
+ from rslearn.utils.geometry import STGeometry
13
22
 
14
23
  from .task import BasicTask
15
24
 
@@ -20,23 +29,29 @@ class RegressionTask(BasicTask):
20
29
  def __init__(
21
30
  self,
22
31
  property_name: str,
23
- filters: list[tuple[str, str]] | None,
32
+ filters: list[tuple[str, str]] | None = None,
24
33
  allow_invalid: bool = False,
25
34
  scale_factor: float = 1,
26
- metric_mode: str = "mse",
27
- **kwargs,
28
- ):
35
+ metric_mode: Literal["mse", "l1"] = "mse",
36
+ use_accuracy_metric: bool = False,
37
+ within_factor: float = 0.1,
38
+ **kwargs: Any,
39
+ ) -> None:
29
40
  """Initialize a new RegressionTask.
30
41
 
31
42
  Args:
32
- property_name: the property from which to extract the regression value. The
33
- value is read from the first matching feature.
43
+ property_name: the property from which to extract the ground truth
44
+ regression value. The value is read from the first matching feature.
34
45
  filters: optional list of (property_name, property_value) to only consider
35
46
  features with matching properties.
36
47
  allow_invalid: instead of throwing error when no regression label is found
37
48
  at a window, simply mark the example invalid for this task
38
- scale_factor: multiply the label value by this factor
39
- metric_mode: what metric to use, either mse or l1
49
+ scale_factor: multiply the label value by this factor for training
50
+ metric_mode: what metric to use, either "mse" (default) or "l1"
51
+ use_accuracy_metric: include metric that reports percentage of
52
+ examples where output is within a factor of the ground truth.
53
+ within_factor: the factor for accuracy metric. If it's 0.2, and ground
54
+ truth is 5.0, then values from 5.0*0.8 to 5.0*1.2 are accepted.
40
55
  kwargs: other arguments to pass to BasicTask
41
56
  """
42
57
  super().__init__(**kwargs)
@@ -45,14 +60,16 @@ class RegressionTask(BasicTask):
45
60
  self.allow_invalid = allow_invalid
46
61
  self.scale_factor = scale_factor
47
62
  self.metric_mode = metric_mode
63
+ self.use_accuracy_metric = use_accuracy_metric
64
+ self.within_factor = within_factor
48
65
 
49
66
  if not self.filters:
50
67
  self.filters = []
51
68
 
52
69
  def process_inputs(
53
70
  self,
54
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
55
- metadata: dict[str, Any],
71
+ raw_inputs: dict[str, RasterImage | list[Feature]],
72
+ metadata: SampleMetadata,
56
73
  load_targets: bool = True,
57
74
  ) -> tuple[dict[str, Any], dict[str, Any]]:
58
75
  """Processes the data into targets.
@@ -70,7 +87,10 @@ class RegressionTask(BasicTask):
70
87
  return {}, {}
71
88
 
72
89
  data = raw_inputs["targets"]
90
+ assert isinstance(data, list)
73
91
  for feat in data:
92
+ if feat.properties is None or self.filters is None:
93
+ continue
74
94
  for property_name, property_value in self.filters:
75
95
  if feat.properties.get(property_name) != property_value:
76
96
  continue
@@ -90,6 +110,35 @@ class RegressionTask(BasicTask):
90
110
  "valid": torch.tensor(0, dtype=torch.float32),
91
111
  }
92
112
 
def process_output(
    self, raw_output: Any, metadata: SampleMetadata
) -> list[Feature]:
    """Convert a scalar regression prediction into a vector Feature.

    Args:
        raw_output: the prediction head output for one example; it must be a
            0-dimensional (scalar) torch.Tensor.
        metadata: metadata about the patch being read. Its projection and the
            first two entries of patch_bounds determine the output geometry.

    Returns:
        a single-element list holding a Feature. Its geometry is a point at
        (patch_bounds[0], patch_bounds[1]) in the patch projection, not the full
        patch extent. Its properties map property_name to the prediction divided
        by scale_factor.

    Raises:
        ValueError: if raw_output is not a scalar torch.Tensor.
    """
    if not (isinstance(raw_output, torch.Tensor) and len(raw_output.shape) == 0):
        raise ValueError("output for RegressionTask must be a scalar Tensor")

    # Labels are multiplied by scale_factor for training, so divide here to
    # report the prediction in the original label units.
    predicted_value = raw_output.item() / self.scale_factor

    # The prediction is anchored at a single corner point of the patch bounds.
    anchor = shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1])
    geometry = STGeometry(metadata.projection, anchor, None)
    properties = {self.property_name: predicted_value}
    return [Feature(geometry, properties)]
141
+
93
142
  def visualize(
94
143
  self,
95
144
  input_dict: dict[str, Any],
@@ -109,6 +158,8 @@ class RegressionTask(BasicTask):
109
158
  image = super().visualize(input_dict, target_dict, output)["image"]
110
159
  image = Image.fromarray(image)
111
160
  draw = ImageDraw.Draw(image)
161
+ if target_dict is None:
162
+ raise ValueError("target_dict is required for visualization")
112
163
  target = target_dict["value"] / self.scale_factor
113
164
  output = output / self.scale_factor
114
165
  text = f"Label: {target:.2f}\nOutput: {output:.2f}"
@@ -121,27 +172,36 @@ class RegressionTask(BasicTask):
121
172
 
def get_metrics(self) -> MetricCollection:
    """Get the metrics for this task.

    Returns:
        a MetricCollection containing:
        - the error metric selected by metric_mode, keyed "mse" or "l1";
        - an "accuracy" metric (RegressionAccuracy with within_factor), only
          when use_accuracy_metric is set.
        Each metric is wrapped in RegressionMetricWrapper together with the
        task's scale_factor.

    Raises:
        ValueError: if metric_mode is neither "mse" nor "l1".
    """
    metric_dict: dict[str, Metric] = {}

    if self.metric_mode == "mse":
        metric_dict["mse"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
        )
    elif self.metric_mode == "l1":
        metric_dict["l1"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
        )
    else:
        # The Literal annotation is not enforced at runtime. Fail loudly rather
        # than silently evaluating without any error metric.
        raise ValueError(f"unknown metric mode {self.metric_mode}")

    if self.use_accuracy_metric:
        metric_dict["accuracy"] = RegressionMetricWrapper(
            metric=RegressionAccuracy(self.within_factor),
            scale_factor=self.scale_factor,
        )

    return MetricCollection(metric_dict)
193
+
194
+
195
+ class RegressionHead(Predictor):
138
196
  """Head for regression task."""
139
197
 
140
- def __init__(self, loss_mode: str = "mse", use_sigmoid: bool = False):
198
+ def __init__(
199
+ self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
200
+ ):
141
201
  """Initialize a new RegressionHead.
142
202
 
143
203
  Args:
144
- loss_mode: the loss function to use, either "mse" or "l1".
204
+ loss_mode: the loss function to use, either "mse" (default) or "l1".
145
205
  use_sigmoid: whether to apply a sigmoid activation on the output. This
146
206
  requires targets to be between 0-1.
147
207
  """
@@ -151,48 +211,59 @@ class RegressionHead(torch.nn.Module):
151
211
 
def forward(
    self,
    intermediates: Any,
    context: ModelContext,
    targets: list[dict[str, Any]] | None = None,
) -> ModelOutput:
    """Compute regression outputs and, when targets are given, the loss.

    Args:
        intermediates: output from the previous model component. It must be a
            FeatureVector whose feature_vector has channel dimension size 1
            (B x 1).
        context: the model context (not used by this head).
        targets: optional per-example target dicts. Each holds a "value" tensor
            with the regression label and a "valid" flag tensor; a flag of 0
            zeroes out that example's loss term.

    Returns:
        a ModelOutput with two parts:
        - outputs: a length-B tensor with one scalar per example, passed
          through a sigmoid when use_sigmoid is set;
        - loss_dict: a "regress" entry when targets are provided, otherwise
          empty.

    Raises:
        ValueError: in any of these cases:
        - intermediates is not a FeatureVector;
        - its channel dimension is not 1;
        - loss_mode is not "mse" or "l1".
    """
    if not isinstance(intermediates, FeatureVector):
        raise ValueError("the input to RegressionHead must be a FeatureVector")
    features = intermediates.feature_vector
    if features.shape[1] != 1:
        raise ValueError(
            f"the input to RegressionHead must have channel dimension size 1, but got shape {features.shape}"
        )

    # Drop the singleton channel dimension: (B, 1) -> (B,).
    logits = features[:, 0]
    outputs = torch.nn.functional.sigmoid(logits) if self.use_sigmoid else logits

    losses: dict[str, torch.Tensor] = {}
    if targets:
        labels = torch.stack([t["value"] for t in targets])
        mask = torch.stack([t["valid"] for t in targets])
        if self.loss_mode == "mse":
            per_example = torch.square(outputs - labels)
        elif self.loss_mode == "l1":
            per_example = torch.abs(outputs - labels)
        else:
            raise ValueError(f"unknown loss mode {self.loss_mode}")
        # Masked-out examples contribute zero but still count in the mean's
        # denominator.
        losses["regress"] = torch.mean(per_example * mask)

    return ModelOutput(outputs=outputs, loss_dict=losses)
190
261
 
191
262
 
192
263
  class RegressionMetricWrapper(Metric):
193
264
  """Metric for regression task."""
194
265
 
195
- def __init__(self, metric: Metric, scale_factor: float, **kwargs):
266
+ def __init__(self, metric: Metric, scale_factor: float, **kwargs: Any) -> None:
196
267
  """Initialize a new RegressionMetricWrapper.
197
268
 
198
269
  Args:
@@ -206,14 +277,17 @@ class RegressionMetricWrapper(Metric):
206
277
  self.metric = metric
207
278
  self.scale_factor = scale_factor
208
279
 
209
- def update(self, preds: list[Any], targets: list[dict[str, Any]]) -> None:
280
+ def update(
281
+ self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
282
+ ) -> None:
210
283
  """Update metric.
211
284
 
212
285
  Args:
213
286
  preds: the predictions
214
287
  targets: the targets
215
288
  """
216
- preds = torch.stack(preds)
289
+ if not isinstance(preds, torch.Tensor):
290
+ preds = torch.stack(preds)
217
291
  labels = torch.stack([target["value"] for target in targets])
218
292
 
219
293
  # Sub-select the valid labels.
@@ -237,3 +311,46 @@ class RegressionMetricWrapper(Metric):
def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
    """Plot the metric by delegating to the wrapped metric's own plot method."""
    wrapped = self.metric
    return wrapped.plot(*args, **kwargs)
314
+
315
+
316
+ class RegressionAccuracy(Metric):
317
+ """Percentage of examples with estimate within some factor of ground truth."""
318
+
319
+ def __init__(self, factor: float) -> None:
320
+ """Initialize a new RegressionAccuracy.
321
+
322
+ Args:
323
+ factor: the factor so if estimate is within this much of ground truth then
324
+ it is marked correct.
325
+ """
326
+ super().__init__()
327
+ self.factor = factor
328
+ self.correct = 0
329
+ self.total = 0
330
+
331
+ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
332
+ """Update metric.
333
+
334
+ Args:
335
+ preds: the predictions
336
+ labels: the ground truth data
337
+ """
338
+ decisions = (preds >= labels * (1 - self.factor)) & (
339
+ preds <= labels * (1 + self.factor)
340
+ )
341
+ self.correct += torch.count_nonzero(decisions)
342
+ self.total += len(decisions)
343
+
344
+ def compute(self) -> Any:
345
+ """Returns the computed metric."""
346
+ return torch.tensor(self.correct / self.total)
347
+
348
+ def reset(self) -> None:
349
+ """Reset metric."""
350
+ super().reset()
351
+ self.correct = 0
352
+ self.total = 0
353
+
354
+ def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
355
+ """Returns a plot of the metric."""
356
+ return None