rslearn 0.0.9__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. rslearn/models/anysat.py +5 -1
  2. rslearn/models/dinov3.py +6 -1
  3. rslearn/models/feature_center_crop.py +50 -0
  4. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  5. rslearn/models/olmoearth_pretrain/model.py +263 -0
  6. rslearn/models/olmoearth_pretrain/norm.py +84 -0
  7. rslearn/models/pooling_decoder.py +43 -0
  8. rslearn/models/prithvi.py +9 -1
  9. rslearn/train/lightning_module.py +0 -3
  10. rslearn/train/tasks/classification.py +2 -2
  11. rslearn/train/tasks/detection.py +5 -5
  12. rslearn/train/tasks/per_pixel_regression.py +5 -4
  13. rslearn/train/tasks/regression.py +5 -5
  14. rslearn/train/transforms/pad.py +3 -3
  15. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/METADATA +3 -1
  16. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/RECORD +21 -25
  17. rslearn-0.0.12.dist-info/licenses/NOTICE +115 -0
  18. rslearn/models/copernicusfm.py +0 -228
  19. rslearn/models/copernicusfm_src/__init__.py +0 -1
  20. rslearn/models/copernicusfm_src/aurora/area.py +0 -50
  21. rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
  22. rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
  23. rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
  24. rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
  25. rslearn/models/copernicusfm_src/model_vit.py +0 -348
  26. rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
  27. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/WHEEL +0 -0
  28. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/entry_points.txt +0 -0
  29. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/licenses/LICENSE +0 -0
  30. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/top_level.txt +0 -0
rslearn/models/anysat.py CHANGED
@@ -1,4 +1,8 @@
- """AnySat model."""
+ """AnySat model.
+
+ This code loads the AnySat model from torch hub. See
+ https://github.com/gastruc/AnySat for applicable license and copyright information.
+ """

  from typing import Any

rslearn/models/dinov3.py CHANGED
@@ -1,4 +1,9 @@
- """DinoV3 model."""
+ """DinoV3 model.
+
+ This code loads the DINOv3 model. You must obtain the model separately from Meta to use
+ it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
+ information.
+ """

  from enum import StrEnum
  from pathlib import Path
rslearn/models/feature_center_crop.py ADDED
@@ -0,0 +1,50 @@
+ """Apply center cropping on a feature map."""
+
+ from typing import Any
+
+ import torch
+
+
+ class FeatureCenterCrop(torch.nn.Module):
+     """Apply center cropping on the input feature maps."""
+
+     def __init__(
+         self,
+         sizes: list[tuple[int, int]],
+     ) -> None:
+         """Create a new FeatureCenterCrop.
+
+         Only the center of each feature map will be retained and passed to the next
+         module.
+
+         Args:
+             sizes: a list of (height, width) tuples, with one tuple for each input
+                 feature map.
+         """
+         super().__init__()
+         self.sizes = sizes
+
+     def forward(
+         self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+     ) -> list[torch.Tensor]:
+         """Apply center cropping on the feature maps.
+
+         Args:
+             features: list of feature maps at different resolutions.
+             inputs: original inputs (ignored).
+
+         Returns:
+             center cropped feature maps.
+         """
+         new_features = []
+         for i, feat in enumerate(features):
+             height, width = self.sizes[i]
+             if feat.shape[2] < height or feat.shape[3] < width:
+                 raise ValueError(
+                     "feature map is smaller than the desired height and width"
+                 )
+             start_h = feat.shape[2] // 2 - height // 2
+             start_w = feat.shape[3] // 2 - width // 2
+             feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
+             new_features.append(feat)
+         return new_features
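Note: a minimal usage sketch for the new FeatureCenterCrop module (tensor shapes are illustrative, not taken from the diff):

import torch
from rslearn.models.feature_center_crop import FeatureCenterCrop

# Crop a single BCHW feature map down to its central 8x8 region.
crop = FeatureCenterCrop(sizes=[(8, 8)])
features = [torch.randn(2, 32, 16, 16)]  # batch=2, channels=32, 16x16 spatial
cropped = crop(features, inputs=[])
print(cropped[0].shape)  # torch.Size([2, 32, 8, 8])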
rslearn/models/olmoearth_pretrain/__init__.py ADDED
@@ -0,0 +1 @@
+ """OlmoEarth model architecture."""
rslearn/models/olmoearth_pretrain/model.py ADDED
@@ -0,0 +1,263 @@
+ """OlmoEarth model wrapper for fine-tuning in rslearn."""
+
+ import json
+ from contextlib import nullcontext
+ from typing import Any
+
+ import torch
+ from einops import rearrange
+ from olmo_core.config import Config
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
+ from olmoearth_pretrain.data.constants import Modality
+ from olmoearth_pretrain.model_loader import (
+     ModelID,
+     load_model_from_id,
+     load_model_from_path,
+ )
+ from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
+ from upath import UPath
+
+ from rslearn.log_utils import get_logger
+
+ logger = get_logger(__name__)
+
+ MODALITY_NAMES = [
+     "sentinel2_l2a",
+     "sentinel1",
+     "worldcover",
+     "openstreetmap_raster",
+     "landsat",
+ ]
+
+ AUTOCAST_DTYPE_MAP = {
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+     "float32": torch.float32,
+ }
+
+ EMBEDDING_SIZES = {
+     ModelID.OLMOEARTH_V1_NANO: 128,
+     ModelID.OLMOEARTH_V1_TINY: 192,
+     ModelID.OLMOEARTH_V1_BASE: 768,
+ }
+
+
+ class OlmoEarth(torch.nn.Module):
+     """A wrapper to support the OlmoEarth model."""
+
+     def __init__(
+         self,
+         patch_size: int,
+         model_id: ModelID | None = None,
+         model_path: str | None = None,
+         checkpoint_path: str | None = None,
+         selector: list[str | int] = ["encoder"],
+         forward_kwargs: dict[str, Any] = {},
+         random_initialization: bool = False,
+         embedding_size: int | None = None,
+         autocast_dtype: str | None = "bfloat16",
+     ):
+         """Create a new OlmoEarth model.
+
+         Args:
+             patch_size: token spatial patch size to use.
+             model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
+                 set.
+             model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
+                 set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
+             checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
+                 set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
+                 folder.
+             selector: an optional sequence of attribute names or list indices to select
+                 the sub-module that should be applied on the input images. Defaults to
+                 ["encoder"] to select only the transformer encoder.
+             forward_kwargs: additional arguments to pass to forward pass besides the
+                 MaskedOlmoEarthSample.
+             random_initialization: whether to skip loading the checkpoint so the
+                 weights are randomly initialized. In this case, the checkpoint is only
+                 used to define the model architecture.
+             embedding_size: optional embedding size to report via
+                 get_backbone_channels (if model_id is not set).
+             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+         """
+         if (
+             sum(
+                 [
+                     model_id is not None,
+                     model_path is not None,
+                     checkpoint_path is not None,
+                 ]
+             )
+             != 1
+         ):
+             raise ValueError(
+                 "exactly one of model_id, model_path, or checkpoint_path must be set"
+             )
+
+         super().__init__()
+         self.patch_size = patch_size
+         self.forward_kwargs = forward_kwargs
+         self.embedding_size = embedding_size
+
+         if autocast_dtype is not None:
+             self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+         else:
+             self.autocast_dtype = None
+
+         if model_id is not None:
+             # Load from Hugging Face.
+             model = load_model_from_id(model_id, load_weights=not random_initialization)
+             if self.embedding_size is None and model_id in EMBEDDING_SIZES:
+                 self.embedding_size = EMBEDDING_SIZES[model_id]
+
+         elif model_path is not None:
+             # Load from path.
+             model = load_model_from_path(
+                 UPath(model_path), load_weights=not random_initialization
+             )
+
+         else:
+             # Load the distributed model checkpoint by path through Olmo Core
+             model = self._load_model_from_checkpoint(
+                 UPath(checkpoint_path), random_initialization
+             )
+
+         # Select just the portion of the model that we actually want to use.
+         for part in selector:
+             if isinstance(part, str):
+                 model = getattr(model, part)
+             else:
+                 model = model[part]
+         self.model = model
+
+     def _load_model_from_checkpoint(
+         self, checkpoint_upath: UPath, random_initialization: bool
+     ) -> torch.nn.Module:
+         """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
+
+         The folder should contain config.json as well as the model_and_optim folder
+         that contains the distributed checkpoint. This is the format produced by
+         pre-training runs in olmoearth_pretrain.
+         """
+         # Load the model config and initialize it.
+         # We avoid loading the train module here because it depends on running within
+         # olmo_core.
+         with (checkpoint_upath / "config.json").open() as f:
+             config_dict = json.load(f)
+         model_config = Config.from_dict(config_dict["model"])
+
+         model = model_config.build()
+
+         # Load the checkpoint.
+         if not random_initialization:
+             train_module_dir = checkpoint_upath / "model_and_optim"
+             if train_module_dir.exists():
+                 load_model_and_optim_state(str(train_module_dir), model)
+                 logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
+             else:
+                 logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+
+         return model
+
+     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+         """Compute feature maps from the OlmoEarth backbone.
+
+         Inputs:
+             inputs: input dicts. It should include keys corresponding to the modalities
+                 that should be passed to the OlmoEarth model.
+         """
+         kwargs = {}
+         present_modalities = []
+         device = None
+         # Handle the case where some modalities are multitemporal and some are not.
+         # We assume all multitemporal modalities have the same number of timesteps.
+         max_timesteps = 1
+         for modality in MODALITY_NAMES:
+             if modality not in inputs[0]:
+                 continue
+             present_modalities.append(modality)
+             cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+             device = cur.device
+             # Check if it's single or multitemporal, and reshape accordingly
+             num_bands = Modality.get(modality).num_bands
+             num_timesteps = cur.shape[1] // num_bands
+             max_timesteps = max(max_timesteps, num_timesteps)
+             cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+             kwargs[modality] = cur
+             # Create mask array which is BHWTS (without channels but with band sets).
+             num_band_sets = len(Modality.get(modality).band_sets)
+             mask_shape = cur.shape[0:4] + (num_band_sets,)
+             mask = (
+                 torch.ones(mask_shape, dtype=torch.int32, device=device)
+                 * MaskValue.ONLINE_ENCODER.value
+             )
+             kwargs[f"{modality}_mask"] = mask
+
+         # Timestamps is required.
+         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+         # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
+         timestamps = torch.zeros(
+             (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+         )
+         timestamps[:, :, 0] = 1  # day
+         timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+             None, :
+         ]  # month
+         timestamps[:, :, 2] = 2024  # year
+         kwargs["timestamps"] = timestamps
+
+         sample = MaskedOlmoEarthSample(**kwargs)
+
+         # Decide context based on self.autocast_dtype.
+         if self.autocast_dtype is None:
+             context = nullcontext()
+         else:
+             assert device is not None
+             context = torch.amp.autocast(
+                 device_type=device.type, dtype=self.autocast_dtype
+             )
+
+         with context:
+             # Currently we assume the provided model always returns a TokensAndMasks object.
+             tokens_and_masks: TokensAndMasks
+             if isinstance(self.model, Encoder):
+                 # Encoder has a fast_pass argument to indicate mask is not needed.
+                 tokens_and_masks = self.model(
+                     sample,
+                     fast_pass=True,
+                     patch_size=self.patch_size,
+                     **self.forward_kwargs,
+                 )["tokens_and_masks"]
+             else:
+                 # Other models like STEncoder do not have this option supported.
+                 tokens_and_masks = self.model(
+                     sample, patch_size=self.patch_size, **self.forward_kwargs
+                 )["tokens_and_masks"]
+
+         # Apply temporal/modality pooling so we just have one feature per patch.
+         features = []
+         for modality in present_modalities:
+             modality_features = getattr(tokens_and_masks, modality)
+             # Pool over band sets and timesteps (BHWTSC -> BHWC).
+             pooled = modality_features.mean(dim=[3, 4])
+             # We want BHWC -> BCHW.
+             pooled = rearrange(pooled, "b h w c -> b c h w")
+             features.append(pooled)
+         # Pool over the modalities, so we get one BCHW feature map.
+         pooled = torch.stack(features, dim=0).mean(dim=0)
+         return [pooled]
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels is a list of (downsample_factor, depth) that corresponds
+         to the feature maps that the backbone returns. For example, an element [2, 32]
+         indicates that the corresponding feature map is 1/2 the input resolution and
+         has 32 channels.
+
+         Returns:
+             the output channels of the backbone as a list of (downsample_factor, depth)
+             tuples.
+         """
+         return [(self.patch_size, self.embedding_size)]
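Note: a minimal configuration sketch for the new OlmoEarth wrapper (requires the olmoearth_pretrain package; loading OLMOEARTH_V1_BASE fetches pre-trained weights, and the patch_size value here is illustrative):

from olmoearth_pretrain.model_loader import ModelID
from rslearn.models.olmoearth_pretrain.model import OlmoEarth

# Load the pre-trained base encoder and use it as an rslearn backbone.
backbone = OlmoEarth(
    patch_size=4,
    model_id=ModelID.OLMOEARTH_V1_BASE,
    selector=["encoder"],
)
# Reports [(downsample_factor, depth)]; for the base model this is [(4, 768)].
print(backbone.get_backbone_channels())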
rslearn/models/olmoearth_pretrain/norm.py ADDED
@@ -0,0 +1,84 @@
+ """Normalization transforms."""
+
+ import json
+ from typing import Any
+
+ from olmoearth_pretrain.data.normalize import load_computed_config
+
+ from rslearn.log_utils import get_logger
+ from rslearn.train.transforms.transform import Transform
+
+ logger = get_logger(__file__)
+
+
+ class OlmoEarthNormalize(Transform):
+     """Normalize using OlmoEarth JSON config.
+
+     For Sentinel-1 data, the values should be converted to decibels before being passed
+     to this transform.
+     """
+
+     def __init__(
+         self,
+         band_names: dict[str, list[str]],
+         std_multiplier: float | None = 2,
+         config_fname: str | None = None,
+     ) -> None:
+         """Initialize a new OlmoEarthNormalize.
+
+         Args:
+             band_names: map from modality name to the list of bands in that modality in
+                 the order they are being loaded. Note that this order must match the
+                 expected order for the OlmoEarth model.
+             std_multiplier: the std multiplier matching the one used for the model
+                 training in OlmoEarth.
+             config_fname: load the normalization configuration from this file, instead
+                 of getting it from OlmoEarth.
+         """
+         super().__init__()
+         self.band_names = band_names
+         self.std_multiplier = std_multiplier
+
+         if config_fname is None:
+             self.norm_config = load_computed_config()
+         else:
+             logger.warning(
+                 f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
+             )
+             with open(config_fname) as f:
+                 self.norm_config = json.load(f)
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply normalization over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             normalized (input_dicts, target_dicts) tuple
+         """
+         for modality_name, cur_band_names in self.band_names.items():
+             band_norms = self.norm_config[modality_name]
+             image = input_dict[modality_name]
+             # Keep a set of indices to make sure that we normalize all of them.
+             needed_band_indices = set(range(image.shape[0]))
+             num_timesteps = image.shape[0] // len(cur_band_names)
+
+             for band, norm_dict in band_norms.items():
+                 # If multitemporal, normalize each timestep separately.
+                 for t in range(num_timesteps):
+                     band_idx = cur_band_names.index(band) + t * len(cur_band_names)
+                     min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
+                     max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
+                     image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                     needed_band_indices.remove(band_idx)
+
+             if len(needed_band_indices) > 0:
+                 raise ValueError(
+                     f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
+                 )
+
+         return input_dict, target_dict
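Note: a minimal sketch of configuring the new OlmoEarthNormalize transform; the modality keys appear in the diff above, while the band lists are illustrative and must match both the load order and the order the OlmoEarth model expects:

from rslearn.models.olmoearth_pretrain.norm import OlmoEarthNormalize

# Band lists below are placeholders, not taken from the diff.
normalize = OlmoEarthNormalize(
    band_names={
        "sentinel2_l2a": ["B02", "B03", "B04", "B08"],
        "sentinel1": ["vv", "vh"],
    },
    std_multiplier=2,
)
# For std_multiplier=2, each band is rescaled as
# (value - (mean - 2*std)) / (4*std), timestep by timestep for multitemporal stacks,
# using the statistics returned by load_computed_config().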
rslearn/models/pooling_decoder.py CHANGED
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
          features = torch.amax(features, dim=(2, 3))
          features = self.fc_layers(features)
          return self.output_layer(features)
+
+
+ class SegmentationPoolingDecoder(PoolingDecoder):
+     """Like PoolingDecoder, but copy output to all pixels.
+
+     This allows for the model to produce a global output while still being compatible
+     with SegmentationTask. This only makes sense for very small windows, since the
+     output probabilities will be the same at all pixels. The main use case is to train
+     for a classification-like task on small windows, but still produce a raster during
+     inference on large windows.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         image_key: str = "image",
+         **kwargs: Any,
+     ):
+         """Create a new SegmentationPoolingDecoder.
+
+         Args:
+             in_channels: input channels (channels in the last feature map passed to
+                 this module)
+             out_channels: channels for the output flat feature vector
+             image_key: the key in inputs for the image from which the expected width
+                 and height is derived.
+             kwargs: other arguments to pass to PoolingDecoder.
+         """
+         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
+         self.image_key = image_key
+
+     def forward(
+         self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+     ) -> torch.Tensor:
+         """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
+
+         This only works when all of the pixels have the same segmentation target.
+         """
+         output_probs = super().forward(features, inputs)
+         # BC -> BCHW
+         h, w = inputs[0][self.image_key].shape[1:3]
+         return output_probs[:, :, None, None].repeat([1, 1, h, w])
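Note: a minimal usage sketch for the new SegmentationPoolingDecoder (tensor shapes are illustrative, and it assumes the remaining PoolingDecoder constructor arguments have workable defaults):

import torch
from rslearn.models.pooling_decoder import SegmentationPoolingDecoder

# Pool a BCHW feature map to one 2-class prediction per example, then copy that
# prediction to every pixel of the 64x64 input image so SegmentationTask can consume it.
decoder = SegmentationPoolingDecoder(in_channels=768, out_channels=2)
features = [torch.randn(1, 768, 16, 16)]
inputs = [{"image": torch.zeros(3, 64, 64)}]
out = decoder(features, inputs)  # expected shape (1, 2, 64, 64)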
rslearn/models/prithvi.py CHANGED
@@ -1,4 +1,12 @@
- """Prithvi V2."""
+ """Prithvi V2.
+
+ This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
+
+ The code is released under:
+
+ MIT License
+ Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
+ """

  import json
  import logging
rslearn/train/lightning_module.py CHANGED
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
          restore_config: RestoreConfig | None = None,
          print_parameters: bool = False,
          print_model: bool = False,
-         strict_loading: bool = True,
          # Deprecated options.
          lr: float = 1e-3,
          plateau: bool = False,
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
              print_parameters: whether to print the list of model parameters after model
                  initialization
              print_model: whether to print the model after model initialization
-             strict_loading: whether to strictly load the model parameters.
              lr: deprecated.
              plateau: deprecated.
              plateau_factor: deprecated.
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
          self.visualize_dir = visualize_dir
          self.metrics_file = metrics_file
          self.restore_config = restore_config
-         self.strict_loading = strict_loading

          self.scheduler_factory: SchedulerFactory | None = None
          if scheduler:
rslearn/train/tasks/classification.py CHANGED
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
                  features with matching properties.
              read_class_id: whether to read an integer class ID instead of the class
                  name.
-             allow_invalid: instead of throwing error when no regression label is found
-                 at a window, simply mark the example invalid for this task
+             allow_invalid: instead of throwing error when no classification label is
+                 found at a window, simply mark the example invalid for this task
              skip_unknown_categories: whether to skip examples with categories that are
                  not passed via classes, instead of throwing error
              prob_property: when predicting, write probabilities in addition to class ID
rslearn/train/tasks/detection.py CHANGED
@@ -72,11 +72,11 @@ class DetectionTask(BasicTask):
          f1_metric_kwargs: dict[str, Any] = {},
          **kwargs: Any,
      ) -> None:
-         """Initialize a new SegmentationTask.
+         """Initialize a new DetectionTask.

          Args:
-             property_name: the property from which to extract the class name. The class
-                 is read from the first matching feature.
+             property_name: the property from which to extract the class name. Features
+                 without this property name are ignored.
              classes: a list of class names.
              filters: optional list of (property_name, property_value) to only consider
                  features with matching properties.
@@ -86,8 +86,8 @@ class DetectionTask(BasicTask):
                  not passed via classes, instead of throwing error
              skip_empty_examples: whether to skip examples with zero labels.
              colors: optional colors for each class
-             box_size: force all boxes to be this size, centered at the centroid of the
-                 geometry. Required for Point geometries.
+             box_size: force all boxes to be two times this size, centered at the
+                 centroid of the geometry. Required for Point geometries.
              clip_boxes: whether to clip boxes to the image bounds.
              exclude_by_center: before optionally clipping boxes, exclude boxes if the
                  center is outside the image bounds.
rslearn/train/tasks/per_pixel_regression.py CHANGED
@@ -26,10 +26,11 @@ class PerPixelRegressionTask(BasicTask):
          """Initialize a new PerPixelRegressionTask.

          Args:
-             scale_factor: multiply the label value by this factor before using it for
+             scale_factor: multiply ground truth values by this factor before using it for
                  training.
-             metric_mode: what metric to use, either mse or l1
-             nodata_value: optional value to treat as invalid
+             metric_mode: what metric to use, either "mse" (default) or "l1"
+             nodata_value: optional value to treat as invalid. The loss will be masked
+                 at pixels where the ground truth value is equal to nodata_value.
              kwargs: other arguments to pass to BasicTask
          """
          super().__init__(**kwargs)
@@ -141,7 +142,7 @@ class PerPixelRegressionHead(torch.nn.Module):
          """Initialize a new RegressionHead.

          Args:
-             loss_mode: the loss function to use, either "mse" or "l1".
+             loss_mode: the loss function to use, either "mse" (default) or "l1".
              use_sigmoid: whether to apply a sigmoid activation on the output. This
                  requires targets to be between 0-1.
          """
rslearn/train/tasks/regression.py CHANGED
@@ -33,14 +33,14 @@ class RegressionTask(BasicTask):
          """Initialize a new RegressionTask.

          Args:
-             property_name: the property from which to extract the regression value. The
-                 value is read from the first matching feature.
+             property_name: the property from which to extract the ground truth
+                 regression value. The value is read from the first matching feature.
              filters: optional list of (property_name, property_value) to only consider
                  features with matching properties.
              allow_invalid: instead of throwing error when no regression label is found
                  at a window, simply mark the example invalid for this task
-             scale_factor: multiply the label value by this factor
-             metric_mode: what metric to use, either mse or l1
+             scale_factor: multiply the label value by this factor for training
+             metric_mode: what metric to use, either "mse" (default) or "l1"
              use_accuracy_metric: include metric that reports percentage of
                  examples where output is within a factor of the ground truth.
              within_factor: the factor for accuracy metric. If it's 0.2, and ground
@@ -189,7 +189,7 @@ class RegressionHead(torch.nn.Module):
          """Initialize a new RegressionHead.

          Args:
-             loss_mode: the loss function to use, either "mse" or "l1".
+             loss_mode: the loss function to use, either "mse" (default) or "l1".
              use_sigmoid: whether to apply a sigmoid activation on the output. This
                  requires targets to be between 0-1.
          """
rslearn/train/transforms/pad.py CHANGED
@@ -25,8 +25,8 @@ class Pad(Transform):
          Args:
              size: the size to pad to, or a min/max range of pad sizes. If the image is
                  larger than this size, then it is cropped instead.
-             mode: "center" (default) to apply padding equally on all sides, or
-                 "topleft" to only apply it on the bottom and right.
+             mode: "topleft" (default) to only apply padding on the bottom and right
+                 sides, or "center" to apply padding equally on all sides.
              image_selectors: image items to transform.
              box_selectors: boxes items to transform.
          """
@@ -64,7 +64,7 @@ class Pad(Transform):
      ) -> torch.Tensor:
          # Before/after must either be both non-negative or both negative.
          # >=0 indicates padding while <0 indicates cropping.
-         assert (before < 0 and after < 0) or (before >= 0 and after >= 0)
+         assert (before < 0 and after <= 0) or (before >= 0 and after >= 0)
          if before > 0:
              # Padding.
              if horizontal:
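Note: a minimal sketch of the Pad transform with the documented default mode (assuming an integer size is accepted and the default image selectors are used):

from rslearn.train.transforms.pad import Pad

# Pad images up to 256x256 (or crop if larger), adding pixels only on the bottom
# and right, which is the default "topleft" mode per the docstring change above.
pad = Pad(size=256, mode="topleft")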
{rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.9
+ Version: 0.0.12
  Summary: A library for developing remote sensing datasets and models
  Author: OlmoEarth Team
  License: Apache License
@@ -211,6 +211,7 @@ Project-URL: repository, https://github.com/allenai/rslearn
  Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ License-File: NOTICE
  Requires-Dist: boto3>=1.39
  Requires-Dist: fiona>=1.10
  Requires-Dist: fsspec>=2025.9.0
@@ -243,6 +244,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
  Requires-Dist: pycocotools>=2.0; extra == "extra"
  Requires-Dist: pystac_client>=0.9; extra == "extra"
  Requires-Dist: rtree>=1.4; extra == "extra"
+ Requires-Dist: termcolor>=3.0; extra == "extra"
  Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
  Requires-Dist: scipy>=1.16; extra == "extra"
  Requires-Dist: terratorch>=1.0.2; extra == "extra"