rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +186 -49
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +27 -32
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/colors.py +20 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
rslearn/train/dataset_index.py
ADDED

@@ -0,0 +1,156 @@
+"""Dataset index for caching window lists to speed up ModelDataset initialization."""
+
+import hashlib
+import json
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+from upath import UPath
+
+from rslearn.dataset.window import Window
+from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
+
+if TYPE_CHECKING:
+    from rslearn.dataset.storage.storage import WindowStorage
+
+logger = get_logger(__name__)
+
+# Increment this when the index format changes to force rebuild
+INDEX_VERSION = 1
+
+# Directory name for storing index files
+INDEX_DIR_NAME = ".rslearn_dataset_index"
+
+
+class DatasetIndex:
+    """Manages indexed window lists for faster ModelDataset initialization.
+
+    Note: The index does NOT automatically detect when windows are added or removed
+    from the dataset. Use refresh=True after modifying dataset windows.
+    """
+
+    def __init__(
+        self,
+        storage: "WindowStorage",
+        dataset_path: UPath,
+        groups: list[str] | None,
+        names: list[str] | None,
+        tags: dict[str, Any] | None,
+        num_samples: int | None,
+        skip_targets: bool,
+        inputs: dict[str, Any],
+    ) -> None:
+        """Initialize DatasetIndex with specific configuration.
+
+        Args:
+            storage: WindowStorage for deserializing windows.
+            dataset_path: path to the dataset directory.
+            groups: list of window groups to include.
+            names: list of window names to include.
+            tags: tags to filter windows by.
+            num_samples: limit on number of samples.
+            skip_targets: whether targets are skipped.
+            inputs: dict mapping input names to DataInput objects.
+        """
+        self.storage = storage
+        self.dataset_path = dataset_path
+        self.index_dir = dataset_path / INDEX_DIR_NAME
+
+        # Compute index key from configuration
+        inputs_data = {}
+        for name, inp in inputs.items():
+            inputs_data[name] = {
+                "layers": inp.layers,
+                "required": inp.required,
+                "load_all_layers": inp.load_all_layers,
+                "is_target": inp.is_target,
+            }
+
+        key_data = {
+            "groups": groups,
+            "names": names,
+            "tags": tags,
+            "num_samples": num_samples,
+            "skip_targets": skip_targets,
+            "inputs": inputs_data,
+        }
+        self.index_key = hashlib.sha256(
+            json.dumps(key_data, sort_keys=True).encode()
+        ).hexdigest()
+
+    def _get_config_hash(self) -> str:
+        """Get hash of config.json for quick validation.
+
+        Returns:
+            A 16-character hex string hash of the config, or empty string if no config.
+        """
+        config_path = self.dataset_path / "config.json"
+        if config_path.exists():
+            with config_path.open() as f:
+                return hashlib.sha256(f.read().encode()).hexdigest()[:16]
+        return ""
+
+    def load_windows(self, refresh: bool = False) -> list[Window] | None:
+        """Load indexed window list if valid, else return None.
+
+        Args:
+            refresh: if True, ignore existing index and return None.
+
+        Returns:
+            List of Window objects if index is valid, None otherwise.
+        """
+        if refresh:
+            logger.info("refresh=True, rebuilding index")
+            return None
+
+        index_file = self.index_dir / f"{self.index_key}.json"
+        if not index_file.exists():
+            logger.info(f"No index found at {index_file}, will build")
+            return None
+
+        try:
+            with index_file.open() as f:
+                index_data = json.load(f)
+        except (OSError, json.JSONDecodeError):
+            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
+            return None
+
+        # Check index version
+        if index_data.get("version") != INDEX_VERSION:
+            logger.info(
+                f"Index version mismatch (got {index_data.get('version')}, "
+                f"expected {INDEX_VERSION}), will rebuild"
+            )
+            return None
+
+        # Quick validation: check config hash
+        if index_data.get("config_hash") != self._get_config_hash():
+            logger.info("Config hash mismatch, index invalidated")
+            return None
+
+        # Deserialize windows
+        return [Window.from_metadata(self.storage, w) for w in index_data["windows"]]
+
+    def save_windows(self, windows: list[Window]) -> None:
+        """Save processed windows to index with atomic write.
+
+        Args:
+            windows: list of Window objects to index.
+        """
+        self.index_dir.mkdir(parents=True, exist_ok=True)
+        index_file = self.index_dir / f"{self.index_key}.json"
+
+        # Serialize windows
+        serialized_windows = [w.get_metadata() for w in windows]
+
+        index_data = {
+            "version": INDEX_VERSION,
+            "config_hash": self._get_config_hash(),
+            "created_at": datetime.now().isoformat(),
+            "num_windows": len(windows),
+            "windows": serialized_windows,
+        }
+        with open_atomic(index_file, "w") as f:
+            json.dump(index_data, f)
+        logger.info(f"Saved {len(windows)} windows to index at {index_file}")
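The new DatasetIndex is a read-through cache: try the index first, fall back to the slow window scan, then persist the result keyed by the dataset configuration. A minimal sketch of that call pattern follows; it assumes a ModelDataset-style caller, and `scan_windows_slowly` is a hypothetical stand-in for the expensive window enumeration being cached, not an rslearn function.

```python
from rslearn.train.dataset_index import DatasetIndex

def get_windows(storage, dataset_path, data_inputs, refresh=False):
    # Key the index by the same configuration that determines the window list.
    index = DatasetIndex(
        storage=storage,
        dataset_path=dataset_path,
        groups=None,
        names=None,
        tags=None,
        num_samples=None,
        skip_targets=False,
        inputs=data_inputs,
    )
    windows = index.load_windows(refresh=refresh)
    if windows is None:
        # Cache miss (no index, stale version, or config change): rebuild and save.
        windows = scan_windows_slowly(storage, data_inputs)  # hypothetical helper
        index.save_windows(windows)
    return windows
```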
rslearn/train/model_context.py
CHANGED

@@ -43,6 +43,22 @@ class RasterImage:
             raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
         return self.image[:, 0]
 
+    def get_hw_tensor(self) -> torch.Tensor:
+        """Get a 2D HW tensor from a single-channel, single-timestep RasterImage.
+
+        This function checks that C=1 and T=1, then returns the HW tensor.
+        Useful for per-pixel labels like segmentation masks.
+        """
+        if self.image.shape[0] != 1:
+            raise ValueError(
+                f"Expected single channel (C=1), got {self.image.shape[0]}"
+            )
+        if self.image.shape[1] != 1:
+            raise ValueError(
+                f"Expected single timestep (T=1), got {self.image.shape[1]}"
+            )
+        return self.image[0, 0]
+
 
 @dataclass
 class SampleMetadata:
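The new `get_hw_tensor` complements the existing single-timestep accessor: it collapses a C=1, T=1 CTHW image to a plain HW tensor and raises otherwise. A small illustration, using the `RasterImage(tensor, timestamps=...)` construction that appears elsewhere in this diff:

```python
import torch
from rslearn.train.model_context import RasterImage

# A single-band, single-timestep 4x4 label image in CTHW layout.
label = RasterImage(torch.zeros(1, 1, 4, 4), timestamps=None)
print(label.get_hw_tensor().shape)  # torch.Size([4, 4])

# Multi-channel (or multi-timestep) images are rejected.
rgb = RasterImage(torch.zeros(3, 1, 4, 4), timestamps=None)
try:
    rgb.get_hw_tensor()
except ValueError as e:
    print(e)  # Expected single channel (C=1), got 3
```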
rslearn/train/tasks/detection.py
CHANGED

@@ -14,27 +14,10 @@ from torchmetrics import Metric, MetricCollection
 
 from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature, STGeometry
+from rslearn.utils.colors import DEFAULT_COLORS
 
 from .task import BasicTask
 
-DEFAULT_COLORS = [
-    (255, 0, 0),
-    (0, 255, 0),
-    (0, 0, 255),
-    (255, 255, 0),
-    (0, 255, 255),
-    (255, 0, 255),
-    (0, 128, 0),
-    (255, 160, 122),
-    (139, 69, 19),
-    (128, 128, 128),
-    (255, 255, 255),
-    (143, 188, 143),
-    (95, 158, 160),
-    (255, 200, 0),
-    (128, 0, 0),
-]
-
 
 class DetectionTask(BasicTask):
     """A point or bounding box detection task."""
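The hardcoded palette moves to the new `rslearn.utils.colors` module (the duplicate copy in segmentation.py is removed below). For illustration only, a caller can index the shared palette per class; the modulo wrap for more than 15 classes is an assumption here, not documented rslearn behavior.

```python
from rslearn.utils.colors import DEFAULT_COLORS

def class_color(class_id: int) -> tuple[int, int, int]:
    # Cycle through the shared 15-entry RGB palette.
    return DEFAULT_COLORS[class_id % len(DEFAULT_COLORS)]
```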
rslearn/train/tasks/per_pixel_regression.py
CHANGED

@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+        labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor
 
         if self.nodata_value is not None:
-            valid = (
-                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
-            ).float()
+            valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
+        # used in image transforms.
         return {}, {
-            "values": labels,
-            "valid": valid,
+            "values": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(

@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_values = target_dict["values"].cpu().numpy()
+        gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
         pred_values = output.cpu().numpy()[0, :, :]
         gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
         pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)

@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):
 
         losses = {}
         if targets:
-            labels = torch.stack([target["values"] for target in targets])
-            mask = torch.stack([target["valid"] for target in targets])
+            labels = torch.stack(
+                [target["values"].get_hw_tensor() for target in targets]
+            )
+            mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])
 
         if self.loss_mode == "mse":
             scores = torch.square(outputs - labels)

@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["values"] for target in targets])
+        labels = torch.stack([target["values"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
        # We flatten the prediction and label images at valid pixels.
         if len(preds.shape) == 4:
             assert preds.shape[1] == 1
             preds = preds[:, 0, :, :]
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds[mask]
         labels = labels[mask]
         if len(preds) == 0:
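The pattern above — wrap HW label tensors as C=1, T=1 RasterImages in process_inputs, then unwrap with get_hw_tensor in the head and metrics — is what lets geometric transforms treat targets like any other image. A standalone sketch of the round trip:

```python
import torch
from rslearn.train.model_context import RasterImage

labels = torch.rand(64, 64)    # HW regression target
valid = (labels != 0).float()  # HW validity mask

# Wrap as CTHW with C=1, T=1 so crop/flip/etc. can be applied to targets too.
targets = {
    "values": RasterImage(labels[None, None, :, :], timestamps=None),
    "valid": RasterImage(valid[None, None, :, :], timestamps=None),
}

# Loss heads and metrics recover the HW tensors unchanged.
assert torch.equal(targets["values"].get_hw_tensor(), labels)
```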
rslearn/train/tasks/segmentation.py
CHANGED

@@ -17,28 +17,10 @@ from rslearn.train.model_context import (
     SampleMetadata,
 )
 from rslearn.utils import Feature
+from rslearn.utils.colors import DEFAULT_COLORS
 
 from .task import BasicTask
 
-# TODO: This is duplicated code fix it
-DEFAULT_COLORS = [
-    (255, 0, 0),
-    (0, 255, 0),
-    (0, 0, 255),
-    (255, 255, 0),
-    (0, 255, 255),
-    (255, 0, 255),
-    (0, 128, 0),
-    (255, 160, 122),
-    (139, 69, 19),
-    (128, 128, 128),
-    (255, 255, 255),
-    (143, 188, 143),
-    (95, 158, 160),
-    (255, 200, 0),
-    (128, 0, 0),
-]
-
 
 class SegmentationTask(BasicTask):
     """A segmentation (per-pixel classification) task."""

@@ -146,9 +128,7 @@ class SegmentationTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+        labels = raw_inputs["targets"].get_hw_tensor().long()
 
         if self.class_id_mapping is not None:
             new_labels = labels.clone()

@@ -164,9 +144,11 @@ class SegmentationTask(BasicTask):
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
+        # used in image transforms.
         return {}, {
-            "classes": labels,
-            "valid": valid,
+            "classes": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(

@@ -224,7 +206,7 @@ class SegmentationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_classes = target_dict["classes"].cpu().numpy()
+        gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
         pred_vis = np.zeros(

@@ -309,12 +291,19 @@
 class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
-    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+    def __init__(
+        self,
+        weights: list[float] | None = None,
+        dice_loss: bool = False,
+        temperature: float = 1.0,
+    ):
         """Initialize a new SegmentationTask.
 
         Args:
             weights: weights for cross entropy loss (Tensor of size C)
             dice_loss: whether to add dice loss to cross entropy
+            temperature: temperature scaling for softmax, does not affect the loss,
+                only the predictor outputs
         """
         super().__init__()
         if weights is not None:

@@ -322,6 +311,7 @@ class SegmentationHead(Predictor):
         else:
             self.weights = None
         self.dice_loss = dice_loss
+        self.temperature = temperature
 
     def forward(
         self,

@@ -350,12 +340,16 @@
         )
 
         logits = intermediates.feature_maps[0]
-        outputs = torch.nn.functional.softmax(logits, dim=1)
+        outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)
 
         losses = {}
         if targets:
-            labels = torch.stack([target["classes"] for target in targets], dim=0)
-            mask = torch.stack([target["valid"] for target in targets], dim=0)
+            labels = torch.stack(
+                [target["classes"].get_hw_tensor() for target in targets], dim=0
+            )
+            mask = torch.stack(
+                [target["valid"].get_hw_tensor() for target in targets], dim=0
+            )
             per_pixel_loss = torch.nn.functional.cross_entropy(
                 logits, labels, weight=self.weights, reduction="none"
             )

@@ -368,7 +362,8 @@ class SegmentationHead(Predictor):
         # the summed mask loss be zero.
         losses["cls"] = torch.sum(per_pixel_loss * mask)
         if self.dice_loss:
-            dice_loss = DiceLoss()(outputs, labels, mask)
+            softmax_woT = torch.nn.functional.softmax(logits, dim=1)
+            dice_loss = DiceLoss()(softmax_woT, labels, mask)
             losses["dice"] = dice_loss
 
         return ModelOutput(

@@ -419,12 +414,12 @@ class SegmentationMetric(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["classes"] for target in targets])
+        labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds.permute(0, 2, 3, 1)[mask]
         labels = labels[mask]
         if len(preds) == 0:
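The temperature parameter only rescales the probabilities the head emits; cross-entropy still sees the raw logits and dice loss recomputes an unscaled softmax, so training is unaffected. A quick standalone check of that claim:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 8, 8)  # BCHW logits for 3 classes
temperature = 2.0

probs = F.softmax(logits / temperature, dim=1)  # what the head now outputs
probs_no_t = F.softmax(logits, dim=1)           # what the losses still use

# T > 1 flattens the distribution but never changes the predicted class.
assert torch.equal(probs.argmax(dim=1), probs_no_t.argmax(dim=1))
```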
rslearn/train/transforms/concatenate.py
CHANGED

@@ -54,36 +54,26 @@ class Concatenate(Transform):
             target_dict: the target
 
         Returns:
-
-
-            Otherwise it will be a torch.Tensor.
+            (input_dicts, target_dicts) where the entry corresponding to
+            output_selector contains the concatenated RasterImage.
         """
-
-        return_raster_image: bool = False
+        tensors: list[torch.Tensor] = []
         timestamps: list[tuple[datetime, datetime]] | None = None
+
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if
-
-
-
-
-
-
-
-
-
-
-
-
-                # number of timestamps
-                timestamps = image.timestamps
-        if return_raster_image:
-            result = RasterImage(
-                torch.concatenate(images, dim=self.concatenate_dim),
-                timestamps=timestamps,
-            )
-        else:
-            result = torch.concatenate(images, dim=self.concatenate_dim)
+            if wanted_bands:
+                tensors.append(image.image[wanted_bands, :, :])
+            else:
+                tensors.append(image.image)
+            if timestamps is None and image.timestamps is not None:
+                # assume all concatenated modalities have the same
+                # number of timestamps
+                timestamps = image.timestamps
+
+        result = RasterImage(
+            torch.concatenate(tensors, dim=self.concatenate_dim),
+            timestamps=timestamps,
+        )
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
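Stripped of the Transform plumbing, the reworked loop is just band selection followed by concatenation of CTHW tensors: a non-empty band list indexes the C dimension, and the result is now always a RasterImage. The core logic, shown standalone:

```python
import torch

s2 = torch.rand(4, 2, 32, 32)  # 4 bands, 2 timesteps
s1 = torch.rand(2, 2, 32, 32)  # 2 bands, 2 timesteps

wanted_bands = [2, 1, 0]       # band subset for s2; an empty list keeps all bands
tensors = [s2[wanted_bands, :, :], s1]

# concatenate_dim=0 stacks along the channel (C) dimension.
combined = torch.concatenate(tensors, dim=0)
print(combined.shape)  # torch.Size([5, 2, 32, 32])
```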
rslearn/train/transforms/crop.py
CHANGED

@@ -71,9 +71,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:

@@ -84,22 +82,13 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        if isinstance(image, RasterImage):
-            image.image = torchvision.transforms.functional.crop(
-                image.image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
-        else:
-            image = torchvision.transforms.functional.crop(
-                image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
+        image.image = torchvision.transforms.functional.crop(
+            image.image,
+            top=remove_from_top,
+            left=remove_from_left,
+            height=crop_size,
+            width=crop_size,
+        )
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
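Dropping the plain torch.Tensor branch works because torchvision's functional crop operates on the trailing two dimensions, so it applies directly to the CTHW tensor inside a RasterImage:

```python
import torch
import torchvision.transforms.functional as TF

image = torch.rand(3, 2, 64, 64)  # C=3, T=2, 64x64 spatial
cropped = TF.crop(image, top=8, left=8, height=32, width=32)
print(cropped.shape)  # torch.Size([3, 2, 32, 32])
```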
rslearn/train/transforms/flip.py
CHANGED

@@ -57,16 +57,10 @@ class Flip(Transform):
             image: the image to transform.
             state: the sampled state.
         """
-        if isinstance(image, RasterImage):
-            if state["horizontal"]:
-                image.image = torch.flip(image.image, dims=[-1])
-            if state["vertical"]:
-                image.image = torch.flip(image.image, dims=[-2])
-        elif isinstance(image, torch.Tensor):
-            if state["horizontal"]:
-                image = torch.flip(image, dims=[-1])
-            if state["vertical"]:
-                image = torch.flip(image, dims=[-2])
+        if state["horizontal"]:
+            image.image = torch.flip(image.image, dims=[-1])
+        if state["vertical"]:
+            image.image = torch.flip(image.image, dims=[-2])
         return image
 
     def apply_boxes(
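Likewise for Flip: torch.flip on dims=[-1] (width) and dims=[-2] (height) ignores any leading dimensions, so one code path covers the CTHW layout:

```python
import torch

image = torch.arange(8.0).reshape(1, 1, 2, 4)  # C=1, T=1, 2x4 spatial
print(torch.flip(image, dims=[-1])[0, 0])      # columns reversed (horizontal flip)
print(torch.flip(image, dims=[-2])[0, 0])      # rows reversed (vertical flip)
```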
rslearn/train/transforms/mask.py
CHANGED

@@ -1,7 +1,5 @@
 """Mask transform."""
 
-import torch
-
 from rslearn.train.model_context import RasterImage
 from rslearn.train.transforms.transform import Transform, read_selector
 

@@ -32,9 +30,7 @@ class Mask(Transform):
         self.mask_selector = mask_selector
         self.mask_value = mask_value
 
-    def apply_image(
-        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
         """Apply the mask on the image.
 
         Args:

@@ -44,21 +40,19 @@ class Mask(Transform):
         Returns:
             masked image
         """
-        # Extract the mask tensor if it is a RasterImage.
-        if isinstance(mask, RasterImage):
-            mask = mask.image
+        # Extract the mask tensor (CTHW format)
+        mask_tensor = mask.image
 
-        if image.shape[0] != mask.shape[0]:
-            if mask.shape[0] != 1:
+        # Tile the mask to have same number of bands (C dimension) as the image.
+        if image.shape[0] != mask_tensor.shape[0]:
+            if mask_tensor.shape[0] != 1:
                 raise ValueError(
                     "expected mask to either have same bands as image, or one band"
                 )
-            mask = mask.repeat(image.shape[0], 1, 1, 1)
+            # Repeat along C dimension, keep T, H, W the same
+            mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)
 
-        if isinstance(image, torch.Tensor):
-            image[mask == 0] = self.mask_value
-        else:
-            image.image[mask == 0] = self.mask_value
+        image.image[mask_tensor == 0] = self.mask_value
         return image
 
     def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]: