rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
@@ -20,13 +20,15 @@ from rslearn.config import (
     LayerConfig,
 )
 from rslearn.dataset.dataset import Dataset
+from rslearn.dataset.storage.file import FileWindowStorage
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
 from rslearn.log_utils import get_logger
-from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
-from rslearn.utils.geometry import PixelBounds
+from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered

+from .model_context import SampleMetadata
+from .tasks import Task
 from .transforms import Sequential

 logger = get_logger(__name__)

@@ -128,6 +130,10 @@ class DataInput:
     """Specification of a piece of data from a window that is needed for training.

     The DataInput includes which layer(s) the data can be obtained from for each window.
+
+    Note that this class is not a dataclass because jsonargparse does not play well
+    with dataclasses without enabling specialized options which we have not validated
+    will work with the rest of our code.
     """

     def __init__(

@@ -141,7 +147,9 @@ class DataInput:
         dtype: DType = DType.FLOAT32,
         load_all_layers: bool = False,
         load_all_item_groups: bool = False,
-    ):
+        resolution_factor: ResolutionFactor = ResolutionFactor(),
+        resampling: Resampling = Resampling.nearest,
+    ):
         """Initialize a new DataInput.

         Args:
@@ -164,6 +172,11 @@ class DataInput:
                 are reading from. By default, we assume the specified layer name is of
                 the form "{layer_name}.{group_idx}" and read that item group only. With
                 this option enabled, we ignore the group_idx and read all item groups.
+            resolution_factor: controls the resolution at which raster data is loaded for training.
+                By default (factor=1), data is loaded at the window resolution.
+                E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
+                the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
+            resampling: resampling method (default nearest neighbor).
         """
         self.data_type = data_type
         self.layers = layers
@@ -174,6 +187,8 @@ class DataInput:
         self.dtype = dtype
         self.load_all_layers = load_all_layers
         self.load_all_item_groups = load_all_item_groups
+        self.resolution_factor = resolution_factor
+        self.resampling = resampling


 def read_raster_layer_for_data_input(
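To illustrate the new parameters, here is a hedged sketch of constructing a DataInput with a resolution factor. The layer and band names are made up, the `bands` argument and the `ResolutionFactor(1, 2)` constructor form are assumptions (this diff only shows the no-argument default), and `Resampling` is assumed to come from rasterio:

```python
from rasterio.enums import Resampling

from rslearn.train.dataset import DataInput
from rslearn.utils.geometry import ResolutionFactor

# Hypothetical input: read the "sentinel2" raster layer at half the window
# resolution, so a 64x64 window at 10 m/pixel yields a 32x32 tensor at
# 20 m/pixel, resampled bilinearly instead of the nearest-neighbor default.
image_input = DataInput(
    data_type="raster",
    layers=["sentinel2"],
    bands=["B04", "B03", "B02"],
    resolution_factor=ResolutionFactor(1, 2),
    resampling=Resampling.bilinear,
)
```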
@@ -231,15 +246,23 @@ def read_raster_layer_for_data_input(
         + f"window {window.name} layer {layer_name} group {group_idx}"
     )

+    # Get the projection and bounds to read under (multiply window resolution
+    # by the specified resolution factor).
+    final_projection = data_input.resolution_factor.multiply_projection(
+        window.projection
+    )
+    final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
+
     image = torch.zeros(
-        (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
+        (
+            len(needed_bands),
+            final_bounds[3] - final_bounds[1],
+            final_bounds[2] - final_bounds[0],
+        ),
         dtype=get_torch_dtype(data_input.dtype),
     )

     for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
-        final_projection, final_bounds = band_set.get_final_projection_and_bounds(
-            window.projection, bounds
-        )
         if band_set.format is None:
             raise ValueError(f"No format specified for {layer_name}")
         raster_format = band_set.instantiate_raster_format()
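ResolutionFactor itself is added in rslearn/utils/geometry.py (+73 lines, not shown in this diff), so the following is only a sketch of what `multiply_projection` and `multiply_bounds` appear to do, inferred from the call sites above. The `Projection(crs, x_resolution, y_resolution)` constructor matches how rslearn uses Projection elsewhere, but treat the details as assumptions:

```python
from fractions import Fraction

from rslearn.utils.geometry import Projection

class ResolutionFactorSketch:
    """Hypothetical stand-in for rslearn.utils.geometry.ResolutionFactor."""

    def __init__(self, numerator: int = 1, denominator: int = 1):
        self.factor = Fraction(numerator, denominator)

    def multiply_projection(self, projection: Projection) -> Projection:
        # A factor of 1/2 halves the pixel grid, i.e. doubles the
        # projection-units-per-pixel resolution.
        return Projection(
            projection.crs,
            projection.x_resolution / float(self.factor),
            projection.y_resolution / float(self.factor),
        )

    def multiply_bounds(self, bounds: tuple[int, int, int, int]) -> tuple[int, ...]:
        # Scale pixel bounds, e.g. (0, 0, 64, 64) with factor 1/2 -> (0, 0, 32, 32).
        return tuple(int(v * self.factor) for v in bounds)
```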
@@ -247,44 +270,16 @@ def read_raster_layer_for_data_input(
             layer_name, band_set.bands, group_idx=group_idx
         )

-        #
-        #
-        #
-        #
-        #
-        #
-        is_bounds_zoomable = True
-        if band_set.zoom_offset < 0:
-            zoom_factor = 2 ** (-band_set.zoom_offset)
-            is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
-                bounds[2] - bounds[0]
-            ) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
-                bounds[3] - bounds[1]
-            )
-
-        if is_bounds_zoomable:
-            src = raster_format.decode_raster(
-                raster_dir, final_projection, final_bounds
-            )
-
-            # Resize to patch size if needed.
-            # This is for band sets that are stored at a lower resolution.
-            # Here we assume that it is a multiple.
-            if src.shape[1:3] != image.shape[1:3]:
-                if src.shape[1] < image.shape[1]:
-                    factor = image.shape[1] // src.shape[1]
-                    src = src.repeat(repeats=factor, axis=1).repeat(
-                        repeats=factor, axis=2
-                    )
-                else:
-                    factor = src.shape[1] // image.shape[1]
-                    src = src[:, ::factor, ::factor]
-
-        else:
-            src = raster_format.decode_raster(
-                raster_dir, window.projection, bounds, resampling=Resampling.nearest
-            )
+        # TODO: previously we try to read based on band_set.zoom_offset when possible,
+        # and handle zooming in with torch.repeat (if resampling method is nearest
+        # neighbor). However, we have not benchmarked whether this actually improves
+        # data loading speed, so for simplicity, for now we let rasterio handle the
+        # resampling. If it really is much faster to handle it via torch, then it may
+        # make sense to bring back that functionality.

+        src = raster_format.decode_raster(
+            raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
+        )
         image[dst_indexes, :, :] = torch.as_tensor(
             src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
         )
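For context on the removed fast path mentioned in the TODO: zooming in with repeat is plain nearest-neighbor upsampling. A standalone sketch of that operation, assuming an exact integer factor as the old code did:

```python
import numpy as np

# Nearest-neighbor 2x zoom of a (bands, H, W) array via repeat, mirroring the
# removed fast path; it only works when the target size is an exact multiple.
src = np.arange(4, dtype=np.float32).reshape(1, 2, 2)
factor = 2
zoomed = src.repeat(factor, axis=1).repeat(factor, axis=2)
print(zoomed.shape)  # (1, 4, 4)
```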
@@ -575,37 +570,7 @@ class ModelDataset(torch.utils.data.Dataset):
         else:
             self.patch_size = split_config.get_patch_size()

-        if split_config.names:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups,
-                names=split_config.names,
-                show_progress=True,
-                workers=workers,
-            )
-        elif split_config.groups:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups, show_progress=True, workers=workers
-            )
-        else:
-            windows = self.dataset.load_windows(show_progress=True, workers=workers)
-
-        if split_config.tags:
-            # Filter the window.options.
-            new_windows = []
-            num_removed: dict[str, int] = {}
-            for window in windows:
-                for k, v in split_config.tags.items():
-                    if k not in window.options or (v and window.options[k] != v):
-                        num_removed[k] = num_removed.get(k, 0) + 1
-                        break
-                else:
-                    new_windows.append(window)
-            logger.info(
-                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
-            )
-            for k, v in num_removed.items():
-                logger.info(f"Removed {v} windows due to tag {k}")
-            windows = new_windows
+        windows = self._get_initial_windows(split_config, workers)

         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -615,17 +580,11 @@ class ModelDataset(torch.utils.data.Dataset):

         # Eliminate windows that are missing either a requisite input layer, or missing
         # all target layers.
-        # We use only main thread if the index is set, since that can take a long time
-        # to send to the worker threads, it may get serialized for each window.
         new_windows = []
-        if workers == 0
+        if workers == 0:
             for window in windows:
                 if check_window(self.inputs, window) is None:
                     continue
-                # The index may be set, but now that this check is done, from here on
-                # we no longer need it. We set it None so that we don't end up passing
-                # it later to the dataloader workers.
-                window.index = None
                 new_windows.append(window)
         else:
             p = multiprocessing.Pool(workers)
@@ -681,12 +640,62 @@ class ModelDataset(torch.utils.data.Dataset):
         with open(self.dataset_examples_fname, "w") as f:
             json.dump([self._serialize_item(example) for example in windows], f)

+    def _get_initial_windows(
+        self, split_config: SplitConfig, workers: int
+    ) -> list[Window]:
+        """Get the initial windows before input layer filtering.
+
+        The windows are filtered based on configured window names, groups, and tags.
+
+        This is a helper for the init function.
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+
+        Returns:
+            list of windows from the dataset after applying the aforementioned filters.
+        """
+        # Load windows from dataset.
+        # If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
+        kwargs: dict[str, Any] = {}
+        if isinstance(self.dataset.storage, FileWindowStorage):
+            kwargs["workers"] = workers
+            kwargs["show_progress"] = True
+        # We also add the name/group filters to the kwargs.
+        if split_config.names:
+            kwargs["names"] = split_config.names
+        if split_config.groups:
+            kwargs["groups"] = split_config.groups
+
+        windows = self.dataset.load_windows(**kwargs)
+
+        # Filter by tags (if provided) using the window.options.
+        if split_config.tags:
+            new_windows = []
+            num_removed: dict[str, int] = {}
+            for window in windows:
+                for k, v in split_config.tags.items():
+                    if k not in window.options or (v and window.options[k] != v):
+                        num_removed[k] = num_removed.get(k, 0) + 1
+                        break
+                else:
+                    new_windows.append(window)
+            logger.info(
+                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
+            )
+            for k, v in num_removed.items():
+                logger.info(f"Removed {v} windows due to tag {k}")
+            windows = new_windows
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()

     def _deserialize_item(self, d: dict[str, Any]) -> Window:
         return Window.from_metadata(
-
+            self.dataset.storage,
             d,
         )
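The tag filter in `_get_initial_windows` relies on Python's `for`/`else`: the `else` branch runs only when the inner loop completes without `break`, i.e. when the window matched every requested tag. A self-contained sketch of the same pattern with made-up windows and tags:

```python
windows = [
    {"name": "a", "options": {"split": "train"}},
    {"name": "b", "options": {"split": "val"}},
    {"name": "c", "options": {}},
]
tags = {"split": "train"}

kept = []
for window in windows:
    for k, v in tags.items():
        # Reject if the tag is missing, or set to a different non-empty value.
        if k not in window["options"] or (v and window["options"][k] != v):
            break
    else:
        # No break occurred: every tag matched, so keep the window.
        kept.append(window)

print([w["name"] for w in kept])  # ['a']
```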
@@ -713,7 +722,7 @@ class ModelDataset(torch.utils.data.Dataset):

     def get_raw_inputs(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs and base metadata for this example.

         This is the raster or vector data before being processed by the Task. So it
@@ -775,21 +784,23 @@ class ModelDataset(torch.utils.data.Dataset):
         if data_input.passthrough:
             passthrough_inputs[name] = raw_inputs[name]

-        metadata =
-
-
-
-
-
-
-
-
+        metadata = SampleMetadata(
+            window_group=window.group,
+            window_name=window.name,
+            window_bounds=window.bounds,
+            patch_bounds=bounds,
+            patch_idx=0,
+            num_patches_in_window=1,
+            time_range=window.time_range,
+            projection=window.projection,
+            dataset_source=self.name,
+        )

         return raw_inputs, passthrough_inputs, metadata

     def __getitem__(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Read one training example.

         Args:
@@ -801,8 +812,6 @@ class ModelDataset(torch.utils.data.Dataset):
         logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)

         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
-        metadata["patch_idx"] = 0
-        metadata["num_patches"] = 1

         input_dict, target_dict = self.task.process_inputs(
             raw_inputs,
@@ -811,7 +820,6 @@ class ModelDataset(torch.utils.data.Dataset):
         )
         input_dict.update(passthrough_inputs)
         input_dict, target_dict = self.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.name

         logger.debug("__getitem__ finish pid=%d item_idx=%d", os.getpid(), idx)

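Taken together, these dataset.py changes mean each example is now a 3-tuple ending in a SampleMetadata instance: patch_idx and num_patches_in_window are typed fields rather than dict entries, and dataset_source moves out of input_dict. A hedged sketch of consuming one example (the ModelDataset construction itself is elided and hypothetical):

```python
# model_dataset is an already-constructed ModelDataset (setup elided).
input_dict, target_dict, metadata = model_dataset[0]

print(metadata.window_name, metadata.patch_idx, metadata.num_patches_in_window)
print(metadata.dataset_source)  # previously read from input_dict["dataset_source"]
```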
rslearn/train/lightning_module.py
CHANGED

@@ -12,6 +12,7 @@ from upath import UPath

 from rslearn.log_utils import get_logger

+from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
 from .tasks import Task
@@ -231,12 +232,16 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             The loss tensor.
         """
-        inputs, targets,
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(
-        self.on_train_forward(
+        model_outputs = self(context, targets)
+        self.on_train_forward(context, targets, model_outputs)

-        loss_dict = model_outputs
+        loss_dict = model_outputs.loss_dict
         train_loss = sum(loss_dict.values())
         self.log_dict(
             {"train_" + k: v for k, v in loss_dict.items()},
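The batch is still the 3-tuple from ModelDataset, but the model is now called with a single ModelContext and returns a ModelOutput. A hedged sketch of a component written against the new convention; the real components inherit from a shared base in rslearn/models/component.py (not shown in this diff), so this toy module only illustrates the calling contract:

```python
from typing import Any

import torch

from rslearn.train.model_context import ModelContext, ModelOutput

class MeanPredictor(torch.nn.Module):
    """Toy model: predicts the mean pixel value of each example's image."""

    def forward(
        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
    ) -> ModelOutput:
        # One scalar prediction per example in the batch.
        outputs = [inp["image"].mean() for inp in context.inputs]
        loss_dict: dict[str, torch.Tensor] = {}
        if targets is not None:
            # L1 loss against a hypothetical scalar "value" target per example.
            loss_dict["l1"] = torch.stack(
                [(o - t["value"]).abs() for o, t in zip(outputs, targets)]
            ).mean()
        return ModelOutput(outputs=outputs, loss_dict=loss_dict)
```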
@@ -266,13 +271,17 @@ class RslearnLightningModule(L.LightningModule):
             batch_idx: Integer displaying index of this batch.
             dataloader_idx: Index of the current dataloader.
         """
-        inputs, targets,
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(
-        self.on_val_forward(
+        model_outputs = self(context, targets)
+        self.on_val_forward(context, targets, model_outputs)

-        loss_dict = model_outputs
-        outputs = model_outputs
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         val_loss = sum(loss_dict.values())
         self.log_dict(
             {"val_" + k: v for k, v in loss_dict.items()},
@@ -304,12 +313,16 @@ class RslearnLightningModule(L.LightningModule):
             dataloader_idx: Index of the current dataloader.
         """
         inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(
-        self.on_test_forward(
+        model_outputs = self(context, targets)
+        self.on_test_forward(context, targets, model_outputs)

-        loss_dict = model_outputs
-        outputs = model_outputs
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         test_loss = sum(loss_dict.values())
         self.log_dict(
             {"test_" + k: v for k, v in loss_dict.items()},
@@ -345,7 +358,7 @@ class RslearnLightningModule(L.LightningModule):

     def predict_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) ->
+    ) -> ModelOutput:
         """Compute the predicted class probabilities.

         Args:
@@ -356,63 +369,69 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             Output predicted probabilities.
         """
-        inputs, _,
-
+        inputs, _, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
+        model_outputs = self(context)
         return model_outputs

-    def forward(
+    def forward(
+        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
+    ) -> ModelOutput:
         """Forward pass of the model.

         Args:
-
-
+            context: the model context.
+            targets: the target dicts.

         Returns:
             Output of the model.
         """
-        return self.model(
+        return self.model(context, targets)

     def on_train_forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs:
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during training.

         Args:
-
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model
+            model_outputs: The output of the model.
         """
         pass

     def on_val_forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs:
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during validation.

         Args:
-
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model
+            model_outputs: The output of the model.
         """
         pass

     def on_test_forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs:
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during testing.

         Args:
-
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model
+            model_outputs: The output of the model.
         """
         pass
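Subclasses overriding these hooks must adopt the new signatures. A minimal hedged sketch; the body is illustrative only:

```python
from typing import Any

from rslearn.train.lightning_module import RslearnLightningModule
from rslearn.train.model_context import ModelContext, ModelOutput

class LoggingLightningModule(RslearnLightningModule):
    def on_val_forward(
        self,
        context: ModelContext,
        targets: list[dict[str, Any]],
        model_outputs: ModelOutput,
    ) -> None:
        # Illustrative: report which windows this validation batch came from.
        window_names = [m.window_name for m in context.metadatas]
        self.print(f"validated batch of {len(window_names)}: {window_names[:3]}")
```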
rslearn/train/model_context.py
ADDED

@@ -0,0 +1,54 @@
+"""Data classes to provide various context to models."""
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+import torch
+
+from rslearn.utils.geometry import PixelBounds, Projection
+
+
+@dataclass
+class SampleMetadata:
+    """Metadata pertaining to an example."""
+
+    window_group: str
+    window_name: str
+    window_bounds: PixelBounds
+    patch_bounds: PixelBounds
+    patch_idx: int
+    num_patches_in_window: int
+    time_range: tuple[datetime, datetime] | None
+    projection: Projection
+
+    # Task name to differentiate different tasks.
+    dataset_source: str | None
+
+
+@dataclass
+class ModelContext:
+    """Context to pass to all model components."""
+
+    # One input dict per example in the batch.
+    inputs: list[dict[str, torch.Tensor]]
+    # One SampleMetadata per example in the batch.
+    metadatas: list[SampleMetadata]
+    # Arbitrary dict that components can add to.
+    context_dict: dict[str, Any] = field(default_factory=lambda: {})
+
+
+@dataclass
+class ModelOutput:
+    """The output from the Predictor.
+
+    Args:
+        outputs: output compatible with the configured Task.
+        loss_dict: map from loss names to scalar tensors.
+        metadata: arbitrary dict that can be used to store other outputs.
+    """
+
+    outputs: Iterable[Any]
+    loss_dict: dict[str, torch.Tensor]
+    metadata: dict[str, Any] = field(default_factory=lambda: {})