rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/models/molmo.py
ADDED
@@ -0,0 +1,65 @@
+"""Molmo model."""
+
+from typing import Any
+
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+
+class Molmo(torch.nn.Module):
+    """Molmo image encoder."""
+
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        """Instantiate a new Molmo instance.
+
+        Args:
+            model_name: the model name like "allenai/Molmo-7B-D-0924".
+        """
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )
+        self.encoder = model.model.vision_backbone
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute outputs from the backbone.
+
+        Inputs:
+            inputs: input dicts that must include "image" key containing the image to
+                process. The images should have values 0-255.
+
+        Returns:
+            list of feature maps. Molmo produces features at one scale, so the list
+                contains a single Bx24x24x2048 tensor.
+        """
+        device = inputs[0]["image"].device
+        molmo_inputs_list = []
+        # Process each one so we can isolate just the full image without any crops.
+        for inp in inputs:
+            image = inp["image"].cpu().numpy().transpose(1, 2, 0)
+            processed = self.processor.process(
+                images=[image],
+                text="",
+            )
+            molmo_inputs_list.append(processed["images"][0])
+        molmo_inputs: torch.Tensor = torch.stack(molmo_inputs_list, dim=0).unsqueeze(1)
+
+        image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
+
+        # 576x2048 -> 24x24x2048
+        return [
+            image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
+        ]
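For orientation, a minimal usage sketch of the new encoder (this example is not part of the package diff; the model name comes from the docstring above, and the random image stands in for real data):

import torch

from rslearn.models.molmo import Molmo

# Sketch only: instantiating downloads the Molmo-7B-D weights, which are large.
encoder = Molmo(model_name="allenai/Molmo-7B-D-0924")

# One dict per example; per the docstring, pixel values should be 0-255.
image = torch.randint(0, 256, (3, 512, 512)).float()
features = encoder([{"image": image}])
# Single-scale output; shape (1, 2048, 24, 24) after the reshape/permute above.
print(features[0].shape)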
rslearn/models/multitask.py
CHANGED
@@ -44,7 +44,7 @@ class MultiTaskModel(torch.nn.Module):
             tuple (outputs, loss_dict) from the last module.
         """
         features = self.encoder(inputs)
-        outputs = [{} for _ in inputs]
+        outputs: list[dict[str, Any]] = [{} for _ in inputs]
         losses = {}
         for name, decoder in self.decoders.items():
             cur = features

rslearn/models/pooling_decoder.py
CHANGED
@@ -21,7 +21,7 @@ class PoolingDecoder(torch.nn.Module):
         num_fc_layers: int = 0,
         conv_channels: int = 128,
         fc_channels: int = 512,
-    ):
+    ) -> None:
         """Initialize a PoolingDecoder.

         Args:
@@ -57,7 +57,9 @@ class PoolingDecoder(torch.nn.Module):

         self.output_layer = torch.nn.Linear(prev_channels, out_channels)

-    def forward(…
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
         """Compute flat output vector from multi-scale feature map.

         Args:
rslearn/models/satlaspretrain.py
CHANGED
@@ -13,7 +13,7 @@ class SatlasPretrain(torch.nn.Module):
         self,
         model_identifier: str,
         fpn: bool = False,
-    ):
+    ) -> None:
         """Instantiate a new SatlasPretrain instance.

         Args:
@@ -25,7 +25,7 @@ class SatlasPretrain(torch.nn.Module):
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
         self.model = weights_manager.get_pretrained_model(
-            model_identifier=model_identifier, fpn=fpn
+            model_identifier=model_identifier, fpn=fpn, device="cpu"
         )

         if "SwinB" in model_identifier:
@@ -50,20 +50,17 @@ class SatlasPretrain(torch.nn.Module):
             [32, 2048],
         ]

-    def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the SatlasPretrain backbone.

         Inputs:
             inputs: input dicts that must include "image" key containing the image to
                 process.
-            targets: target dicts that are ignored
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)
         return self.model(images)

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds

rslearn/models/simple_time_series.py
CHANGED
@@ -24,7 +24,7 @@ class SimpleTimeSeries(torch.nn.Module):
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
-    ):
+    ) -> None:
         """Create a new SimpleTimeSeries.

         Args:
@@ -55,63 +55,69 @@ class SimpleTimeSeries(torch.nn.Module):
         else:
             self.num_groups = 1

-        if self.op …
-            …
-                    torch.nn.ReLU(inplace=True),
-                )
-            ]
-            for _ in range(num_layers - 1):
-                cur_layer.append(
+        if self.op in ["convrnn", "conv3d", "conv1d"]:
+            if num_layers is None:
+                raise ValueError(f"num_layers must be specified for {self.op} op")
+
+            if self.op == "convrnn":
+                rnn_kernel_size = 3
+                rnn_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
                         torch.nn.Sequential(
                             torch.nn.Conv2d(
-                                count, count, rnn_kernel_size, padding="same"
+                                2 * count, count, rnn_kernel_size, padding="same"
                             ),
                             torch.nn.ReLU(inplace=True),
                         )
-            …
+                    ]
+                    for _ in range(num_layers - 1):
+                        cur_layer.append(
+                            torch.nn.Sequential(
+                                torch.nn.Conv2d(
+                                    count, count, rnn_kernel_size, padding="same"
+                                ),
+                                torch.nn.ReLU(inplace=True),
+                            )
+                        )
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    rnn_layers.append(cur_layer)
+                self.rnn_layers = torch.nn.ModuleList(rnn_layers)
+
+            elif self.op == "conv3d":
+                conv3d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv3d(
+                                count, count, 3, padding=1, stride=(2, 1, 1)
+                            ),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv3d_layers.append(cur_layer)
+                self.conv3d_layers = torch.nn.ModuleList(conv3d_layers)
+
+            elif self.op == "conv1d":
+                conv1d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv1d_layers.append(cur_layer)
+                self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)

         else:
             assert self.op in ["max", "mean"]

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -129,14 +135,14 @@ class SimpleTimeSeries(torch.nn.Module):
         return out_channels

     def forward(
-        self,
-        …
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
         """Compute outputs from the backbone.

         Inputs:
             inputs: input dicts that must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
-            targets: target dicts that are ignored unless
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
@@ -171,13 +177,13 @@ class SimpleTimeSeries(torch.nn.Module):
         for feature_idx in range(len(all_features)):
             aggregated_features = []
             for group in groups:
-                …
+                group_features_list = []
                 for image_idx in group:
-                    …
+                    group_features_list.append(
                         all_features[feature_idx][:, image_idx, :, :, :]
                     )
                 # Resulting group features are (depth, batch, C, height, width).
-                group_features = torch.stack(…
+                group_features = torch.stack(group_features_list, dim=0)

                 if self.op == "max":
                     group_features = torch.amax(group_features, dim=0)
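One note on the Conv2d change in the convrnn branch (count -> 2 * count input channels on the first layer): this is what a convolutional RNN step needs when it concatenates the running hidden state with the next frame's features along the channel axis. The updated forward pass is not shown in this hunk, so treat that concatenation as an assumption; a toy illustration of the shape arithmetic:

import torch

count = 8  # per-scale feature depth; hypothetical value
step = torch.nn.Sequential(
    torch.nn.Conv2d(2 * count, count, 3, padding="same"),
    torch.nn.ReLU(inplace=True),
)
hidden = torch.zeros(2, count, 16, 16)  # running state, (B, C, H, W)
frame = torch.randn(2, count, 16, 16)   # features for the next timestep
# Concatenating state and input doubles the channels, hence 2 * count above.
hidden = step(torch.cat([hidden, frame], dim=1))
print(hidden.shape)  # torch.Size([2, 8, 16, 16])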
rslearn/models/ssl4eo_s12.py
CHANGED
@@ -14,7 +14,7 @@ class Ssl4eoS12(torch.nn.Module):
         backbone_ckpt_path: str,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.

         Args:
@@ -51,7 +51,7 @@ class Ssl4eoS12(torch.nn.Module):
             f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
         )

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -65,16 +65,17 @@ class Ssl4eoS12(torch.nn.Module):
         """
         if self.arch == "resnet50":
            all_out_channels = [
-                …
+                (4, 256),
+                (8, 512),
+                (16, 1024),
+                (32, 2048),
            ]
        return [all_out_channels[idx] for idx in self.output_layers]

    def forward(
-        self,
-        …
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
        """Compute outputs from the backbone.

        If output_layers is set, then the outputs are multi-scale feature maps;
@@ -84,7 +85,6 @@ class Ssl4eoS12(torch.nn.Module):
        Inputs:
            inputs: input dicts that must include "image" key containing the image to
                process.
-            targets: target dicts that are ignored unless
        """
        x = torch.stack([inp["image"] for inp in inputs], dim=0)
        x = self.model.conv1(x)
rslearn/models/swin.py
CHANGED
@@ -28,7 +28,7 @@ class Swin(torch.nn.Module):
         input_channels: int = 3,
         output_layers: list[int] | None = None,
         num_outputs: int = 1000,
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.

         Args:
@@ -89,7 +89,7 @@ class Swin(torch.nn.Module):
         if num_outputs != self.model.head.out_features:
             self.model.head = torch.nn.Linear(self.model.head.in_features, num_outputs)

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -105,31 +105,33 @@ class Swin(torch.nn.Module):

         if self.arch in ["swin_b", "swin_v2_b"]:
             all_out_channels = [
-                …
+                (4, 128),
+                (4, 128),
+                (4, 128),
+                (8, 256),
+                (8, 256),
+                (16, 512),
+                (16, 512),
+                (32, 1024),
+                (32, 1024),
             ]
         elif self.arch in ["swin_s", "swin_v2_s", "swin_t", "swin_v2_t"]:
             all_out_channels = [
-                …
+                (4, 96),
+                (4, 96),
+                (8, 192),
+                (8, 192),
+                (16, 384),
+                (16, 384),
+                (32, 768),
+                (32, 768),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]

     def forward(
-        self,
-        …
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
         """Compute outputs from the backbone.

         If output_layers is set, then the outputs are multi-scale feature maps;
@@ -139,7 +141,6 @@ class Swin(torch.nn.Module):
         Inputs:
             inputs: input dicts that must include "image" key containing the image to
                 process.
-            targets: target dicts that are ignored unless
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)

rslearn/models/unet.py
CHANGED
@@ -19,7 +19,7 @@ class UNetDecoder(torch.nn.Module):
         out_channels: int,
         conv_layers_per_resolution: int = 1,
         kernel_size: int = 3,
-    ):
+    ) -> None:
         """Initialize a UNetDecoder.

         Args:
@@ -110,7 +110,9 @@ class UNetDecoder(torch.nn.Module):
         layers.append(torch.nn.Sequential(*cur_layers))
         self.layers = torch.nn.ModuleList(layers)

-    def forward(…
+    def forward(
+        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
         """Compute output from multi-scale feature map.

         Args:

rslearn/models/upsample.py
ADDED
@@ -0,0 +1,35 @@
+"""An upsampling layer."""
+
+import torch
+
+
+class Upsample(torch.nn.Module):
+    """Upsamples each input feature map by the same factor."""
+
+    def __init__(
+        self,
+        scale_factor: int,
+        mode: str = "bilinear",
+    ):
+        """Initialize an Upsample.
+
+        Args:
+            scale_factor: the upsampling factor, e.g. 2 to double the size.
+            mode: "nearest" or "bilinear".
+        """
+        super().__init__()
+        self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
+    ) -> list[torch.Tensor]:
+        """Compute flat output vector from multi-scale feature map.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs (ignored).
+
+        Returns:
+            flat feature vector
+        """
+        return [self.layer(feat_map) for feat_map in features]
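A short sketch of the new Upsample module in isolation (not from the package); it applies the same torch.nn.Upsample to every feature map in the list and ignores its second argument:

import torch

from rslearn.models.upsample import Upsample

upsample = Upsample(scale_factor=2, mode="bilinear")
features = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)]
out = upsample(features, [])  # the inputs argument is ignored
# Each map doubles in spatial size: (1, 64, 64, 64) and (1, 128, 32, 32).
print([tuple(f.shape) for f in out])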
rslearn/tile_stores/file.py
CHANGED
@@ -11,7 +11,6 @@ from rslearn.utils import Feature, PixelBounds, Projection
 from rslearn.utils.fsspec import open_atomic
 from rslearn.utils.raster_format import (
     GeotiffRasterFormat,
-    RasterFormat,
     load_raster_format,
 )
 from rslearn.utils.vector_format import (
@@ -30,7 +29,7 @@ class FileTileStoreLayer(TileStoreLayer):
         self,
         path: UPath,
         projection: Projection | None = None,
-        raster_format: …
+        raster_format: GeotiffRasterFormat = GeotiffRasterFormat(),
         vector_format: VectorFormat = GeojsonVectorFormat(),
     ):
         """Creates a new FileTileStoreLayer.
@@ -70,6 +69,8 @@ class FileTileStoreLayer(TileStoreLayer):
             bounds: the bounds of the raster
             array: the raster data
         """
+        if self.projection is None:
+            raise ValueError("need a projection to encode and write raster data")
         self.raster_format.encode_raster(self.path, self.projection, bounds, array)

     def get_raster_bounds(self) -> PixelBounds:
@@ -93,6 +94,8 @@ class FileTileStoreLayer(TileStoreLayer):
         Args:
             data: the vector data
         """
+        if self.projection is None:
+            raise ValueError("need a projection to encode and write vector data")
         self.vector_format.encode_vector(self.path, self.projection, data)

     def get_metadata(self) -> LayerMetadata:
@@ -125,7 +128,7 @@ class FileTileStore(TileStore):
     def __init__(
         self,
         path: UPath,
-        raster_format: …
+        raster_format: GeotiffRasterFormat = GeotiffRasterFormat(),
         vector_format: VectorFormat = GeojsonVectorFormat(),
     ):
         """Initialize a new FileTileStore.
rslearn/tile_stores/tile_store.py
CHANGED
@@ -1,5 +1,6 @@
 """Base class for tile stores."""

+from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Any

@@ -52,12 +53,13 @@ class LayerMetadata:
     )


-class TileStoreLayer:
+class TileStoreLayer(ABC):
     """An abstract class for a layer in a tile store.

     The layer can store one or more raster and vector datas.
     """

+    @abstractmethod
     def read_raster(self, bounds: PixelBounds) -> npt.NDArray[Any] | None:
         """Read raster data from the store.

@@ -67,8 +69,9 @@ class TileStoreLayer:
         Returns:
             the raster data
         """
-        …
+        pass

+    @abstractmethod
     def write_raster(self, bounds: PixelBounds, array: npt.NDArray[Any]) -> None:
         """Write raster data to the store.

@@ -76,8 +79,9 @@ class TileStoreLayer:
             bounds: the bounds of the raster
             array: the raster data
         """
-        …
+        pass

+    @abstractmethod
     def read_vector(self, bounds: PixelBounds) -> list[Feature]:
         """Read vector data from the store.

@@ -87,20 +91,23 @@ class TileStoreLayer:
         Returns:
             the vector data
         """
-        …
+        pass

+    @abstractmethod
     def write_vector(self, data: list[Feature]) -> None:
         """Save vector tiles to the store.

         Args:
             data: the vector data
         """
-        …
+        pass

+    @abstractmethod
     def get_metadata(self) -> LayerMetadata:
         """Get the LayerMetadata associated with this layer."""
-        …
+        pass

+    @abstractmethod
     def set_property(self, key: str, value: Any) -> None:
         """Set a property in the metadata for this layer.

@@ -108,7 +115,12 @@ class TileStoreLayer:
             key: the property key
             value: the property value
         """
-        …
+        pass
+
+    @abstractmethod
+    def get_raster_bounds(self) -> PixelBounds:
+        """Get the bounds of the raster data in the store."""
+        pass


 class TileStore:
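Since TileStoreLayer is now an ABC, a subclass must implement all seven abstract methods before it can be instantiated. A schematic in-memory subclass, purely illustrative and not part of the package:

from typing import Any

from rslearn.tile_stores.tile_store import TileStoreLayer


class InMemoryTileStoreLayer(TileStoreLayer):
    """Hypothetical layer that keeps one raster and a feature list in memory."""

    def __init__(self) -> None:
        self.array = None
        self.bounds = None
        self.features: list = []
        self.properties: dict[str, Any] = {}

    def read_raster(self, bounds):
        return self.array

    def write_raster(self, bounds, array) -> None:
        self.bounds, self.array = bounds, array

    def read_vector(self, bounds):
        return self.features

    def write_vector(self, data) -> None:
        self.features = data

    def get_metadata(self):
        ...  # would build a LayerMetadata from self.properties

    def set_property(self, key: str, value: Any) -> None:
        self.properties[key] = value

    def get_raster_bounds(self):
        return self.bounds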
rslearn/train/callbacks/freeze_unfreeze.py
CHANGED
@@ -14,7 +14,7 @@ class FreezeUnfreeze(BaseFinetuning):
         module_selector: list[str | int],
         unfreeze_at_epoch: int | None = None,
         unfreeze_lr_factor: float = 1,
-    ):
+    ) -> None:
         """Creates a new FreezeUnfreeze.

         Args:
@@ -40,7 +40,7 @@ class FreezeUnfreeze(BaseFinetuning):
             target_module = getattr(target_module, k)
         return target_module

-    def freeze_before_training(self, pl_module: LightningModule):
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
         """Freeze the model at the beginning of training.

         Args:
@@ -51,7 +51,7 @@ class FreezeUnfreeze(BaseFinetuning):

     def finetune_function(
         self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
-    ):
+    ) -> None:
         """Check whether we should unfreeze the model on each epoch.

         Args: