rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Abstract classes for window metadata storage."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from upath import UPath
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from rslearn.dataset.window import Window, WindowLayerData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class WindowStorage(abc.ABC):
|
|
13
|
+
"""An abstract class for the storage backend for window metadata.
|
|
14
|
+
|
|
15
|
+
This is instantiated by a WindowStorageFactory for a specific rslearn dataset.
|
|
16
|
+
|
|
17
|
+
Window metadata includes the location and time range of windows (metadata.json),
|
|
18
|
+
the window layer datas (items.json), and the completed (materialized) layers. It
|
|
19
|
+
excludes the actual materialized data. All operations involving window metadata go
|
|
20
|
+
through the WindowStorage, including enumerating windows, creating new windows, and
|
|
21
|
+
updating window layer datas during `rslearn dataset prepare` or the completed
|
|
22
|
+
layers during `rslearn dataset materialize`.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
@abc.abstractmethod
|
|
26
|
+
def get_window_root(self, group: str, name: str) -> UPath:
|
|
27
|
+
"""Get the path where the window should be stored."""
|
|
28
|
+
raise NotImplementedError
|
|
29
|
+
|
|
30
|
+
@abc.abstractmethod
|
|
31
|
+
def get_windows(
|
|
32
|
+
self,
|
|
33
|
+
groups: list[str] | None = None,
|
|
34
|
+
names: list[str] | None = None,
|
|
35
|
+
) -> list["Window"]:
|
|
36
|
+
"""Load the windows in the dataset.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
groups: an optional list of groups to filter loading
|
|
40
|
+
names: an optional list of window names to filter loading
|
|
41
|
+
"""
|
|
42
|
+
raise NotImplementedError
|
|
43
|
+
|
|
44
|
+
@abc.abstractmethod
|
|
45
|
+
def create_or_update_window(self, window: "Window") -> None:
|
|
46
|
+
"""Create or update the window.
|
|
47
|
+
|
|
48
|
+
An existing window is only updated if there is one with the same name and group.
|
|
49
|
+
|
|
50
|
+
If there is a window with the same name but a different group, the behavior is
|
|
51
|
+
undefined.
|
|
52
|
+
"""
|
|
53
|
+
raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
@abc.abstractmethod
|
|
56
|
+
def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
|
|
57
|
+
"""Get the window layer datas for the specified window.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
group: the window group.
|
|
61
|
+
name: the window name.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
a dict mapping from the layer name to the layer data for that layer, if one
|
|
65
|
+
was previously saved.
|
|
66
|
+
"""
|
|
67
|
+
raise NotImplementedError
|
|
68
|
+
|
|
69
|
+
@abc.abstractmethod
|
|
70
|
+
def save_layer_datas(
|
|
71
|
+
self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
|
|
72
|
+
) -> None:
|
|
73
|
+
"""Set the window layer datas for the specified window."""
|
|
74
|
+
raise NotImplementedError
|
|
75
|
+
|
|
76
|
+
@abc.abstractmethod
|
|
77
|
+
def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
|
|
78
|
+
"""List the layers available for this window that are completed.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
group: the window group.
|
|
82
|
+
name: the window name.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
a list of (layer_name, group_idx) completed layers.
|
|
86
|
+
"""
|
|
87
|
+
raise NotImplementedError
|
|
88
|
+
|
|
89
|
+
@abc.abstractmethod
|
|
90
|
+
def is_layer_completed(
|
|
91
|
+
self, group: str, name: str, layer_name: str, group_idx: int = 0
|
|
92
|
+
) -> bool:
|
|
93
|
+
"""Check whether the specified layer is completed in the given window.
|
|
94
|
+
|
|
95
|
+
Completed means there is data in the layer and the data has been written
|
|
96
|
+
(materialized).
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
group: the window group.
|
|
100
|
+
name: the window name.
|
|
101
|
+
layer_name: the layer name.
|
|
102
|
+
group_idx: the index of the group within the layer.
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
whether the layer is completed.
|
|
106
|
+
"""
|
|
107
|
+
raise NotImplementedError
|
|
108
|
+
|
|
109
|
+
@abc.abstractmethod
|
|
110
|
+
def mark_layer_completed(
|
|
111
|
+
self, group: str, name: str, layer_name: str, group_idx: int = 0
|
|
112
|
+
) -> None:
|
|
113
|
+
"""Mark the specified layer completed for the given window.
|
|
114
|
+
|
|
115
|
+
This must be done after the contents of the layer have been written. If a layer
|
|
116
|
+
has multiple groups, the caller should wait until the contents of all groups
|
|
117
|
+
have been written before marking them completed; this is because, when
|
|
118
|
+
materializing a window, we skip materialization if the first group
|
|
119
|
+
(group_idx=0) is marked completed.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
group: the window group.
|
|
123
|
+
name: the window name.
|
|
124
|
+
layer_name: the layer name.
|
|
125
|
+
group_idx: the index of the group within the layer.
|
|
126
|
+
"""
|
|
127
|
+
raise NotImplementedError
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class WindowStorageFactory(abc.ABC):
|
|
131
|
+
"""An abstract class for a configurable storage backend for window metadata.
|
|
132
|
+
|
|
133
|
+
The dataset config includes a StorageConfig that configures a WindowStorageFactory,
|
|
134
|
+
which in turn creates a WindowStorage given a dataset path.
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
@abc.abstractmethod
|
|
138
|
+
def get_storage(self, ds_path: UPath) -> WindowStorage:
|
|
139
|
+
"""Get a WindowStorage for the given dataset path."""
|
|
140
|
+
raise NotImplementedError
|
rslearn/dataset/window.py
CHANGED
|
@@ -1,20 +1,16 @@
|
|
|
1
1
|
"""rslearn windows."""
|
|
2
2
|
|
|
3
|
-
import json
|
|
4
3
|
from datetime import datetime
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Any
|
|
6
5
|
|
|
7
6
|
import shapely
|
|
8
7
|
from upath import UPath
|
|
9
8
|
|
|
9
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
10
10
|
from rslearn.log_utils import get_logger
|
|
11
11
|
from rslearn.utils import Projection, STGeometry
|
|
12
|
-
from rslearn.utils.fsspec import open_atomic
|
|
13
12
|
from rslearn.utils.raster_format import get_bandset_dirname
|
|
14
13
|
|
|
15
|
-
if TYPE_CHECKING:
|
|
16
|
-
from .index import DatasetIndex
|
|
17
|
-
|
|
18
14
|
logger = get_logger(__name__)
|
|
19
15
|
|
|
20
16
|
LAYERS_DIRECTORY_NAME = "layers"
|
|
@@ -138,14 +134,13 @@ class Window:
|
|
|
138
134
|
|
|
139
135
|
def __init__(
|
|
140
136
|
self,
|
|
141
|
-
|
|
137
|
+
storage: WindowStorage,
|
|
142
138
|
group: str,
|
|
143
139
|
name: str,
|
|
144
140
|
projection: Projection,
|
|
145
141
|
bounds: tuple[int, int, int, int],
|
|
146
142
|
time_range: tuple[datetime, datetime] | None,
|
|
147
143
|
options: dict[str, Any] = {},
|
|
148
|
-
index: "DatasetIndex | None" = None,
|
|
149
144
|
) -> None:
|
|
150
145
|
"""Creates a new Window instance.
|
|
151
146
|
|
|
@@ -153,23 +148,21 @@ class Window:
|
|
|
153
148
|
stored in metadata.json.
|
|
154
149
|
|
|
155
150
|
Args:
|
|
156
|
-
|
|
151
|
+
storage: the dataset storage for the underlying rslearn dataset.
|
|
157
152
|
group: the group the window belongs to
|
|
158
153
|
name: the unique name for this window
|
|
159
154
|
projection: the projection of the window
|
|
160
155
|
bounds: the bounds of the window in pixel coordinates
|
|
161
156
|
time_range: optional time range of the window
|
|
162
157
|
options: additional options (?)
|
|
163
|
-
index: DatasetIndex if it is available
|
|
164
158
|
"""
|
|
165
|
-
self.
|
|
159
|
+
self.storage = storage
|
|
166
160
|
self.group = group
|
|
167
161
|
self.name = name
|
|
168
162
|
self.projection = projection
|
|
169
163
|
self.bounds = bounds
|
|
170
164
|
self.time_range = time_range
|
|
171
165
|
self.options = options
|
|
172
|
-
self.index = index
|
|
173
166
|
|
|
174
167
|
def get_geometry(self) -> STGeometry:
|
|
175
168
|
"""Computes the STGeometry corresponding to this window."""
|
|
@@ -181,29 +174,11 @@ class Window:
|
|
|
181
174
|
|
|
182
175
|
def load_layer_datas(self) -> dict[str, WindowLayerData]:
|
|
183
176
|
"""Load layer datas describing items in retrieved layers from items.json."""
|
|
184
|
-
|
|
185
|
-
if self.index is not None:
|
|
186
|
-
layer_datas = self.index.layer_datas.get(self.name, [])
|
|
187
|
-
|
|
188
|
-
else:
|
|
189
|
-
items_fname = self.path / "items.json"
|
|
190
|
-
if not items_fname.exists():
|
|
191
|
-
return {}
|
|
192
|
-
with items_fname.open("r") as f:
|
|
193
|
-
layer_datas = [
|
|
194
|
-
WindowLayerData.deserialize(layer_data)
|
|
195
|
-
for layer_data in json.load(f)
|
|
196
|
-
]
|
|
197
|
-
|
|
198
|
-
return {layer_data.layer_name: layer_data for layer_data in layer_datas}
|
|
177
|
+
return self.storage.get_layer_datas(self.group, self.name)
|
|
199
178
|
|
|
200
179
|
def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
|
|
201
180
|
"""Save layer datas to items.json."""
|
|
202
|
-
|
|
203
|
-
items_fname = self.path / "items.json"
|
|
204
|
-
logger.info(f"Saving window items to {items_fname}")
|
|
205
|
-
with open_atomic(items_fname, "w") as f:
|
|
206
|
-
json.dump(json_data, f)
|
|
181
|
+
self.storage.save_layer_datas(self.group, self.name, layer_datas)
|
|
207
182
|
|
|
208
183
|
def list_completed_layers(self) -> list[tuple[str, int]]:
|
|
209
184
|
"""List the layers available for this window that are completed.
|
|
@@ -211,18 +186,7 @@ class Window:
|
|
|
211
186
|
Returns:
|
|
212
187
|
a list of (layer_name, group_idx) completed layers.
|
|
213
188
|
"""
|
|
214
|
-
|
|
215
|
-
if not layers_directory.exists():
|
|
216
|
-
return []
|
|
217
|
-
|
|
218
|
-
completed_layers = []
|
|
219
|
-
for layer_dir in layers_directory.iterdir():
|
|
220
|
-
layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
|
|
221
|
-
if not self.is_layer_completed(layer_name, group_idx):
|
|
222
|
-
continue
|
|
223
|
-
completed_layers.append((layer_name, group_idx))
|
|
224
|
-
|
|
225
|
-
return completed_layers
|
|
189
|
+
return self.storage.list_completed_layers(self.group, self.name)
|
|
226
190
|
|
|
227
191
|
def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
|
|
228
192
|
"""Get the directory containing materialized data for the specified layer.
|
|
@@ -235,7 +199,9 @@ class Window:
|
|
|
235
199
|
Returns:
|
|
236
200
|
the path where data is or should be materialized.
|
|
237
201
|
"""
|
|
238
|
-
return get_window_layer_dir(
|
|
202
|
+
return get_window_layer_dir(
|
|
203
|
+
self.storage.get_window_root(self.group, self.name), layer_name, group_idx
|
|
204
|
+
)
|
|
239
205
|
|
|
240
206
|
def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
|
|
241
207
|
"""Check whether the specified layer is completed.
|
|
@@ -250,14 +216,9 @@ class Window:
|
|
|
250
216
|
Returns:
|
|
251
217
|
whether the layer is completed
|
|
252
218
|
"""
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
self.name, []
|
|
257
|
-
)
|
|
258
|
-
|
|
259
|
-
layer_dir = self.get_layer_dir(layer_name, group_idx)
|
|
260
|
-
return (layer_dir / "completed").exists()
|
|
219
|
+
return self.storage.is_layer_completed(
|
|
220
|
+
self.group, self.name, layer_name, group_idx
|
|
221
|
+
)
|
|
261
222
|
|
|
262
223
|
def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
|
|
263
224
|
"""Mark the specified layer completed.
|
|
@@ -272,8 +233,7 @@ class Window:
|
|
|
272
233
|
layer_name: the layer name.
|
|
273
234
|
group_idx: the index of the group within the layer.
|
|
274
235
|
"""
|
|
275
|
-
|
|
276
|
-
(layer_dir / "completed").touch()
|
|
236
|
+
self.storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
|
|
277
237
|
|
|
278
238
|
def get_raster_dir(
|
|
279
239
|
self, layer_name: str, bands: list[str], group_idx: int = 0
|
|
@@ -289,7 +249,12 @@ class Window:
|
|
|
289
249
|
Returns:
|
|
290
250
|
the directory containing the raster.
|
|
291
251
|
"""
|
|
292
|
-
return get_window_raster_dir(
|
|
252
|
+
return get_window_raster_dir(
|
|
253
|
+
self.storage.get_window_root(self.group, self.name),
|
|
254
|
+
layer_name,
|
|
255
|
+
bands,
|
|
256
|
+
group_idx,
|
|
257
|
+
)
|
|
293
258
|
|
|
294
259
|
def get_metadata(self) -> dict[str, Any]:
|
|
295
260
|
"""Returns the window metadata dictionary."""
|
|
@@ -308,18 +273,14 @@ class Window:
|
|
|
308
273
|
|
|
309
274
|
def save(self) -> None:
|
|
310
275
|
"""Save the window metadata to its root directory."""
|
|
311
|
-
self.
|
|
312
|
-
metadata_path = self.path / "metadata.json"
|
|
313
|
-
logger.debug(f"Saving window metadata to {metadata_path}")
|
|
314
|
-
with open_atomic(metadata_path, "w") as f:
|
|
315
|
-
json.dump(self.get_metadata(), f)
|
|
276
|
+
self.storage.create_or_update_window(self)
|
|
316
277
|
|
|
317
278
|
@staticmethod
|
|
318
|
-
def from_metadata(
|
|
319
|
-
"""Create a Window from
|
|
279
|
+
def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
|
|
280
|
+
"""Create a Window from the WindowStorage and the window's metadata dictionary.
|
|
320
281
|
|
|
321
282
|
Args:
|
|
322
|
-
|
|
283
|
+
storage: the WindowStorage for the underlying dataset.
|
|
323
284
|
metadata: the window metadata.
|
|
324
285
|
|
|
325
286
|
Returns:
|
|
@@ -334,7 +295,7 @@ class Window:
|
|
|
334
295
|
)
|
|
335
296
|
|
|
336
297
|
return Window(
|
|
337
|
-
|
|
298
|
+
storage=storage,
|
|
338
299
|
group=metadata["group"],
|
|
339
300
|
name=metadata["name"],
|
|
340
301
|
projection=Projection.deserialize(metadata["projection"]),
|
|
@@ -350,21 +311,6 @@ class Window:
|
|
|
350
311
|
options=metadata["options"],
|
|
351
312
|
)
|
|
352
313
|
|
|
353
|
-
@staticmethod
|
|
354
|
-
def load(path: UPath) -> "Window":
|
|
355
|
-
"""Load a Window from a UPath.
|
|
356
|
-
|
|
357
|
-
Args:
|
|
358
|
-
path: the root directory of the window
|
|
359
|
-
|
|
360
|
-
Returns:
|
|
361
|
-
the Window
|
|
362
|
-
"""
|
|
363
|
-
metadata_fname = path / "metadata.json"
|
|
364
|
-
with metadata_fname.open("r") as f:
|
|
365
|
-
metadata = json.load(f)
|
|
366
|
-
return Window.from_metadata(path, metadata)
|
|
367
|
-
|
|
368
314
|
@staticmethod
|
|
369
315
|
def get_window_root(ds_path: UPath, group: str, name: str) -> UPath:
|
|
370
316
|
"""Gets the root directory of a window.
|
rslearn/lightning_cli.py
CHANGED
|
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
|
|
|
21
21
|
from rslearn.train.data_module import RslearnDataModule
|
|
22
22
|
from rslearn.train.lightning_module import RslearnLightningModule
|
|
23
23
|
from rslearn.utils.fsspec import open_atomic
|
|
24
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
24
25
|
|
|
25
26
|
WANDB_ID_FNAME = "wandb_id"
|
|
26
27
|
|
|
@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
390
391
|
|
|
391
392
|
Sets the dataset path for any configured RslearnPredictionWriter callbacks.
|
|
392
393
|
"""
|
|
393
|
-
|
|
394
|
-
|
|
394
|
+
if not hasattr(self.config, "subcommand"):
|
|
395
|
+
logger.warning(
|
|
396
|
+
"Config does not have subcommand attribute, assuming we are in run=False mode"
|
|
397
|
+
)
|
|
398
|
+
subcommand = None
|
|
399
|
+
c = self.config
|
|
400
|
+
else:
|
|
401
|
+
subcommand = self.config.subcommand
|
|
402
|
+
c = self.config[subcommand]
|
|
395
403
|
|
|
396
404
|
# If there is a RslearnPredictionWriter, set its path.
|
|
397
405
|
prediction_writer_callback = None
|
|
@@ -415,16 +423,17 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
415
423
|
if subcommand == "predict":
|
|
416
424
|
c.return_predictions = False
|
|
417
425
|
|
|
418
|
-
#
|
|
426
|
+
# Default to DDP with find_unused_parameters. Likely won't get called with unified config
|
|
419
427
|
if subcommand == "fit":
|
|
420
|
-
c.trainer.strategy
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
+
if not c.trainer.strategy:
|
|
429
|
+
c.trainer.strategy = jsonargparse.Namespace(
|
|
430
|
+
{
|
|
431
|
+
"class_path": "lightning.pytorch.strategies.DDPStrategy",
|
|
432
|
+
"init_args": jsonargparse.Namespace(
|
|
433
|
+
{"find_unused_parameters": True}
|
|
434
|
+
),
|
|
435
|
+
}
|
|
436
|
+
)
|
|
428
437
|
|
|
429
438
|
if c.management_dir:
|
|
430
439
|
self.enable_project_management(c.management_dir)
|
|
@@ -432,6 +441,8 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
432
441
|
|
|
433
442
|
def model_handler() -> None:
|
|
434
443
|
"""Handler for any rslearn model X commands."""
|
|
444
|
+
init_jsonargparse()
|
|
445
|
+
|
|
435
446
|
RslearnLightningCLI(
|
|
436
447
|
model_class=RslearnLightningModule,
|
|
437
448
|
datamodule_class=RslearnDataModule,
|
rslearn/main.py
CHANGED
|
@@ -27,13 +27,13 @@ from rslearn.dataset.handler_summaries import (
|
|
|
27
27
|
PrepareDatasetWindowsSummary,
|
|
28
28
|
UnknownIngestCounts,
|
|
29
29
|
)
|
|
30
|
-
from rslearn.dataset.index import DatasetIndex
|
|
31
30
|
from rslearn.dataset.manage import (
|
|
32
31
|
AttemptsCounter,
|
|
33
32
|
materialize_dataset_windows,
|
|
34
33
|
prepare_dataset_windows,
|
|
35
34
|
retry,
|
|
36
35
|
)
|
|
36
|
+
from rslearn.dataset.storage.file import FileWindowStorage
|
|
37
37
|
from rslearn.log_utils import get_logger
|
|
38
38
|
from rslearn.tile_stores import get_tile_store_with_layer
|
|
39
39
|
from rslearn.utils import Projection, STGeometry
|
|
@@ -315,7 +315,8 @@ def apply_on_windows(
|
|
|
315
315
|
load_workers: optional different number of workers to use for loading the
|
|
316
316
|
windows. If set, workers controls the number of workers to process the
|
|
317
317
|
jobs, while load_workers controls the number of workers to use for reading
|
|
318
|
-
windows from the rslearn dataset.
|
|
318
|
+
windows from the rslearn dataset. Workers is only passed if the window
|
|
319
|
+
storage is FileWindowStorage.
|
|
319
320
|
batch_size: if workers > 0, the maximum number of windows to pass to the
|
|
320
321
|
function.
|
|
321
322
|
jobs_per_process: optional, terminate processes after they have handled this
|
|
@@ -336,11 +337,14 @@ def apply_on_windows(
|
|
|
336
337
|
else:
|
|
337
338
|
groups = group
|
|
338
339
|
|
|
339
|
-
if
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
340
|
+
# Load the windows. We pass workers and show_progress if it is FileWindowStorage.
|
|
341
|
+
kwargs: dict[str, Any] = {}
|
|
342
|
+
if isinstance(dataset.storage, FileWindowStorage):
|
|
343
|
+
if load_workers is None:
|
|
344
|
+
load_workers = workers
|
|
345
|
+
kwargs["workers"] = load_workers
|
|
346
|
+
kwargs["show_progress"] = True
|
|
347
|
+
windows = dataset.load_windows(groups=groups, names=names, **kwargs)
|
|
344
348
|
logger.info(f"found {len(windows)} windows")
|
|
345
349
|
|
|
346
350
|
if hasattr(f, "get_jobs"):
|
|
@@ -376,7 +380,7 @@ def apply_on_windows(
|
|
|
376
380
|
|
|
377
381
|
def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
|
|
378
382
|
"""Call apply_on_windows with arguments passed via command-line interface."""
|
|
379
|
-
dataset = Dataset(UPath(args.root), args.disabled_layers)
|
|
383
|
+
dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
|
|
380
384
|
apply_on_windows(
|
|
381
385
|
f=f,
|
|
382
386
|
dataset=dataset,
|
|
@@ -798,35 +802,6 @@ def dataset_materialize() -> None:
|
|
|
798
802
|
apply_on_windows_args(fn, args)
|
|
799
803
|
|
|
800
804
|
|
|
801
|
-
@register_handler("dataset", "build_index")
|
|
802
|
-
def dataset_build_index() -> None:
|
|
803
|
-
"""Handler for the rslearn dataset build_index command."""
|
|
804
|
-
parser = argparse.ArgumentParser(
|
|
805
|
-
prog="rslearn dataset build_index",
|
|
806
|
-
description=("rslearn dataset build_index: " + "create a dataset index file"),
|
|
807
|
-
)
|
|
808
|
-
parser.add_argument(
|
|
809
|
-
"--root",
|
|
810
|
-
type=str,
|
|
811
|
-
required=True,
|
|
812
|
-
help="Dataset path",
|
|
813
|
-
)
|
|
814
|
-
parser.add_argument(
|
|
815
|
-
"--workers",
|
|
816
|
-
type=int,
|
|
817
|
-
default=16,
|
|
818
|
-
help="Number of workers",
|
|
819
|
-
)
|
|
820
|
-
args = parser.parse_args(args=sys.argv[3:])
|
|
821
|
-
ds_path = UPath(args.root)
|
|
822
|
-
dataset = Dataset(ds_path)
|
|
823
|
-
index = DatasetIndex.build_index(
|
|
824
|
-
dataset=dataset,
|
|
825
|
-
workers=args.workers,
|
|
826
|
-
)
|
|
827
|
-
index.save_index(ds_path)
|
|
828
|
-
|
|
829
|
-
|
|
830
805
|
@register_handler("model", "fit")
|
|
831
806
|
def model_fit() -> None:
|
|
832
807
|
"""Handler for rslearn model fit."""
|
rslearn/models/anysat.py
CHANGED
|
@@ -4,11 +4,13 @@ This code loads the AnySat model from torch hub. See
|
|
|
4
4
|
https://github.com/gastruc/AnySat for applicable license and copyright information.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import Any
|
|
8
|
-
|
|
9
7
|
import torch
|
|
10
8
|
from einops import rearrange
|
|
11
9
|
|
|
10
|
+
from rslearn.train.model_context import ModelContext
|
|
11
|
+
|
|
12
|
+
from .component import FeatureExtractor, FeatureMaps
|
|
13
|
+
|
|
12
14
|
# AnySat github: https://github.com/gastruc/AnySat
|
|
13
15
|
# Modalities and expected resolutions (meters)
|
|
14
16
|
MODALITY_RESOLUTIONS: dict[str, float] = {
|
|
@@ -44,7 +46,7 @@ MODALITY_BANDS: dict[str, list[str]] = {
|
|
|
44
46
|
TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
class AnySat(
|
|
49
|
+
class AnySat(FeatureExtractor):
|
|
48
50
|
"""AnySat backbone (outputs one feature map)."""
|
|
49
51
|
|
|
50
52
|
def __init__(
|
|
@@ -117,17 +119,17 @@ class AnySat(torch.nn.Module):
|
|
|
117
119
|
)
|
|
118
120
|
self._embed_dim = 768 # base width, 'dense' returns 2x
|
|
119
121
|
|
|
120
|
-
def forward(self,
|
|
122
|
+
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
121
123
|
"""Forward pass for the AnySat model.
|
|
122
124
|
|
|
123
125
|
Args:
|
|
124
|
-
|
|
126
|
+
context: the model context. Input dicts must include modalities as keys
|
|
127
|
+
which are defined in the self.modalities list
|
|
125
128
|
|
|
126
129
|
Returns:
|
|
127
|
-
|
|
130
|
+
a FeatureMaps with one feature map at the configured patch size.
|
|
128
131
|
"""
|
|
129
|
-
|
|
130
|
-
raise ValueError("empty inputs")
|
|
132
|
+
inputs = context.inputs
|
|
131
133
|
|
|
132
134
|
batch: dict[str, torch.Tensor] = {}
|
|
133
135
|
spatial_extent: tuple[float, float] | None = None
|
|
@@ -192,7 +194,7 @@ class AnySat(torch.nn.Module):
|
|
|
192
194
|
kwargs["output_modality"] = self.output_modality
|
|
193
195
|
|
|
194
196
|
features = self.model(batch, **kwargs)
|
|
195
|
-
return [rearrange(features, "b h w d -> b d h w")]
|
|
197
|
+
return FeatureMaps([rearrange(features, "b h w d -> b d h w")])
|
|
196
198
|
|
|
197
199
|
def get_backbone_channels(self) -> list:
|
|
198
200
|
"""Returns the output channels of this model when used as a backbone.
|