rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""The SelectBands transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from rslearn.train.model_context import RasterImage
|
|
6
|
+
|
|
7
|
+
from .transform import Transform, read_selector, write_selector
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SelectBands(Transform):
    """Select a subset of bands from an image."""

    def __init__(
        self,
        band_indices: list[int],
        input_selector: str = "image",
        output_selector: str = "image",
        num_bands_per_timestep: int | None = None,
    ):
        """Initialize a new SelectBands.

        Args:
            band_indices: the bands to select.
            input_selector: the selector to read the input image.
            output_selector: the output selector under which to save the output image.
            num_bands_per_timestep: the number of bands per image, to distinguish
                between stacked images in an image time series. If set, then the
                band_indices are selected for each image in the time series.
        """
        super().__init__()
        self.input_selector = input_selector
        self.output_selector = output_selector
        self.band_indices = band_indices
        self.num_bands_per_timestep = num_bands_per_timestep

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply band selection over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            (input_dicts, target_dicts) tuple with the bands selected
        """
        image = read_selector(input_dict, target_dict, self.input_selector)
        # Default to treating the whole channel dimension as one timestep.
        num_bands_per_timestep = (
            self.num_bands_per_timestep
            if self.num_bands_per_timestep is not None
            else image.shape[0]
        )
        if isinstance(image, RasterImage):
            # RasterImage keeps timesteps on their own axis, so the channel axis must
            # hold exactly one timestep's worth of bands.
            assert num_bands_per_timestep == image.shape[0], (
                "Expect a separate dimension for timesteps in RasterImages."
            )

        if image.shape[0] % num_bands_per_timestep != 0:
            raise ValueError(
                f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
            )

        # Copy the band indices for each timestep in the input.
        wanted_bands: list[int] = []
        for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
            wanted_bands.extend(
                [(start_channel_idx + band_idx) for band_idx in self.band_indices]
            )

        if isinstance(image, RasterImage):
            image.image = image.image[wanted_bands]
        else:
            image = image[wanted_bands]
        write_selector(input_dict, target_dict, self.output_selector, image)
        return input_dict, target_dict
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Transforms related to Sentinel-1 data."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from rslearn.train.model_context import RasterImage
|
|
8
|
+
|
|
9
|
+
from .transform import Transform
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Sentinel1ToDecibels(Transform):
    """Convert Sentinel-1 data from raw intensity to or from decibels."""

    def __init__(
        self,
        selectors: list[str] = ["image"],
        from_decibels: bool = False,
        epsilon: float = 1e-6,
    ):
        """Initialize a new Sentinel1ToDecibels.

        Args:
            selectors: the input selectors to apply the transform on.
            from_decibels: convert from decibels to intensities instead of intensity to
                decibels.
            epsilon: when converting to decibels, clip the intensities to this minimum
                value to avoid log issues. This is mostly to avoid pixels that have no
                data with no data value being 0.
        """
        super().__init__()
        # NOTE(review): the default list is only read, never mutated, so the shared
        # mutable default is safe here.
        self.selectors = selectors
        self.from_decibels = from_decibels
        self.epsilon = epsilon

    def apply_image(
        self, image: torch.Tensor | RasterImage
    ) -> torch.Tensor | RasterImage:
        """Convert the specified image between linear intensity and decibel scales.

        The direction of conversion is controlled by self.from_decibels.

        Args:
            image: the image to transform.
        """
        # Unwrap RasterImage so both input types share one conversion path.
        if isinstance(image, torch.Tensor):
            image_to_process = image
        else:
            image_to_process = image.image
        if self.from_decibels:
            # Decibels to linear scale.
            image_to_process = torch.pow(10.0, image_to_process / 10.0)
        else:
            # Linear scale to decibels.
            # Clamp to epsilon so log10 never sees zero (e.g. nodata pixels).
            image_to_process = 10 * torch.log10(
                torch.clamp(image_to_process, min=self.epsilon)
            )
        if isinstance(image, torch.Tensor):
            return image_to_process
        else:
            # Re-wrap the converted tensor in the original RasterImage container.
            image.image = image_to_process
            return image

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply the decibel conversion over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            converted (input_dicts, target_dicts) tuple
        """
        self.apply_fn(self.apply_image, input_dict, target_dict, self.selectors)
        return input_dict, target_dict
|
|
@@ -6,96 +6,115 @@ from typing import Any
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
transform them.
|
|
14
|
-
"""
|
|
9
|
+
def get_dict_and_subselector(
    input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
) -> tuple[dict[str, Any], str]:
    """Resolve which dict a selector addresses, plus the remaining sub-selector.

    A selector of "input/x" addresses the input dict with sub-selector "x", and
    "target/y" addresses the target dict with sub-selector "y". A selector with
    neither prefix defaults to the input dict.

    Args:
        input_dict: the input dict
        target_dict: the target dict
        selector: the full selector configured by the user

    Returns:
        a tuple (referenced dict, sub-selector string)
    """
    # Try each known prefix in turn; the first match wins.
    for prefix, container in (("input/", input_dict), ("target/", target_dict)):
        if selector.startswith(prefix):
            return container, selector[len(prefix) :]
    # No prefix: assume the selector is relative to the input dict.
    return input_dict, selector
|
|
46
41
|
|
|
47
|
-
return d, selector
|
|
48
42
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
43
|
+
def read_selector(
    input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
) -> Any:
    """Look up the value that a selector refers to.

    Args:
        input_dict: the input dict
        target_dict: the target dict
        selector: the selector specifying the item to read

    Returns:
        the item specified by the selector
    """
    container, sub_selector = get_dict_and_subselector(
        input_dict, target_dict, selector
    )
    # An empty sub-selector refers to the chosen dict itself.
    value: Any = container
    if sub_selector:
        for key in sub_selector.split("/"):
            value = value[key]
    return value
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def write_selector(
    input_dict: dict[str, Any],
    target_dict: dict[str, Any],
    selector: str,
    v: Any,
) -> None:
    """Write the item to the specified selector.

    Args:
        input_dict: the input dict
        target_dict: the target dict
        selector: the selector specifying the item to write
        v: the value to write

    Raises:
        ValueError: if the selector references the input or target dict directly
            but v is not a dict.
    """
    d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
    if selector:
        parts = selector.split("/")
        cur = d
        for part in parts[:-1]:
            cur = cur[part]
        cur[parts[-1]] = v
    else:
        # If the selector references the input or target dictionary directly, then we
        # have a special case where instead of overwriting with v, we replace the keys
        # with those in v. v must be a dictionary here, not a tensor, since otherwise
        # it wouldn't match the type of the input or target dictionary.
        if not isinstance(v, dict):
            raise ValueError(
                "when directly specifying the input or target dict, expected the value to be a dict"
            )
        if d is v:
            # This may happen if the writer did not make a copy of the dictionary. In
            # this case the code below would not update d correctly since it would also
            # clear v. Use an identity check rather than ==: comparing dicts by value
            # invokes __eq__ on the contained values, which raises for multi-element
            # tensors, and identity is the actual condition the no-copy case requires.
            return
        d.clear()
        d.update(v)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class Transform(torch.nn.Module):
|
|
104
|
+
"""An rslearn transform.
|
|
105
|
+
|
|
106
|
+
Provides helper functions for subclasses to select input and target keys and to
|
|
107
|
+
transform them.
|
|
108
|
+
"""
|
|
90
109
|
|
|
91
110
|
def apply_fn(
|
|
92
111
|
self,
|
|
93
|
-
fn: Callable
|
|
112
|
+
fn: Callable,
|
|
94
113
|
input_dict: dict[str, Any],
|
|
95
114
|
target_dict: dict[str, Any],
|
|
96
115
|
selectors: list[str],
|
|
97
116
|
**kwargs: dict[str, Any],
|
|
98
|
-
):
|
|
117
|
+
) -> None:
|
|
99
118
|
"""Apply the specified function on the selectors in input/target dicts.
|
|
100
119
|
|
|
101
120
|
Args:
|
|
@@ -106,9 +125,9 @@ class Transform(torch.nn.Module):
|
|
|
106
125
|
kwargs: additional arguments to pass to the function
|
|
107
126
|
"""
|
|
108
127
|
for selector in selectors:
|
|
109
|
-
v =
|
|
128
|
+
v = read_selector(input_dict, target_dict, selector)
|
|
110
129
|
v = fn(v, **kwargs)
|
|
111
|
-
|
|
130
|
+
write_selector(input_dict, target_dict, selector, v)
|
|
112
131
|
|
|
113
132
|
|
|
114
133
|
class Identity(Transform):
|
rslearn/utils/__init__.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""rslearn utilities."""
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
import os
|
|
3
|
+
from rslearn.log_utils import get_logger
|
|
5
4
|
|
|
6
5
|
from .feature import Feature
|
|
7
6
|
from .geometry import (
|
|
@@ -14,10 +13,8 @@ from .geometry import (
|
|
|
14
13
|
from .get_utm_ups_crs import get_utm_ups_crs
|
|
15
14
|
from .grid_index import GridIndex
|
|
16
15
|
from .time import daterange
|
|
17
|
-
from .utils import open_atomic
|
|
18
16
|
|
|
19
|
-
logger =
|
|
20
|
-
logger.setLevel(os.environ.get("RSLEARN_LOGLEVEL", "INFO").upper())
|
|
17
|
+
logger = get_logger(__name__)
|
|
21
18
|
|
|
22
19
|
__all__ = (
|
|
23
20
|
"Feature",
|
|
@@ -29,6 +26,5 @@ __all__ = (
|
|
|
29
26
|
"get_utm_ups_crs",
|
|
30
27
|
"is_same_resolution",
|
|
31
28
|
"logger",
|
|
32
|
-
"open_atomic",
|
|
33
29
|
"shp_intersects",
|
|
34
30
|
)
|
rslearn/utils/array.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
"""Array util functions."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
4
|
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
import torch
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
def copy_spatial_array(
|
|
10
|
-
src: torch.Tensor | npt.NDArray[Any],
|
|
11
|
-
dst: torch.Tensor | npt.NDArray[Any],
|
|
12
|
+
src: "torch.Tensor | npt.NDArray[Any]",
|
|
13
|
+
dst: "torch.Tensor | npt.NDArray[Any]",
|
|
12
14
|
src_offset: tuple[int, int],
|
|
13
15
|
dst_offset: tuple[int, int],
|
|
14
|
-
):
|
|
16
|
+
) -> None:
|
|
15
17
|
"""Copy image content from a source array onto a destination array.
|
|
16
18
|
|
|
17
19
|
The source and destination might be in the same coordinate system. Only the portion
|
|
@@ -59,4 +61,4 @@ def copy_spatial_array(
|
|
|
59
61
|
src_col_offset : src_col_offset + col_overlap,
|
|
60
62
|
]
|
|
61
63
|
else:
|
|
62
|
-
|
|
64
|
+
raise ValueError(f"Unsupported src shape: {src.shape}")
|
rslearn/utils/feature.py
CHANGED
|
@@ -11,7 +11,7 @@ from .geometry import Projection, STGeometry
|
|
|
11
11
|
class Feature:
|
|
12
12
|
"""A GeoJSON-like feature that contains one vector geometry."""
|
|
13
13
|
|
|
14
|
-
def __init__(self, geometry: STGeometry, properties: dict[str, Any]
|
|
14
|
+
def __init__(self, geometry: STGeometry, properties: dict[str, Any] = {}):
|
|
15
15
|
"""Initialize a new Feature.
|
|
16
16
|
|
|
17
17
|
Args:
|
|
@@ -41,7 +41,7 @@ class Feature:
|
|
|
41
41
|
return Feature(self.geometry.to_projection(projection), self.properties)
|
|
42
42
|
|
|
43
43
|
@staticmethod
|
|
44
|
-
def from_geojson(projection: Projection, d: dict[str, Any]):
|
|
44
|
+
def from_geojson(projection: Projection, d: dict[str, Any]) -> "Feature":
|
|
45
45
|
"""Construct a Feature from a GeoJSON encoding.
|
|
46
46
|
|
|
47
47
|
Args:
|
rslearn/utils/fsspec.py
CHANGED
|
@@ -6,9 +6,15 @@ from collections.abc import Generator
|
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
|
+
import rasterio
|
|
10
|
+
import rasterio.io
|
|
9
11
|
from fsspec.implementations.local import LocalFileSystem
|
|
10
12
|
from upath import UPath
|
|
11
13
|
|
|
14
|
+
from rslearn.log_utils import get_logger
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
12
18
|
|
|
13
19
|
@contextmanager
|
|
14
20
|
def get_upath_local(
|
|
@@ -65,7 +71,7 @@ def join_upath(path: UPath, suffix: str) -> UPath:
|
|
|
65
71
|
|
|
66
72
|
|
|
67
73
|
@contextmanager
|
|
68
|
-
def open_atomic(path: UPath, *args:
|
|
74
|
+
def open_atomic(path: UPath, *args: Any, **kwargs: Any) -> Generator[Any, None, None]:
|
|
69
75
|
"""Open a path for atomic writing.
|
|
70
76
|
|
|
71
77
|
If it is local filesystem, we will write to a temporary file, and rename it to the
|
|
@@ -79,11 +85,94 @@ def open_atomic(path: UPath, *args: list[Any], **kwargs: dict[str, Any]):
|
|
|
79
85
|
**kwargs: any valid keyword arguments for :code:`open`
|
|
80
86
|
"""
|
|
81
87
|
if isinstance(path.fs, LocalFileSystem):
|
|
88
|
+
logger.debug("open_atomic: writing atomically to local file at %s", path)
|
|
82
89
|
tmppath = path.path + ".tmp." + str(os.getpid())
|
|
83
90
|
with open(tmppath, *args, **kwargs) as file:
|
|
84
91
|
yield file
|
|
85
92
|
os.rename(tmppath, path.path)
|
|
86
93
|
|
|
87
94
|
else:
|
|
95
|
+
logger.debug("open_atomic: writing to remote file at %s", path)
|
|
88
96
|
with path.open(*args, **kwargs) as file:
|
|
89
97
|
yield file
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@contextmanager
def open_rasterio_upath_reader(
    path: UPath, **kwargs: Any
) -> Generator[rasterio.io.DatasetReader, None, None]:
    """Open a raster for reading.

    If the UPath is local, then we open with rasterio directly, since this is much
    faster. Otherwise, we open the file stream first and then use rasterio with file
    stream.

    Args:
        path: the path to read.
        **kwargs: additional keyword arguments for :code:`rasterio.open`
    """
    # Remote paths go through a file stream; local paths can skip it.
    if not isinstance(path.fs, LocalFileSystem):
        logger.debug("reading from remote rasterio dataset at %s", path)
        with path.open("rb") as f:
            with rasterio.open(f, **kwargs) as raster:
                yield raster
    else:
        logger.debug("reading from local rasterio dataset at %s", path)
        with rasterio.open(path.path, **kwargs) as raster:
            yield raster
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@contextmanager
def open_rasterio_upath_writer(
    path: UPath, **kwargs: Any
) -> Generator[rasterio.io.DatasetWriter, None, None]:
    """Open a raster for writing.

    If the UPath is local, then we open with rasterio directly, since this is much
    faster. We also write atomically by writing to temporary file and then renaming it,
    to avoid concurrency issues. Otherwise, we open the file stream first and then use
    rasterio with file stream (and assume that it is object storage so the write will
    be atomic).

    Args:
        path: the path to write.
        **kwargs: additional keyword arguments for :code:`rasterio.open`
    """
    is_local = isinstance(path.fs, LocalFileSystem)
    if not is_local:
        logger.debug(
            "open_rasterio_upath_writer: writing to remote rasterio dataset at %s", path
        )
        with path.open("wb") as f:
            with rasterio.open(f, "w", **kwargs) as raster:
                yield raster
    else:
        logger.debug(
            "open_rasterio_upath_writer: writing atomically to local rasterio dataset at %s",
            path,
        )
        # Write to a pid-suffixed temp file, then rename into place for atomicity.
        tmp_name = path.path + ".tmp." + str(os.getpid())
        with rasterio.open(tmp_name, "w", **kwargs) as raster:
            yield raster
        os.rename(tmp_name, path.path)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_relative_suffix(base_dir: UPath, fname: UPath) -> str:
|
|
162
|
+
"""Get the suffix of fname relative to base_dir.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
base_dir: the base directory.
|
|
166
|
+
fname: a filename within the base directory.
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
the suffix on base_dir that would yield the given filename.
|
|
170
|
+
"""
|
|
171
|
+
if not fname.path.startswith(base_dir.path):
|
|
172
|
+
raise ValueError(
|
|
173
|
+
f"filename {fname.path} must start with base directory {base_dir.path}"
|
|
174
|
+
)
|
|
175
|
+
suffix = fname.path[len(base_dir.path) :]
|
|
176
|
+
if suffix.startswith("/"):
|
|
177
|
+
suffix = suffix[1:]
|
|
178
|
+
return suffix
|