rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +49 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +18 -12
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/terramind.py
CHANGED
@@ -8,8 +8,11 @@ import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 # TerraMind v1 provides two sizes: base and large
 class TerramindSize(str, Enum):
@@ -85,7 +88,7 @@ PRETRAINED_BANDS = {
 }
 
 
-class Terramind(torch.nn.Module):
+class Terramind(FeatureExtractor):
     """Terramind backbones."""
 
     def __init__(
@@ -123,21 +126,25 @@ class Terramind(torch.nn.Module):
         self.modalities = modalities
         self.do_resizing = do_resizing
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the Terramind model.
 
         Args:
-
+            context: the model context. Input dicts must include modalities as keys
+                which are defined in the self.modalities list.
 
         Returns:
-
+            a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+                resolution.
         """
         model_inputs = {}
         for modality in self.modalities:
             # We assume the all the inputs include the same modalities
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
-            cur = torch.stack(
+            cur = torch.stack(
+                [inp[modality] for inp in context.inputs], dim=0
+            )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
             ):
@@ -159,7 +166,17 @@ class Terramind(torch.nn.Module):
         image_features = self.model(model_inputs)[-1]
         batch_size, num_patches, _ = image_features.shape
         height, width = int(num_patches**0.5), int(num_patches**0.5)
-        return
+        return FeatureMaps(
+            [
+                rearrange(
+                    image_features,
+                    "b (h w) d -> b d h w",
+                    b=batch_size,
+                    h=height,
+                    w=width,
+                )
+            ]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
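Note: this change moves Terramind onto the new component interface from rslearn/models/component.py: forward now takes a single ModelContext and returns a FeatureMaps rather than a bare tensor. Below is a minimal runnable sketch of that contract; the ModelContext and FeatureMaps classes here are stand-ins inferred from this diff (only the inputs and feature_maps fields are assumed), not the real definitions.

from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class ModelContext:
    # Stand-in: per-sample input dicts, e.g. inputs[i]["image"] is (C, H, W).
    inputs: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class FeatureMaps:
    # Stand-in: each entry is a (B, D, H, W) tensor at one scale.
    feature_maps: list[torch.Tensor] = field(default_factory=list)


class ToyExtractor(torch.nn.Module):
    """Hypothetical FeatureExtractor: stack one modality across the batch."""

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Same batching pattern as the new Terramind.forward: (B, C, H, W).
        cur = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        return FeatureMaps([cur])


ctx = ModelContext(inputs=[{"image": torch.zeros(3, 32, 32)} for _ in range(2)])
print(ToyExtractor()(ctx).feature_maps[0].shape)  # torch.Size([2, 3, 32, 32])

The design choice visible in the diff is that batching moves inside the extractor: each sample is a dict of modality tensors, and the component stacks them along a new batch dimension.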
rslearn/models/trunk.py
CHANGED
@@ -7,6 +7,7 @@ import torch
 
 from rslearn.log_utils import get_logger
 from rslearn.models.task_embedding import BaseTaskEmbedding
+from rslearn.train.model_context import ModelOutput
 
 logger = get_logger(__name__)
 
@@ -32,10 +33,11 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
         dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
         and optionally other keys.
         """
+        raise NotImplementedError
 
     @abstractmethod
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs:
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 
@@ -43,6 +45,7 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
             trunk_out: The output of the trunk.
             outs: The output of the decoders, with key "loss_dict" containing the losses.
         """
+        raise NotImplementedError
 
 
 class DecoderTrunk(torch.nn.Module):
@@ -122,7 +125,7 @@ class DecoderTrunk(torch.nn.Module):
         return out
 
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs:
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 
@@ -130,7 +133,7 @@ class DecoderTrunk(torch.nn.Module):
 
         Args:
             trunk_out: The output of the trunk.
-            outs: The output of the decoders
+            outs: The output of the decoders.
         """
         for layer in self.layers:
             layer.apply_auxiliary_losses(trunk_out, outs)
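Note: DecoderTrunkLayer's abstract methods now raise NotImplementedError and type outs as ModelOutput. A hedged sketch of a conforming layer follows; this ModelOutput is a stand-in, assumed only to carry a mutable loss_dict as the apply_auxiliary_losses docstring suggests.

from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class ModelOutput:
    # Stand-in for rslearn.train.model_context.ModelOutput; assumed only to
    # carry a mutable loss_dict, per the apply_auxiliary_losses docstring.
    loss_dict: dict[str, torch.Tensor] = field(default_factory=dict)


class NoOpTrunkLayer(torch.nn.Module):
    """Hypothetical trunk layer that passes features through unchanged."""

    def forward(self, x: torch.Tensor) -> dict[str, Any]:
        # Contract: a dict with key "outputs" of shape (batch_size, seq_len, dim).
        return {"outputs": x}

    def apply_auxiliary_losses(
        self, trunk_out: dict[str, Any], outs: ModelOutput
    ) -> None:
        # Mutate the decoder losses in-place; a zero-weight term for illustration.
        outs.loss_dict["aux_noop"] = 0.0 * trunk_out["outputs"].mean()


layer = NoOpTrunkLayer()
outs = ModelOutput()
trunk_out = layer(torch.zeros(2, 4, 8))
layer.apply_auxiliary_losses(trunk_out, outs)
print(sorted(outs.loss_dict))  # ['aux_noop']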
rslearn/models/unet.py
CHANGED
@@ -5,8 +5,15 @@ from typing import Any
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
 
-class UNetDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class UNetDecoder(IntermediateComponent):
     """UNet-style decoder.
 
     It inputs multi-scale features. Starting from last (lowest resolution) feature map,
@@ -143,23 +150,25 @@ class UNetDecoder(torch.nn.Module):
             align_corners=False,
         )
 
-    def forward(
-        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute output from multi-scale feature map.
 
         Args:
-
-
+            intermediates: the output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            output
+            output FeatureMaps consisting of one map. The embedding size is equal to the
+                configured out_channels.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
         # Reverse the features since we will pass them in from lowest resolution to highest.
-        in_features = list(reversed(
+        in_features = list(reversed(intermediates.feature_maps))
         cur_features = self.layers[0](in_features[0])
         for in_feat, layer in zip(in_features[1:], self.layers[1:]):
             cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
         if self.original_size_to_interpolate is not None:
             cur_features = self._resize(cur_features)
-        return cur_features
+        return FeatureMaps([cur_features])
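Note: UNetDecoder is now an IntermediateComponent that validates its FeatureMaps input and wraps its output in FeatureMaps([cur_features]). The coarse-to-fine fusion it performs, reduced to a toy loop (a fixed 2x nearest-neighbor upsample stands in for the learned self.layers):

import torch
import torch.nn.functional as F

# Three backbone maps, finest first, as a FeatureMaps would carry them.
feats = [
    torch.zeros(1, 8, 64, 64),
    torch.zeros(1, 16, 32, 32),
    torch.zeros(1, 32, 16, 16),
]

# Walk from the coarsest map back up, fusing each skip connection by channel
# concatenation; UNetDecoder applies a learned layer where we merely upsample.
in_features = list(reversed(feats))
cur = in_features[0]
for skip in in_features[1:]:
    cur = F.interpolate(cur, scale_factor=2, mode="nearest")
    cur = torch.cat([cur, skip], dim=1)
print(cur.shape)  # torch.Size([1, 56, 64, 64])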
rslearn/models/upsample.py
CHANGED
@@ -1,9 +1,18 @@
 """An upsampling layer."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class Upsample(torch.nn.Module):
+
+class Upsample(IntermediateComponent):
     """Upsamples each input feature map by the same factor."""
 
     def __init__(
@@ -20,16 +29,20 @@ class Upsample(torch.nn.Module):
         super().__init__()
         self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
 
-    def forward(
-
-    ) -> list[torch.Tensor]:
-        """Upsample each feature map.
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Upsample each feature map by scale_factor.
 
         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            upsampled feature maps
+            upsampled feature maps.
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Upsample must be a FeatureMaps")
+
+        upsampled_feat_maps = [
+            self.layer(feat_map) for feat_map in intermediates.feature_maps
+        ]
+        return FeatureMaps(upsampled_feat_maps)
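Note: the rewritten Upsample is a thin wrapper: check that the previous component produced a FeatureMaps, then apply one shared torch.nn.Upsample to every map. The underlying behavior in isolation:

import torch

# One torch.nn.Upsample applies the same scale factor to every feature map,
# which is all the component does after its isinstance check.
layer = torch.nn.Upsample(scale_factor=2, mode="bilinear")
maps = [torch.zeros(1, 4, 16, 16), torch.zeros(1, 4, 8, 8)]
print([layer(m).shape for m in maps])
# [torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 16, 16])]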
rslearn/train/all_patches_dataset.py
CHANGED
@@ -2,6 +2,7 @@
 
 import itertools
 from collections.abc import Iterable, Iterator
+from dataclasses import replace
 from typing import Any
 
 import shapely
@@ -9,6 +10,7 @@ import torch
 
 from rslearn.dataset import Window
 from rslearn.train.dataset import ModelDataset
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
@@ -218,7 +220,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
 
     def __iter__(
         self,
-    ) -> Iterator[tuple[dict[str, Any], dict[str, Any],
+    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
         """Iterate over all patches in each element of the underlying ModelDataset."""
         # Iterate over the window IDs until we have returned enough samples.
         window_ids, num_samples_needed = self._get_worker_iteration_data()
@@ -229,7 +231,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                 window_id
             )
-            bounds = metadata
+            bounds = metadata.patch_bounds
 
             # For simplicity, pad tensors by patch size to ensure that any patch bounds
             # extending outside the window bounds will not have issues when we slice
@@ -244,7 +246,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             )
             for patch_idx, patch_bounds in enumerate(patches):
                 cur_geom = STGeometry(
-                    metadata
+                    metadata.projection, shapely.box(*patch_bounds), None
                 )
                 start_offset = (
                     patch_bounds[0] - bounds[0],
@@ -282,10 +284,12 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
 
                 # Adjust the metadata as well.
-                cur_metadata =
-
-
-
+                cur_metadata = replace(
+                    metadata,
+                    patch_bounds=patch_bounds,
+                    patch_idx=patch_idx,
+                    num_patches_in_window=len(patches),
+                )
 
                 # Now we can compute input and target dicts via the task.
                 input_dict, target_dict = self.dataset.task.process_inputs(
@@ -297,7 +301,6 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 input_dict, target_dict = self.dataset.transforms(
                     input_dict, target_dict
                 )
-                input_dict["dataset_source"] = self.dataset.name
 
                 if num_samples_returned < num_samples_needed:
                     yield input_dict, target_dict, cur_metadata
@@ -346,7 +349,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
         self.windows = self.dataset.get_dataset_examples()
         self.window_cache: dict[
-            int, tuple[dict[str, Any], dict[str, Any],
+            int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
         ] = {}
 
         # Precompute the batch boundaries for each window
@@ -360,7 +363,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
         self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs for a single patch. Retrieve from cache if possible.
 
         Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
@@ -410,13 +413,13 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def __getitem__(
        self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Return (input_dict, target_dict, metadata) for a single flattened patch."""
         (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata
+        bounds = metadata.patch_bounds
 
-        cur_geom = STGeometry(metadata
+        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
         start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
         end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
 
@@ -428,10 +431,12 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
 
         # Adjust the metadata as well.
-        cur_metadata =
-
-
-
+        cur_metadata = replace(
+            metadata,
+            patch_bounds=patch_bounds,
+            patch_idx=patch_idx,
+            num_patches_in_window=num_patches,
+        )
 
         # Now we can compute input and target dicts via the task.
         input_dict, target_dict = self.dataset.task.process_inputs(
@@ -441,7 +446,6 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
         input_dict.update(cur_passthrough_inputs)
         input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.dataset.name
 
         return input_dict, target_dict, cur_metadata
 
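Note: both AllPatches datasets now derive per-patch metadata with dataclasses.replace on a SampleMetadata instance instead of mutating a dict. A small sketch of the pattern; this SampleMetadata is a stand-in with only the three fields touched here (the real class in rslearn.train.model_context also carries window name, projection, and so on):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SampleMetadata:
    # Stand-in with only the fields touched in this diff; whether the real
    # class is frozen is an assumption (replace() works either way).
    patch_bounds: tuple[int, int, int, int]
    patch_idx: int
    num_patches_in_window: int


window_meta = SampleMetadata((0, 0, 256, 256), patch_idx=0, num_patches_in_window=1)
# replace() copies the dataclass with the given fields overridden, which is how
# per-patch metadata is now derived from the window-level metadata.
patch_meta = replace(
    window_meta,
    patch_bounds=(0, 0, 128, 128),
    patch_idx=3,
    num_patches_in_window=4,
)
print(patch_meta.patch_idx, patch_meta.num_patches_in_window)  # 3 4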
rslearn/train/dataset.py
CHANGED
@@ -20,13 +20,15 @@ from rslearn.config import (
     LayerConfig,
 )
 from rslearn.dataset.dataset import Dataset
+from rslearn.dataset.storage.file import FileWindowStorage
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
 from rslearn.log_utils import get_logger
-from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
 from rslearn.utils.mp import star_imap_unordered
 
+from .model_context import SampleMetadata
+from .tasks import Task
 from .transforms import Sequential
 
 logger = get_logger(__name__)
@@ -575,37 +577,7 @@ class ModelDataset(torch.utils.data.Dataset):
         else:
             self.patch_size = split_config.get_patch_size()
 
-        if split_config.names:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups,
-                names=split_config.names,
-                show_progress=True,
-                workers=workers,
-            )
-        elif split_config.groups:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups, show_progress=True, workers=workers
-            )
-        else:
-            windows = self.dataset.load_windows(show_progress=True, workers=workers)
-
-        if split_config.tags:
-            # Filter the window.options.
-            new_windows = []
-            num_removed: dict[str, int] = {}
-            for window in windows:
-                for k, v in split_config.tags.items():
-                    if k not in window.options or (v and window.options[k] != v):
-                        num_removed[k] = num_removed.get(k, 0) + 1
-                        break
-                else:
-                    new_windows.append(window)
-            logger.info(
-                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
-            )
-            for k, v in num_removed.items():
-                logger.info(f"Removed {v} windows due to tag {k}")
-            windows = new_windows
+        windows = self._get_initial_windows(split_config, workers)
 
         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -615,17 +587,11 @@ class ModelDataset(torch.utils.data.Dataset):
 
         # Eliminate windows that are missing either a requisite input layer, or missing
         # all target layers.
-        # We use only main thread if the index is set, since that can take a long time
-        # to send to the worker threads, it may get serialized for each window.
         new_windows = []
-        if workers == 0
+        if workers == 0:
             for window in windows:
                 if check_window(self.inputs, window) is None:
                     continue
-                # The index may be set, but now that this check is done, from here on
-                # we no longer need it. We set it None so that we don't end up passing
-                # it later to the dataloader workers.
-                window.index = None
                 new_windows.append(window)
         else:
             p = multiprocessing.Pool(workers)
@@ -681,12 +647,62 @@ class ModelDataset(torch.utils.data.Dataset):
         with open(self.dataset_examples_fname, "w") as f:
             json.dump([self._serialize_item(example) for example in windows], f)
 
+    def _get_initial_windows(
+        self, split_config: SplitConfig, workers: int
+    ) -> list[Window]:
+        """Get the initial windows before input layer filtering.
+
+        The windows are filtered based on configured window names, groups, and tags.
+
+        This is a helper for the init function.
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+
+        Returns:
+            list of windows from the dataset after applying the aforementioned filters.
+        """
+        # Load windows from dataset.
+        # If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
+        kwargs: dict[str, Any] = {}
+        if isinstance(self.dataset.storage, FileWindowStorage):
+            kwargs["workers"] = workers
+            kwargs["show_progress"] = True
+        # We also add the name/group filters to the kwargs.
+        if split_config.names:
+            kwargs["names"] = split_config.names
+        if split_config.groups:
+            kwargs["groups"] = split_config.groups
+
+        windows = self.dataset.load_windows(**kwargs)
+
+        # Filter by tags (if provided) using the window.options.
+        if split_config.tags:
+            new_windows = []
+            num_removed: dict[str, int] = {}
+            for window in windows:
+                for k, v in split_config.tags.items():
+                    if k not in window.options or (v and window.options[k] != v):
+                        num_removed[k] = num_removed.get(k, 0) + 1
+                        break
+                else:
+                    new_windows.append(window)
+            logger.info(
+                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
+            )
+            for k, v in num_removed.items():
+                logger.info(f"Removed {v} windows due to tag {k}")
+            windows = new_windows
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()
 
     def _deserialize_item(self, d: dict[str, Any]) -> Window:
         return Window.from_metadata(
-
+            self.dataset.storage,
             d,
         )
 
@@ -713,7 +729,7 @@ class ModelDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs and base metadata for this example.
 
         This is the raster or vector data before being processed by the Task. So it
@@ -775,21 +791,23 @@ class ModelDataset(torch.utils.data.Dataset):
             if data_input.passthrough:
                 passthrough_inputs[name] = raw_inputs[name]
 
-        metadata =
-
-
-
-
-
-
-
+        metadata = SampleMetadata(
+            window_group=window.group,
+            window_name=window.name,
+            window_bounds=window.bounds,
+            patch_bounds=bounds,
+            patch_idx=0,
+            num_patches_in_window=1,
+            time_range=window.time_range,
+            projection=window.projection,
+            dataset_source=self.name,
+        )
 
         return raw_inputs, passthrough_inputs, metadata
 
     def __getitem__(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Read one training example.
 
         Args:
@@ -801,8 +819,6 @@ class ModelDataset(torch.utils.data.Dataset):
         logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)
 
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
-        metadata["patch_idx"] = 0
-        metadata["num_patches"] = 1
 
         input_dict, target_dict = self.task.process_inputs(
             raw_inputs,
@@ -811,7 +827,6 @@ class ModelDataset(torch.utils.data.Dataset):
         )
         input_dict.update(passthrough_inputs)
         input_dict, target_dict = self.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.name
 
 
         logger.debug("__getitem__ finish pid=%d item_idx=%d", os.getpid(), idx)
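Note: the extracted _get_initial_windows helper collapses the old three-branch load_windows dispatch into a single call with conditionally built kwargs, passing workers/show_progress only when the storage is a FileWindowStorage. The pattern in isolation, with a stand-in load_windows (the real method lives on Dataset and returns Window objects):

from typing import Any


def load_windows(
    names: list[str] | None = None,
    groups: list[str] | None = None,
    workers: int | None = None,
    show_progress: bool = False,
) -> list[str]:
    # Stand-in for Dataset.load_windows, just to make the dispatch runnable.
    return [f"{g}/{n}" for g in (groups or ["default"]) for n in (names or ["w0"])]


# Build kwargs only for the options that are set, then make a single call;
# this mirrors how the helper replaced the old if/elif/else branches.
kwargs: dict[str, Any] = {"workers": 4, "show_progress": True}
names: list[str] | None = ["w1", "w2"]
groups: list[str] | None = None
if names:
    kwargs["names"] = names
if groups:
    kwargs["groups"] = groups
print(load_windows(**kwargs))  # ['default/w1', 'default/w2']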