rslearn 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. rslearn/arg_parser.py +59 -0
  2. rslearn/data_sources/copernicus.py +10 -8
  3. rslearn/data_sources/earthdaily.py +21 -1
  4. rslearn/data_sources/eurocrops.py +246 -0
  5. rslearn/data_sources/gcp_public_data.py +3 -3
  6. rslearn/data_sources/local_files.py +11 -0
  7. rslearn/data_sources/openstreetmap.py +2 -4
  8. rslearn/data_sources/utils.py +1 -17
  9. rslearn/main.py +10 -1
  10. rslearn/models/copernicusfm.py +216 -0
  11. rslearn/models/copernicusfm_src/__init__.py +1 -0
  12. rslearn/models/copernicusfm_src/aurora/area.py +50 -0
  13. rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
  14. rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
  15. rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
  16. rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
  17. rslearn/models/copernicusfm_src/model_vit.py +348 -0
  18. rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
  19. rslearn/models/panopticon.py +167 -0
  20. rslearn/models/presto/__init__.py +5 -0
  21. rslearn/models/presto/presto.py +247 -0
  22. rslearn/models/presto/single_file_presto.py +932 -0
  23. rslearn/models/trunk.py +0 -144
  24. rslearn/models/unet.py +15 -0
  25. rslearn/train/callbacks/adapters.py +53 -0
  26. rslearn/train/callbacks/freeze_unfreeze.py +319 -0
  27. rslearn/train/callbacks/gradients.py +54 -34
  28. rslearn/train/data_module.py +70 -41
  29. rslearn/train/dataset.py +232 -54
  30. rslearn/train/lightning_module.py +4 -0
  31. rslearn/train/prediction_writer.py +7 -0
  32. rslearn/train/scheduler.py +15 -0
  33. rslearn/train/tasks/per_pixel_regression.py +259 -0
  34. rslearn/train/tasks/regression.py +6 -4
  35. rslearn/train/tasks/segmentation.py +44 -14
  36. rslearn/train/transforms/mask.py +69 -0
  37. rslearn/utils/geometry.py +8 -8
  38. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/METADATA +6 -3
  39. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/RECORD +43 -27
  40. rslearn/models/moe/distributed.py +0 -262
  41. rslearn/models/moe/soft.py +0 -676
  42. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/WHEEL +0 -0
  43. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/entry_points.txt +0 -0
  44. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/licenses/LICENSE +0 -0
  45. {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,216 @@
1
+ """Copernicus FM model."""
2
+
3
+ import logging
4
+ import math
5
+ from enum import Enum
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from upath import UPath
11
+
12
+ from .copernicusfm_src.model_vit import vit_base_patch16
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class CopernicusFMModality(Enum):
    """Input modalities that the Copernicus FM wrapper knows how to handle.

    The enum values are the string keys used to look up per-modality band
    metadata and to select tensors from the model input dictionaries.
    """

    # Sentinel-2 L2A imagery.
    SENTINEL2_L2A = "sentinel2_l2a"
    # Sentinel-1 imagery.
    SENTINEL1 = "sentinel1"
22
+
23
+
24
# Per-modality band metadata. For each modality the three lists are parallel:
# entry i of "band_wavelengths" / "band_bandwidths" describes band i of
# "band_names". Values are copied from the upstream Copernicus-Bench configs
# linked below (presumably expressed in nanometers -- TODO confirm upstream).
MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
    # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml
    CopernicusFMModality.SENTINEL2_L2A.value: {
        "band_names": [
            "B01", "B02", "B03", "B04", "B05", "B06", "B07",
            "B08", "B8A", "B09", "B10", "B11", "B12",
        ],
        "band_wavelengths": [
            440, 490, 560, 665, 705, 740, 783,
            842, 860, 940, 1370, 1610, 2190,
        ],
        "band_bandwidths": [
            20, 65, 35, 30, 15, 15, 20,
            115, 20, 20, 30, 90, 180,
        ],
    },
    # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml
    CopernicusFMModality.SENTINEL1.value: {
        "band_names": ["vv", "vh"],
        "band_wavelengths": [50000000, 50000000],
        "band_bandwidths": [1e9, 1e9],
    },
}
66
+
67
+
68
class CopernicusFM(torch.nn.Module):
    """Wrapper for Copernicus FM to ingest Masked Helios Sample.

    The wrapper resizes each supported modality to a fixed resolution,
    concatenates the modalities along the channel dimension, and runs the
    Copernicus FM ViT on the result together with the per-band wavelengths
    and bandwidths.
    """

    image_resolution = 224
    patch_size = 16
    input_mode = "spectral"
    # Don't need this as band order is provided
    supported_modalities = [
        CopernicusFMModality.SENTINEL2_L2A.value,
        CopernicusFMModality.SENTINEL1.value,
    ]

    def __init__(
        self,
        band_order: dict[str, list[str]],
        load_directory: str | None,
    ) -> None:
        """Initialize the Copernicus FM wrapper.

        Args:
            band_order: The band order for each modality
            load_directory: The directory to load from, if None no weights are loaded

        Raises:
            ValueError: if band_order names a band that is not listed for its
                modality in MODALITY_TO_WAVELENGTH_BANDWIDTHS.
        """
        super().__init__()

        self.band_order = band_order
        # global_pool=True so that we initialize the fc_norm layer
        self.model = vit_base_patch16(num_classes=10, global_pool=True)
        if load_directory is not None:
            check_point = torch.load(
                UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth",
                weights_only=True,
            )
            # The checkpoint may either wrap the weights under a "model" key or
            # be the state dict itself.
            state_dict = check_point["model"] if "model" in check_point else check_point
            # strict=False: do not fail on keys that are missing or unexpected.
            self.model.load_state_dict(state_dict, strict=False)

        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the
        # same ordering as the caller-provided band order for each modality
        self.modality_to_wavelength_bandwidths = {}
        for modality in self.supported_modalities:
            modality_band_order = self.band_order.get(modality)
            if modality_band_order is None:
                logger.warning(
                    "Band order for modality %s not found in band_order dictionary, "
                    "unable to use this modality unless specified",
                    modality,
                )
                continue

            reference = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
            known_bands = reference["band_names"]
            wavelengths = []
            bandwidths = []
            for b in modality_band_order:
                try:
                    cfm_idx = known_bands.index(b)
                except ValueError as err:
                    # Same exception type as list.index, but say which band
                    # and modality were at fault.
                    raise ValueError(
                        f"Unknown band {b!r} for modality {modality}; "
                        f"expected one of {known_bands}"
                    ) from err
                wavelengths.append(reference["band_wavelengths"][cfm_idx])
                bandwidths.append(reference["band_bandwidths"][cfm_idx])
            self.modality_to_wavelength_bandwidths[modality] = {
                "band_bandwidths": bandwidths,
                "band_wavelengths": wavelengths,
            }

    def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
        """Resize one modality's data to the spatial size fed to the model.

        Args:
            data: Input tensor of shape [B, C, H, W]

        Returns:
            tensor of shape [B, C, S, S], where S is patch_size when the input
            height is 1 and image_resolution otherwise.
        """
        original_height = data.shape[2]
        # Single-pixel inputs are expanded to exactly one patch; everything
        # else is resized to the full image resolution.
        new_height = self.patch_size if original_height == 1 else self.image_resolution
        return F.interpolate(
            data,
            size=(new_height, new_height),
            mode="bilinear",
            align_corners=False,
        )

    def prepare_input(
        self,
        inputs: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, list[float], list[float]]:
        """Prepare input for the CopernicusFM model from MaskedHeliosSample.

        Args:
            inputs: mapping from modality name to a [B, C, H, W] tensor.
                Unsupported modalities and None entries are skipped.

        Returns:
            a tuple (data, wavelengths, bandwidths): the resized modalities
            concatenated along the channel dimension, and the per-channel
            wavelengths and bandwidths in the same order.

        Raises:
            KeyError: if a supported modality is present in inputs but no band
                order was configured for it at construction time.
            ValueError: if inputs contains no usable supported modality.
        """
        wavelengths: list[float] = []
        bandwidths: list[float] = []
        all_processed_data: list[torch.Tensor] = []
        for modality, data in inputs.items():
            if modality not in self.supported_modalities:
                logger.debug(
                    "Skipping modality %s as it is not in the supported "
                    "modalities list %s",
                    modality,
                    self.supported_modalities,
                )
                continue

            if data is None:
                continue

            spectral_info = self.modality_to_wavelength_bandwidths.get(modality)
            if spectral_info is None:
                raise KeyError(
                    f"No band order was configured for modality {modality}; "
                    "pass it in band_order when constructing CopernicusFM"
                )

            all_processed_data.append(self._resize_data(data))
            wavelengths.extend(spectral_info["band_wavelengths"])
            bandwidths.extend(spectral_info["band_bandwidths"])

        if not all_processed_data:
            raise ValueError(
                "No supported modality found in inputs; expected at least one of "
                f"{self.supported_modalities}, got {list(inputs)}"
            )

        concatenated_processed_data = torch.cat(all_processed_data, dim=1)
        return concatenated_processed_data, wavelengths, bandwidths

    def forward(
        self,
        inputs: list[dict[str, torch.Tensor]],
    ) -> list[torch.Tensor]:
        """Forward pass through CopernicusFM model.

        Args:
            inputs: one dict per example mapping modality name to a tensor. The
                keys of the first example determine which entries are stacked.

        Returns:
            a single-element list containing the feature map rearranged to
            [B, C, H, W].
        """
        # Stack the per-example tensors into one batch tensor per key.
        batch_inputs = {
            key: torch.stack([inp[key] for inp in inputs], dim=0)
            for key in inputs[0].keys()
        }
        # Prepare input
        data, wavelengths, bandwidths = self.prepare_input(batch_inputs)
        meta = torch.full(
            (1, 4), float("nan"), device=data.device
        )  # [lon, lat, delta_time, patch_token_area], assume unknown
        # "The embed tensor contains the encoded image features, which can be used for downstream tasks."
        _, timestep_output = self.model(
            data,
            meta,
            wavelengths,
            bandwidths,
            None,
            self.input_mode,
            self.patch_size,
        )
        # no norm, following
        # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py
        # The tokens are laid out as a square grid, so recover its side length.
        side = math.isqrt(timestep_output.shape[1])
        output_features = rearrange(
            timestep_output, "b (h w) c -> b c h w ", h=side, w=side
        )
        return [output_features]

    def get_backbone_channels(self) -> list[tuple[int, int]]:
        """Returns the output channels of this model when used as a backbone.

        Returns:
            a single-element list holding the tuple (patch_size, 768).
        """
        # TODO: load this from a constant depending on the model size
        return [(self.patch_size, 768)]
@@ -0,0 +1 @@
1
+ # mypy: ignore-errors
@@ -0,0 +1,50 @@
1
+ """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
2
+
3
+ import torch
4
+
5
+ __all__ = ["area", "radius_earth"]
6
+
7
+
8
# float: Radius of the earth in kilometers (6378137 meters converted to km).
radius_earth = 6378137 / 1000


def area(polygon: torch.Tensor) -> torch.Tensor:
    """Compute the area of a polygon given as latitudes and longitudes in degrees.

    This function is a PyTorch port of the PyPI package `area`. In particular, it is heavily
    inspired by the following file:

        https://github.com/scisco/area/blob/9d9549d6ebffcbe4bffe11b71efa2d406d1c9fe9/area/__init__.py

    Args:
        polygon (:class:`torch.Tensor`): Polygon of the shape `(*b, n, 2)` where `b` is an optional
            multidimensional batch size, `n` is the number of points of the polygon, and the last
            dimension holds first the latitude and then the longitude. The polygon does not have
            to be closed.

    Returns:
        :class:`torch.Tensor`: Area in square kilometers, of shape `(*b)`.
    """
    # Append a copy of the final vertex; the modular indexing below is what
    # wraps the ring back around to the first vertex.
    ring = torch.cat((polygon, polygon[..., -1:, :]), dim=-2)
    num_points = ring.shape[-2]  # Number of points including the appended one

    total = torch.zeros(ring.shape[:-2], dtype=ring.dtype, device=ring.device)

    if num_points > 2:
        # Convert once up front instead of per vertex inside the loop.
        lon = torch.deg2rad(ring[..., 1])
        sin_lat = torch.sin(torch.deg2rad(ring[..., 0]))

        for lower in range(num_points):
            middle = (lower + 1) % num_points
            upper = (lower + 2) % num_points
            total = total + (lon[..., upper] - lon[..., lower]) * sin_lat[..., middle]

        total = total * radius_earth * radius_earth / 2

    # The sign of the sum depends on the winding order, so drop it.
    return torch.abs(total)
@@ -0,0 +1,134 @@
1
+ # type: ignore
2
+ """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .area import area, radius_earth
11
+
12
+ __all__ = [
13
+ "FourierExpansion",
14
+ "pos_expansion",
15
+ "scale_expansion",
16
+ "lead_time_expansion",
17
+ "levels_expansion",
18
+ "absolute_time_expansion",
19
+ ]
20
+
21
+
22
class FourierExpansion(nn.Module):
    """A Fourier series-style expansion into a high-dimensional space.

    Attributes:
        lower (float): Lower wavelength.
        upper (float): Upper wavelength.
        assert_range (bool): Assert that the encoded tensor is within the specified wavelength
            range.
    """

    def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
        """Initialise.

        Args:
            lower (float): Lower wavelength.
            upper (float): Upper wavelength.
            assert_range (bool, optional): Assert that the encoded tensor is within the specified
                wavelength range. Defaults to `True`.
        """
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.assert_range = assert_range

    def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
        """Perform the expansion.

        Adds a dimension of length `d` to the end of the shape of `x`.

        Args:
            x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. The absolute value of
                every non-zero element of `x` must lie within `[self.lower, self.upper]` if
                `self.assert_range` is `True`.
            d (int): Dimensionality. Must be a multiple of two.

        Raises:
            AssertionError: If `self.assert_range` is `True` and some non-zero element of `x` has
                an absolute value outside `[self.lower, self.upper]`.
            ValueError: If `d` is not a multiple of two.

        Returns:
            torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`.
        """
        # If the input is not within the configured range, the embedding might be ambiguous!
        # Only pay for the check (and its reduction) when it is actually enabled.
        if self.assert_range:
            magnitude = x.abs()
            in_range = torch.logical_and(
                self.lower <= magnitude, magnitude <= self.upper
            )
            # Allow zeros to pass through.
            if not torch.all(torch.logical_or(in_range, x == 0)):
                raise AssertionError(
                    f"The input tensor is not within the configured range"
                    f" `[{self.lower}, {self.upper}]`."
                )

        # We will use half of the dimensionality for `sin` and the other half for `cos`.
        if d % 2 != 0:
            raise ValueError("The dimensionality must be a multiple of two.")

        # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
        x = x.double()

        # Wavelengths are spaced logarithmically between `lower` and `upper`.
        wavelengths = torch.logspace(
            math.log10(self.lower),
            math.log10(self.upper),
            d // 2,
            base=10,
            device=x.device,
            dtype=x.dtype,
        )
        prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths)
        encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)

        return encoding.float()  # Cast to `float32` to avoid incompatibilities.
96
+
97
+
98
# Shared, pre-configured expansions. The lower bound of the scale expansion is
# derived from the area of the smallest latitude/longitude cell we expect.
_delta = 0.01  # Reasonable smallest delta in latitude and longitude

# The smallest patches will be at the poles. Just use the north pole: a
# `_delta` by `_delta` cell with one edge on latitude 90.
_min_patch_area: float = area(
    torch.tensor(
        [
            [90, 0],
            [90, _delta],
            [90 - _delta, _delta],
            [90 - _delta, 0],
        ],
        dtype=torch.float64,
    )
).item()

# Surface area of a sphere with the earth's radius, used as the upper bound.
_area_earth = 4 * np.pi * radius_earth * radius_earth

# Positions: magnitudes between `_delta` and 720.
pos_expansion = FourierExpansion(_delta, 720)

# Patch areas: from the smallest polar cell up to the whole sphere.
scale_expansion = FourierExpansion(_min_patch_area, _area_earth)

# Lead times (presumably in hours, i.e. one minute up to three weeks -- TODO confirm).
lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3)

levels_expansion = FourierExpansion(0.01, 1e5)

# Absolute time (presumably hours, up to one year -- TODO confirm); the range
# check is disabled for this expansion.
absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False)

### new for SSL4EO-S ###
# min wavelength: ultraviolet light (100 nm)
# max wavelength: radio waves (1 m)
spectrum_central_expansion = FourierExpansion(1e-7, 1)

# min bandwidth: 10nm
# max bandwidth: 1m
# NOTE(review): the lower bound below is 1e-7, the same value used for the
# 100 nm wavelength bound above, which does not match the "10nm" stated here
# -- confirm which one is intended before relying on it.
spectrum_width_expansion = FourierExpansion(1e-7, 1)