rslearn 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,12 +2,12 @@
2
2
 
3
3
  import functools
4
4
  import json
5
+ from collections.abc import Callable
5
6
  from typing import Any, Generic, TypeVar
6
7
 
7
8
  import fiona
8
9
  import shapely
9
10
  import shapely.geometry
10
- from class_registry import ClassRegistry
11
11
  from rasterio.crs import CRS
12
12
  from upath import UPath
13
13
 
@@ -23,7 +23,24 @@ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
23
23
  from .data_source import DataSource, Item, QueryConfig
24
24
 
25
25
  logger = get_logger("__name__")
26
- Importers = ClassRegistry()
26
+ _ImporterT = TypeVar("_ImporterT", bound="Importer")
27
+
28
+
29
+ class _ImporterRegistry(dict[str, type["Importer"]]):
30
+ """Registry for Importer classes."""
31
+
32
+ def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
33
+ """Decorator to register an importer class."""
34
+
35
+ def decorator(cls: type[_ImporterT]) -> type[_ImporterT]:
36
+ self[name] = cls
37
+ return cls
38
+
39
+ return decorator
40
+
41
+
42
+ Importers = _ImporterRegistry()
43
+
27
44
 
28
45
  ItemType = TypeVar("ItemType", bound=Item)
29
46
  LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
@@ -425,7 +442,7 @@ class LocalFiles(DataSource):
425
442
  """
426
443
  self.config = config
427
444
 
428
- self.importer = Importers[config.layer_type.value]
445
+ self.importer = Importers[config.layer_type.value]()
429
446
  self.src_dir = src_dir
430
447
 
431
448
  @staticmethod
@@ -83,6 +83,10 @@ class PlanetaryComputer(DataSource, TileStore):
83
83
 
84
84
  STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
85
85
 
86
+ # Default threshold for recreating the STAC client to prevent memory leaks
87
+ # from the pystac Catalog's resolved objects cache growing unbounded
88
+ DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
89
+
86
90
  def __init__(
87
91
  self,
88
92
  collection_name: str,
@@ -93,6 +97,7 @@ class PlanetaryComputer(DataSource, TileStore):
93
97
  timeout: timedelta = timedelta(seconds=10),
94
98
  skip_items_missing_assets: bool = False,
95
99
  cache_dir: UPath | None = None,
100
+ max_items_per_client: int | None = None,
96
101
  ):
97
102
  """Initialize a new PlanetaryComputer instance.
98
103
 
@@ -109,6 +114,9 @@ class PlanetaryComputer(DataSource, TileStore):
109
114
  cache_dir: optional directory to cache items by name, including asset URLs.
110
115
  If not set, there will be no cache and instead STAC requests will be
111
116
  needed each time.
117
+ max_items_per_client: number of STAC items to process before recreating
118
+ the client to prevent memory leaks from the resolved objects cache.
119
+ Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
112
120
  """
113
121
  self.collection_name = collection_name
114
122
  self.asset_bands = asset_bands
@@ -118,12 +126,15 @@ class PlanetaryComputer(DataSource, TileStore):
118
126
  self.timeout = timeout
119
127
  self.skip_items_missing_assets = skip_items_missing_assets
120
128
  self.cache_dir = cache_dir
129
+ self.max_items_per_client = (
130
+ max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
131
+ )
121
132
 
122
133
  if self.cache_dir is not None:
123
134
  self.cache_dir.mkdir(parents=True, exist_ok=True)
124
135
 
125
136
  self.client: pystac_client.Client | None = None
126
- self.collection: pystac_client.CollectionClient | None = None
137
+ self._client_item_count = 0
127
138
 
128
139
  @staticmethod
129
140
  def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
@@ -142,7 +153,12 @@ class PlanetaryComputer(DataSource, TileStore):
142
153
  if "cache_dir" in d:
143
154
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
144
155
 
145
- simple_optionals = ["query", "sort_by", "sort_ascending"]
156
+ simple_optionals = [
157
+ "query",
158
+ "sort_by",
159
+ "sort_ascending",
160
+ "max_items_per_client",
161
+ ]
146
162
  for k in simple_optionals:
147
163
  if k in d:
148
164
  kwargs[k] = d[k]
@@ -151,20 +167,40 @@ class PlanetaryComputer(DataSource, TileStore):
151
167
 
152
168
  def _load_client(
153
169
  self,
154
- ) -> tuple[pystac_client.Client, pystac_client.CollectionClient]:
170
+ ) -> pystac_client.Client:
155
171
  """Lazily load pystac client.
156
172
 
157
173
  We don't load it when creating the data source because it takes time and caller
158
174
  may not be calling get_items. Additionally, loading it during the get_items
159
175
  call enables leveraging the retry loop functionality in
160
176
  prepare_dataset_windows.
161
- """
162
- if self.client is not None:
163
- return self.client, self.collection
164
177
 
178
+ Note: We periodically recreate the client to prevent memory leaks from the
179
+ pystac Catalog's resolved objects cache, which grows unbounded as STAC items
180
+ are deserialized and cached. The cache cannot be cleared or disabled.
181
+ """
182
+ if self.client is None:
183
+ logger.info("Creating initial STAC client")
184
+ self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
185
+ return self.client
186
+
187
+ if self._client_item_count < self.max_items_per_client:
188
+ return self.client
189
+
190
+ # Recreate client to clear the resolved objects cache
191
+ current_client = self.client
192
+ logger.debug(
193
+ "Recreating STAC client after processing %d items (threshold: %d)",
194
+ self._client_item_count,
195
+ self.max_items_per_client,
196
+ )
197
+ client_root = current_client.get_root()
198
+ client_root.clear_links()
199
+ client_root.clear_items()
200
+ client_root.clear_children()
201
+ self._client_item_count = 0
165
202
  self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
166
- self.collection = self.client.get_collection(self.collection_name)
167
- return self.client, self.collection
203
+ return self.client
168
204
 
169
205
  def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
170
206
  shp = shapely.geometry.shape(stac_item.geometry)
@@ -210,10 +246,26 @@ class PlanetaryComputer(DataSource, TileStore):
210
246
 
211
247
  # No cache or not in cache, so we need to make the STAC request.
212
248
  logger.debug("Getting STAC item {name}")
213
- _, collection = self._load_client()
214
- stac_item = collection.get_item(name)
249
+ client = self._load_client()
250
+
251
+ search_result = client.search(ids=[name], collections=[self.collection_name])
252
+ stac_items = list(search_result.items())
253
+
254
+ if not stac_items:
255
+ raise ValueError(
256
+ f"Item {name} not found in collection {self.collection_name}"
257
+ )
258
+ if len(stac_items) > 1:
259
+ raise ValueError(
260
+ f"Multiple items found for ID {name} in collection {self.collection_name}"
261
+ )
262
+
263
+ stac_item = stac_items[0]
215
264
  item = self._stac_item_to_item(stac_item)
216
265
 
266
+ # Track items processed for client recreation threshold (after deserialization)
267
+ self._client_item_count += 1
268
+
217
269
  # Finally we cache it if cache_dir is set.
218
270
  if cache_fname is not None:
219
271
  with cache_fname.open("w") as f:
@@ -233,7 +285,7 @@ class PlanetaryComputer(DataSource, TileStore):
233
285
  Returns:
234
286
  List of groups of items that should be retrieved for each geometry.
235
287
  """
236
- client, _ = self._load_client()
288
+ client = self._load_client()
237
289
 
238
290
  groups = []
239
291
  for geometry in geometries:
@@ -247,7 +299,9 @@ class PlanetaryComputer(DataSource, TileStore):
247
299
  datetime=wgs84_geometry.time_range,
248
300
  query=self.query,
249
301
  )
250
- stac_items = [item for item in result.item_collection()]
302
+ stac_items = [item for item in result.items()]
303
+ # Track items processed for client recreation threshold (after deserialization)
304
+ self._client_item_count += len(stac_items)
251
305
  logger.debug("STAC search yielded %d items", len(stac_items))
252
306
 
253
307
  if self.skip_items_missing_assets:
@@ -580,7 +634,13 @@ class Sentinel2(PlanetaryComputer):
580
634
  if "cache_dir" in d:
581
635
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
582
636
 
583
- simple_optionals = ["harmonize", "query", "sort_by", "sort_ascending"]
637
+ simple_optionals = [
638
+ "harmonize",
639
+ "query",
640
+ "sort_by",
641
+ "sort_ascending",
642
+ "max_items_per_client",
643
+ ]
584
644
  for k in simple_optionals:
585
645
  if k in d:
586
646
  kwargs[k] = d[k]
@@ -756,7 +816,12 @@ class Sentinel1(PlanetaryComputer):
756
816
  if "cache_dir" in d:
757
817
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
758
818
 
759
- simple_optionals = ["query", "sort_by", "sort_ascending"]
819
+ simple_optionals = [
820
+ "query",
821
+ "sort_by",
822
+ "sort_ascending",
823
+ "max_items_per_client",
824
+ ]
760
825
  for k in simple_optionals:
761
826
  if k in d:
762
827
  kwargs[k] = d[k]
rslearn/dataset/manage.py CHANGED
@@ -396,9 +396,9 @@ def materialize_window(
396
396
  )
397
397
 
398
398
  if dataset.materializer_name:
399
- materializer = Materializers[dataset.materializer_name]
399
+ materializer = Materializers[dataset.materializer_name]()
400
400
  else:
401
- materializer = Materializers[layer_cfg.layer_type.value]
401
+ materializer = Materializers[layer_cfg.layer_type.value]()
402
402
 
403
403
  retry(
404
404
  fn=lambda: materializer.materialize(
@@ -1,10 +1,10 @@
1
1
  """Classes to implement dataset materialization."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from typing import Any, Generic, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
  from rasterio.enums import Resampling
9
9
 
10
10
  from rslearn.config import (
@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
25
25
  from .remap import Remapper, load_remapper
26
26
  from .window import Window
27
27
 
28
- Materializers = ClassRegistry()
28
+ _MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
29
+
30
+
31
+ class _MaterializerRegistry(dict[str, type["Materializer"]]):
32
+ """Registry for Materializer classes."""
33
+
34
+ def register(
35
+ self, name: str
36
+ ) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
37
+ """Decorator to register a materializer class."""
38
+
39
+ def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
40
+ self[name] = cls
41
+ return cls
42
+
43
+ return decorator
44
+
45
+
46
+ Materializers = _MaterializerRegistry()
47
+
29
48
 
30
49
  LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
31
50
 
rslearn/dataset/remap.py CHANGED
@@ -1,18 +1,42 @@
1
1
  """Classes to remap raster values."""
2
2
 
3
- from typing import Any
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
 
9
- Remappers = ClassRegistry()
9
+ _RemapperT = TypeVar("_RemapperT", bound="Remapper")
10
+
11
+
12
+ class _RemapperRegistry(dict[str, type["Remapper"]]):
13
+ """Registry for Remapper classes."""
14
+
15
+ def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
16
+ """Decorator to register a remapper class."""
17
+
18
+ def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
19
+ self[name] = cls
20
+ return cls
21
+
22
+ return decorator
23
+
24
+
25
+ Remappers = _RemapperRegistry()
10
26
  """Registry of Remapper implementations."""
11
27
 
12
28
 
13
29
  class Remapper:
14
30
  """An abstract class that remaps pixel values based on layer configuration."""
15
31
 
32
+ def __init__(self, config: dict[str, Any]) -> None:
33
+ """Initialize a Remapper.
34
+
35
+ Args:
36
+ config: the config dict for this remapper.
37
+ """
38
+ pass
39
+
16
40
  def __call__(
17
41
  self, array: npt.NDArray[Any], dtype: npt.DTypeLike
18
42
  ) -> npt.NDArray[Any]:
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
67
91
 
68
92
  def load_remapper(config: dict[str, Any]) -> Remapper:
69
93
  """Load a remapper from a configuration dictionary."""
70
- return Remappers.get(config["name"], config=config)
94
+ cls = Remappers[config["name"]]
95
+ return cls(config)
rslearn/models/dinov3.py CHANGED
@@ -7,8 +7,8 @@ from typing import Any
7
7
  import torch
8
8
  import torchvision
9
9
  from einops import rearrange
10
- from torchvision.transforms import v2
11
10
 
11
+ from rslearn.train.transforms.normalize import Normalize
12
12
  from rslearn.train.transforms.transform import Transform
13
13
 
14
14
 
@@ -139,15 +139,17 @@ class DinoV3Normalize(Transform):
139
139
  super().__init__()
140
140
  self.satellite = satellite
141
141
  if satellite:
142
- self.normalize = v2.Normalize(
143
- mean=(0.430, 0.411, 0.296),
144
- std=(0.213, 0.156, 0.143),
145
- )
142
+ mean = [0.430, 0.411, 0.296]
143
+ std = [0.213, 0.156, 0.143]
146
144
  else:
147
- self.normalize = v2.Normalize(
148
- mean=(0.485, 0.456, 0.406),
149
- std=(0.229, 0.224, 0.225),
150
- )
145
+ mean = [0.485, 0.456, 0.406]
146
+ std = [0.229, 0.224, 0.225]
147
+
148
+ self.normalize = Normalize(
149
+ [value * 255 for value in mean],
150
+ [value * 255 for value in std],
151
+ num_bands=3,
152
+ )
151
153
 
152
154
  def forward(
153
155
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -161,5 +163,4 @@ class DinoV3Normalize(Transform):
161
163
  Returns:
162
164
  normalized (input_dicts, target_dicts) tuple
163
165
  """
164
- input_dict["image"] = self.normalize(input_dict["image"] / 255.0)
165
- return input_dict, target_dict
166
+ return self.normalize(input_dict, target_dict)
@@ -2,6 +2,7 @@
2
2
 
3
3
  import math
4
4
  import tempfile
5
+ from contextlib import nullcontext
5
6
  from enum import StrEnum
6
7
  from typing import Any, cast
7
8
 
@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {
63
64
 
64
65
  DEFAULT_NORMALIZER = Normalizer()
65
66
 
67
+ AUTOCAST_DTYPE_MAP = {
68
+ "bfloat16": torch.bfloat16,
69
+ "float32": torch.float32,
70
+ }
71
+
66
72
 
67
73
  class GalileoModel(nn.Module):
68
74
  """Galileo backbones."""
@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
85
91
  size: GalileoSize,
86
92
  patch_size: int = 4,
87
93
  pretrained_path: str | UPath | None = None,
94
+ autocast_dtype: str | None = "bfloat16",
88
95
  ) -> None:
89
96
  """Initialize the Galileo model.
90
97
 
@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
93
100
  patch_size: The patch size to use.
94
101
  pretrained_path: the local path to the pretrained weights. Otherwise it is
95
102
  downloaded and cached in temp directory.
103
+ autocast_dtype: which dtype to use for autocasting, or set None to disable.
96
104
  """
97
105
  super().__init__()
98
106
  if pretrained_path is None:
@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
128
136
  idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
129
137
  ]
130
138
 
139
+ self.size = size
131
140
  self.patch_size = patch_size
132
141
 
142
+ if autocast_dtype is not None:
143
+ self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
144
+ else:
145
+ self.autocast_dtype = None
146
+
133
147
  @staticmethod
134
148
  def to_cartesian(
135
149
  lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
@@ -484,18 +498,31 @@ class GalileoModel(nn.Module):
484
498
  patch_size = h
485
499
  else:
486
500
  patch_size = self.patch_size
487
- outputs = self.model(
488
- s_t_x=galileo_input.s_t_x,
489
- s_t_m=galileo_input.s_t_m,
490
- sp_x=galileo_input.sp_x,
491
- sp_m=galileo_input.sp_m,
492
- t_x=galileo_input.t_x,
493
- t_m=galileo_input.t_m,
494
- st_x=galileo_input.st_x,
495
- st_m=galileo_input.st_m,
496
- months=galileo_input.months,
497
- patch_size=patch_size,
498
- )
501
+
502
+ # Decide context based on self.autocast_dtype.
503
+ device = galileo_input.s_t_x.device
504
+ if self.autocast_dtype is None:
505
+ context = nullcontext()
506
+ else:
507
+ assert device is not None
508
+ context = torch.amp.autocast(
509
+ device_type=device.type, dtype=self.autocast_dtype
510
+ )
511
+
512
+ with context:
513
+ outputs = self.model(
514
+ s_t_x=galileo_input.s_t_x,
515
+ s_t_m=galileo_input.s_t_m,
516
+ sp_x=galileo_input.sp_x,
517
+ sp_m=galileo_input.sp_m,
518
+ t_x=galileo_input.t_x,
519
+ t_m=galileo_input.t_m,
520
+ st_x=galileo_input.st_x,
521
+ st_m=galileo_input.st_m,
522
+ months=galileo_input.months,
523
+ patch_size=patch_size,
524
+ )
525
+
499
526
  if h == patch_size:
500
527
  # only one spatial patch, so we can just take an average
501
528
  # of all the tokens to output b c_g 1 1
@@ -515,3 +542,22 @@ class GalileoModel(nn.Module):
515
542
  "b h w c_g d -> b c_g d h w",
516
543
  ).mean(dim=1)
517
544
  ]
545
+
546
+ def get_backbone_channels(self) -> list:
547
+ """Returns the output channels of this model when used as a backbone.
548
+
549
+ The output channels is a list of (patch_size, depth) that corresponds
550
+ to the feature maps that the backbone returns.
551
+
552
+ Returns:
553
+ the output channels of the backbone as a list of (patch_size, depth) tuples.
554
+ """
555
+ if self.size == GalileoSize.BASE:
556
+ depth = 768
557
+ elif self.model_size == GalileoSize.TINY:
558
+ depth = 192
559
+ elif self.model_size == GalileoSize.NANO:
560
+ depth = 128
561
+ else:
562
+ raise ValueError(f"Invalid model size: {self.size}")
563
+ return [(self.patch_size, depth)]
@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
1469
1469
  # we take the inverse of the mask because a value
1470
1470
  # of True indicates the value *should* take part in
1471
1471
  # attention
1472
- x = blk(x=x, y=None, attn_mask=~new_m.bool())
1472
+ temp_mask = ~new_m.bool()
1473
+ if temp_mask.all():
1474
+ # if all the tokens are used in attention we can pass a None mask
1475
+ # to the attention block
1476
+ temp_mask = None
1477
+
1478
+ x = blk(x=x, y=None, attn_mask=temp_mask)
1473
1479
 
1474
1480
  if exit_ids_seq is not None:
1475
1481
  assert exited_tokens is not None
@@ -248,3 +248,14 @@ class Presto(nn.Module):
248
248
  output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
249
249
 
250
250
  return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
251
+
252
+ def get_backbone_channels(self) -> list:
253
+ """Returns the output channels of this model when used as a backbone.
254
+
255
+ The output channels is a list of (patch_size, depth) that corresponds
256
+ to the feature maps that the backbone returns.
257
+
258
+ Returns:
259
+ the output channels of the backbone as a list of (patch_size, depth) tuples.
260
+ """
261
+ return [(1, 128)]
rslearn/models/prithvi.py CHANGED
@@ -173,6 +173,17 @@ class PrithviV2(nn.Module):
173
173
  features, num_timesteps
174
174
  )
175
175
 
176
+ def get_backbone_channels(self) -> list:
177
+ """Returns the output channels of this model when used as a backbone.
178
+
179
+ The output channels is a list of (patch_size, depth) that corresponds
180
+ to the feature maps that the backbone returns.
181
+
182
+ Returns:
183
+ the output channels of the backbone as a list of (patch_size, depth) tuples.
184
+ """
185
+ return [(1, 1024)]
186
+
176
187
 
177
188
  class PrithviNormalize(Transform):
178
189
  """Normalize inputs using Prithvi normalization.
@@ -1,5 +1,22 @@
1
1
  """Model registry."""
2
2
 
3
- from class_registry import ClassRegistry
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
- Models = ClassRegistry()
6
+ _ModelT = TypeVar("_ModelT")
7
+
8
+
9
+ class _ModelRegistry(dict[str, type[Any]]):
10
+ """Registry for Model classes."""
11
+
12
+ def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
13
+ """Decorator to register a model class."""
14
+
15
+ def decorator(cls: type[_ModelT]) -> type[_ModelT]:
16
+ self[name] = cls
17
+ return cls
18
+
19
+ return decorator
20
+
21
+
22
+ Models = _ModelRegistry()
@@ -130,10 +130,12 @@ class DefaultTileStore(TileStore):
130
130
  """
131
131
  raster_dir = self._get_raster_dir(layer_name, item_name, bands)
132
132
  for fname in raster_dir.iterdir():
133
- # Ignore completed sentinel files as well as temporary files created by
133
+ # Ignore completed sentinel files, bands files, as well as temporary files created by
134
134
  # open_atomic (in case this tile store is on local filesystem).
135
135
  if fname.name == COMPLETED_FNAME:
136
136
  continue
137
+ if fname.name == BANDS_FNAME:
138
+ continue
137
139
  if ".tmp." in fname.name:
138
140
  continue
139
141
  return fname
@@ -54,7 +54,7 @@ def read_selector(
54
54
  the item specified by the selector
55
55
  """
56
56
  d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
57
- parts = selector.split("/")
57
+ parts = selector.split("/") if selector else []
58
58
  cur = d
59
59
  for part in parts:
60
60
  cur = cur[part]
@@ -76,11 +76,28 @@ def write_selector(
76
76
  v: the value to write
77
77
  """
78
78
  d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
79
- parts = selector.split("/")
80
- cur = d
81
- for part in parts[:-1]:
82
- cur = cur[part]
83
- cur[parts[-1]] = v
79
+ if selector:
80
+ parts = selector.split("/")
81
+ cur = d
82
+ for part in parts[:-1]:
83
+ cur = cur[part]
84
+ cur[parts[-1]] = v
85
+ else:
86
+ # If the selector references the input or target dictionary directly, then we
87
+ # have a special case where instead of overwriting with v, we replace the keys
88
+ # with those in v. v must be a dictionary here, not a tensor, since otherwise
89
+ # it wouldn't match the type of the input or target dictionary.
90
+ if not isinstance(v, dict):
91
+ raise ValueError(
92
+ "when directly specifying the input or target dict, expected the value to be a dict"
93
+ )
94
+ if d == v:
95
+ # This may happen if the writer did not make a copy of the dictionary. In
96
+ # this case the code below would not update d correctly since it would also
97
+ # clear v.
98
+ return
99
+ d.clear()
100
+ d.update(v)
84
101
 
85
102
 
86
103
  class Transform(torch.nn.Module):
@@ -2,13 +2,13 @@
2
2
 
3
3
  import hashlib
4
4
  import json
5
- from typing import Any, BinaryIO
5
+ from collections.abc import Callable
6
+ from typing import Any, BinaryIO, TypeVar
6
7
 
7
8
  import affine
8
9
  import numpy as np
9
10
  import numpy.typing as npt
10
11
  import rasterio
11
- from class_registry import ClassRegistry
12
12
  from PIL import Image
13
13
  from rasterio.crs import CRS
14
14
  from rasterio.enums import Resampling
@@ -21,7 +21,27 @@ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath
21
21
 
22
22
  from .geometry import PixelBounds, Projection
23
23
 
24
- RasterFormats = ClassRegistry()
24
+ _RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
25
+
26
+
27
+ class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
28
+ """Registry for RasterFormat classes."""
29
+
30
+ def register(
31
+ self, name: str
32
+ ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
33
+ """Decorator to register a raster format class."""
34
+
35
+ def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
36
+ self[name] = cls
37
+ return cls
38
+
39
+ return decorator
40
+
41
+
42
+ RasterFormats = _RasterFormatRegistry()
43
+
44
+
25
45
  logger = get_logger(__name__)
26
46
 
27
47
 
@@ -147,6 +167,19 @@ class RasterFormat:
147
167
  """
148
168
  raise NotImplementedError
149
169
 
170
+ @staticmethod
171
+ def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
172
+ """Create a RasterFormat from a config dict.
173
+
174
+ Args:
175
+ name: the name of this format
176
+ config: the config dict
177
+
178
+ Returns:
179
+ the RasterFormat instance
180
+ """
181
+ raise NotImplementedError
182
+
150
183
 
151
184
  @RasterFormats.register("image_tile")
152
185
  class ImageTileRasterFormat(RasterFormat):
@@ -716,5 +749,5 @@ def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
716
749
  Returns:
717
750
  the loaded RasterFormat implementation
718
751
  """
719
- cls = RasterFormats.get_class(config.name)
752
+ cls = RasterFormats[config.name]
720
753
  return cls.from_config(config.name, config.config_dict)
@@ -1,11 +1,11 @@
1
1
  """Classes for writing vector data to a UPath."""
2
2
 
3
3
  import json
4
+ from collections.abc import Callable
4
5
  from enum import Enum
5
- from typing import Any
6
+ from typing import Any, TypeVar
6
7
 
7
8
  import shapely
8
- from class_registry import ClassRegistry
9
9
  from rasterio.crs import CRS
10
10
  from upath import UPath
11
11
 
@@ -18,7 +18,25 @@ from .feature import Feature
18
18
  from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
19
19
 
20
20
  logger = get_logger(__name__)
21
- VectorFormats = ClassRegistry()
21
+ _VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
22
+
23
+
24
+ class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
25
+ """Registry for VectorFormat classes."""
26
+
27
+ def register(
28
+ self, name: str
29
+ ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
30
+ """Decorator to register a vector format class."""
31
+
32
+ def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
33
+ self[name] = cls
34
+ return cls
35
+
36
+ return decorator
37
+
38
+
39
+ VectorFormats = _VectorFormatRegistry()
22
40
 
23
41
 
24
42
  class VectorFormat:
@@ -53,6 +71,19 @@ class VectorFormat:
53
71
  """
54
72
  raise NotImplementedError
55
73
 
74
+ @staticmethod
75
+ def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
76
+ """Create a VectorFormat from a config dict.
77
+
78
+ Args:
79
+ name: the name of this format
80
+ config: the config dict
81
+
82
+ Returns:
83
+ the VectorFormat instance
84
+ """
85
+ raise NotImplementedError
86
+
56
87
 
57
88
  @VectorFormats.register("tile")
58
89
  class TileVectorFormat(VectorFormat):
@@ -410,5 +441,5 @@ def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
410
441
  Returns:
411
442
  the loaded VectorFormat implementation
412
443
  """
413
- cls = VectorFormats.get_class(config.name)
444
+ cls = VectorFormats[config.name]
414
445
  return cls.from_config(config.name, config.config_dict)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -212,7 +212,6 @@ Requires-Python: >=3.11
212
212
  Description-Content-Type: text/markdown
213
213
  License-File: LICENSE
214
214
  Requires-Dist: boto3>=1.39
215
- Requires-Dist: class_registry>=2.1
216
215
  Requires-Dist: fiona>=1.10
217
216
  Requires-Dist: fsspec>=2025.9.0
218
217
  Requires-Dist: jsonargparse>=4.35.0
@@ -284,6 +283,7 @@ Quick links:
284
283
  - [Examples](docs/Examples.md) contains more examples, including customizing different
285
284
  stages of rslearn with additional code.
286
285
  - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
286
+ - [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
287
287
 
288
288
 
289
289
  Setup
@@ -20,11 +20,11 @@ rslearn/data_sources/eurocrops.py,sha256=bH_ul45RvNLLvVKU9JjZH1XMenAmgn7Lakq0pXj
20
20
  rslearn/data_sources/gcp_public_data.py,sha256=kr9stYo7ZCvz8s4E3wmoY-yAGZoLa_9RCwjS-Q5k9dM,36128
21
21
  rslearn/data_sources/geotiff.py,sha256=sFUp919chaX4j6lQytNp__xnMLlDI3Ac3rfB6F8sgZ0,45
22
22
  rslearn/data_sources/google_earth_engine.py,sha256=hpkt74ly2lEwjRrDp8FBmGvB3MEw_mQ38Av4rQOR3_w,24246
23
- rslearn/data_sources/local_files.py,sha256=NnmXWbOoKHxZFzwCiWsybWfPrFoPTpoZbjoK57xhoqc,19049
23
+ rslearn/data_sources/local_files.py,sha256=d08m6IzrUN_80VfvgpHahMJrv-n6_CI6EIocp6kyDRs,19490
24
24
  rslearn/data_sources/openstreetmap.py,sha256=qUSMFiIA_laJkO3meBXf9TmSI7OBD-o3i4JxqllUv3Q,19232
25
25
  rslearn/data_sources/planet.py,sha256=F2JoLaQ5Cb3k1cTm0hwSWTL2TPfbaAUMXZ8q4Dy7UlA,10109
26
26
  rslearn/data_sources/planet_basemap.py,sha256=wuWM9dHSJMdINfyWb78Zk9i-KvJHTrf9J0Q2gyEyiiA,10450
27
- rslearn/data_sources/planetary_computer.py,sha256=Wchr-OmAffuVteUW6VRofIqFpE-cJqI6GMTU9v1i8pw,28089
27
+ rslearn/data_sources/planetary_computer.py,sha256=uHNYxvnMkmo8zbqIiDRdnkz8LQ7TSs6K39Y1AXjboDI,30392
28
28
  rslearn/data_sources/raster_source.py,sha256=b8wo55GhVLxXwx1WYLzeRAlzD_ZkE_P9tnvUOdnsfQE,689
29
29
  rslearn/data_sources/usda_cdl.py,sha256=2_V11AhPRgLEGd4U5Pmx3UvE2HWBPbsFXhUIQVRVFeE,7138
30
30
  rslearn/data_sources/usgs_landsat.py,sha256=31GmOUfmxwTE6MTiVI4psb-ciVmunuA8cfvqDuvTHPE,19312
@@ -39,9 +39,9 @@ rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVk
39
39
  rslearn/dataset/dataset.py,sha256=bjf9nI55j-MF0bIQWSNPjNbpfqnLK4jy-96TAcwO0MM,5214
40
40
  rslearn/dataset/handler_summaries.py,sha256=wGnbBpjLWTxVn3UT7j7nPoHlYsaWb9_MVJ5DhU0qWXY,2581
41
41
  rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
42
- rslearn/dataset/manage.py,sha256=E2PF29MWuYu8XMcX8PYG4itCdgxaUV768R5NCHmJ-e8,18058
43
- rslearn/dataset/materialize.py,sha256=emsxgkBKWC-ybZDtudsYOsxx7yxgmdKy5YrysgqzNac,21556
44
- rslearn/dataset/remap.py,sha256=HF5Kn5z6dEPbCpSxaBmYVDSzCuBjG2w8csPsk0RRlv4,2088
42
+ rslearn/dataset/manage.py,sha256=mkdBHo1RFGxMx8f9zBT_VmRO_6y8Qb2KfWPPziKWYkg,18062
43
+ rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWBw,22056
44
+ rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
45
45
  rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
46
46
  rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
47
47
  rslearn/models/anysat.py,sha256=3BnaiS1sYB4SnV6qRjHksiz_r9vUuZeGPUO2XUziFA0,7810
@@ -49,7 +49,7 @@ rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
49
49
  rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
50
50
  rslearn/models/copernicusfm.py,sha256=3AiORuUre9sZYwydbrDgShwKtxeTLmExp7WQmJtBylg,7842
51
51
  rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
52
- rslearn/models/dinov3.py,sha256=KgR-vbmb3-6Qn4AC7ybGL6Ri3_f4jf2XOaKxAR9uMnU,6150
52
+ rslearn/models/dinov3.py,sha256=GKk5qXZPCEporATJdjaSWsDTfWDlAGRWBplFUJN5nRM,6146
53
53
  rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
54
54
  rslearn/models/fpn.py,sha256=s3cz29I14FaSuvBvLOcwCrqVsaRBxG5GjLlqap4WgPc,1603
55
55
  rslearn/models/module_wrapper.py,sha256=H2zb-8Au4t31kawW_4JEKHsaXFjpYDawb31ZEauKcxU,2728
@@ -58,8 +58,8 @@ rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,1
58
58
  rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
59
59
  rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
60
60
  rslearn/models/pooling_decoder.py,sha256=jZfEQCfthfa21C9sEjgFHUcfhHMVlvG7_nDMw_1FLaE,2727
61
- rslearn/models/prithvi.py,sha256=bGY2Tf_V9tlJGDtDSCMkoFMJdFXUCW6cyj4r8FER8Pw,39557
62
- rslearn/models/registry.py,sha256=0j6jY-rmB_wKckHEUeCYWXtjviDOEr5c7WhIMxxDwBk,90
61
+ rslearn/models/prithvi.py,sha256=SVM3ypJlVTkXQ69pPhB4UeJr87VnmADTCuyV365dbkU,39961
62
+ rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
63
63
  rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
64
64
  rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
65
65
  rslearn/models/satlaspretrain.py,sha256=YpjXl-uClhTZMDmyhN64Fg3AszzT-ymZgJB0fO9RyoY,2419
@@ -91,8 +91,8 @@ rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJ
91
91
  rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
92
92
  rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
93
93
  rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
94
- rslearn/models/galileo/galileo.py,sha256=b8LMyn5lN-WHmCNTJUy-ykjL0s5_-jg9RTO1CwJVra4,19533
95
- rslearn/models/galileo/single_file_galileo.py,sha256=hmtBGdlFfX9JceZLguntSlXtuHElOafxwu1yIrhrEJY,55920
94
+ rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
95
+ rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
96
96
  rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
97
97
  rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
98
98
  rslearn/models/panopticon_data/sensors/goes.yaml,sha256=o00aoWCYqam0aB1rPmXq1MKe8hsKak_qyBG7BPL27Sc,152
@@ -106,10 +106,10 @@ rslearn/models/panopticon_data/sensors/sentinel2.yaml,sha256=qYJ92x-GHO0ZdCrTtCj
106
106
  rslearn/models/panopticon_data/sensors/superdove.yaml,sha256=QpIRyopdV4hAez_EIsDwhGFT4VtTk7UgzQveyc8t8fc,795
107
107
  rslearn/models/panopticon_data/sensors/wv23.yaml,sha256=SWYSlkka6UViKAz6YI8aqwQ-Ayo-S5kmNa9rO3iGW6o,1172
108
108
  rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRygze0,64
109
- rslearn/models/presto/presto.py,sha256=Vr4DBIP5Twnop5pZ21BhChbOwWP7KI-GMdg2XILkj5c,8774
109
+ rslearn/models/presto/presto.py,sha256=8mZnc0jk_r_JikybHQNyyHg6t7JNPmoPmgoivyNf-U8,9177
110
110
  rslearn/models/presto/single_file_presto.py,sha256=Kbwp8V7pO8HHM2vlCPpjekQiFiDryW8zQkWmt1g05BY,30381
111
111
  rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_Qk,2050
112
- rslearn/tile_stores/default.py,sha256=P78LQm9OXNjd6Igkh-YqUFDmHSBN1Uk092xwqkuOPSs,14973
112
+ rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
113
113
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
114
114
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
115
115
  rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
@@ -140,7 +140,7 @@ rslearn/train/transforms/normalize.py,sha256=uyv2hE5hw5B2kCRHa4JIx0tfowm-C7bgumw
140
140
  rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
141
141
  rslearn/train/transforms/select_bands.py,sha256=uDfD9G8Z4VTt88QZsjj1FB20QEmzSefhKf7uDXYn77M,2441
142
142
  rslearn/train/transforms/sentinel1.py,sha256=FrLaYZs2AjqWQCun8DTFtgo1l0xLxqaFKtDNIehtpDg,1913
143
- rslearn/train/transforms/transform.py,sha256=8Q-dPrmDr0tJ9ZOwjWqWK8kbnKi4uLxEnS9Nwf6BVJk,3594
143
+ rslearn/train/transforms/transform.py,sha256=n1Qzqix2dVvej-Q7iPzHeOQbqH79IBlvqPoymxhNVpE,4446
144
144
  rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
145
145
  rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
146
146
  rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
@@ -150,15 +150,15 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
150
150
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
151
151
  rslearn/utils/jsonargparse.py,sha256=JcTKQoZ6jgwag-kSeTIEVBO9AsRj0X1oEJBsoaCazH4,658
152
152
  rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
153
- rslearn/utils/raster_format.py,sha256=3hjBR_0Gl1M1b0Vjx4UsQBgAVziCh34VAVWeW32k3Ds,25167
153
+ rslearn/utils/raster_format.py,sha256=dBTSa8l6Ms9Ndbx9Krgqm9z4RU7j2hwLBkw2w-KibU4,26009
154
154
  rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
155
155
  rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
156
156
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
157
157
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
158
- rslearn/utils/vector_format.py,sha256=C0-qlSXFFnhUU0S7eiOn4-4d5-pCkwU2jZ-q0Ldikk8,15299
159
- rslearn-0.0.8.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
160
- rslearn-0.0.8.dist-info/METADATA,sha256=hdHA3W7Hb4bQPPZG6SscWr1AOOVhALL0F08tEMt7neQ,36206
161
- rslearn-0.0.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
- rslearn-0.0.8.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
163
- rslearn-0.0.8.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
164
- rslearn-0.0.8.dist-info/RECORD,,
158
+ rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
159
+ rslearn-0.0.9.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
160
+ rslearn-0.0.9.dist-info/METADATA,sha256=6BV8wt9tuo94FkoKjR3RcF3AbKNbU3IodkJtK4tASkE,36248
161
+ rslearn-0.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
+ rslearn-0.0.9.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
163
+ rslearn-0.0.9.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
164
+ rslearn-0.0.9.dist-info/RECORD,,