rslearn 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (42)
  1. rslearn/dataset/handler_summaries.py +130 -0
  2. rslearn/dataset/manage.py +157 -22
  3. rslearn/main.py +60 -8
  4. rslearn/models/anysat.py +207 -0
  5. rslearn/models/clay/clay.py +219 -0
  6. rslearn/models/clay/configs/metadata.yaml +295 -0
  7. rslearn/models/copernicusfm.py +37 -25
  8. rslearn/models/dinov3.py +165 -0
  9. rslearn/models/galileo/__init__.py +5 -0
  10. rslearn/models/galileo/galileo.py +517 -0
  11. rslearn/models/galileo/single_file_galileo.py +1672 -0
  12. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  13. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  14. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  15. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  16. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  17. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  18. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  19. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  20. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  21. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  22. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  23. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  24. rslearn/models/presto/presto.py +10 -7
  25. rslearn/models/prithvi.py +1122 -0
  26. rslearn/models/resize_features.py +45 -0
  27. rslearn/models/simple_time_series.py +65 -10
  28. rslearn/models/unet.py +17 -11
  29. rslearn/models/upsample.py +2 -2
  30. rslearn/tile_stores/default.py +31 -6
  31. rslearn/train/transforms/normalize.py +34 -5
  32. rslearn/train/transforms/select_bands.py +67 -0
  33. rslearn/train/transforms/sentinel1.py +60 -0
  34. rslearn/utils/geometry.py +61 -1
  35. rslearn/utils/raster_format.py +7 -1
  36. rslearn/utils/vector_format.py +13 -10
  37. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/METADATA +144 -15
  38. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/RECORD +42 -18
  39. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/WHEEL +0 -0
  40. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/entry_points.txt +0 -0
  41. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/licenses/LICENSE +0 -0
  42. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/top_level.txt +0 -0
--- a/rslearn/models/copernicusfm.py
+++ b/rslearn/models/copernicusfm.py
@@ -3,11 +3,12 @@
 import logging
 import math
 from enum import Enum
+from pathlib import Path
 
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from upath import UPath
+from huggingface_hub import hf_hub_download
 
 from .copernicusfm_src.model_vit import vit_base_patch16
 
@@ -64,6 +65,10 @@ MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
     },
 }
 
+HF_REPO_ID = "wangyi111/Copernicus-FM"
+HF_REPO_REVISION = "e1db406d517a122c8373802e1c130c5fc4789f84"
+HF_FILENAME = "CopernicusFM_ViT_base_varlang_e100.pth"
+
 
 class CopernicusFM(torch.nn.Module):
     """Wrapper for Copernicus FM to ingest Masked Helios Sample."""
@@ -80,44 +85,51 @@ class CopernicusFM(torch.nn.Module):
     def __init__(
         self,
         band_order: dict[str, list[str]],
-        load_directory: str | None,
+        cache_dir: str | Path | None = None,
     ) -> None:
         """Initialize the Copernicus FM wrapper.
 
         Args:
-            band_order: The band order for each modality
-            load_directory: The directory to load from, if None no weights are loaded
+            band_order: The band order for each modality that will be used. The bands
+                can be provided in any order, and any subset can be used.
+            cache_dir: The directory to cache the weights. If None, a default directory
+                managed by huggingface_hub is used. The weights are downloaded from
+                Hugging Face (https://huggingface.co/wangyi111/Copernicus-FM).
         """
         super().__init__()
 
+        # Make sure all keys in band_order are in supported_modalities.
+        for modality_name in band_order.keys():
+            if modality_name in self.supported_modalities:
+                continue
+            raise ValueError(
+                f"band_order contains unsupported modality {modality_name}"
+            )
+
         # global_pool=True so that we initialize the fc_norm layer
-        self.band_order = band_order
         self.model = vit_base_patch16(num_classes=10, global_pool=True)
-        if load_directory is not None:
-            check_point = torch.load(
-                UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth",
-                weights_only=True,
-            )
-            if "model" in check_point:
-                state_dict = check_point["model"]
-            else:
-                state_dict = check_point
-            self.model.load_state_dict(state_dict, strict=False)
-
-        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same
-        # ordering as the Helios band orders, defined by Modality.band_order
+
+        # Load weights, downloading if needed.
+        local_fname = hf_hub_download(
+            repo_id=HF_REPO_ID,
+            revision=HF_REPO_REVISION,
+            filename=HF_FILENAME,
+            local_dir=cache_dir,
+        )  # nosec
+        state_dict = torch.load(local_fname, weights_only=True)
+        self.model.load_state_dict(state_dict, strict=False)
+
+        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
+        # ordering as the user-provided band order.
         self.modality_to_wavelength_bandwidths = {}
         for modality in self.supported_modalities:
+            if modality not in band_order:
+                continue
+
             wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
             wavelengths = []
             bandwidths = []
-            modality_band_order = self.band_order.get(modality, None)
-            if modality_band_order is None:
-                logger.warning(
-                    f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified"
-                )
-                continue
-            for b in modality_band_order:
+            for b in band_order[modality]:
                 cfm_idx = wavelength_bandwidths["band_names"].index(b)
                 wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
                 bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
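
To illustrate the new API, a minimal usage sketch. The "sentinel2" modality key and "B02"/"B03"/"B04" band names are placeholders: the valid keys are defined by CopernicusFM.supported_modalities and the band_names entries of MODALITY_TO_WAVELENGTH_BANDWIDTHS, neither of which appears in this diff.

    from rslearn.models.copernicusfm import CopernicusFM

    # Hypothetical modality and band names; any subset of the supported
    # bands, in any order, is accepted by the new constructor.
    model = CopernicusFM(
        band_order={"sentinel2": ["B02", "B03", "B04"]},
        # None delegates caching to huggingface_hub's default directory;
        # pass a path to control where the checkpoint is stored.
        cache_dir=None,
    )

Note that pinning HF_REPO_REVISION to a specific commit keeps the download reproducible even if the upstream Hugging Face repository changes.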
--- /dev/null
+++ b/rslearn/models/dinov3.py
@@ -0,0 +1,165 @@
+"""DinoV3 model."""
+
+from enum import StrEnum
+from pathlib import Path
+from typing import Any
+
+import torch
+import torchvision
+from einops import rearrange
+from torchvision.transforms import v2
+
+from rslearn.train.transforms.transform import Transform
+
+
+class DinoV3Models(StrEnum):
+    """Names for different DinoV3 images on torch hub."""
+
+    SMALL_WEB = "dinov3_vits16"
+    SMALL_PLUS_WEB = "dinov3_vits16plus"
+    BASE_WEB = "dinov3_vitb16"
+    LARGE_WEB = "dinov3_vitl16"
+    HUGE_PLUS_WEB = "dinov3_vith16plus"
+    FULL_7B_WEB = "dinov3_vit7b16"
+    LARGE_SATELLITE = "dinov3_vitl16_sat"
+    FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
+
+
+DINOV3_PTHS: dict[str, str] = {
+    DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
+    DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
+    DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
+    DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
+    DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
+    DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
+}
+
+
+class DinoV3(torch.nn.Module):
+    """DinoV3 Backbones.
+
+    Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
+    See https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models
+
+    Only takes RGB as input. Expects normalized data (use the below normalizer).
+
+    Uses patch size 16. The input is resized to 256x256; when applying DinoV3 on
+    segmentation or detection tasks with inputs larger than 256x256, it may be best to
+    train and predict on 256x256 crops (using SplitConfig.patch_size argument).
+    """
+
+    image_size: int = 256
+    patch_size: int = 16
+    output_dim: int = 1024
+
+    def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
+        model_name = size.replace("_sat", "")
+        if checkpoint_dir is not None:
+            weights = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
+            return torch.hub.load(
+                "facebookresearch/dinov3",
+                model_name,
+                weights=weights,
+            )  # nosec
+        return torch.hub.load("facebookresearch/dinov3", model_name, pretrained=False)  # nosec
+
+    def __init__(
+        self,
+        checkpoint_dir: str | None,
+        size: str = DinoV3Models.LARGE_SATELLITE,
+        use_cls_token: bool = False,
+        do_resizing: bool = True,
+    ) -> None:
+        """Instantiate a new DinoV3 instance.
+
+        Args:
+            checkpoint_dir: the local path to the pretrained weight dir. If None, we
+                load the architecture only (randomly initialized).
+            size: the model size, see class for various models.
+            use_cls_token: use pooled class token (for classification), otherwise
+                returns spatial feature map.
+            do_resizing: whether to resize inputs to 256x256. Default true.
+        """
+        super().__init__()
+        self.size = size
+        self.checkpoint_dir = checkpoint_dir
+        self.use_cls_token = use_cls_token
+        self.do_resizing = do_resizing
+        self.model = self._load_model(size, checkpoint_dir)
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Forward pass for the dinov3 model.
+
+        Args:
+            inputs: input dicts that must include "image" key.
+
+        Returns:
+            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+        """
+        cur = torch.stack([inp["image"] for inp in inputs], dim=0)  # (B, C, H, W)
+
+        if self.do_resizing and (
+            cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
+        ):
+            cur = torchvision.transforms.functional.resize(
+                cur,
+                [self.image_size, self.image_size],
+            )
+
+        if self.use_cls_token:
+            features = self.model(cur)
+        else:
+            features = self.model.forward_features(cur)["x_norm_patchtokens"]
+            batch_size, num_patches, _ = features.shape
+            height, width = int(num_patches**0.5), int(num_patches**0.5)
+            features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
+
+        return [features]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+        """
+        return [(self.patch_size, self.output_dim)]
+
+
+class DinoV3Normalize(Transform):
+    """Normalize inputs using DinoV3 normalization.
+
+    Normalize "image" key in input according to Dino statistics from pretraining.
+    Satellite pretraining has slightly different normalization than the base image
+    model, so set 'satellite' depending on which pretrained model you are using.
+
+    Input "image" should be an RGB-like image between 0-255.
+    """
+
+    def __init__(self, satellite: bool = True):
+        """Initialize a new DinoV3Normalize."""
+        super().__init__()
+        self.satellite = satellite
+        if satellite:
+            self.normalize = v2.Normalize(
+                mean=(0.430, 0.411, 0.296),
+                std=(0.213, 0.156, 0.143),
+            )
+        else:
+            self.normalize = v2.Normalize(
+                mean=(0.485, 0.456, 0.406),
+                std=(0.229, 0.224, 0.225),
+            )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Normalize the specified image with DinoV3 normalization.
+
+        Args:
+            input_dict: the input dictionary.
+            target_dict: the target dictionary.
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        input_dict["image"] = self.normalize(input_dict["image"] / 255.0)
+        return input_dict, target_dict
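
A usage sketch of the new backbone and transform, assuming rslearn's Transform dispatches like a torch.nn.Module. With checkpoint_dir=None the architecture is loaded randomly initialized via torch.hub (which fetches the facebookresearch/dinov3 repo); the shapes assume the LARGE_SATELLITE default (patch size 16, 1024-dim features).

    import torch

    from rslearn.models.dinov3 import DinoV3, DinoV3Models, DinoV3Normalize

    # Scale a fake 0-255 RGB image using the satellite pretraining statistics.
    normalize = DinoV3Normalize(satellite=True)
    input_dict = {"image": torch.randint(0, 256, (3, 256, 256)).float()}
    input_dict, _ = normalize(input_dict, {})

    # Random weights only; point checkpoint_dir at a directory containing the
    # .pth files from DINOV3_PTHS to load pretrained weights.
    model = DinoV3(checkpoint_dir=None, size=DinoV3Models.LARGE_SATELLITE)
    features = model([input_dict])  # one (1, 1024, 16, 16) feature map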
--- /dev/null
+++ b/rslearn/models/galileo/__init__.py
@@ -0,0 +1,5 @@
+"""Galileo model."""
+
+from .galileo import GalileoModel, GalileoSize
+
+__all__ = ["GalileoModel", "GalileoSize"]