rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/train/data_module.py
CHANGED
@@ -15,7 +15,7 @@ from .dataset import DataInput, ModelDataset, RetryDataset, SplitConfig
 
 def collate_fn(
     batch: list[tuple[dict[str, Any], dict[str, Any]]],
-) -> tuple
+) -> tuple:
     """Collate batch of training examples.
 
     We just make list of the inputs and another of the targets.
@@ -48,7 +48,7 @@ class RslearnDataModule(L.LightningDataModule):
         val_config: SplitConfig = SplitConfig(),
         test_config: SplitConfig = SplitConfig(),
         predict_config: SplitConfig = SplitConfig(),
-    ):
+    ) -> None:
        """Initialize a new RslearnDataModule.
 
        Args:
@@ -79,7 +79,7 @@ class RslearnDataModule(L.LightningDataModule):
            "predict": default_config.update(predict_config),
        }
 
-    def setup(self, stage: str):
+    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.
 
        Args:
@@ -106,12 +106,13 @@ class RslearnDataModule(L.LightningDataModule):
 
    def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]:
        dataset = self.datasets[split]
+       persistent_workers = self.num_workers > 0
        kwargs = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
-            persistent_workers=
+           persistent_workers=persistent_workers,
        )
        sampler_factory = self.split_configs[split].sampler
        if sampler_factory:
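The `_get_dataloader` change above ties `persistent_workers` to `num_workers`, because PyTorch's `DataLoader` rejects `persistent_workers=True` when `num_workers == 0`. A minimal standalone sketch of the same guard (the toy dataset is just a stand-in):

```python
# Sketch only: mirror the persistent_workers guard from the diff above.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(torch.arange(16, dtype=torch.float32))


def make_loader(num_workers: int) -> DataLoader:
    # DataLoader raises a ValueError for persistent_workers=True with num_workers == 0,
    # so only keep workers alive between epochs when there are workers at all.
    return DataLoader(
        toy_dataset,
        batch_size=4,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


loader = make_loader(num_workers=0)  # fine: persistent_workers resolves to False
```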
rslearn/train/dataset.py
CHANGED
@@ -1,5 +1,6 @@
 """Default Dataset for rslearn."""
 
+import hashlib
 import multiprocessing
 import os
 import random
@@ -47,7 +48,9 @@ class SamplerFactory:
 class RandomSamplerFactory(SamplerFactory):
     """A sampler factory for RandomSampler."""
 
-    def __init__(
+    def __init__(
+        self, replacement: bool = False, num_samples: int | None = None
+    ) -> None:
         """Initialize a RandomSamplerFactory.
 
         Args:
@@ -75,7 +78,9 @@ class RandomSamplerFactory(SamplerFactory):
 class WeightedRandomSamplerFactory(SamplerFactory):
     """A sampler factory for WeightedRandomSampler."""
 
-    def __init__(
+    def __init__(
+        self, option_key: str, num_samples: int, replacement: bool = True
+    ) -> None:
         """Initialize a WeightedRandomSamplerFactory.
 
         Args:
@@ -119,7 +124,7 @@ class DataInput:
         passthrough: bool = False,
         is_target: bool = False,
         dtype: DType = DType.FLOAT32,
-    ):
+    ) -> None:
         """Initialize a new DataInput.
 
         Args:
@@ -157,7 +162,7 @@ class SplitConfig:
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
         skip_targets: bool | None = None,
-    ):
+    ) -> None:
         """Initialize a new SplitConfig.
 
         Args:
@@ -242,7 +247,7 @@ class SplitConfig:
         return True if self.skip_targets is True else False
 
 
-def check_window(inputs: dict[str, DataInput], window: Window) ->
+def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
     """Verify that the window has the required layers based on the specified inputs.
 
     Args:
@@ -254,7 +259,7 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
     """
 
     # Make sure window has all the needed layers.
-    def is_any_layer_available(data_input):
+    def is_any_layer_available(data_input: DataInput) -> bool:
        for layer_name in data_input.layers:
            completed_fname = window.path / "layers" / layer_name / "completed"
            if completed_fname.exists():
@@ -347,37 +352,53 @@ class ModelDataset(torch.utils.data.Dataset):
 
        # Eliminate windows that are missing either a requisite input layer, or missing
        # all target layers.
-        p = multiprocessing.Pool(workers)
-        outputs = star_imap_unordered(
-            p,
-            check_window,
-            [
-                dict(
-                    inputs=self.inputs,
-                    window=window,
-                )
-                for window in windows
-            ],
-        )
        new_windows = []
-
-
-
-
-
-
-
+       if workers == 0:
+           for window in windows:
+               if check_window(self.inputs, window) is None:
+                   continue
+               new_windows.append(window)
+       else:
+           p = multiprocessing.Pool(workers)
+           outputs = star_imap_unordered(
+               p,
+               check_window,
+               [
+                   dict(
+                       inputs=self.inputs,
+                       window=window,
+                   )
+                   for window in windows
+               ],
+           )
+           for window in tqdm.tqdm(
+               outputs, total=len(windows), desc="Checking available layers in windows"
+           ):
+               if window is None:
+                   continue
+               new_windows.append(window)
+           p.close()
        windows = new_windows
 
+       # Sort the windows to ensure that the dataset is consistent across GPUs.
+       # Inconsistent ordering can lead to a subset of windows being processed during
+       # "model test" / "model predict" when using multiple GPUs.
+       # We use a hash so that functionality like num_samples limit gets a random
+       # subset of windows (with respect to the hash function choice).
+       windows.sort(
+           key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
+       )
+
        # Limit windows to num_samples if requested.
        if split_config.num_samples:
-            #
+           # The windows are sorted by hash of window name so this distribution should
+           # be representative of the population.
            windows = windows[0 : split_config.num_samples]
 
-        self.windows = windows
+       self.windows: list = windows
 
        # If we're loading all patches, we need to include the patch details.
-        if split_config.get_load_all_patches():
+       if split_config.get_load_all_patches() and self.patch_size is not None:
            patches = []
            overlap_size = int(
                self.patch_size[0] * split_config.overlap_ratio
@@ -386,6 +407,8 @@ class ModelDataset(torch.utils.data.Dataset):
            )
            for window in self.windows:
                cur_patches = []
+               if window is None:
+                   raise ValueError("Window is None in load_all_patches")
                for col in range(
                    window.bounds[0],
                    window.bounds[2],
@@ -412,7 +435,9 @@ class ModelDataset(torch.utils.data.Dataset):
        """Returns the dataset length."""
        return len(self.windows)
 
-    def __getitem__(
+    def __getitem__(
+        self, idx: int
+    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """Read one training example.
 
        Args:
@@ -429,7 +454,7 @@ class ModelDataset(torch.utils.data.Dataset):
            window, bounds, (patch_idx, num_patches) = window
        elif self.patch_size:
 
-            def get_patch_range(n_patch, n_window):
+           def get_patch_range(n_patch: int, n_window: int) -> list[int]:
                if n_patch > n_window:
                    # Select arbitrary range containing the entire window.
                    # Basically arbitrarily padding the window to get to patch size.
@@ -459,7 +484,7 @@ class ModelDataset(torch.utils.data.Dataset):
            bounds = window.bounds
 
        # Read the inputs and targets.
-        def read_input(data_input: DataInput):
+       def read_input(data_input: DataInput) -> torch.Tensor:
            # First enumerate all options of individual layers to read.
            layer_options = []
            for layer_name in data_input.layers:
@@ -473,7 +498,13 @@ class ModelDataset(torch.utils.data.Dataset):
            # the options, as well as picking multiple for series inputs.
            layer = random.choice(layer_options)
            layer_dir = window.path / "layers" / layer
-
+
+           # The model config may reference a specific group within a layer, like
+           # "image.2" in a dataset that has a layer "image" with max_matches > 1.
+           # So we need to split off the period. Layer names should not contain
+           # period.
+           layer_ds_key = layer.split(".")[0]
+           layer_config = self.dataset.layers[layer_ds_key]
 
            if data_input.data_type == "raster":
                assert isinstance(layer_config, RasterLayerConfig)
@@ -481,6 +512,8 @@ class ModelDataset(torch.utils.data.Dataset):
                # See what different sets of bands we need to read to get all the
                # configured bands.
                needed_bands = data_input.bands
+               if needed_bands is None:
+                   raise ValueError(f"No bands specified for {layer}")
                needed_band_indexes = {}
                for i, band in enumerate(needed_bands):
                    needed_band_indexes[band] = i
@@ -488,6 +521,8 @@ class ModelDataset(torch.utils.data.Dataset):
                for band_set in layer_config.band_sets:
                    needed_src_indexes = []
                    needed_dst_indexes = []
+                   if band_set.bands is None:
+                       continue
                    for i, band in enumerate(band_set.bands):
                        if band not in needed_band_indexes:
                            continue
@@ -514,12 +549,20 @@ class ModelDataset(torch.utils.data.Dataset):
                    _, final_bounds = band_set.get_final_projection_and_bounds(
                        window.projection, bounds
                    )
+                   if band_set.format is None:
+                       raise ValueError(f"No format specified for {layer}")
                    raster_format = load_raster_format(
                        RasterFormatConfig(band_set.format["name"], band_set.format)
                    )
+                   if band_set.bands is None:
+                       # Raising Error as It is unclear the intended behavior here.
+                       raise ValueError("No bands specified for band set")
                    cur_path = layer_dir / "_".join(band_set.bands)
+                   if final_bounds is None:
+                       raise ValueError("Final bounds are None")
                    src = raster_format.decode_raster(cur_path, final_bounds)
-
+                   if src is None:
+                       raise ValueError(f"Source is None for {data_input}")
                    # Resize to patch size if needed.
                    # This is for band sets that are stored at a lower resolution.
                    # Here we assume that it is a multiple.
@@ -594,7 +637,7 @@ class RetryDataset(torch.utils.data.Dataset):
 
    def __init__(
        self, dataset: torch.utils.data.Dataset, retries: int = 3, delay: float = 5
-    ):
+    ) -> None:
        """Create a new RetryDataset.
 
        Args:
@@ -606,7 +649,7 @@ class RetryDataset(torch.utils.data.Dataset):
        self.retries = retries
        self.delay = delay
 
-    def __len__(self):
+    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.dataset)
 
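The `ModelDataset` change above sorts windows by the SHA-256 hash of their names, so the ordering is identical across GPUs and the `num_samples` limit takes a stable, pseudo-random prefix. A small standalone sketch of that ordering trick (the window names are made up):

```python
# Sketch only: deterministic but effectively shuffled ordering via name hashing.
import hashlib

window_names = ["seattle_2020", "nairobi_2021", "oslo_2019", "lima_2022"]

# Sorting by the hex digest is reproducible across processes and runs,
# yet unrelated to the original name order, so a prefix behaves like a
# fixed random sample of the windows.
ordered = sorted(
    window_names, key=lambda name: hashlib.sha256(name.encode()).hexdigest()
)

num_samples = 2
print(ordered)
print(ordered[:num_samples])
```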
rslearn/train/lightning_module.py
CHANGED
@@ -49,7 +49,7 @@ class RestoreConfig:
        """Returns the state dict configured in this RestoreConfig."""
        print(f"loading state dict from {self.restore_path}")
        with self.restore_path.open("rb") as f:
-            state_dict = torch.load(f)
+           state_dict = torch.load(f, map_location="cpu")
        for k in self.selector:
            state_dict = state_dict[k]
 
@@ -124,6 +124,7 @@ class RslearnLightningModule(L.LightningModule):
        self.plateau_min_lr = plateau_min_lr
        self.plateau_cooldown = plateau_cooldown
        self.visualize_dir = visualize_dir
+       self.restore_config = restore_config
 
        if print_parameters:
            for name, param in self.named_parameters():
@@ -132,8 +133,19 @@ class RslearnLightningModule(L.LightningModule):
        if print_model:
            print(self.model)
 
-
-
+       self.epochs = 0
+
+       metrics = self.task.get_metrics()
+       self.val_metrics = metrics.clone(prefix="val_")
+       self.test_metrics = metrics.clone(prefix="test_")
+
+       self.schedulers: dict = {}
+
+   def on_fit_start(self) -> None:
+       """Called when the fit begins."""
+       # Only restore if doing a fresh fit.
+       if self.trainer.ckpt_path is None and self.restore_config:
+           state_dict = self.restore_config.get_state_dict()
            missing_keys, unexpected_keys = self.model.load_state_dict(
                state_dict, strict=False
            )
@@ -142,14 +154,6 @@ class RslearnLightningModule(L.LightningModule):
                f"warning: restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
            )
 
-        self.epochs = 0
-
-        metrics = self.task.get_metrics()
-        self.val_metrics = metrics.clone(prefix="val_")
-        self.test_metrics = metrics.clone(prefix="test_")
-
-        self.schedulers = {}
-
    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
        """Initialize the optimizer and learning rate scheduler.
 
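`RestoreConfig.get_state_dict` now loads checkpoints with `map_location="cpu"`, and the restore itself moved into `on_fit_start`, gated on no Lightning checkpoint being resumed. A rough standalone sketch of the load-and-select pattern (the file path and selector keys here are invented for illustration):

```python
# Sketch only: load a checkpoint onto CPU and drill into it with selector keys.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save({"state_dict": {"model": model.state_dict()}}, "/tmp/example_ckpt.pt")

selector = ["state_dict", "model"]  # analogous to RestoreConfig.selector
with open("/tmp/example_ckpt.pt", "rb") as f:
    # map_location="cpu" avoids device errors when the checkpoint was written
    # on a GPU that is not present at load time.
    state_dict = torch.load(f, map_location="cpu")
for k in selector:
    state_dict = state_dict[k]

missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(missing_keys, unexpected_keys)
```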
rslearn/train/prediction_writer.py
CHANGED
@@ -8,7 +8,12 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from upath import UPath
 
-from rslearn.config import
+from rslearn.config import (
+    LayerType,
+    RasterFormatConfig,
+    RasterLayerConfig,
+    VectorLayerConfig,
+)
 from rslearn.dataset import Dataset
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.raster_format import load_raster_format
@@ -20,17 +25,14 @@ from .lightning_module import RslearnLightningModule
 class PatchPredictionMerger:
     """Base class for merging predictions from multiple patches."""
 
-    def merge(
-
-    ) -> tuple[Sequence[Any], Sequence[Any]]:
-        """Merge the outputs and metadatas.
+    def merge(self, outputs: Sequence[Any]) -> tuple[Sequence[Any]]:
+        """Merge the outputs.
 
        Args:
            outputs: the outputs to process.
-            metadatas: the metadatas to process.
 
        Returns:
-            the merged outputs
+           the merged outputs.
        """
        raise NotImplementedError
 
@@ -57,6 +59,7 @@ class RslearnWriter(BasePredictionWriter):
            output_layer: which layer to write the outputs under.
            path_options: additional options for path to pass to fsspec
            selector: keys to access the desired output in the output dict if needed.
+               e.g ["key1", "key2"] gets output["key1"]["key2"]
            merger: merger to use to merge outputs from overlapped patches.
        """
        super().__init__(write_interval="batch")
@@ -65,13 +68,16 @@ class RslearnWriter(BasePredictionWriter):
        self.path = UPath(path, **path_options)
        self.dataset = Dataset(self.path)
        self.layer_config = self.dataset.layers[self.output_layer]
-
+       # TODO: This is a bit of a hack to get the type checker to be happy.
+       self.format: Any
        if self.layer_config.layer_type == LayerType.RASTER:
+           assert isinstance(self.layer_config, RasterLayerConfig)
            band_cfg = self.layer_config.band_sets[0]
            self.format = load_raster_format(
                RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
            )
        elif self.layer_config.layer_type == LayerType.VECTOR:
+           assert isinstance(self.layer_config, VectorLayerConfig)
            self.format = load_vector_format(self.layer_config.format)
        else:
            raise ValueError(f"invalid layer type {self.layer_config.layer_type}")
@@ -81,7 +87,7 @@ class RslearnWriter(BasePredictionWriter):
        # Map from window name to pending data to write.
        # This is used when windows are split up into patches, so the data from all the
        # patches of each window need to be reconstituted.
-        self.pending_outputs = {}
+       self.pending_outputs: dict[str, Any] = {}
 
    def write_on_batch_end(
        self,
@@ -92,7 +98,7 @@ class RslearnWriter(BasePredictionWriter):
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
-    ):
+    ) -> None:
        """Write a batch of predictions into the rslearn dataset.
 
        Args:
@@ -112,6 +118,8 @@ class RslearnWriter(BasePredictionWriter):
        ]
 
        for output, metadata in zip(outputs, metadatas):
+           if not isinstance(output, dict):
+               raise ValueError(f"Unsupported output type {type(output)}")
            for k in self.selector:
                output = output[k]
 
@@ -120,7 +128,9 @@ class RslearnWriter(BasePredictionWriter):
            window_bounds = metadata["window_bounds"]
 
            if self.layer_config.layer_type == LayerType.RASTER:
-                if window_name not in self.pending_outputs
+               if window_name not in self.pending_outputs and isinstance(
+                   output, np.ndarray
+               ):
                    self.pending_outputs[window_name] = np.zeros(
                        (
                            output.shape[0],
@@ -167,6 +177,7 @@ class RslearnWriter(BasePredictionWriter):
            )
 
            if self.layer_config.layer_type == LayerType.RASTER:
+               assert isinstance(self.layer_config, RasterLayerConfig)
                band_dir = layer_dir / "_".join(self.layer_config.band_sets[0].bands)
                self.format.encode_raster(
                    band_dir, metadata["projection"], window_bounds, pending_output
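The expanded `selector` docstring describes drilling into nested prediction dicts key by key. A tiny sketch of that access pattern (the keys and output structure are hypothetical):

```python
# Sketch only: each selector key indexes one level deeper into the output dict.
from typing import Any


def apply_selector(output: dict[str, Any], selector: list[str]) -> Any:
    for k in selector:
        output = output[k]
    return output


prediction = {"key1": {"key2": [0.1, 0.9]}}
print(apply_selector(prediction, ["key1", "key2"]))  # [0.1, 0.9]
```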
rslearn/train/tasks/classification.py
CHANGED
@@ -26,8 +26,8 @@ class ClassificationTask(BasicTask):
    def __init__(
        self,
        property_name: str,
-        classes: list
-        filters: list[tuple[str, str]]
+       classes: list,  # TODO: Should this be a list of str or int or can it be both?
+       filters: list[tuple[str, str]] = [],
        read_class_id: bool = False,
        allow_invalid: bool = False,
        skip_unknown_categories: bool = False,
@@ -37,7 +37,7 @@ class ClassificationTask(BasicTask):
        f1_metric_kwargs: dict[str, Any] = {},
        positive_class: str | None = None,
        positive_class_threshold: float = 0.5,
-        **kwargs,
+       **kwargs: Any,
    ):
        """Initialize a new ClassificationTask.
 
@@ -95,9 +95,6 @@ class ClassificationTask(BasicTask):
        else:
            self.positive_class_id = self.classes.index(self.positive_class)
 
-        if not self.filters:
-            self.filters = []
-
    def process_inputs(
        self,
        raw_inputs: dict[str, torch.Tensor | list[Feature]],
@@ -120,6 +117,8 @@ class ClassificationTask(BasicTask):
 
        data = raw_inputs["targets"]
        for feat in data:
+           if feat.properties is None:
+               continue
            for property_name, property_value in self.filters:
                if feat.properties.get(property_name) != property_value:
                    continue
@@ -178,7 +177,7 @@ class ClassificationTask(BasicTask):
        class_idx = probs.argmax()
 
        if not self.read_class_id:
-            value = self.classes[class_idx]
+           value = self.classes[class_idx]  # type: ignore
        else:
            value = class_idx
 
@@ -192,7 +191,7 @@ class ClassificationTask(BasicTask):
                self.property_name: value,
            },
        )
-        if self.prob_property:
+       if self.prob_property is not None and feature.properties is not None:
            feature.properties[self.prob_property] = probs.tolist()
        return [feature]
 
@@ -215,6 +214,8 @@ class ClassificationTask(BasicTask):
        image = super().visualize(input_dict, target_dict, output)["image"]
        image = Image.fromarray(image)
        draw = ImageDraw.Draw(image)
+       if target_dict is None:
+           raise ValueError("target_dict is required for visualization")
        target_class = self.classes[target_dict["class"]]
        output_class = self.classes[output.argmax()]
        text = f"Label: {target_class}\nOutput: {output_class}"