rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/dataset/remap.py CHANGED
@@ -1,18 +1,42 @@
1
1
  """Classes to remap raster values."""
2
2
 
3
- from typing import Any
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
 
9
- Remappers = ClassRegistry()
9
+ _RemapperT = TypeVar("_RemapperT", bound="Remapper")
10
+
11
+
12
class _RemapperRegistry(dict[str, type["Remapper"]]):
    """Name-to-class mapping of Remapper implementations.

    Classes are added with the ``register`` decorator and later looked up by name.
    """

    def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
        """Return a decorator that records the decorated class under ``name``.

        Args:
            name: the key under which the decorated class is stored.

        Returns:
            a decorator that stores the class in this registry and returns the
            class unchanged.
        """

        def _store(remapper_cls: type[_RemapperT]) -> type[_RemapperT]:
            # Re-registering an existing name replaces the previous entry.
            self.__setitem__(name, remapper_cls)
            return remapper_cls

        return _store


Remappers = _RemapperRegistry()
"""Registry of Remapper implementations."""
11
27
 
12
28
 
13
29
  class Remapper:
14
30
  """An abstract class that remaps pixel values based on layer configuration."""
15
31
 
32
+ def __init__(self, config: dict[str, Any]) -> None:
33
+ """Initialize a Remapper.
34
+
35
+ Args:
36
+ config: the config dict for this remapper.
37
+ """
38
+ pass
39
+
16
40
  def __call__(
17
41
  self, array: npt.NDArray[Any], dtype: npt.DTypeLike
18
42
  ) -> npt.NDArray[Any]:
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
67
91
 
68
92
  def load_remapper(config: dict[str, Any]) -> Remapper:
69
93
  """Load a remapper from a configuration dictionary."""
70
- return Remappers.get(config["name"], config=config)
94
+ cls = Remappers[config["name"]]
95
+ return cls(config)
@@ -0,0 +1 @@
1
+ """Storage backends for rslearn window metadata."""
@@ -0,0 +1,202 @@
1
+ """The default file-based window storage backend."""
2
+
3
+ import json
4
+ import multiprocessing
5
+
6
+ import tqdm
7
+ from typing_extensions import override
8
+ from upath import UPath
9
+
10
+ from rslearn.dataset.window import (
11
+ LAYERS_DIRECTORY_NAME,
12
+ Window,
13
+ WindowLayerData,
14
+ get_layer_and_group_from_dir_name,
15
+ get_window_layer_dir,
16
+ )
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.utils.fsspec import open_atomic
19
+ from rslearn.utils.mp import star_imap_unordered
20
+
21
+ from .storage import WindowStorage, WindowStorageFactory
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
def load_window(storage: "FileWindowStorage", window_dir: UPath) -> Window:
    """Read ``metadata.json`` from a window directory and build the Window.

    Args:
        storage: the FileWindowStorage that the resulting Window is bound to.
        window_dir: the directory containing the window's metadata.json.

    Returns:
        the Window reconstructed from its stored metadata.
    """
    with (window_dir / "metadata.json").open() as metadata_file:
        window_metadata = json.load(metadata_file)
    return Window.from_metadata(storage, window_metadata)
40
+
41
+
42
class FileWindowStorage(WindowStorage):
    """The default file-backed window storage.

    Window metadata is written to ``metadata.json`` and layer datas to
    ``items.json`` inside each window's root directory (``get_window_root``).
    A layer group is marked completed by an empty ``completed`` file inside its
    layer directory. ``get_windows`` enumerates windows under
    ``<dataset path>/windows/<group>/<name>``.
    """

    def __init__(self, path: UPath):
        """Create a new FileWindowStorage.

        Args:
            path: the path to the dataset.
        """
        self.path = path

    @override
    def get_window_root(self, group: str, name: str) -> UPath:
        return Window.get_window_root(self.path, group, name)

    @override
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
        show_progress: bool = False,
        workers: int = 0,
    ) -> list["Window"]:
        """Load the windows in the dataset.

        Args:
            groups: an optional list of groups to filter loading
            names: an optional list of window names to filter loading
            show_progress: whether to show a tqdm progress bar
            workers: number of parallel worker processes; 0 (the default) loads
                windows sequentially in the calling process.

        Returns:
            the loaded windows. When workers > 0 the order is not deterministic
            because results are collected unordered.
        """
        windows_dir = self.path / "windows"
        # Avoid directory does not exist errors later.
        if not windows_dir.exists():
            return []

        if not groups:
            groups = [group_dir.name for group_dir in windows_dir.iterdir()]

        window_dirs: list[UPath] = []
        for group in groups:
            group_dir = windows_dir / group
            if not group_dir.exists():
                logger.warning(
                    f"Skipping group directory {group_dir} since it does not exist"
                )
                continue
            if names:
                cur_names = names
            else:
                cur_names = [window_dir.name for window_dir in group_dir.iterdir()]
            window_dirs.extend(group_dir / window_name for window_name in cur_names)

        if workers == 0:
            sequential = (load_window(self, window_dir) for window_dir in window_dirs)
            if show_progress:
                sequential = tqdm.tqdm(
                    sequential, total=len(window_dirs), desc="Loading windows"
                )
            return list(sequential)

        # The Pool context manager terminates the workers on exit, including when
        # loading raises, so every result must be consumed inside the with block.
        with multiprocessing.Pool(workers) as pool:
            outputs = star_imap_unordered(
                pool,
                load_window,
                [
                    dict(storage=self, window_dir=window_dir)
                    for window_dir in window_dirs
                ],
            )
            if show_progress:
                outputs = tqdm.tqdm(
                    outputs, total=len(window_dirs), desc="Loading windows"
                )
            return list(outputs)

    @override
    def create_or_update_window(self, window: Window) -> None:
        window_path = self.get_window_root(window.group, window.name)
        window_path.mkdir(parents=True, exist_ok=True)
        metadata_path = window_path / "metadata.json"
        logger.debug(f"Saving window metadata to {metadata_path}")
        with open_atomic(metadata_path, "w") as f:
            json.dump(window.get_metadata(), f)

    @override
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        window_path = self.get_window_root(group, name)
        items_fname = window_path / "items.json"
        # A window that was never prepared has no items.json yet.
        if not items_fname.exists():
            return {}

        with items_fname.open() as f:
            layer_datas = [
                WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
            ]

        return {layer_data.layer_name: layer_data for layer_data in layer_datas}

    @override
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        window_path = self.get_window_root(group, name)
        json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
        items_fname = window_path / "items.json"
        # DEBUG (not INFO), matching create_or_update_window: this runs per window.
        logger.debug(f"Saving window items to {items_fname}")
        with open_atomic(items_fname, "w") as f:
            json.dump(json_data, f)

    @override
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        window_path = self.get_window_root(group, name)
        layers_directory = window_path / LAYERS_DIRECTORY_NAME
        if not layers_directory.exists():
            return []

        completed_layers = []
        for layer_dir in layers_directory.iterdir():
            # Raises ValueError for directory names that are not "<layer>" or
            # "<layer>.<group_idx>".
            layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
            if not self.is_layer_completed(group, name, layer_name, group_idx):
                continue
            completed_layers.append((layer_name, group_idx))

        return completed_layers

    @override
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(
            window_path,
            layer_name,
            group_idx,
        )
        return (layer_dir / "completed").exists()

    @override
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
        # We assume the directory exists because the layer should be materialized before
        # being marked completed.
        (layer_dir / "completed").touch()
194
+
195
+
196
class FileWindowStorageFactory(WindowStorageFactory):
    """Factory that builds a FileWindowStorage rooted at a dataset path."""

    @override
    def get_storage(self, ds_path: UPath) -> FileWindowStorage:
        """Return a file-backed window storage for the dataset at ``ds_path``."""
        storage = FileWindowStorage(ds_path)
        return storage
@@ -0,0 +1,140 @@
1
+ """Abstract classes for window metadata storage."""
2
+
3
+ import abc
4
+ from typing import TYPE_CHECKING
5
+
6
+ from upath import UPath
7
+
8
+ if TYPE_CHECKING:
9
+ from rslearn.dataset.window import Window, WindowLayerData
10
+
11
+
12
class WindowStorage(abc.ABC):
    """An abstract class for the storage backend for window metadata.

    This is instantiated by a WindowStorageFactory for a specific rslearn dataset.

    Window metadata includes the location and time range of windows (metadata.json),
    the window layer datas (items.json), and the completed (materialized) layers. It
    excludes the actual materialized data. All operations involving window metadata go
    through the WindowStorage, including enumerating windows, creating new windows, and
    updating window layer datas during `rslearn dataset prepare` or the completed
    layers during `rslearn dataset materialize`.
    """

    @abc.abstractmethod
    def get_window_root(self, group: str, name: str) -> UPath:
        """Get the path where the window should be stored.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            the root directory of the window.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
        show_progress: bool = False,
        workers: int = 0,
    ) -> list["Window"]:
        """Load the windows in the dataset.

        Args:
            groups: an optional list of groups to filter loading
            names: an optional list of window names to filter loading
            show_progress: whether to show a progress bar while loading.
                Implementations that cannot report progress may ignore this.
            workers: number of parallel workers to use for loading; 0 means load
                in the calling process. Implementations may ignore this.

        Returns:
            the matching windows.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def create_or_update_window(self, window: "Window") -> None:
        """Create or update the window.

        An existing window is only updated if there is one with the same name and group.

        If there is a window with the same name but a different group, the behavior is
        undefined.

        Args:
            window: the window whose metadata should be stored.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        """Get the window layer datas for the specified window.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            a dict mapping from the layer name to the layer data for that layer, if one
                was previously saved.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        """Set the window layer datas for the specified window.

        Args:
            group: the window group.
            name: the window name.
            layer_datas: a dict mapping from layer name to the layer data to save.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        """List the layers available for this window that are completed.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            a list of (layer_name, group_idx) completed layers.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        """Check whether the specified layer is completed in the given window.

        Completed means there is data in the layer and the data has been written
        (materialized).

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.

        Returns:
            whether the layer is completed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        """Mark the specified layer completed for the given window.

        This must be done after the contents of the layer have been written. If a layer
        has multiple groups, the caller should wait until the contents of all groups
        have been written before marking them completed; this is because, when
        materializing a window, we skip materialization if the first group
        (group_idx=0) is marked completed.

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.
        """
        raise NotImplementedError
128
+
129
+
130
class WindowStorageFactory(abc.ABC):
    """Configurable factory that produces a WindowStorage for a dataset.

    The dataset config's StorageConfig selects and configures the factory, and
    the factory then builds the WindowStorage for a concrete dataset path.
    """

    @abc.abstractmethod
    def get_storage(self, ds_path: UPath) -> WindowStorage:
        """Return the WindowStorage to use for the dataset at ``ds_path``.

        Args:
            ds_path: the dataset path.
        """
        raise NotImplementedError
rslearn/dataset/window.py CHANGED
@@ -1,14 +1,79 @@
1
1
  """rslearn windows."""
2
2
 
3
- import json
4
3
  from datetime import datetime
5
4
  from typing import Any
6
5
 
7
6
  import shapely
8
7
  from upath import UPath
9
8
 
9
+ from rslearn.dataset.storage.storage import WindowStorage
10
+ from rslearn.log_utils import get_logger
10
11
  from rslearn.utils import Projection, STGeometry
11
- from rslearn.utils.fsspec import open_atomic
12
+ from rslearn.utils.raster_format import get_bandset_dirname
13
+
14
+ logger = get_logger(__name__)
15
+
16
+ LAYERS_DIRECTORY_NAME = "layers"
17
+
18
+
19
def get_window_layer_dir(
    window_path: UPath, layer_name: str, group_idx: int = 0
) -> UPath:
    """Return the directory holding materialized data for one group of a layer.

    Group 0 uses the bare layer name as its folder name; any other group gets a
    ``.<group_idx>`` suffix appended to the layer name.

    Args:
        window_path: the window directory.
        layer_name: the layer name.
        group_idx: the index of the group within the layer (default 0).

    Returns:
        the path where data is or should be materialized.
    """
    folder_name = layer_name if group_idx == 0 else f"{layer_name}.{group_idx}"
    return window_path / LAYERS_DIRECTORY_NAME / folder_name
38
+
39
+
40
def get_layer_and_group_from_dir_name(layer_dir_name: str) -> tuple[str, int]:
    """Get the layer name and group index from the layer directory name.

    This is the inverse of the naming used by ``get_window_layer_dir``: a name
    without '.' is group 0 of that layer, and "<layer>.<group_idx>" is the given
    group of the layer.

    Args:
        layer_dir_name: the name of the layer folder.

    Returns:
        a tuple (layer_name, group_idx)

    Raises:
        ValueError: if the name contains more than one '.', or if the part after
            the '.' is not a non-negative integer.
    """
    if "." not in layer_dir_name:
        return (layer_dir_name, 0)

    parts = layer_dir_name.split(".")
    if len(parts) != 2:
        raise ValueError(
            f"expected layer directory name {layer_dir_name} to only contain one '.'"
        )
    layer_name, group_str = parts
    try:
        group_idx = int(group_str)
    except ValueError as err:
        raise ValueError(
            f"expected the part after '.' in layer directory name {layer_dir_name} "
            "to be an integer group index"
        ) from err
    # get_window_layer_dir never produces negative group indices.
    if group_idx < 0:
        raise ValueError(
            f"expected non-negative group index in layer directory name {layer_dir_name}"
        )
    return (layer_name, group_idx)
58
+
59
+
60
def get_window_raster_dir(
    window_path: UPath, layer_name: str, bands: list[str], group_idx: int = 0
) -> UPath:
    """Return the directory that holds the raster for a band set of a layer.

    Args:
        window_path: the window directory.
        layer_name: the layer name
        bands: the bands in the raster. It should match a band set defined for this
            layer.
        group_idx: the index of the group within the layer.

    Returns:
        the band-set subdirectory of the layer (group) directory.
    """
    bandset_dirname = get_bandset_dirname(bands)
    layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
    return layer_dir / bandset_dirname
12
77
 
13
78
 
14
79
  class WindowLayerData:
@@ -69,7 +134,7 @@ class Window:
69
134
 
70
135
  def __init__(
71
136
  self,
72
- path: UPath,
137
+ storage: WindowStorage,
73
138
  group: str,
74
139
  name: str,
75
140
  projection: Projection,
@@ -83,7 +148,7 @@ class Window:
83
148
  stored in metadata.json.
84
149
 
85
150
  Args:
86
- path: the directory of this window
151
+ storage: the dataset storage for the underlying rslearn dataset.
87
152
  group: the group the window belongs to
88
153
  name: the unique name for this window
89
154
  projection: the projection of the window
@@ -91,7 +156,7 @@ class Window:
91
156
  time_range: optional time range of the window
92
157
  options: additional options (?)
93
158
  """
94
- self.path = path
159
+ self.storage = storage
95
160
  self.group = group
96
161
  self.name = name
97
162
  self.projection = projection
@@ -99,25 +164,6 @@ class Window:
99
164
  self.time_range = time_range
100
165
  self.options = options
101
166
 
102
- def save(self) -> None:
103
- """Save the window metadata to its root directory."""
104
- self.path.mkdir(parents=True, exist_ok=True)
105
- metadata = {
106
- "group": self.group,
107
- "name": self.name,
108
- "projection": self.projection.serialize(),
109
- "bounds": self.bounds,
110
- "time_range": (
111
- [self.time_range[0].isoformat(), self.time_range[1].isoformat()]
112
- if self.time_range
113
- else None
114
- ),
115
- "options": self.options,
116
- }
117
- metadata_path = self.path / "metadata.json"
118
- with open_atomic(metadata_path, "w") as f:
119
- json.dump(metadata, f)
120
-
121
167
  def get_geometry(self) -> STGeometry:
122
168
  """Computes the STGeometry corresponding to this window."""
123
169
  return STGeometry(
@@ -128,41 +174,132 @@ class Window:
128
174
 
129
175
  def load_layer_datas(self) -> dict[str, WindowLayerData]:
130
176
  """Load layer datas describing items in retrieved layers from items.json."""
131
- items_fname = self.path / "items.json"
132
- if not items_fname.exists():
133
- return {}
134
- with items_fname.open("r") as f:
135
- layer_datas = [
136
- WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
137
- ]
138
- return {layer_data.layer_name: layer_data for layer_data in layer_datas}
177
+ return self.storage.get_layer_datas(self.group, self.name)
139
178
 
140
179
  def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
141
180
  """Save layer datas to items.json."""
142
- json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
143
- items_fname = self.path / "items.json"
144
- with open_atomic(items_fname, "w") as f:
145
- json.dump(json_data, f)
181
+ self.storage.save_layer_datas(self.group, self.name, layer_datas)
182
+
183
def list_completed_layers(self) -> list[tuple[str, int]]:
    """Return the completed layers of this window, as reported by its storage.

    Returns:
        a list of (layer_name, group_idx) completed layers.
    """
    group, name = self.group, self.name
    return self.storage.list_completed_layers(group, name)
190
+
191
def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
    """Return this window's directory for materialized data of a layer group.

    Args:
        layer_name: the layer name.
        group_idx: the index of the group within the layer to get the directory
            for (default 0).

    Returns:
        the path where data is or should be materialized.
    """
    window_root = self.storage.get_window_root(self.group, self.name)
    return get_window_layer_dir(window_root, layer_name, group_idx)
205
+
206
def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
    """Report whether a layer group of this window has been materialized.

    Completed means there is data in the layer and the data has been written
    (materialized). The check is delegated to the window's storage backend.

    Args:
        layer_name: the layer name.
        group_idx: the index of the group within the layer.

    Returns:
        whether the layer is completed
    """
    storage = self.storage
    completed = storage.is_layer_completed(
        self.group, self.name, layer_name, group_idx
    )
    return completed
222
+
223
def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
    """Record through the storage backend that a layer group is completed.

    Call this only after the layer's contents have been written. For a layer
    with multiple groups, wait until every group is written before marking any
    of them: materialization of a window is skipped when group 0 is already
    marked completed.

    Args:
        layer_name: the layer name.
        group_idx: the index of the group within the layer.
    """
    storage = self.storage
    storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
237
+
238
def get_raster_dir(
    self, layer_name: str, bands: list[str], group_idx: int = 0
) -> UPath:
    """Return this window's directory for a materialized raster band set.

    Args:
        layer_name: the layer name
        bands: the bands in the raster. It should match a band set defined for this
            layer.
        group_idx: the index of the group within the layer.

    Returns:
        the directory containing the raster.
    """
    window_root = self.storage.get_window_root(self.group, self.name)
    return get_window_raster_dir(window_root, layer_name, bands, group_idx)
258
+
259
def get_metadata(self) -> dict[str, Any]:
    """Build the metadata dictionary that storage backends persist for this window.

    The time range, when set, is stored as a two-element list of ISO-8601
    strings; otherwise it is stored as None.
    """
    if self.time_range:
        serialized_time_range = [
            self.time_range[0].isoformat(),
            self.time_range[1].isoformat(),
        ]
    else:
        serialized_time_range = None
    metadata: dict[str, Any] = dict(
        group=self.group,
        name=self.name,
        projection=self.projection.serialize(),
        bounds=self.bounds,
        time_range=serialized_time_range,
        options=self.options,
    )
    return metadata
273
+
274
def save(self) -> None:
    """Persist this window's metadata via its storage backend.

    The storage creates the window if it is new, or updates the stored window
    with the same group and name.
    """
    storage = self.storage
    storage.create_or_update_window(self)
146
277
 
147
278
  @staticmethod
148
- def load(path: UPath) -> "Window":
149
- """Load a Window from a UPath.
279
+ def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
280
+ """Create a Window from the WindowStorage and the window's metadata dictionary.
150
281
 
151
282
  Args:
152
- path: the root directory of the window
283
+ storage: the WindowStorage for the underlying dataset.
284
+ metadata: the window metadata.
153
285
 
154
286
  Returns:
155
287
  the Window
156
288
  """
157
- metadata_fname = path / "metadata.json"
158
- with metadata_fname.open("r") as f:
159
- metadata = json.load(f)
289
+ # Ensure bounds is converted from list to tuple.
290
+ bounds = (
291
+ metadata["bounds"][0],
292
+ metadata["bounds"][1],
293
+ metadata["bounds"][2],
294
+ metadata["bounds"][3],
295
+ )
296
+
160
297
  return Window(
161
- path=path,
298
+ storage=storage,
162
299
  group=metadata["group"],
163
300
  name=metadata["name"],
164
301
  projection=Projection.deserialize(metadata["projection"]),
165
- bounds=metadata["bounds"],
302
+ bounds=bounds,
166
303
  time_range=(
167
304
  (
168
305
  datetime.fromisoformat(metadata["time_range"][0]),