rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/transforms/select_bands.py ADDED
@@ -0,0 +1,76 @@
+ """The SelectBands transform."""
+
+ from typing import Any
+
+ from rslearn.train.model_context import RasterImage
+
+ from .transform import Transform, read_selector, write_selector
+
+
+ class SelectBands(Transform):
+     """Select a subset of bands from an image."""
+
+     def __init__(
+         self,
+         band_indices: list[int],
+         input_selector: str = "image",
+         output_selector: str = "image",
+         num_bands_per_timestep: int | None = None,
+     ):
+         """Initialize a new SelectBands.
+
+         Args:
+             band_indices: the bands to select.
+             input_selector: the selector to read the input image.
+             output_selector: the output selector under which to save the output image.
+             num_bands_per_timestep: the number of bands per image, to distinguish
+                 between stacked images in an image time series. If set, then the
+                 band_indices are selected for each image in the time series.
+         """
+         super().__init__()
+         self.input_selector = input_selector
+         self.output_selector = output_selector
+         self.band_indices = band_indices
+         self.num_bands_per_timestep = num_bands_per_timestep
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply band selection over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             the transformed (input_dict, target_dict) tuple
+         """
+         image = read_selector(input_dict, target_dict, self.input_selector)
+         num_bands_per_timestep = (
+             self.num_bands_per_timestep
+             if self.num_bands_per_timestep is not None
+             else image.shape[0]
+         )
+         if isinstance(image, RasterImage):
+             assert num_bands_per_timestep == image.shape[0], (
+                 "Expect a separate dimension for timesteps in RasterImages."
+             )
+
+         if image.shape[0] % num_bands_per_timestep != 0:
+             raise ValueError(
+                 f"channel dimension {image.shape[0]} is not a multiple of bands per timestep {num_bands_per_timestep}"
+             )
+
+         # Copy the band indices for each timestep in the input.
+         wanted_bands: list[int] = []
+         for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
+             wanted_bands.extend(
+                 [(start_channel_idx + band_idx) for band_idx in self.band_indices]
+             )
+
+         if isinstance(image, RasterImage):
+             image.image = image.image[wanted_bands]
+         else:
+             image = image[wanted_bands]
+         write_selector(input_dict, target_dict, self.output_selector, image)
+         return input_dict, target_dict
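For reference, a minimal sketch of applying the new SelectBands transform outside a training pipeline; the tensor shape and band indices below are invented for illustration:

    import torch

    from rslearn.train.transforms.select_bands import SelectBands

    # Two timesteps of a 6-band image stacked along the channel dimension: (12, H, W).
    input_dict = {"image": torch.randn(12, 32, 32)}
    target_dict = {}

    # Keep bands 0, 1, and 3 from each timestep, yielding 6 output channels.
    transform = SelectBands(band_indices=[0, 1, 3], num_bands_per_timestep=6)
    input_dict, target_dict = transform(input_dict, target_dict)
    print(input_dict["image"].shape)  # torch.Size([6, 32, 32])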
rslearn/train/transforms/sentinel1.py ADDED
@@ -0,0 +1,75 @@
+ """Transforms related to Sentinel-1 data."""
+
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import RasterImage
+
+ from .transform import Transform
+
+
+ class Sentinel1ToDecibels(Transform):
+     """Convert Sentinel-1 data from raw intensity to or from decibels."""
+
+     def __init__(
+         self,
+         selectors: list[str] = ["image"],
+         from_decibels: bool = False,
+         epsilon: float = 1e-6,
+     ):
+         """Initialize a new Sentinel1ToDecibels.
+
+         Args:
+             selectors: the input selectors to apply the transform on.
+             from_decibels: convert from decibels to intensities instead of intensity to
+                 decibels.
+             epsilon: when converting to decibels, clip the intensities to this minimum
+                 value to avoid log issues. This mostly guards against nodata pixels
+                 whose value is 0.
+         """
+         super().__init__()
+         self.selectors = selectors
+         self.from_decibels = from_decibels
+         self.epsilon = epsilon
+
+     def apply_image(
+         self, image: torch.Tensor | RasterImage
+     ) -> torch.Tensor | RasterImage:
+         """Convert the specified image.
+
+         Args:
+             image: the image to transform.
+         """
+         if isinstance(image, torch.Tensor):
+             image_to_process = image
+         else:
+             image_to_process = image.image
+         if self.from_decibels:
+             # Decibels to linear scale.
+             image_to_process = torch.pow(10.0, image_to_process / 10.0)
+         else:
+             # Linear scale to decibels.
+             image_to_process = 10 * torch.log10(
+                 torch.clamp(image_to_process, min=self.epsilon)
+             )
+         if isinstance(image, torch.Tensor):
+             return image_to_process
+         else:
+             image.image = image_to_process
+             return image
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply the conversion over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             the transformed (input_dict, target_dict) tuple
+         """
+         self.apply_fn(self.apply_image, input_dict, target_dict, self.selectors)
+         return input_dict, target_dict
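A minimal sketch of the decibel conversion this transform performs, 10 * log10(clamp(x, epsilon)); the pixel values below are invented:

    import torch

    from rslearn.train.transforms.sentinel1 import Sentinel1ToDecibels

    input_dict = {"image": torch.tensor([[[0.0, 0.01], [0.1, 1.0]]])}
    target_dict = {}

    to_db = Sentinel1ToDecibels(selectors=["image"])
    input_dict, target_dict = to_db(input_dict, target_dict)
    # Zero-valued (nodata) pixels are clamped to epsilon=1e-6, i.e. -60 dB.
    print(input_dict["image"])  # [[[-60., -20.], [-10., 0.]]]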
rslearn/train/transforms/transform.py CHANGED
@@ -6,96 +6,115 @@ from typing import Any
  import torch
 
 
- class Transform(torch.nn.Module):
-     """An rslearn transform.
-
-     Provides helper functions for subclasses to select input and target keys and to
-     transform them.
-     """
+ def get_dict_and_subselector(
+     input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
+ ) -> tuple[dict[str, Any], str]:
+     """Determine whether to use input or target dict, and the sub-selector.
 
-     def get_dict_and_subselector(
-         self, input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
-     ) -> tuple[dict[str, Any], str]:
-         """Determine whether to use input or target dict, and the sub-selector.
+     For example, if the selector is "input/x", then we use input dict and the
+     sub-selector is "x".
 
-         For example, if the selector is "input/x", then we use input dict and the
-         sub-selector is "x".
+     If neither input/ nor target/ prefixes are present, then we assume it is for
+     input dict.
 
-         If neither input/ nor target/ prefixes are present, then we assume it is for
-         input dict.
+     Args:
+         input_dict: the input dict
+         target_dict: the target dict
+         selector: the full selector configured by the user
 
-         Args:
-             input_dict: the input dict
-             target_dict: the target dict
-             selector: the full selector configured by the user
+     Returns:
+         a tuple (referenced dict, sub-selector string)
+     """
+     input_prefix = "input/"
+     target_prefix = "target/"
 
-         Returns:
-             a tuple (referenced dict, sub-selector string)
-         """
-         input_prefix = "input/"
-         target_prefix = "target/"
+     if selector.startswith(input_prefix):
+         d = input_dict
+         selector = selector[len(input_prefix) :]
+     elif selector.startswith(target_prefix):
+         d = target_dict
+         selector = selector[len(target_prefix) :]
+     else:
+         d = input_dict
 
-         if selector.startswith(input_prefix):
-             d = input_dict
-             selector = selector[len(input_prefix) :]
-         elif selector.startswith(target_prefix):
-             d = target_dict
-             selector = selector[len(target_prefix) :]
-         else:
-             d = input_dict
+     return d, selector
 
-         return d, selector
 
-     def read_selector(
-         self, input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
-     ) -> Any:
-         """Read the item referenced by the selector.
+ def read_selector(
+     input_dict: dict[str, Any], target_dict: dict[str, Any], selector: str
+ ) -> Any:
+     """Read the item referenced by the selector.
 
-         Args:
-             input_dict: the input dict
-             target_dict: the target dict
-             selector: the selector specifying the item to read
+     Args:
+         input_dict: the input dict
+         target_dict: the target dict
+         selector: the selector specifying the item to read
 
-         Returns:
-             the item specified by the selector
-         """
-         d, selector = self.get_dict_and_subselector(input_dict, target_dict, selector)
-         parts = selector.split("/")
-         cur = d
-         for part in parts:
-             cur = cur[part]
-         return cur
-
-     def write_selector(
-         self,
-         input_dict: dict[str, Any],
-         target_dict: dict[str, Any],
-         selector: str,
-         v: Any,
-     ):
-         """Write the item to the specified selector.
-
-         Args:
-             input_dict: the input dict
-             target_dict: the target dict
-             selector: the selector specifying the item to write
-             v: the value to write
-         """
-         d, selector = self.get_dict_and_subselector(input_dict, target_dict, selector)
+     Returns:
+         the item specified by the selector
+     """
+     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
+     parts = selector.split("/") if selector else []
+     cur = d
+     for part in parts:
+         cur = cur[part]
+     return cur
+
+
+ def write_selector(
+     input_dict: dict[str, Any],
+     target_dict: dict[str, Any],
+     selector: str,
+     v: Any,
+ ) -> None:
+     """Write the item to the specified selector.
+
+     Args:
+         input_dict: the input dict
+         target_dict: the target dict
+         selector: the selector specifying the item to write
+         v: the value to write
+     """
+     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
+     if selector:
          parts = selector.split("/")
          cur = d
          for part in parts[:-1]:
              cur = cur[part]
          cur[parts[-1]] = v
+     else:
+         # If the selector references the input or target dictionary directly, then we
+         # have a special case where instead of overwriting with v, we replace the keys
+         # with those in v. v must be a dictionary here, not a tensor, since otherwise
+         # it wouldn't match the type of the input or target dictionary.
+         if not isinstance(v, dict):
+             raise ValueError(
+                 "when directly specifying the input or target dict, expected the value to be a dict"
+             )
+         if d == v:
+             # This may happen if the writer did not make a copy of the dictionary. In
+             # this case the code below would not update d correctly since it would also
+             # clear v.
+             return
+         d.clear()
+         d.update(v)
+
+
+ class Transform(torch.nn.Module):
+     """An rslearn transform.
+
+     Provides helper functions for subclasses to select input and target keys and to
+     transform them.
+     """
 
      def apply_fn(
          self,
-         fn: Callable[[Any], Any],
+         fn: Callable,
          input_dict: dict[str, Any],
          target_dict: dict[str, Any],
          selectors: list[str],
         **kwargs: dict[str, Any],
-     ):
+     ) -> None:
          """Apply the specified function on the selectors in input/target dicts.
 
          Args:
@@ -106,9 +125,9 @@ class Transform(torch.nn.Module):
              kwargs: additional arguments to pass to the function
          """
          for selector in selectors:
-             v = self.read_selector(input_dict, target_dict, selector)
+             v = read_selector(input_dict, target_dict, selector)
              v = fn(v, **kwargs)
-             self.write_selector(input_dict, target_dict, selector, v)
+             write_selector(input_dict, target_dict, selector, v)
 
 
  class Identity(Transform):
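A short sketch of how the now module-level selector helpers resolve keys (the nested dictionary contents are invented):

    from rslearn.train.transforms.transform import read_selector, write_selector

    input_dict = {"sentinel2": {"image": [1, 2, 3]}}
    target_dict = {"label": 7}

    # Without a prefix, selectors resolve against the input dict; "/" descends into
    # nested dictionaries, and a "target/" prefix switches to the target dict.
    assert read_selector(input_dict, target_dict, "sentinel2/image") == [1, 2, 3]
    assert read_selector(input_dict, target_dict, "target/label") == 7

    # Writing to a nested selector updates the referenced dict in place.
    write_selector(input_dict, target_dict, "sentinel2/image", [4, 5, 6])
    assert input_dict["sentinel2"]["image"] == [4, 5, 6]

    # With an empty sub-selector the written value must itself be a dict, and its
    # keys replace the keys of the referenced dict.
    write_selector(input_dict, target_dict, "target/", {"label": 9})
    assert target_dict == {"label": 9}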
rslearn/utils/__init__.py CHANGED
@@ -1,7 +1,6 @@
  """rslearn utilities."""
 
- import logging
- import os
+ from rslearn.log_utils import get_logger
 
  from .feature import Feature
  from .geometry import (
@@ -14,10 +13,8 @@ from .geometry import (
  from .get_utm_ups_crs import get_utm_ups_crs
  from .grid_index import GridIndex
  from .time import daterange
- from .utils import open_atomic
 
- logger = logging.getLogger(__name__)
- logger.setLevel(os.environ.get("RSLEARN_LOGLEVEL", "INFO").upper())
+ logger = get_logger(__name__)
 
  __all__ = (
      "Feature",
@@ -29,6 +26,5 @@ __all__ = (
      "get_utm_ups_crs",
      "is_same_resolution",
      "logger",
-     "open_atomic",
      "shp_intersects",
  )
rslearn/utils/array.py CHANGED
@@ -1,17 +1,19 @@
  """Array util functions."""
 
- from typing import Any
+ from typing import TYPE_CHECKING, Any
 
  import numpy.typing as npt
- import torch
+
+ if TYPE_CHECKING:
+     import torch
 
 
  def copy_spatial_array(
-     src: torch.Tensor | npt.NDArray[Any],
-     dst: torch.Tensor | npt.NDArray[Any],
+     src: "torch.Tensor | npt.NDArray[Any]",
+     dst: "torch.Tensor | npt.NDArray[Any]",
      src_offset: tuple[int, int],
      dst_offset: tuple[int, int],
- ):
+ ) -> None:
      """Copy image content from a source array onto a destination array.
 
      The source and destination might be in the same coordinate system. Only the portion
@@ -59,4 +61,4 @@ def copy_spatial_array(
              src_col_offset : src_col_offset + col_overlap,
          ]
      else:
-         assert False
+         raise ValueError(f"Unsupported src shape: {src.shape}")
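The array.py change above uses the standard TYPE_CHECKING idiom so torch is only needed for type checking, not at runtime. A small self-contained sketch of the same idea (the function below is hypothetical):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only imported by type checkers; torch is not a runtime dependency here.
        import torch


    def first_pixel(image: "torch.Tensor") -> float:
        """Return the first element; the quoted annotation avoids importing torch."""
        return float(image.flatten()[0])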
rslearn/utils/feature.py CHANGED
@@ -11,7 +11,7 @@ from .geometry import Projection, STGeometry
  class Feature:
      """A GeoJSON-like feature that contains one vector geometry."""
 
-     def __init__(self, geometry: STGeometry, properties: dict[str, Any] | None = {}):
+     def __init__(self, geometry: STGeometry, properties: dict[str, Any] = {}):
          """Initialize a new Feature.
 
          Args:
@@ -41,7 +41,7 @@ class Feature:
          return Feature(self.geometry.to_projection(projection), self.properties)
 
      @staticmethod
-     def from_geojson(projection: Projection, d: dict[str, Any]):
+     def from_geojson(projection: Projection, d: dict[str, Any]) -> "Feature":
          """Construct a Feature from a GeoJSON encoding.
 
          Args:
rslearn/utils/fsspec.py CHANGED
@@ -6,9 +6,15 @@ from collections.abc import Generator
  from contextlib import contextmanager
  from typing import Any
 
+ import rasterio
+ import rasterio.io
  from fsspec.implementations.local import LocalFileSystem
  from upath import UPath
 
+ from rslearn.log_utils import get_logger
+
+ logger = get_logger(__name__)
+
 
  @contextmanager
  def get_upath_local(
@@ -65,7 +71,7 @@ def join_upath(path: UPath, suffix: str) -> UPath:
 
 
  @contextmanager
- def open_atomic(path: UPath, *args: list[Any], **kwargs: dict[str, Any]):
+ def open_atomic(path: UPath, *args: Any, **kwargs: Any) -> Generator[Any, None, None]:
      """Open a path for atomic writing.
 
      If it is local filesystem, we will write to a temporary file, and rename it to the
@@ -79,11 +85,94 @@ def open_atomic(path: UPath, *args: list[Any], **kwargs: dict[str, Any]):
          **kwargs: any valid keyword arguments for :code:`open`
      """
      if isinstance(path.fs, LocalFileSystem):
+         logger.debug("open_atomic: writing atomically to local file at %s", path)
          tmppath = path.path + ".tmp." + str(os.getpid())
          with open(tmppath, *args, **kwargs) as file:
              yield file
          os.rename(tmppath, path.path)
 
      else:
+         logger.debug("open_atomic: writing to remote file at %s", path)
          with path.open(*args, **kwargs) as file:
              yield file
+
+
+ @contextmanager
+ def open_rasterio_upath_reader(
+     path: UPath, **kwargs: Any
+ ) -> Generator[rasterio.io.DatasetReader, None, None]:
+     """Open a raster for reading.
+
+     If the UPath is local, then we open with rasterio directly, since this is much
+     faster. Otherwise, we open the file stream first and then use rasterio with file
+     stream.
+
+     Args:
+         path: the path to read.
+         **kwargs: additional keyword arguments for :code:`rasterio.open`
+     """
+     if isinstance(path.fs, LocalFileSystem):
+         logger.debug("reading from local rasterio dataset at %s", path)
+         with rasterio.open(path.path, **kwargs) as raster:
+             yield raster
+
+     else:
+         logger.debug("reading from remote rasterio dataset at %s", path)
+         with path.open("rb") as f:
+             with rasterio.open(f, **kwargs) as raster:
+                 yield raster
+
+
+ @contextmanager
+ def open_rasterio_upath_writer(
+     path: UPath, **kwargs: Any
+ ) -> Generator[rasterio.io.DatasetWriter, None, None]:
+     """Open a raster for writing.
+
+     If the UPath is local, then we open with rasterio directly, since this is much
+     faster. We also write atomically by writing to temporary file and then renaming it,
+     to avoid concurrency issues. Otherwise, we open the file stream first and then use
+     rasterio with file stream (and assume that it is object storage so the write will
+     be atomic).
+
+     Args:
+         path: the path to write.
+         **kwargs: additional keyword arguments for :code:`rasterio.open`
+     """
+     if isinstance(path.fs, LocalFileSystem):
+         logger.debug(
+             "open_rasterio_upath_writer: writing atomically to local rasterio dataset at %s",
+             path,
+         )
+         tmppath = path.path + ".tmp." + str(os.getpid())
+         with rasterio.open(tmppath, "w", **kwargs) as raster:
+             yield raster
+         os.rename(tmppath, path.path)
+
+     else:
+         logger.debug(
+             "open_rasterio_upath_writer: writing to remote rasterio dataset at %s", path
+         )
+         with path.open("wb") as f:
+             with rasterio.open(f, "w", **kwargs) as raster:
+                 yield raster
+
+
+ def get_relative_suffix(base_dir: UPath, fname: UPath) -> str:
+     """Get the suffix of fname relative to base_dir.
+
+     Args:
+         base_dir: the base directory.
+         fname: a filename within the base directory.
+
+     Returns:
+         the suffix on base_dir that would yield the given filename.
+     """
+     if not fname.path.startswith(base_dir.path):
+         raise ValueError(
+             f"filename {fname.path} must start with base directory {base_dir.path}"
+         )
+     suffix = fname.path[len(base_dir.path) :]
+     if suffix.startswith("/"):
+         suffix = suffix[1:]
+     return suffix
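A minimal sketch of round-tripping a raster through the new fsspec helpers; the GeoTIFF profile below is invented and omits georeferencing for brevity:

    import numpy as np
    from upath import UPath

    from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer

    path = UPath("/tmp/example.tif")
    data = np.zeros((1, 16, 16), dtype=np.uint8)

    # Local paths go through rasterio directly and are written atomically via a
    # temporary file; remote paths (e.g. gs:// or s3://) are streamed instead.
    with open_rasterio_upath_writer(
        path, driver="GTiff", width=16, height=16, count=1, dtype="uint8"
    ) as raster:
        raster.write(data)

    with open_rasterio_upath_reader(path) as raster:
        assert raster.read().shape == (1, 16, 16)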