rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,140 @@
1
+ """Abstract classes for window metadata storage."""
2
+
3
+ import abc
4
+ from typing import TYPE_CHECKING
5
+
6
+ from upath import UPath
7
+
8
+ if TYPE_CHECKING:
9
+ from rslearn.dataset.window import Window, WindowLayerData
10
+
11
+
12
class WindowStorage(abc.ABC):
    """Abstract storage backend for rslearn window metadata.

    Instances are created by a WindowStorageFactory for a particular rslearn
    dataset.

    Window metadata covers the window's location and time range
    (metadata.json), its layer datas (items.json), and which layers have been
    completed (materialized); it does not cover the materialized data itself.
    Every operation involving window metadata goes through this interface:
    enumerating windows, creating windows, updating layer datas during
    `rslearn dataset prepare`, and recording completed layers during
    `rslearn dataset materialize`.
    """

    @abc.abstractmethod
    def get_window_root(self, group: str, name: str) -> UPath:
        """Return the path under which the given window should be stored."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
    ) -> list["Window"]:
        """Load windows from the dataset.

        Args:
            groups: an optional list of groups to filter loading
            names: an optional list of window names to filter loading
        """
        raise NotImplementedError

    @abc.abstractmethod
    def create_or_update_window(self, window: "Window") -> None:
        """Create the window, or update it if it already exists.

        A window is considered existing only when one with the same name and
        group is present.

        The behavior is undefined when a window with the same name exists in a
        different group.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        """Fetch the layer datas stored for a window.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            a dict mapping from the layer name to the layer data for that layer, if one
            was previously saved.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        """Persist the given layer datas for a window."""
        raise NotImplementedError

    @abc.abstractmethod
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        """Enumerate the completed layers of a window.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            a list of (layer_name, group_idx) completed layers.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        """Report whether a particular layer of a window is completed.

        A layer is completed when it contains data and that data has been
        written (materialized).

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.

        Returns:
            whether the layer is completed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        """Record that a layer of a window has been completed.

        Call this only after the layer's contents have been written. When a
        layer has several groups, wait until every group's contents are
        written before marking any of them completed: window materialization
        is skipped whenever the first group (group_idx=0) is already marked
        completed.

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.
        """
        raise NotImplementedError
128
+
129
+
130
class WindowStorageFactory(abc.ABC):
    """Abstract factory for a configurable window-metadata storage backend.

    A StorageConfig in the dataset config selects and configures a
    WindowStorageFactory; the factory then produces a WindowStorage for a
    given dataset path.
    """

    @abc.abstractmethod
    def get_storage(self, ds_path: UPath) -> WindowStorage:
        """Return a WindowStorage bound to the given dataset path."""
        raise NotImplementedError
rslearn/dataset/window.py CHANGED
@@ -1,20 +1,16 @@
1
1
  """rslearn windows."""
2
2
 
3
- import json
4
3
  from datetime import datetime
5
- from typing import TYPE_CHECKING, Any
4
+ from typing import Any
6
5
 
7
6
  import shapely
8
7
  from upath import UPath
9
8
 
9
+ from rslearn.dataset.storage.storage import WindowStorage
10
10
  from rslearn.log_utils import get_logger
11
11
  from rslearn.utils import Projection, STGeometry
12
- from rslearn.utils.fsspec import open_atomic
13
12
  from rslearn.utils.raster_format import get_bandset_dirname
14
13
 
15
- if TYPE_CHECKING:
16
- from .index import DatasetIndex
17
-
18
14
  logger = get_logger(__name__)
19
15
 
20
16
  LAYERS_DIRECTORY_NAME = "layers"
@@ -138,14 +134,13 @@ class Window:
138
134
 
139
135
  def __init__(
140
136
  self,
141
- path: UPath,
137
+ storage: WindowStorage,
142
138
  group: str,
143
139
  name: str,
144
140
  projection: Projection,
145
141
  bounds: tuple[int, int, int, int],
146
142
  time_range: tuple[datetime, datetime] | None,
147
143
  options: dict[str, Any] = {},
148
- index: "DatasetIndex | None" = None,
149
144
  ) -> None:
150
145
  """Creates a new Window instance.
151
146
 
@@ -153,23 +148,21 @@ class Window:
153
148
  stored in metadata.json.
154
149
 
155
150
  Args:
156
- path: the directory of this window
151
+ storage: the dataset storage for the underlying rslearn dataset.
157
152
  group: the group the window belongs to
158
153
  name: the unique name for this window
159
154
  projection: the projection of the window
160
155
  bounds: the bounds of the window in pixel coordinates
161
156
  time_range: optional time range of the window
162
157
  options: additional options (?)
163
- index: DatasetIndex if it is available
164
158
  """
165
- self.path = path
159
+ self.storage = storage
166
160
  self.group = group
167
161
  self.name = name
168
162
  self.projection = projection
169
163
  self.bounds = bounds
170
164
  self.time_range = time_range
171
165
  self.options = options
172
- self.index = index
173
166
 
174
167
  def get_geometry(self) -> STGeometry:
175
168
  """Computes the STGeometry corresponding to this window."""
@@ -181,29 +174,11 @@ class Window:
181
174
 
182
175
  def load_layer_datas(self) -> dict[str, WindowLayerData]:
183
176
  """Load layer datas describing items in retrieved layers from items.json."""
184
- # Load from index if it is available.
185
- if self.index is not None:
186
- layer_datas = self.index.layer_datas.get(self.name, [])
187
-
188
- else:
189
- items_fname = self.path / "items.json"
190
- if not items_fname.exists():
191
- return {}
192
- with items_fname.open("r") as f:
193
- layer_datas = [
194
- WindowLayerData.deserialize(layer_data)
195
- for layer_data in json.load(f)
196
- ]
197
-
198
- return {layer_data.layer_name: layer_data for layer_data in layer_datas}
177
+ return self.storage.get_layer_datas(self.group, self.name)
199
178
 
200
179
  def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
201
180
  """Save layer datas to items.json."""
202
- json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
203
- items_fname = self.path / "items.json"
204
- logger.info(f"Saving window items to {items_fname}")
205
- with open_atomic(items_fname, "w") as f:
206
- json.dump(json_data, f)
181
+ self.storage.save_layer_datas(self.group, self.name, layer_datas)
207
182
 
208
183
  def list_completed_layers(self) -> list[tuple[str, int]]:
209
184
  """List the layers available for this window that are completed.
@@ -211,18 +186,7 @@ class Window:
211
186
  Returns:
212
187
  a list of (layer_name, group_idx) completed layers.
213
188
  """
214
- layers_directory = self.path / LAYERS_DIRECTORY_NAME
215
- if not layers_directory.exists():
216
- return []
217
-
218
- completed_layers = []
219
- for layer_dir in layers_directory.iterdir():
220
- layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
221
- if not self.is_layer_completed(layer_name, group_idx):
222
- continue
223
- completed_layers.append((layer_name, group_idx))
224
-
225
- return completed_layers
189
+ return self.storage.list_completed_layers(self.group, self.name)
226
190
 
227
191
  def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
228
192
  """Get the directory containing materialized data for the specified layer.
@@ -235,7 +199,9 @@ class Window:
235
199
  Returns:
236
200
  the path where data is or should be materialized.
237
201
  """
238
- return get_window_layer_dir(self.path, layer_name, group_idx)
202
+ return get_window_layer_dir(
203
+ self.storage.get_window_root(self.group, self.name), layer_name, group_idx
204
+ )
239
205
 
240
206
  def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
241
207
  """Check whether the specified layer is completed.
@@ -250,14 +216,9 @@ class Window:
250
216
  Returns:
251
217
  whether the layer is completed
252
218
  """
253
- # Use the index to speed up the completed check if it is available.
254
- if self.index is not None:
255
- return (layer_name, group_idx) in self.index.completed_layers.get(
256
- self.name, []
257
- )
258
-
259
- layer_dir = self.get_layer_dir(layer_name, group_idx)
260
- return (layer_dir / "completed").exists()
219
+ return self.storage.is_layer_completed(
220
+ self.group, self.name, layer_name, group_idx
221
+ )
261
222
 
262
223
  def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
263
224
  """Mark the specified layer completed.
@@ -272,8 +233,7 @@ class Window:
272
233
  layer_name: the layer name.
273
234
  group_idx: the index of the group within the layer.
274
235
  """
275
- layer_dir = self.get_layer_dir(layer_name, group_idx)
276
- (layer_dir / "completed").touch()
236
+ self.storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
277
237
 
278
238
  def get_raster_dir(
279
239
  self, layer_name: str, bands: list[str], group_idx: int = 0
@@ -289,7 +249,12 @@ class Window:
289
249
  Returns:
290
250
  the directory containing the raster.
291
251
  """
292
- return get_window_raster_dir(self.path, layer_name, bands, group_idx)
252
+ return get_window_raster_dir(
253
+ self.storage.get_window_root(self.group, self.name),
254
+ layer_name,
255
+ bands,
256
+ group_idx,
257
+ )
293
258
 
294
259
  def get_metadata(self) -> dict[str, Any]:
295
260
  """Returns the window metadata dictionary."""
@@ -308,18 +273,14 @@ class Window:
308
273
 
309
274
  def save(self) -> None:
310
275
  """Save the window metadata to its root directory."""
311
- self.path.mkdir(parents=True, exist_ok=True)
312
- metadata_path = self.path / "metadata.json"
313
- logger.debug(f"Saving window metadata to {metadata_path}")
314
- with open_atomic(metadata_path, "w") as f:
315
- json.dump(self.get_metadata(), f)
276
+ self.storage.create_or_update_window(self)
316
277
 
317
278
  @staticmethod
318
- def from_metadata(path: UPath, metadata: dict[str, Any]) -> "Window":
319
- """Create a Window from its path and metadata dictionary.
279
+ def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
280
+ """Create a Window from the WindowStorage and the window's metadata dictionary.
320
281
 
321
282
  Args:
322
- path: the root directory of the window.
283
+ storage: the WindowStorage for the underlying dataset.
323
284
  metadata: the window metadata.
324
285
 
325
286
  Returns:
@@ -334,7 +295,7 @@ class Window:
334
295
  )
335
296
 
336
297
  return Window(
337
- path=path,
298
+ storage=storage,
338
299
  group=metadata["group"],
339
300
  name=metadata["name"],
340
301
  projection=Projection.deserialize(metadata["projection"]),
@@ -350,21 +311,6 @@ class Window:
350
311
  options=metadata["options"],
351
312
  )
352
313
 
353
- @staticmethod
354
- def load(path: UPath) -> "Window":
355
- """Load a Window from a UPath.
356
-
357
- Args:
358
- path: the root directory of the window
359
-
360
- Returns:
361
- the Window
362
- """
363
- metadata_fname = path / "metadata.json"
364
- with metadata_fname.open("r") as f:
365
- metadata = json.load(f)
366
- return Window.from_metadata(path, metadata)
367
-
368
314
  @staticmethod
369
315
  def get_window_root(ds_path: UPath, group: str, name: str) -> UPath:
370
316
  """Gets the root directory of a window.
rslearn/lightning_cli.py CHANGED
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
21
21
  from rslearn.train.data_module import RslearnDataModule
22
22
  from rslearn.train.lightning_module import RslearnLightningModule
23
23
  from rslearn.utils.fsspec import open_atomic
24
+ from rslearn.utils.jsonargparse import init_jsonargparse
24
25
 
25
26
  WANDB_ID_FNAME = "wandb_id"
26
27
 
@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):
390
391
 
391
392
  Sets the dataset path for any configured RslearnPredictionWriter callbacks.
392
393
  """
393
- subcommand = self.config.subcommand
394
- c = self.config[subcommand]
394
+ if not hasattr(self.config, "subcommand"):
395
+ logger.warning(
396
+ "Config does not have subcommand attribute, assuming we are in run=False mode"
397
+ )
398
+ subcommand = None
399
+ c = self.config
400
+ else:
401
+ subcommand = self.config.subcommand
402
+ c = self.config[subcommand]
395
403
 
396
404
  # If there is a RslearnPredictionWriter, set its path.
397
405
  prediction_writer_callback = None
@@ -415,16 +423,17 @@ class RslearnLightningCLI(LightningCLI):
415
423
  if subcommand == "predict":
416
424
  c.return_predictions = False
417
425
 
418
- # For now we use DDP strategy with find_unused_parameters=True.
426
+ # Default to DDP with find_unused_parameters. Likely won't get called with unified config
419
427
  if subcommand == "fit":
420
- c.trainer.strategy = jsonargparse.Namespace(
421
- {
422
- "class_path": "lightning.pytorch.strategies.DDPStrategy",
423
- "init_args": jsonargparse.Namespace(
424
- {"find_unused_parameters": True}
425
- ),
426
- }
427
- )
428
+ if not c.trainer.strategy:
429
+ c.trainer.strategy = jsonargparse.Namespace(
430
+ {
431
+ "class_path": "lightning.pytorch.strategies.DDPStrategy",
432
+ "init_args": jsonargparse.Namespace(
433
+ {"find_unused_parameters": True}
434
+ ),
435
+ }
436
+ )
428
437
 
429
438
  if c.management_dir:
430
439
  self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@ class RslearnLightningCLI(LightningCLI):
432
441
 
433
442
  def model_handler() -> None:
434
443
  """Handler for any rslearn model X commands."""
444
+ init_jsonargparse()
445
+
435
446
  RslearnLightningCLI(
436
447
  model_class=RslearnLightningModule,
437
448
  datamodule_class=RslearnDataModule,
rslearn/main.py CHANGED
@@ -27,13 +27,13 @@ from rslearn.dataset.handler_summaries import (
27
27
  PrepareDatasetWindowsSummary,
28
28
  UnknownIngestCounts,
29
29
  )
30
- from rslearn.dataset.index import DatasetIndex
31
30
  from rslearn.dataset.manage import (
32
31
  AttemptsCounter,
33
32
  materialize_dataset_windows,
34
33
  prepare_dataset_windows,
35
34
  retry,
36
35
  )
36
+ from rslearn.dataset.storage.file import FileWindowStorage
37
37
  from rslearn.log_utils import get_logger
38
38
  from rslearn.tile_stores import get_tile_store_with_layer
39
39
  from rslearn.utils import Projection, STGeometry
@@ -315,7 +315,8 @@ def apply_on_windows(
315
315
  load_workers: optional different number of workers to use for loading the
316
316
  windows. If set, workers controls the number of workers to process the
317
317
  jobs, while load_workers controls the number of workers to use for reading
318
- windows from the rslearn dataset.
318
+ windows from the rslearn dataset. Workers is only passed if the window
319
+ storage is FileWindowStorage.
319
320
  batch_size: if workers > 0, the maximum number of windows to pass to the
320
321
  function.
321
322
  jobs_per_process: optional, terminate processes after they have handled this
@@ -336,11 +337,14 @@ def apply_on_windows(
336
337
  else:
337
338
  groups = group
338
339
 
339
- if load_workers is None:
340
- load_workers = workers
341
- windows = dataset.load_windows(
342
- groups=groups, names=names, workers=load_workers, show_progress=True
343
- )
340
+ # Load the windows. We pass workers and show_progress if it is FileWindowStorage.
341
+ kwargs: dict[str, Any] = {}
342
+ if isinstance(dataset.storage, FileWindowStorage):
343
+ if load_workers is None:
344
+ load_workers = workers
345
+ kwargs["workers"] = load_workers
346
+ kwargs["show_progress"] = True
347
+ windows = dataset.load_windows(groups=groups, names=names, **kwargs)
344
348
  logger.info(f"found {len(windows)} windows")
345
349
 
346
350
  if hasattr(f, "get_jobs"):
@@ -376,7 +380,7 @@ def apply_on_windows(
376
380
 
377
381
  def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
378
382
  """Call apply_on_windows with arguments passed via command-line interface."""
379
- dataset = Dataset(UPath(args.root), args.disabled_layers)
383
+ dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
380
384
  apply_on_windows(
381
385
  f=f,
382
386
  dataset=dataset,
@@ -798,35 +802,6 @@ def dataset_materialize() -> None:
798
802
  apply_on_windows_args(fn, args)
799
803
 
800
804
 
801
- @register_handler("dataset", "build_index")
802
- def dataset_build_index() -> None:
803
- """Handler for the rslearn dataset build_index command."""
804
- parser = argparse.ArgumentParser(
805
- prog="rslearn dataset build_index",
806
- description=("rslearn dataset build_index: " + "create a dataset index file"),
807
- )
808
- parser.add_argument(
809
- "--root",
810
- type=str,
811
- required=True,
812
- help="Dataset path",
813
- )
814
- parser.add_argument(
815
- "--workers",
816
- type=int,
817
- default=16,
818
- help="Number of workers",
819
- )
820
- args = parser.parse_args(args=sys.argv[3:])
821
- ds_path = UPath(args.root)
822
- dataset = Dataset(ds_path)
823
- index = DatasetIndex.build_index(
824
- dataset=dataset,
825
- workers=args.workers,
826
- )
827
- index.save_index(ds_path)
828
-
829
-
830
805
  @register_handler("model", "fit")
831
806
  def model_fit() -> None:
832
807
  """Handler for rslearn model fit."""
rslearn/models/anysat.py CHANGED
@@ -4,11 +4,13 @@ This code loads the AnySat model from torch hub. See
4
4
  https://github.com/gastruc/AnySat for applicable license and copyright information.
5
5
  """
6
6
 
7
- from typing import Any
8
-
9
7
  import torch
10
8
  from einops import rearrange
11
9
 
10
+ from rslearn.train.model_context import ModelContext
11
+
12
+ from .component import FeatureExtractor, FeatureMaps
13
+
12
14
  # AnySat github: https://github.com/gastruc/AnySat
13
15
  # Modalities and expected resolutions (meters)
14
16
  MODALITY_RESOLUTIONS: dict[str, float] = {
@@ -44,7 +46,7 @@ MODALITY_BANDS: dict[str, list[str]] = {
44
46
  TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
45
47
 
46
48
 
47
- class AnySat(torch.nn.Module):
49
+ class AnySat(FeatureExtractor):
48
50
  """AnySat backbone (outputs one feature map)."""
49
51
 
50
52
  def __init__(
@@ -117,17 +119,17 @@ class AnySat(torch.nn.Module):
117
119
  )
118
120
  self._embed_dim = 768 # base width, 'dense' returns 2x
119
121
 
120
- def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
122
+ def forward(self, context: ModelContext) -> FeatureMaps:
121
123
  """Forward pass for the AnySat model.
122
124
 
123
125
  Args:
124
- inputs: input dicts that must include modalities as keys which are defined in the self.modalities list
126
+ context: the model context. Input dicts must include modalities as keys
127
+ which are defined in the self.modalities list
125
128
 
126
129
  Returns:
127
- List[torch.Tensor]: Single-scale feature tensors from the encoder.
130
+ a FeatureMaps with one feature map at the configured patch size.
128
131
  """
129
- if not inputs:
130
- raise ValueError("empty inputs")
132
+ inputs = context.inputs
131
133
 
132
134
  batch: dict[str, torch.Tensor] = {}
133
135
  spatial_extent: tuple[float, float] | None = None
@@ -192,7 +194,7 @@ class AnySat(torch.nn.Module):
192
194
  kwargs["output_modality"] = self.output_modality
193
195
 
194
196
  features = self.model(batch, **kwargs)
195
- return [rearrange(features, "b h w d -> b d h w")]
197
+ return FeatureMaps([rearrange(features, "b h w d -> b d h w")])
196
198
 
197
199
  def get_backbone_channels(self) -> list:
198
200
  """Returns the output channels of this model when used as a backbone.