rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/multi_task.py CHANGED
@@ -6,6 +6,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils import Feature
 
 from .task import Task
@@ -30,7 +31,7 @@ class MultiTask(Task):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -46,14 +47,12 @@ class MultiTask(Task):
         """
         input_dict = {}
         target_dict = {}
-        if metadata["dataset_source"] is None:
+        if metadata.dataset_source is None:
             # No multi-dataset, so always compute across all tasks
             task_iter = list(self.tasks.items())
         else:
             # Multi-dataset, so only compute for the task in this dataset
-            task_iter = [
-                (metadata["dataset_source"], self.tasks[metadata["dataset_source"]])
-            ]
+            task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
 
         for task_name, task in task_iter:
             cur_raw_inputs = {}
@@ -71,12 +70,13 @@ class MultiTask(Task):
         return input_dict, target_dict
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> dict[str, Any]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head. It must be a dict mapping from
+                task name to per-task output for this sample.
             metadata: metadata about the patch being read
 
         Returns:
rslearn/train/tasks/per_pixel_regression.py CHANGED
@@ -8,6 +8,8 @@ import torch
 import torchmetrics
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils.feature import Feature
 
 from .task import BasicTask
@@ -41,7 +43,7 @@ class PerPixelRegressionTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -72,20 +74,23 @@ class PerPixelRegressionTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be an HW tensor.
             metadata: metadata about the patch being read
 
         Returns:
            either raster or vector data.
         """
-        # Input could be CHW (with single channel) or just HW.
-        if len(raw_output.shape) == 2:
-            raw_output = raw_output[None, :, :]
+        if not isinstance(raw_output, torch.Tensor):
+            raise ValueError("output for PerPixelRegressionTask must be a tensor")
+        if len(raw_output.shape) != 2:
+            raise ValueError(
+                f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
+            )
         return (raw_output / self.scale_factor).cpu().numpy()
 
     def visualize(
@@ -133,7 +138,7 @@ class PerPixelRegressionTask(BasicTask):
         return MetricCollection(metric_dict)
 
 
-class PerPixelRegressionHead(torch.nn.Module):
+class PerPixelRegressionHead(Predictor):
     """Head for per-pixel regression task."""
 
     def __init__(
@@ -156,24 +161,38 @@ class PerPixelRegressionHead(torch.nn.Module):
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, Any]]:
+    ) -> ModelOutput:
         """Compute the regression outputs and loss from logits and targets.
 
         Args:
-            logits: BxHxW or BxCxHxW tensor.
-            inputs: original inputs (ignored).
-            targets: should contain target key that stores the regression labels.
+            intermediates: output from previous component, which must be a FeatureMaps
+                with one feature map corresponding to the logits. The channel dimension
+                size must be 1.
+            context: the model context.
+            targets: must contain values key that stores the regression labels, and
+                valid key containing mask image indicating where the labels are valid.
 
         Returns:
-            tuple of outputs and loss dict
+            tuple of outputs and loss dict. The output is a BHW tensor so that the
+            per-sample output is an HW tensor.
         """
-        assert len(logits.shape) in [3, 4]
-        if len(logits.shape) == 4:
-            assert logits.shape[1] == 1
-            logits = logits[:, 0, :, :]
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError(
+                "the input to PerPixelRegressionHead must be a FeatureMaps"
+            )
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                "the input to PerPixelRegressionHead must have one feature map"
+            )
+        if intermediates.feature_maps[0].shape[1] != 1:
+            raise ValueError(
+                f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
+            )
+
+        logits = intermediates.feature_maps[0][:, 0, :, :]
 
         if self.use_sigmoid:
             outputs = torch.nn.functional.sigmoid(logits)
@@ -200,7 +219,10 @@ class PerPixelRegressionHead(torch.nn.Module):
         else:
             losses["regress"] = (scores * mask).sum() / mask_total
 
-        return outputs, losses
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class PerPixelRegressionMetricWrapper(Metric):
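
Note: this hunk (and the matching ones in regression.py and segmentation.py below) migrates the head from a plain torch.nn.Module taking (logits, inputs, targets) and returning a tuple to the new Predictor component interface: the head now receives the previous component's intermediates plus a ModelContext and returns a ModelOutput. A minimal sketch of a custom head written against this interface follows; the signatures are taken from the hunks, but the class name, target-key handling, and loss are invented for illustration.

from typing import Any

import torch

from rslearn.models.component import FeatureMaps, Predictor
from rslearn.train.model_context import ModelContext, ModelOutput


class MeanMapHead(Predictor):
    """Toy head: averages the single feature map over channels with an L1 loss (illustrative)."""

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("expected a FeatureMaps input")
        # BxCxHxW -> BxHxW, analogous to how PerPixelRegressionHead extracts logits.
        outputs = intermediates.feature_maps[0].mean(dim=1)
        losses = {}
        if targets is not None:
            # "values" mirrors the target key documented in the hunk above (assumption).
            labels = torch.stack([t["values"] for t in targets])
            losses["l1"] = torch.mean(torch.abs(outputs - labels))
        return ModelOutput(outputs=outputs, loss_dict=losses)
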
rslearn/train/tasks/regression.py CHANGED
@@ -10,6 +10,8 @@ import torchmetrics
 from PIL import Image, ImageDraw
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import STGeometry
 
@@ -62,7 +64,7 @@ class RegressionTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -103,22 +105,26 @@ class RegressionTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a scalar tensor.
            metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            a list with a single Feature corresponding to the patch extent and with a
            property containing the predicted value.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 0:
+            raise ValueError("output for RegressionTask must be a scalar Tensor")
+
         output = raw_output.item() / self.scale_factor
         feature = Feature(
             STGeometry(
-                metadata["projection"],
-                shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
+                metadata.projection,
+                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
                 None,
             ),
             {
@@ -180,7 +186,7 @@ class RegressionTask(BasicTask):
         return MetricCollection(metric_dict)
 
 
-class RegressionHead(torch.nn.Module):
+class RegressionHead(Predictor):
     """Head for regression task."""
 
     def __init__(
@@ -199,24 +205,32 @@ class RegressionHead(torch.nn.Module):
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, Any]]:
+    ) -> ModelOutput:
         """Compute the regression outputs and loss from logits and targets.
 
         Args:
-            logits: tensor that is (BatchSize, 1) or (BatchSize) in shape.
-            inputs: original inputs (ignored).
-            targets: should contain target key that stores the regression label.
+            intermediates: output from previous model component, which must be a
+                FeatureVector with channel dimension size 1 (Bx1).
+            context: the model context.
+            targets: target dicts, which each must contain a "value" key containing the
+                regression label, along with a "valid" key containing a flag indicating
+                whether each example is valid for this task.
 
         Returns:
-            tuple of outputs and loss dict
+            the model outputs. The output is a B tensor so that it is split up into a
            scalar for each example.
         """
-        assert len(logits.shape) in [1, 2]
-        if len(logits.shape) == 2:
-            assert logits.shape[1] == 1
-            logits = logits[:, 0]
+        if not isinstance(intermediates, FeatureVector):
+            raise ValueError("the input to RegressionHead must be a FeatureVector")
+        if intermediates.feature_vector.shape[1] != 1:
+            raise ValueError(
+                f"the input to RegressionHead must have channel dimension size 1, but got shape {intermediates.feature_vector.shape}"
+            )
+
+        logits = intermediates.feature_vector[:, 0]
 
         if self.use_sigmoid:
             outputs = torch.nn.functional.sigmoid(logits)
@@ -232,9 +246,12 @@ class RegressionHead(torch.nn.Module):
         elif self.loss_mode == "l1":
             losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
         else:
-            assert False
+            raise ValueError(f"unknown loss mode {self.loss_mode}")
 
-        return outputs, losses
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class RegressionMetricWrapper(Metric):
rslearn/train/tasks/segmentation.py CHANGED
@@ -8,7 +8,8 @@ import torch
 import torchmetrics.classification
 from torchmetrics import Metric, MetricCollection
 
-from rslearn.utils import Feature
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 
 from .task import BasicTask
 
@@ -108,7 +109,7 @@ class SegmentationTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -148,17 +149,20 @@ class SegmentationTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a CHW tensor.
            metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            CHW numpy array with one channel, containing the predicted class IDs.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
+            raise ValueError("the output for SegmentationTask must be a CHW tensor")
+
         if self.prob_scales is not None:
             raw_output = (
                 raw_output
@@ -166,7 +170,7 @@ class SegmentationTask(BasicTask):
                    self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
                )[:, None, None]
            )
-        classes = raw_output.argmax(dim=0).cpu().numpy().astype(np.uint8)
+        classes = raw_output.argmax(dim=0).cpu().numpy()
         return classes[None, :, :]
 
     def visualize(
@@ -258,25 +262,36 @@ class SegmentationTask(BasicTask):
         return MetricCollection(metrics)
 
 
-class SegmentationHead(torch.nn.Module):
+class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, Any]]:
+    ) -> ModelOutput:
         """Compute the segmentation outputs from logits and targets.
 
         Args:
-            logits: tensor that is (BatchSize, NumClasses, Height, Width) in shape.
-            inputs: original inputs (ignored).
-            targets: should contain classes key that stores the per-pixel class labels.
+            intermediates: a FeatureMaps with a single feature map containing the
+                segmentation logits.
+            context: the model context
+            targets: list of target dicts, where each target dict must contain a key
+                "classes" containing the per-pixel class labels, along with "valid"
+                containing a mask indicating where the example is valid.
 
         Returns:
            tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to SegmentationHead must be a FeatureMaps")
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
+            )
+
+        logits = intermediates.feature_maps[0]
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
         losses = {}
@@ -295,7 +310,10 @@ class SegmentationHead(torch.nn.Module):
        # the summed mask loss be zero.
        losses["cls"] = torch.sum(per_pixel_loss * mask)
 
-        return outputs, losses
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class SegmentationMetric(Metric):
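
Note: the process_output hunk above drops the uint8 cast on the predicted class IDs; when prob_scales is set, the CHW probabilities are still reweighted per class before the argmax. A small standalone illustration with made-up values:

import torch

raw_output = torch.tensor(
    [
        [[0.6, 0.2]],  # class 0 probabilities for a 1x2 patch
        [[0.4, 0.8]],  # class 1 probabilities
    ]
)  # CHW with C=2
prob_scales = [0.5, 1.0]
scaled = raw_output * torch.tensor(prob_scales)[:, None, None]
classes = scaled.argmax(dim=0).cpu().numpy()  # [[1, 1]], int64 rather than uint8
print(classes[None, :, :].shape)  # (1, 1, 2): one-channel CHW array
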
rslearn/train/tasks/task.py CHANGED
@@ -7,6 +7,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection
 
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils import Feature
 
 
@@ -21,7 +22,7 @@ class Task:
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -38,7 +39,7 @@ class Task:
         raise NotImplementedError
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
         """Processes an output into raster or vector data.
 
rslearn/train/transforms/resize.py ADDED
@@ -0,0 +1,74 @@
+"""Resize transform."""
+
+from typing import Any
+
+import torch
+import torchvision
+from torchvision.transforms import InterpolationMode
+
+from .transform import Transform
+
+INTERPOLATION_MODES = {
+    "nearest": InterpolationMode.NEAREST,
+    "nearest_exact": InterpolationMode.NEAREST_EXACT,
+    "bilinear": InterpolationMode.BILINEAR,
+    "bicubic": InterpolationMode.BICUBIC,
+}
+
+
+class Resize(Transform):
+    """Resizes inputs to a target size."""
+
+    def __init__(
+        self,
+        target_size: tuple[int, int],
+        selectors: list[str] = [],
+        interpolation: str = "nearest",
+    ):
+        """Initialize a resize transform.
+
+        Args:
+            target_size: the (height, width) to resize to.
+            selectors: items to transform.
+            interpolation: the interpolation mode to use for resizing.
+                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
+        """
+        super().__init__()
+        self.target_size = target_size
+        self.selectors = selectors
+        self.interpolation = INTERPOLATION_MODES[interpolation]
+
+    def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
+        """Apply resizing on the specified image.
+
+        If the image is 2D, it is unsqueezed to 3D and then squeezed
+        back after resizing.
+
+        Args:
+            image: the image to transform.
+        """
+        if image.dim() == 2:
+            image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+            result = torchvision.transforms.functional.resize(
+                image, self.target_size, self.interpolation
+            )
+            return result.squeeze(0)  # (1, H, W) -> (H, W)
+
+        return torchvision.transforms.functional.resize(
+            image, self.target_size, self.interpolation
+        )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply transform over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            transformed (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
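
Note: the new Resize transform resizes selected items to a fixed (height, width) and transparently handles both CHW and HW tensors. A usage sketch follows; the "image" selector and the sizes are illustrative, and selector resolution comes from the Transform base class rather than this file.

import torch

from rslearn.train.transforms.resize import Resize

resize = Resize(target_size=(64, 64), selectors=["image"], interpolation="bilinear")

chw = torch.zeros(3, 32, 32)
hw = torch.zeros(32, 32)
print(resize.apply_resize(chw).shape)  # torch.Size([3, 64, 64])
print(resize.apply_resize(hw).shape)   # torch.Size([64, 64]), squeezed back to 2D
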
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,79 @@ class Projection:
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
 
 
+class ResolutionFactor:
+    """Multiplier for the resolution in a Projection.
+
+    The multiplier is either an integer x, or the inverse of an integer (1/x).
+
+    Factors greater than 1 increase the projection_units/pixel resolution, increasing
+    the resolution (more pixels per projection unit). Factors less than 1 make it coarser
+    (less pixels).
+    """
+
+    def __init__(self, numerator: int = 1, denominator: int = 1):
+        """Create a new ResolutionFactor.
+
+        Args:
+            numerator: the numerator of the fraction.
+            denominator: the denominator of the fraction. If set, numerator must be 1.
+        """
+        if numerator != 1 and denominator != 1:
+            raise ValueError("one of numerator or denominator must be 1")
+        if not isinstance(numerator, int) or not isinstance(denominator, int):
+            raise ValueError("numerator and denominator must be integers")
+        if numerator < 1 or denominator < 1:
+            raise ValueError("numerator and denominator must be >= 1")
+        self.numerator = numerator
+        self.denominator = denominator
+
+    def multiply_projection(self, projection: Projection) -> Projection:
+        """Multiply the projection by this factor."""
+        if self.denominator > 1:
+            return Projection(
+                projection.crs,
+                projection.x_resolution * self.denominator,
+                projection.y_resolution * self.denominator,
+            )
+        else:
+            return Projection(
+                projection.crs,
+                projection.x_resolution // self.numerator,
+                projection.y_resolution // self.numerator,
+            )
+
+    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
+        """Multiply the bounds by this factor.
+
+        When coarsening, the width and height of the given bounds must be a multiple of
+        the denominator.
+        """
+        if self.denominator > 1:
+            # Verify the width and height are multiples of the denominator.
+            # Otherwise the new width and height is not an integer.
+            width = bounds[2] - bounds[0]
+            height = bounds[3] - bounds[1]
+            if width % self.denominator != 0 or height % self.denominator != 0:
+                raise ValueError(
+                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
+                )
+            # TODO: an offset could be introduced by bounds not being a multiple
+            # of the denominator -> will need to decide how to handle that.
+            return (
+                bounds[0] // self.denominator,
+                bounds[1] // self.denominator,
+                bounds[2] // self.denominator,
+                bounds[3] // self.denominator,
+            )
+        else:
+            return (
+                bounds[0] * self.numerator,
+                bounds[1] * self.numerator,
+                bounds[2] * self.numerator,
+                bounds[3] * self.numerator,
+            )
+
+
 class STGeometry:
     """A spatiotemporal geometry.
 
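
Note: to make the ResolutionFactor semantics concrete, a factor of x multiplies pixel bounds by x and divides the projection's per-pixel resolution by x (finer), while 1/x does the reverse and requires the bounds' width and height to be divisible by x. A sketch with an illustrative CRS and resolution:

from rasterio.crs import CRS

from rslearn.utils.geometry import Projection, ResolutionFactor

base = Projection(CRS.from_epsg(32610), 10, -10)  # 10 units/pixel (illustrative)

half = ResolutionFactor(denominator=2)  # factor 1/2: coarser
coarse = half.multiply_projection(base)
print(coarse.x_resolution, coarse.y_resolution)  # 20 -20
print(half.multiply_bounds((0, 0, 128, 128)))    # (0, 0, 64, 64)

double = ResolutionFactor(numerator=2)  # factor 2: finer
fine = double.multiply_projection(base)
print(fine.x_resolution, fine.y_resolution)      # 5 -5
print(double.multiply_bounds((0, 0, 128, 128)))  # (0, 0, 256, 256)
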
rslearn/utils/jsonargparse.py CHANGED
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
 from upath import UPath
 
 from rslearn.config.dataset import LayerConfig
+from rslearn.utils.geometry import ResolutionFactor
 
 if TYPE_CHECKING:
     from rslearn.data_sources.data_source import DataSourceContext
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
     )
 
 
+def resolution_factor_serializer(v: ResolutionFactor) -> str:
+    """Serialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the ResolutionFactor object.
+
+    Returns:
+        the ResolutionFactor encoded to string
+    """
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return f"{init_args.numerator}/{init_args.denominator}"
+
+    return f"{v.numerator}/{v.denominator}"
+
+
+def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
+    """Deserialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the encoded ResolutionFactor.
+
+    Returns:
+        the decoded ResolutionFactor object
+    """
+    # Handle already-instantiated ResolutionFactor
+    if isinstance(v, ResolutionFactor):
+        return v
+
+    # Handle Namespace from class_path syntax (used during config save/validation)
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return ResolutionFactor(
+            numerator=init_args.numerator,
+            denominator=init_args.denominator,
+        )
+
+    # Handle dict from class_path syntax in YAML config
+    if isinstance(v, dict) and "init_args" in v:
+        init_args = v["init_args"]
+        return ResolutionFactor(
+            numerator=init_args.get("numerator", 1),
+            denominator=init_args.get("denominator", 1),
+        )
+
+    if isinstance(v, int):
+        return ResolutionFactor(numerator=v)
+    elif isinstance(v, str):
+        parts = v.split("/")
+        if len(parts) == 1:
+            return ResolutionFactor(numerator=int(parts[0]))
+        elif len(parts) == 2:
+            return ResolutionFactor(
+                numerator=int(parts[0]),
+                denominator=int(parts[1]),
+            )
+        else:
+            raise ValueError("expected resolution factor to be of the form x or 1/x")
+    else:
+        raise ValueError("expected resolution factor to be str or int")
+
+
 def init_jsonargparse() -> None:
     """Initialize custom jsonargparse serializers."""
     global INITIALIZED
@@ -100,6 +163,9 @@ def init_jsonargparse() -> None:
     jsonargparse.typing.register_type(
         datetime, datetime_serializer, datetime_deserializer
     )
+    jsonargparse.typing.register_type(
+        ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
+    )
 
     from rslearn.data_sources.data_source import DataSourceContext
 
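
Note: the deserializer registered above accepts an int, a bare "x" string, or a "1/x" string (plus the jsonargparse class_path/init_args forms), and the serializer always emits the "numerator/denominator" form. For example, with illustrative values:

from rslearn.utils.jsonargparse import (
    resolution_factor_deserializer,
    resolution_factor_serializer,
)

f = resolution_factor_deserializer(3)      # same as ResolutionFactor(numerator=3)
print(f.numerator, f.denominator)          # 3 1

f = resolution_factor_deserializer("1/2")  # ResolutionFactor(numerator=1, denominator=2)
print(f.numerator, f.denominator)          # 1 2

print(resolution_factor_serializer(f))     # 1/2
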
{rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.17
+Version: 0.0.19
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License