rslearn 0.0.8__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,12 +2,12 @@
2
2
 
3
3
  import functools
4
4
  import json
5
+ from collections.abc import Callable
5
6
  from typing import Any, Generic, TypeVar
6
7
 
7
8
  import fiona
8
9
  import shapely
9
10
  import shapely.geometry
10
- from class_registry import ClassRegistry
11
11
  from rasterio.crs import CRS
12
12
  from upath import UPath
13
13
 
@@ -23,7 +23,24 @@ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
23
23
  from .data_source import DataSource, Item, QueryConfig
24
24
 
25
25
  logger = get_logger("__name__")
26
- Importers = ClassRegistry()
26
_ImporterT = TypeVar("_ImporterT", bound="Importer")


class _ImporterRegistry(dict[str, type["Importer"]]):
    """Name-to-class lookup table for Importer implementations.

    This is a plain dict; ``register`` only adds a decorator-based way to insert
    entries.
    """

    def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
        """Build a class decorator that files the decorated class under ``name``.

        Args:
            name: the key to store the class under. An existing entry with the same
                key is replaced.

        Returns:
            a decorator that records the class and hands it back unchanged.
        """

        def _record(importer_cls: type[_ImporterT]) -> type[_ImporterT]:
            self[name] = importer_cls
            return importer_cls

        return _record


Importers = _ImporterRegistry()
43
+
27
44
 
28
45
  ItemType = TypeVar("ItemType", bound=Item)
29
46
  LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
@@ -425,7 +442,7 @@ class LocalFiles(DataSource):
425
442
  """
426
443
  self.config = config
427
444
 
428
- self.importer = Importers[config.layer_type.value]
445
+ self.importer = Importers[config.layer_type.value]()
429
446
  self.src_dir = src_dir
430
447
 
431
448
  @staticmethod
@@ -83,6 +83,10 @@ class PlanetaryComputer(DataSource, TileStore):
83
83
 
84
84
  STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
85
85
 
86
+ # Default threshold for recreating the STAC client to prevent memory leaks
87
+ # from the pystac Catalog's resolved objects cache growing unbounded
88
+ DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
89
+
86
90
  def __init__(
87
91
  self,
88
92
  collection_name: str,
@@ -93,6 +97,7 @@ class PlanetaryComputer(DataSource, TileStore):
93
97
  timeout: timedelta = timedelta(seconds=10),
94
98
  skip_items_missing_assets: bool = False,
95
99
  cache_dir: UPath | None = None,
100
+ max_items_per_client: int | None = None,
96
101
  ):
97
102
  """Initialize a new PlanetaryComputer instance.
98
103
 
@@ -109,6 +114,9 @@ class PlanetaryComputer(DataSource, TileStore):
109
114
  cache_dir: optional directory to cache items by name, including asset URLs.
110
115
  If not set, there will be no cache and instead STAC requests will be
111
116
  needed each time.
117
+ max_items_per_client: number of STAC items to process before recreating
118
+ the client to prevent memory leaks from the resolved objects cache.
119
+ Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
112
120
  """
113
121
  self.collection_name = collection_name
114
122
  self.asset_bands = asset_bands
@@ -118,12 +126,15 @@ class PlanetaryComputer(DataSource, TileStore):
118
126
  self.timeout = timeout
119
127
  self.skip_items_missing_assets = skip_items_missing_assets
120
128
  self.cache_dir = cache_dir
129
+ self.max_items_per_client = (
130
+ max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
131
+ )
121
132
 
122
133
  if self.cache_dir is not None:
123
134
  self.cache_dir.mkdir(parents=True, exist_ok=True)
124
135
 
125
136
  self.client: pystac_client.Client | None = None
126
- self.collection: pystac_client.CollectionClient | None = None
137
+ self._client_item_count = 0
127
138
 
128
139
  @staticmethod
129
140
  def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
@@ -142,7 +153,12 @@ class PlanetaryComputer(DataSource, TileStore):
142
153
  if "cache_dir" in d:
143
154
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
144
155
 
145
- simple_optionals = ["query", "sort_by", "sort_ascending"]
156
+ simple_optionals = [
157
+ "query",
158
+ "sort_by",
159
+ "sort_ascending",
160
+ "max_items_per_client",
161
+ ]
146
162
  for k in simple_optionals:
147
163
  if k in d:
148
164
  kwargs[k] = d[k]
@@ -151,20 +167,40 @@ class PlanetaryComputer(DataSource, TileStore):
151
167
 
152
168
def _load_client(
    self,
) -> pystac_client.Client:
    """Return the pystac client, opening or re-opening it when needed.

    The client is not opened when the data source is constructed because opening
    takes time and the caller may never call get_items. Opening it during the
    get_items call also lets the retry loop in prepare_dataset_windows cover it.

    Note: the client is periodically recreated to prevent memory leaks from the
    pystac Catalog's resolved objects cache, which grows unbounded as STAC items
    are deserialized and cached. The cache cannot be cleared or disabled.

    Returns:
        the current client: the existing one while the processed-item counter is
        below max_items_per_client, otherwise a newly opened one.
    """
    existing = self.client
    if existing is not None and self._client_item_count < self.max_items_per_client:
        return existing

    if existing is None:
        logger.info("Creating initial STAC client")
    else:
        # Threshold reached: drop what the old client's root holds on to and reset
        # the counter before replacing the client.
        logger.debug(
            "Recreating STAC client after processing %d items (threshold: %d)",
            self._client_item_count,
            self.max_items_per_client,
        )
        stale_root = existing.get_root()
        stale_root.clear_links()
        stale_root.clear_items()
        stale_root.clear_children()
        self._client_item_count = 0

    self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
    return self.client
168
204
 
169
205
  def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
170
206
  shp = shapely.geometry.shape(stac_item.geometry)
@@ -210,10 +246,26 @@ class PlanetaryComputer(DataSource, TileStore):
210
246
 
211
247
  # No cache or not in cache, so we need to make the STAC request.
212
248
  logger.debug("Getting STAC item {name}")
213
- _, collection = self._load_client()
214
- stac_item = collection.get_item(name)
249
+ client = self._load_client()
250
+
251
+ search_result = client.search(ids=[name], collections=[self.collection_name])
252
+ stac_items = list(search_result.items())
253
+
254
+ if not stac_items:
255
+ raise ValueError(
256
+ f"Item {name} not found in collection {self.collection_name}"
257
+ )
258
+ if len(stac_items) > 1:
259
+ raise ValueError(
260
+ f"Multiple items found for ID {name} in collection {self.collection_name}"
261
+ )
262
+
263
+ stac_item = stac_items[0]
215
264
  item = self._stac_item_to_item(stac_item)
216
265
 
266
+ # Track items processed for client recreation threshold (after deserialization)
267
+ self._client_item_count += 1
268
+
217
269
  # Finally we cache it if cache_dir is set.
218
270
  if cache_fname is not None:
219
271
  with cache_fname.open("w") as f:
@@ -233,7 +285,7 @@ class PlanetaryComputer(DataSource, TileStore):
233
285
  Returns:
234
286
  List of groups of items that should be retrieved for each geometry.
235
287
  """
236
- client, _ = self._load_client()
288
+ client = self._load_client()
237
289
 
238
290
  groups = []
239
291
  for geometry in geometries:
@@ -247,7 +299,9 @@ class PlanetaryComputer(DataSource, TileStore):
247
299
  datetime=wgs84_geometry.time_range,
248
300
  query=self.query,
249
301
  )
250
- stac_items = [item for item in result.item_collection()]
302
+ stac_items = [item for item in result.items()]
303
+ # Track items processed for client recreation threshold (after deserialization)
304
+ self._client_item_count += len(stac_items)
251
305
  logger.debug("STAC search yielded %d items", len(stac_items))
252
306
 
253
307
  if self.skip_items_missing_assets:
@@ -580,7 +634,13 @@ class Sentinel2(PlanetaryComputer):
580
634
  if "cache_dir" in d:
581
635
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
582
636
 
583
- simple_optionals = ["harmonize", "query", "sort_by", "sort_ascending"]
637
+ simple_optionals = [
638
+ "harmonize",
639
+ "query",
640
+ "sort_by",
641
+ "sort_ascending",
642
+ "max_items_per_client",
643
+ ]
584
644
  for k in simple_optionals:
585
645
  if k in d:
586
646
  kwargs[k] = d[k]
@@ -756,7 +816,12 @@ class Sentinel1(PlanetaryComputer):
756
816
  if "cache_dir" in d:
757
817
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
758
818
 
759
- simple_optionals = ["query", "sort_by", "sort_ascending"]
819
+ simple_optionals = [
820
+ "query",
821
+ "sort_by",
822
+ "sort_ascending",
823
+ "max_items_per_client",
824
+ ]
760
825
  for k in simple_optionals:
761
826
  if k in d:
762
827
  kwargs[k] = d[k]
rslearn/dataset/manage.py CHANGED
@@ -396,9 +396,9 @@ def materialize_window(
396
396
  )
397
397
 
398
398
  if dataset.materializer_name:
399
- materializer = Materializers[dataset.materializer_name]
399
+ materializer = Materializers[dataset.materializer_name]()
400
400
  else:
401
- materializer = Materializers[layer_cfg.layer_type.value]
401
+ materializer = Materializers[layer_cfg.layer_type.value]()
402
402
 
403
403
  retry(
404
404
  fn=lambda: materializer.materialize(
@@ -1,10 +1,10 @@
1
1
  """Classes to implement dataset materialization."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from typing import Any, Generic, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
  from rasterio.enums import Resampling
9
9
 
10
10
  from rslearn.config import (
@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
25
25
  from .remap import Remapper, load_remapper
26
26
  from .window import Window
27
27
 
28
- Materializers = ClassRegistry()
28
_MaterializerT = TypeVar("_MaterializerT", bound="Materializer")


class _MaterializerRegistry(dict[str, type["Materializer"]]):
    """Name-to-class lookup table for Materializer implementations.

    This is a plain dict; ``register`` only adds a decorator-based way to insert
    entries.
    """

    def register(
        self, name: str
    ) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
        """Build a class decorator that files the decorated class under ``name``.

        Args:
            name: the key to store the class under. An existing entry with the same
                key is replaced.

        Returns:
            a decorator that records the class and hands it back unchanged.
        """

        def _record(materializer_cls: type[_MaterializerT]) -> type[_MaterializerT]:
            self[name] = materializer_cls
            return materializer_cls

        return _record


Materializers = _MaterializerRegistry()
47
+
29
48
 
30
49
  LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
31
50
 
rslearn/dataset/remap.py CHANGED
@@ -1,18 +1,42 @@
1
1
  """Classes to remap raster values."""
2
2
 
3
- from typing import Any
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
 
9
- Remappers = ClassRegistry()
9
_RemapperT = TypeVar("_RemapperT", bound="Remapper")


class _RemapperRegistry(dict[str, type["Remapper"]]):
    """Name-to-class lookup table for Remapper implementations.

    This is a plain dict; ``register`` only adds a decorator-based way to insert
    entries.
    """

    def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
        """Build a class decorator that files the decorated class under ``name``.

        Args:
            name: the key to store the class under. An existing entry with the same
                key is replaced.

        Returns:
            a decorator that records the class and hands it back unchanged.
        """

        def _record(remapper_cls: type[_RemapperT]) -> type[_RemapperT]:
            self[name] = remapper_cls
            return remapper_cls

        return _record


Remappers = _RemapperRegistry()
10
26
  """Registry of Remapper implementations."""
11
27
 
12
28
 
13
29
  class Remapper:
14
30
  """An abstract class that remaps pixel values based on layer configuration."""
15
31
 
32
+ def __init__(self, config: dict[str, Any]) -> None:
33
+ """Initialize a Remapper.
34
+
35
+ Args:
36
+ config: the config dict for this remapper.
37
+ """
38
+ pass
39
+
16
40
  def __call__(
17
41
  self, array: npt.NDArray[Any], dtype: npt.DTypeLike
18
42
  ) -> npt.NDArray[Any]:
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
67
91
 
68
92
def load_remapper(config: dict[str, Any]) -> Remapper:
    """Instantiate the remapper selected by a configuration dictionary.

    Args:
        config: the remapper configuration. Its "name" entry selects the class from
            the Remappers registry, and the whole dict is passed to that class as
            its single positional argument.

    Returns:
        the constructed Remapper.

    Raises:
        KeyError: if "name" is missing from config or is not a registered remapper.
    """
    remapper_name = config["name"]
    return Remappers[remapper_name](config)
rslearn/models/dinov3.py CHANGED
@@ -7,8 +7,8 @@ from typing import Any
7
7
  import torch
8
8
  import torchvision
9
9
  from einops import rearrange
10
- from torchvision.transforms import v2
11
10
 
11
+ from rslearn.train.transforms.normalize import Normalize
12
12
  from rslearn.train.transforms.transform import Transform
13
13
 
14
14
 
@@ -139,15 +139,17 @@ class DinoV3Normalize(Transform):
139
139
  super().__init__()
140
140
  self.satellite = satellite
141
141
  if satellite:
142
- self.normalize = v2.Normalize(
143
- mean=(0.430, 0.411, 0.296),
144
- std=(0.213, 0.156, 0.143),
145
- )
142
+ mean = [0.430, 0.411, 0.296]
143
+ std = [0.213, 0.156, 0.143]
146
144
  else:
147
- self.normalize = v2.Normalize(
148
- mean=(0.485, 0.456, 0.406),
149
- std=(0.229, 0.224, 0.225),
150
- )
145
+ mean = [0.485, 0.456, 0.406]
146
+ std = [0.229, 0.224, 0.225]
147
+
148
+ self.normalize = Normalize(
149
+ [value * 255 for value in mean],
150
+ [value * 255 for value in std],
151
+ num_bands=3,
152
+ )
151
153
 
152
154
  def forward(
153
155
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -161,5 +163,4 @@ class DinoV3Normalize(Transform):
161
163
  Returns:
162
164
  normalized (input_dicts, target_dicts) tuple
163
165
  """
164
- input_dict["image"] = self.normalize(input_dict["image"] / 255.0)
165
- return input_dict, target_dict
166
+ return self.normalize(input_dict, target_dict)
@@ -2,6 +2,7 @@
2
2
 
3
3
  import math
4
4
  import tempfile
5
+ from contextlib import nullcontext
5
6
  from enum import StrEnum
6
7
  from typing import Any, cast
7
8
 
@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {
63
64
 
64
65
  DEFAULT_NORMALIZER = Normalizer()
65
66
 
67
+ AUTOCAST_DTYPE_MAP = {
68
+ "bfloat16": torch.bfloat16,
69
+ "float32": torch.float32,
70
+ }
71
+
66
72
 
67
73
  class GalileoModel(nn.Module):
68
74
  """Galileo backbones."""
@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
85
91
  size: GalileoSize,
86
92
  patch_size: int = 4,
87
93
  pretrained_path: str | UPath | None = None,
94
+ autocast_dtype: str | None = "bfloat16",
88
95
  ) -> None:
89
96
  """Initialize the Galileo model.
90
97
 
@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
93
100
  patch_size: The patch size to use.
94
101
  pretrained_path: the local path to the pretrained weights. Otherwise it is
95
102
  downloaded and cached in temp directory.
103
+ autocast_dtype: which dtype to use for autocasting, or set None to disable.
96
104
  """
97
105
  super().__init__()
98
106
  if pretrained_path is None:
@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
128
136
  idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
129
137
  ]
130
138
 
139
+ self.size = size
131
140
  self.patch_size = patch_size
132
141
 
142
+ if autocast_dtype is not None:
143
+ self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
144
+ else:
145
+ self.autocast_dtype = None
146
+
133
147
  @staticmethod
134
148
  def to_cartesian(
135
149
  lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
@@ -484,18 +498,31 @@ class GalileoModel(nn.Module):
484
498
  patch_size = h
485
499
  else:
486
500
  patch_size = self.patch_size
487
- outputs = self.model(
488
- s_t_x=galileo_input.s_t_x,
489
- s_t_m=galileo_input.s_t_m,
490
- sp_x=galileo_input.sp_x,
491
- sp_m=galileo_input.sp_m,
492
- t_x=galileo_input.t_x,
493
- t_m=galileo_input.t_m,
494
- st_x=galileo_input.st_x,
495
- st_m=galileo_input.st_m,
496
- months=galileo_input.months,
497
- patch_size=patch_size,
498
- )
501
+
502
+ # Decide context based on self.autocast_dtype.
503
+ device = galileo_input.s_t_x.device
504
+ if self.autocast_dtype is None:
505
+ context = nullcontext()
506
+ else:
507
+ assert device is not None
508
+ context = torch.amp.autocast(
509
+ device_type=device.type, dtype=self.autocast_dtype
510
+ )
511
+
512
+ with context:
513
+ outputs = self.model(
514
+ s_t_x=galileo_input.s_t_x,
515
+ s_t_m=galileo_input.s_t_m,
516
+ sp_x=galileo_input.sp_x,
517
+ sp_m=galileo_input.sp_m,
518
+ t_x=galileo_input.t_x,
519
+ t_m=galileo_input.t_m,
520
+ st_x=galileo_input.st_x,
521
+ st_m=galileo_input.st_m,
522
+ months=galileo_input.months,
523
+ patch_size=patch_size,
524
+ )
525
+
499
526
  if h == patch_size:
500
527
  # only one spatial patch, so we can just take an average
501
528
  # of all the tokens to output b c_g 1 1
@@ -515,3 +542,22 @@ class GalileoModel(nn.Module):
515
542
  "b h w c_g d -> b c_g d h w",
516
543
  ).mean(dim=1)
517
544
  ]
545
+
546
def get_backbone_channels(self) -> list:
    """Returns the output channels of this model when used as a backbone.

    The output channels is a list of (patch_size, depth) that corresponds
    to the feature maps that the backbone returns.

    Returns:
        the output channels of the backbone as a list of (patch_size, depth) tuples.

    Raises:
        ValueError: if the model size is not one of BASE, TINY, or NANO.
    """
    # The constructor stores the size as self.size; every branch must read that
    # attribute (there is no self.model_size).
    if self.size == GalileoSize.BASE:
        depth = 768
    elif self.size == GalileoSize.TINY:
        depth = 192
    elif self.size == GalileoSize.NANO:
        depth = 128
    else:
        raise ValueError(f"Invalid model size: {self.size}")
    return [(self.patch_size, depth)]
@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
1469
1469
  # we take the inverse of the mask because a value
1470
1470
  # of True indicates the value *should* take part in
1471
1471
  # attention
1472
- x = blk(x=x, y=None, attn_mask=~new_m.bool())
1472
+ temp_mask = ~new_m.bool()
1473
+ if temp_mask.all():
1474
+ # if all the tokens are used in attention we can pass a None mask
1475
+ # to the attention block
1476
+ temp_mask = None
1477
+
1478
+ x = blk(x=x, y=None, attn_mask=temp_mask)
1473
1479
 
1474
1480
  if exit_ids_seq is not None:
1475
1481
  assert exited_tokens is not None
@@ -0,0 +1 @@
1
+ """OlmoEarth model architecture."""
@@ -0,0 +1,203 @@
1
+ """OlmoEarth model wrapper for fine-tuning in rslearn."""
2
+
3
+ import json
4
+ from contextlib import nullcontext
5
+ from typing import Any
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from olmo_core.config import Config
10
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
11
+ from olmoearth_pretrain.data.constants import Modality
12
+ from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
13
+ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
14
+ from upath import UPath
15
+
16
+ from rslearn.log_utils import get_logger
17
+
18
+ logger = get_logger(__name__)
19
+
20
+ MODALITY_NAMES = [
21
+ "sentinel2_l2a",
22
+ "sentinel1",
23
+ "worldcover",
24
+ "openstreetmap_raster",
25
+ "landsat",
26
+ ]
27
+
28
+ AUTOCAST_DTYPE_MAP = {
29
+ "bfloat16": torch.bfloat16,
30
+ "float16": torch.float16,
31
+ "float32": torch.float32,
32
+ }
33
+
34
+
35
class OlmoEarth(torch.nn.Module):
    """A wrapper to support the OlmoEarth model."""

    def __init__(
        self,
        # TODO: we should accept model ID instead of checkpoint_path once we are closer
        # to being ready for release.
        checkpoint_path: str,
        selector: list[str | int] = [],
        forward_kwargs: dict[str, Any] = {},
        random_initialization: bool = False,
        embedding_size: int | None = None,
        patch_size: int | None = None,
        autocast_dtype: str | None = "bfloat16",
    ):
        """Create a new OlmoEarth model.

        Args:
            checkpoint_path: the checkpoint directory to load. It should contain
                config.json file as well as model_and_optim folder.
            selector: an optional sequence of attribute names or list indices to select
                the sub-module that should be applied on the input images.
            forward_kwargs: additional arguments to pass to forward pass besides the
                MaskedOlmoEarthSample.
            random_initialization: whether to skip loading the checkpoint so the
                weights are randomly initialized. In this case, the checkpoint is only
                used to define the model architecture.
            embedding_size: optional embedding size to report via
                get_backbone_channels.
            patch_size: optional patch size to report via get_backbone_channels.
            autocast_dtype: which dtype to use for autocasting, or set None to disable.
                Must be a key of AUTOCAST_DTYPE_MAP.

        Raises:
            KeyError: if autocast_dtype is not None and not in AUTOCAST_DTYPE_MAP.
        """
        super().__init__()
        _checkpoint_path = UPath(checkpoint_path)
        # Copy so that the shared default dict (or a dict the caller keeps mutating)
        # is never aliased by this instance.
        self.forward_kwargs = dict(forward_kwargs)
        self.embedding_size = embedding_size
        self.patch_size = patch_size

        if autocast_dtype is not None:
            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
        else:
            self.autocast_dtype = None

        # Load the model config and initialize it.
        # We avoid loading the train module here because it depends on running within
        # olmo_core.
        with (_checkpoint_path / "config.json").open() as f:
            config_dict = json.load(f)
            model_config = Config.from_dict(config_dict["model"])

        model = model_config.build()

        # Load the checkpoint.
        if not random_initialization:
            train_module_dir = _checkpoint_path / "model_and_optim"
            if train_module_dir.exists():
                load_model_and_optim_state(str(train_module_dir), model)
                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
            else:
                # The caller asked for pretrained weights but the model keeps its
                # random initialization, so make this visible as a warning.
                logger.warning(
                    f"could not find OlmoEarth encoder at {train_module_dir}"
                )
        else:
            logger.info("skipping loading OlmoEarth encoder")

        # Select just the portion of the model that we actually want to use.
        for part in selector:
            if isinstance(part, str):
                model = getattr(model, part)
            else:
                model = model[part]
        self.model = model

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute feature maps from the OlmoEarth backbone.

        Args:
            inputs: input dicts. It should include keys corresponding to the modalities
                that should be passed to the OlmoEarth model. The modalities to use
                are decided from the keys of the first input dict.

        Returns:
            a single-element list with one feature map, averaged over band sets,
            timesteps, and the modalities that were present.

        Raises:
            ValueError: if the first input dict has none of MODALITY_NAMES.
        """
        kwargs = {}
        present_modalities = []
        device = None
        # Handle the case where some modalities are multitemporal and some are not.
        # We assume all multitemporal modalities have the same number of timesteps.
        max_timesteps = 1
        for modality in MODALITY_NAMES:
            if modality not in inputs[0]:
                continue
            present_modalities.append(modality)
            modality_spec = Modality.get(modality)
            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
            device = cur.device
            # Check if it's single or multitemporal, and reshape accordingly
            num_timesteps = cur.shape[1] // modality_spec.num_bands
            max_timesteps = max(max_timesteps, num_timesteps)
            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
            kwargs[modality] = cur
            # Create mask array which is BHWTS (without channels but with band sets).
            num_band_sets = len(modality_spec.band_sets)
            mask_shape = cur.shape[0:4] + (num_band_sets,)
            mask = (
                torch.ones(mask_shape, dtype=torch.int32, device=device)
                * MaskValue.ONLINE_ENCODER.value
            )
            kwargs[f"{modality}_mask"] = mask

        # device is only set inside the loop, so None means no modality was found.
        # Fail here with a clear message rather than later when pooling an empty list.
        if device is None:
            raise ValueError(
                f"inputs contain none of the supported modalities {MODALITY_NAMES}"
            )

        # Timestamps is required.
        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
        # For now, we assign same timestamps to all inputs, but later we should handle
        # varying timestamps per input.
        # NOTE(review): with more than 12 timesteps the month values below exceed 11;
        # confirm the encoder tolerates that.
        timestamps = torch.zeros(
            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
        )
        timestamps[:, :, 0] = 1  # day
        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
            None, :
        ]  # month
        timestamps[:, :, 2] = 2024  # year
        kwargs["timestamps"] = timestamps

        sample = MaskedOlmoEarthSample(**kwargs)

        # Decide context based on self.autocast_dtype.
        if self.autocast_dtype is None:
            context = nullcontext()
        else:
            context = torch.amp.autocast(
                device_type=device.type, dtype=self.autocast_dtype
            )

        with context:
            # Currently we assume the provided model always returns a TokensAndMasks
            # object.
            tokens_and_masks: TokensAndMasks
            if isinstance(self.model, Encoder):
                # Encoder has a fast_pass argument to indicate mask is not needed.
                tokens_and_masks = self.model(
                    sample, fast_pass=True, **self.forward_kwargs
                )["tokens_and_masks"]
            else:
                # Other models like STEncoder do not have this option supported.
                tokens_and_masks = self.model(sample, **self.forward_kwargs)[
                    "tokens_and_masks"
                ]

        # Apply temporal/modality pooling so we just have one feature per patch.
        features = []
        for modality in present_modalities:
            modality_features = getattr(tokens_and_masks, modality)
            # Pool over band sets and timesteps (BHWTSC -> BHWC).
            pooled = modality_features.mean(dim=[3, 4])
            # We want BHWC -> BCHW.
            pooled = rearrange(pooled, "b h w c -> b c h w")
            features.append(pooled)
        # Pool over the modalities, so we get one BCHW feature map.
        pooled = torch.stack(features, dim=0).mean(dim=0)
        return [pooled]

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (downsample_factor, depth) that corresponds
        to the feature maps that the backbone returns. For example, an element [2, 32]
        indicates that the corresponding feature map is 1/2 the input resolution and
        has 32 channels.

        The values are the patch_size and embedding_size given to the constructor, so
        an entry is None when that argument was not provided.

        Returns:
            the output channels of the backbone as a list of (downsample_factor, depth)
            tuples.
        """
        return [(self.patch_size, self.embedding_size)]
@@ -0,0 +1,84 @@
1
+ """Normalization transforms."""
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from olmoearth_pretrain.data.normalize import load_computed_config
7
+
8
+ from rslearn.log_utils import get_logger
9
+ from rslearn.train.transforms.transform import Transform
10
+
11
+ logger = get_logger(__file__)
12
+
13
+
14
class OlmoEarthNormalize(Transform):
    """Normalize using OlmoEarth JSON config.

    For Sentinel-1 data, the values should be converted to decibels before being passed
    to this transform.
    """

    def __init__(
        self,
        band_names: dict[str, list[str]],
        std_multiplier: float | None = 2,
        config_fname: str | None = None,
    ) -> None:
        """Initialize a new OlmoEarthNormalize.

        Args:
            band_names: map from modality name to the list of bands in that modality in
                the order they are being loaded. Note that this order must match the
                expected order for the OlmoEarth model.
            std_multiplier: the std multiplier matching the one used for the model
                training in OlmoEarth.
            config_fname: load the normalization configuration from this file, instead
                of getting it from OlmoEarth.
        """
        super().__init__()
        self.band_names = band_names
        self.std_multiplier = std_multiplier

        if config_fname is None:
            self.norm_config = load_computed_config()
        else:
            logger.warning(
                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
            )
            # JSON text is UTF-8; do not depend on the platform default encoding.
            with open(config_fname, encoding="utf-8") as f:
                self.norm_config = json.load(f)

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply normalization over the inputs and targets.

        Each band is rescaled so that mean - std_multiplier * std maps to 0 and
        mean + std_multiplier * std maps to 1. The images in input_dict are modified
        in place.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            normalized (input_dicts, target_dicts) tuple

        Raises:
            ValueError: if a band in the normalization config is missing from
                band_names, or if some channel of an image was not normalized.
        """
        for modality_name, cur_band_names in self.band_names.items():
            band_norms = self.norm_config[modality_name]
            image = input_dict[modality_name]
            num_bands = len(cur_band_names)
            # Keep a set of indices to make sure that we normalize all of them.
            needed_band_indices = set(range(image.shape[0]))
            num_timesteps = image.shape[0] // num_bands

            for band, norm_dict in band_norms.items():
                # The band position and the scaling bounds are the same for every
                # timestep, so compute them once per band.
                try:
                    band_offset = cur_band_names.index(band)
                except ValueError:
                    raise ValueError(
                        f"band {band} of modality {modality_name} is in the "
                        "normalization config but missing from band_names"
                    ) from None
                min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]

                # If multitemporal, normalize each timestep separately.
                for t in range(num_timesteps):
                    band_idx = band_offset + t * num_bands
                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
                    needed_band_indices.remove(band_idx)

            if len(needed_band_indices) > 0:
                raise ValueError(
                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
                )

        return input_dict, target_dict
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
76
76
  features = torch.amax(features, dim=(2, 3))
77
77
  features = self.fc_layers(features)
78
78
  return self.output_layer(features)
79
+
80
+
81
class SegmentationPoolingDecoder(PoolingDecoder):
    """PoolingDecoder variant that broadcasts its output to every pixel.

    The model still produces one global output per example, but in a shape that is
    compatible with SegmentationTask. This only makes sense for very small windows,
    since the output is the same at all pixels. The main use case is to train for a
    classification-like task on small windows, but still produce a raster during
    inference on large windows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_key: str = "image",
        **kwargs: Any,
    ):
        """Create a new SegmentationPoolingDecoder.

        Args:
            in_channels: input channels (channels in the last feature map passed to
                this module)
            out_channels: channels for the output flat feature vector
            image_key: the key in inputs for the image from which the expected width
                and height is derived.
            kwargs: other arguments to pass to PoolingDecoder.
        """
        super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
        self.image_key = image_key

    def forward(
        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
    ) -> torch.Tensor:
        """Run PoolingDecoder and tile its per-example output over the image grid.

        This only works when all of the pixels have the same segmentation target.

        Args:
            features: the feature maps, passed through to PoolingDecoder.
            inputs: the input dicts. The image under image_key in the first one
                supplies the output height and width (its dimensions 1 and 2).

        Returns:
            the PoolingDecoder output repeated at every pixel (BC -> BCHW).
        """
        pooled_output = super().forward(features, inputs)
        reference_image = inputs[0][self.image_key]
        height, width = reference_image.shape[1:3]
        # BC -> BC11, then copy along the two new spatial axes.
        per_pixel = pooled_output[:, :, None, None]
        return per_pixel.repeat([1, 1, height, width])
@@ -248,3 +248,14 @@ class Presto(nn.Module):
248
248
  output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
249
249
 
250
250
  return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
251
+
252
def get_backbone_channels(self) -> list:
    """Describe the feature map this model returns when used as a backbone.

    Returns:
        a one-element list holding a (patch_size, depth) tuple for the single
        feature map that the backbone returns; the values are fixed at (1, 128).
    """
    patch_size, depth = 1, 128
    return [(patch_size, depth)]
rslearn/models/prithvi.py CHANGED
@@ -173,6 +173,17 @@ class PrithviV2(nn.Module):
173
173
  features, num_timesteps
174
174
  )
175
175
 
176
+ def get_backbone_channels(self) -> list:
177
+ """Returns the output channels of this model when used as a backbone.
178
+
179
+ The output channels are a list of (patch_size, depth) tuples that correspond
180
+ to the feature maps that the backbone returns.
181
+
182
+ Returns:
183
+ the output channels of the backbone as a list of (patch_size, depth) tuples.
184
+ """
185
+ return [(1, 1024)]
186
+
176
187
 
177
188
  class PrithviNormalize(Transform):
178
189
  """Normalize inputs using Prithvi normalization.
@@ -1,5 +1,22 @@
1
1
  """Model registry."""
2
2
 
3
- from class_registry import ClassRegistry
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
- Models = ClassRegistry()
6
+ _ModelT = TypeVar("_ModelT")
7
+
8
+
9
+ class _ModelRegistry(dict[str, type[Any]]):
10
+ """Registry for Model classes."""
11
+
12
+ def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
13
+ """Decorator to register a model class."""
14
+
15
+ def decorator(cls: type[_ModelT]) -> type[_ModelT]:
16
+ self[name] = cls
17
+ return cls
18
+
19
+ return decorator
20
+
21
+
22
+ Models = _ModelRegistry()
@@ -130,10 +130,12 @@ class DefaultTileStore(TileStore):
130
130
  """
131
131
  raster_dir = self._get_raster_dir(layer_name, item_name, bands)
132
132
  for fname in raster_dir.iterdir():
133
- # Ignore completed sentinel files as well as temporary files created by
133
+ # Ignore completed sentinel files, bands files, as well as temporary files created by
134
134
  # open_atomic (in case this tile store is on local filesystem).
135
135
  if fname.name == COMPLETED_FNAME:
136
136
  continue
137
+ if fname.name == BANDS_FNAME:
138
+ continue
137
139
  if ".tmp." in fname.name:
138
140
  continue
139
141
  return fname
@@ -54,7 +54,7 @@ def read_selector(
54
54
  the item specified by the selector
55
55
  """
56
56
  d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
57
- parts = selector.split("/")
57
+ parts = selector.split("/") if selector else []
58
58
  cur = d
59
59
  for part in parts:
60
60
  cur = cur[part]
@@ -76,11 +76,28 @@ def write_selector(
76
76
  v: the value to write
77
77
  """
78
78
  d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
79
- parts = selector.split("/")
80
- cur = d
81
- for part in parts[:-1]:
82
- cur = cur[part]
83
- cur[parts[-1]] = v
79
+ if selector:
80
+ parts = selector.split("/")
81
+ cur = d
82
+ for part in parts[:-1]:
83
+ cur = cur[part]
84
+ cur[parts[-1]] = v
85
+ else:
86
+ # If the selector references the input or target dictionary directly, then we
87
+ # have a special case where instead of overwriting with v, we replace the keys
88
+ # with those in v. v must be a dictionary here, not a tensor, since otherwise
89
+ # it wouldn't match the type of the input or target dictionary.
90
+ if not isinstance(v, dict):
91
+ raise ValueError(
92
+ "when directly specifying the input or target dict, expected the value to be a dict"
93
+ )
94
+ if d == v:
95
+ # This may happen if the writer did not make a copy of the dictionary. In
96
+ # this case the code below would not update d correctly since it would also
97
+ # clear v.
98
+ return
99
+ d.clear()
100
+ d.update(v)
84
101
 
85
102
 
86
103
  class Transform(torch.nn.Module):
@@ -2,13 +2,13 @@
2
2
 
3
3
  import hashlib
4
4
  import json
5
- from typing import Any, BinaryIO
5
+ from collections.abc import Callable
6
+ from typing import Any, BinaryIO, TypeVar
6
7
 
7
8
  import affine
8
9
  import numpy as np
9
10
  import numpy.typing as npt
10
11
  import rasterio
11
- from class_registry import ClassRegistry
12
12
  from PIL import Image
13
13
  from rasterio.crs import CRS
14
14
  from rasterio.enums import Resampling
@@ -21,7 +21,27 @@ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath
21
21
 
22
22
  from .geometry import PixelBounds, Projection
23
23
 
24
- RasterFormats = ClassRegistry()
24
+ _RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
25
+
26
+
27
+ class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
28
+ """Registry for RasterFormat classes."""
29
+
30
+ def register(
31
+ self, name: str
32
+ ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
33
+ """Decorator to register a raster format class."""
34
+
35
+ def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
36
+ self[name] = cls
37
+ return cls
38
+
39
+ return decorator
40
+
41
+
42
+ RasterFormats = _RasterFormatRegistry()
43
+
44
+
25
45
  logger = get_logger(__name__)
26
46
 
27
47
 
@@ -147,6 +167,19 @@ class RasterFormat:
147
167
  """
148
168
  raise NotImplementedError
149
169
 
170
+ @staticmethod
171
+ def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
172
+ """Create a RasterFormat from a config dict.
173
+
174
+ Args:
175
+ name: the name of this format
176
+ config: the config dict
177
+
178
+ Returns:
179
+ the RasterFormat instance
180
+ """
181
+ raise NotImplementedError
182
+
150
183
 
151
184
  @RasterFormats.register("image_tile")
152
185
  class ImageTileRasterFormat(RasterFormat):
@@ -716,5 +749,5 @@ def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
716
749
  Returns:
717
750
  the loaded RasterFormat implementation
718
751
  """
719
- cls = RasterFormats.get_class(config.name)
752
+ cls = RasterFormats[config.name]
720
753
  return cls.from_config(config.name, config.config_dict)
@@ -1,11 +1,11 @@
1
1
  """Classes for writing vector data to a UPath."""
2
2
 
3
3
  import json
4
+ from collections.abc import Callable
4
5
  from enum import Enum
5
- from typing import Any
6
+ from typing import Any, TypeVar
6
7
 
7
8
  import shapely
8
- from class_registry import ClassRegistry
9
9
  from rasterio.crs import CRS
10
10
  from upath import UPath
11
11
 
@@ -18,7 +18,25 @@ from .feature import Feature
18
18
  from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
19
19
 
20
20
  logger = get_logger(__name__)
21
- VectorFormats = ClassRegistry()
21
+ _VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
22
+
23
+
24
+ class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
25
+ """Registry for VectorFormat classes."""
26
+
27
+ def register(
28
+ self, name: str
29
+ ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
30
+ """Decorator to register a vector format class."""
31
+
32
+ def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
33
+ self[name] = cls
34
+ return cls
35
+
36
+ return decorator
37
+
38
+
39
+ VectorFormats = _VectorFormatRegistry()
22
40
 
23
41
 
24
42
  class VectorFormat:
@@ -53,6 +71,19 @@ class VectorFormat:
53
71
  """
54
72
  raise NotImplementedError
55
73
 
74
+ @staticmethod
75
+ def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
76
+ """Create a VectorFormat from a config dict.
77
+
78
+ Args:
79
+ name: the name of this format
80
+ config: the config dict
81
+
82
+ Returns:
83
+ the VectorFormat instance
84
+ """
85
+ raise NotImplementedError
86
+
56
87
 
57
88
  @VectorFormats.register("tile")
58
89
  class TileVectorFormat(VectorFormat):
@@ -410,5 +441,5 @@ def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
410
441
  Returns:
411
442
  the loaded VectorFormat implementation
412
443
  """
413
- cls = VectorFormats.get_class(config.name)
444
+ cls = VectorFormats[config.name]
414
445
  return cls.from_config(config.name, config.config_dict)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.8
3
+ Version: 0.0.11
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -212,7 +212,6 @@ Requires-Python: >=3.11
212
212
  Description-Content-Type: text/markdown
213
213
  License-File: LICENSE
214
214
  Requires-Dist: boto3>=1.39
215
- Requires-Dist: class_registry>=2.1
216
215
  Requires-Dist: fiona>=1.10
217
216
  Requires-Dist: fsspec>=2025.9.0
218
217
  Requires-Dist: jsonargparse>=4.35.0
@@ -244,6 +243,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
244
243
  Requires-Dist: pycocotools>=2.0; extra == "extra"
245
244
  Requires-Dist: pystac_client>=0.9; extra == "extra"
246
245
  Requires-Dist: rtree>=1.4; extra == "extra"
246
+ Requires-Dist: termcolor>=3.0; extra == "extra"
247
247
  Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
248
248
  Requires-Dist: scipy>=1.16; extra == "extra"
249
249
  Requires-Dist: terratorch>=1.0.2; extra == "extra"
@@ -284,6 +284,7 @@ Quick links:
284
284
  - [Examples](docs/Examples.md) contains more examples, including customizing different
285
285
  stages of rslearn with additional code.
286
286
  - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
287
+ - [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
287
288
 
288
289
 
289
290
  Setup
@@ -20,11 +20,11 @@ rslearn/data_sources/eurocrops.py,sha256=bH_ul45RvNLLvVKU9JjZH1XMenAmgn7Lakq0pXj
20
20
  rslearn/data_sources/gcp_public_data.py,sha256=kr9stYo7ZCvz8s4E3wmoY-yAGZoLa_9RCwjS-Q5k9dM,36128
21
21
  rslearn/data_sources/geotiff.py,sha256=sFUp919chaX4j6lQytNp__xnMLlDI3Ac3rfB6F8sgZ0,45
22
22
  rslearn/data_sources/google_earth_engine.py,sha256=hpkt74ly2lEwjRrDp8FBmGvB3MEw_mQ38Av4rQOR3_w,24246
23
- rslearn/data_sources/local_files.py,sha256=NnmXWbOoKHxZFzwCiWsybWfPrFoPTpoZbjoK57xhoqc,19049
23
+ rslearn/data_sources/local_files.py,sha256=d08m6IzrUN_80VfvgpHahMJrv-n6_CI6EIocp6kyDRs,19490
24
24
  rslearn/data_sources/openstreetmap.py,sha256=qUSMFiIA_laJkO3meBXf9TmSI7OBD-o3i4JxqllUv3Q,19232
25
25
  rslearn/data_sources/planet.py,sha256=F2JoLaQ5Cb3k1cTm0hwSWTL2TPfbaAUMXZ8q4Dy7UlA,10109
26
26
  rslearn/data_sources/planet_basemap.py,sha256=wuWM9dHSJMdINfyWb78Zk9i-KvJHTrf9J0Q2gyEyiiA,10450
27
- rslearn/data_sources/planetary_computer.py,sha256=Wchr-OmAffuVteUW6VRofIqFpE-cJqI6GMTU9v1i8pw,28089
27
+ rslearn/data_sources/planetary_computer.py,sha256=uHNYxvnMkmo8zbqIiDRdnkz8LQ7TSs6K39Y1AXjboDI,30392
28
28
  rslearn/data_sources/raster_source.py,sha256=b8wo55GhVLxXwx1WYLzeRAlzD_ZkE_P9tnvUOdnsfQE,689
29
29
  rslearn/data_sources/usda_cdl.py,sha256=2_V11AhPRgLEGd4U5Pmx3UvE2HWBPbsFXhUIQVRVFeE,7138
30
30
  rslearn/data_sources/usgs_landsat.py,sha256=31GmOUfmxwTE6MTiVI4psb-ciVmunuA8cfvqDuvTHPE,19312
@@ -39,9 +39,9 @@ rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVk
39
39
  rslearn/dataset/dataset.py,sha256=bjf9nI55j-MF0bIQWSNPjNbpfqnLK4jy-96TAcwO0MM,5214
40
40
  rslearn/dataset/handler_summaries.py,sha256=wGnbBpjLWTxVn3UT7j7nPoHlYsaWb9_MVJ5DhU0qWXY,2581
41
41
  rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
42
- rslearn/dataset/manage.py,sha256=E2PF29MWuYu8XMcX8PYG4itCdgxaUV768R5NCHmJ-e8,18058
43
- rslearn/dataset/materialize.py,sha256=emsxgkBKWC-ybZDtudsYOsxx7yxgmdKy5YrysgqzNac,21556
44
- rslearn/dataset/remap.py,sha256=HF5Kn5z6dEPbCpSxaBmYVDSzCuBjG2w8csPsk0RRlv4,2088
42
+ rslearn/dataset/manage.py,sha256=mkdBHo1RFGxMx8f9zBT_VmRO_6y8Qb2KfWPPziKWYkg,18062
43
+ rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWBw,22056
44
+ rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
45
45
  rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
46
46
  rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
47
47
  rslearn/models/anysat.py,sha256=3BnaiS1sYB4SnV6qRjHksiz_r9vUuZeGPUO2XUziFA0,7810
@@ -49,7 +49,7 @@ rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
49
49
  rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
50
50
  rslearn/models/copernicusfm.py,sha256=3AiORuUre9sZYwydbrDgShwKtxeTLmExp7WQmJtBylg,7842
51
51
  rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
52
- rslearn/models/dinov3.py,sha256=KgR-vbmb3-6Qn4AC7ybGL6Ri3_f4jf2XOaKxAR9uMnU,6150
52
+ rslearn/models/dinov3.py,sha256=GKk5qXZPCEporATJdjaSWsDTfWDlAGRWBplFUJN5nRM,6146
53
53
  rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
54
54
  rslearn/models/fpn.py,sha256=s3cz29I14FaSuvBvLOcwCrqVsaRBxG5GjLlqap4WgPc,1603
55
55
  rslearn/models/module_wrapper.py,sha256=H2zb-8Au4t31kawW_4JEKHsaXFjpYDawb31ZEauKcxU,2728
@@ -57,9 +57,9 @@ rslearn/models/molmo.py,sha256=mVrARBhZciMzOgOOjGB5AHlPIf2iO9IBSJmdyKSl1L8,2061
57
57
  rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,14672
58
58
  rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
59
59
  rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
60
- rslearn/models/pooling_decoder.py,sha256=jZfEQCfthfa21C9sEjgFHUcfhHMVlvG7_nDMw_1FLaE,2727
61
- rslearn/models/prithvi.py,sha256=bGY2Tf_V9tlJGDtDSCMkoFMJdFXUCW6cyj4r8FER8Pw,39557
62
- rslearn/models/registry.py,sha256=0j6jY-rmB_wKckHEUeCYWXtjviDOEr5c7WhIMxxDwBk,90
60
+ rslearn/models/pooling_decoder.py,sha256=unr2fSE_QmJHPi3dKtopqMtb1Kn-2h94LgwwAVP9vZg,4437
61
+ rslearn/models/prithvi.py,sha256=SVM3ypJlVTkXQ69pPhB4UeJr87VnmADTCuyV365dbkU,39961
62
+ rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
63
63
  rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
64
64
  rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
65
65
  rslearn/models/satlaspretrain.py,sha256=YpjXl-uClhTZMDmyhN64Fg3AszzT-ymZgJB0fO9RyoY,2419
@@ -91,8 +91,11 @@ rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJ
91
91
  rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
92
92
  rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
93
93
  rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
94
- rslearn/models/galileo/galileo.py,sha256=b8LMyn5lN-WHmCNTJUy-ykjL0s5_-jg9RTO1CwJVra4,19533
95
- rslearn/models/galileo/single_file_galileo.py,sha256=hmtBGdlFfX9JceZLguntSlXtuHElOafxwu1yIrhrEJY,55920
94
+ rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
95
+ rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
96
+ rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
97
+ rslearn/models/olmoearth_pretrain/model.py,sha256=F-B1ym9UZuTPJ0OY15Jwb1TkNtr_EtAUlqI-tr_Z2uo,8352
98
+ rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
96
99
  rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
97
100
  rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
98
101
  rslearn/models/panopticon_data/sensors/goes.yaml,sha256=o00aoWCYqam0aB1rPmXq1MKe8hsKak_qyBG7BPL27Sc,152
@@ -106,10 +109,10 @@ rslearn/models/panopticon_data/sensors/sentinel2.yaml,sha256=qYJ92x-GHO0ZdCrTtCj
106
109
  rslearn/models/panopticon_data/sensors/superdove.yaml,sha256=QpIRyopdV4hAez_EIsDwhGFT4VtTk7UgzQveyc8t8fc,795
107
110
  rslearn/models/panopticon_data/sensors/wv23.yaml,sha256=SWYSlkka6UViKAz6YI8aqwQ-Ayo-S5kmNa9rO3iGW6o,1172
108
111
  rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRygze0,64
109
- rslearn/models/presto/presto.py,sha256=Vr4DBIP5Twnop5pZ21BhChbOwWP7KI-GMdg2XILkj5c,8774
112
+ rslearn/models/presto/presto.py,sha256=8mZnc0jk_r_JikybHQNyyHg6t7JNPmoPmgoivyNf-U8,9177
110
113
  rslearn/models/presto/single_file_presto.py,sha256=Kbwp8V7pO8HHM2vlCPpjekQiFiDryW8zQkWmt1g05BY,30381
111
114
  rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_Qk,2050
112
- rslearn/tile_stores/default.py,sha256=P78LQm9OXNjd6Igkh-YqUFDmHSBN1Uk092xwqkuOPSs,14973
115
+ rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
113
116
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
114
117
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
115
118
  rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
@@ -140,7 +143,7 @@ rslearn/train/transforms/normalize.py,sha256=uyv2hE5hw5B2kCRHa4JIx0tfowm-C7bgumw
140
143
  rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
141
144
  rslearn/train/transforms/select_bands.py,sha256=uDfD9G8Z4VTt88QZsjj1FB20QEmzSefhKf7uDXYn77M,2441
142
145
  rslearn/train/transforms/sentinel1.py,sha256=FrLaYZs2AjqWQCun8DTFtgo1l0xLxqaFKtDNIehtpDg,1913
143
- rslearn/train/transforms/transform.py,sha256=8Q-dPrmDr0tJ9ZOwjWqWK8kbnKi4uLxEnS9Nwf6BVJk,3594
146
+ rslearn/train/transforms/transform.py,sha256=n1Qzqix2dVvej-Q7iPzHeOQbqH79IBlvqPoymxhNVpE,4446
144
147
  rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
145
148
  rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
146
149
  rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
@@ -150,15 +153,15 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
150
153
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
151
154
  rslearn/utils/jsonargparse.py,sha256=JcTKQoZ6jgwag-kSeTIEVBO9AsRj0X1oEJBsoaCazH4,658
152
155
  rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
153
- rslearn/utils/raster_format.py,sha256=3hjBR_0Gl1M1b0Vjx4UsQBgAVziCh34VAVWeW32k3Ds,25167
156
+ rslearn/utils/raster_format.py,sha256=dBTSa8l6Ms9Ndbx9Krgqm9z4RU7j2hwLBkw2w-KibU4,26009
154
157
  rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
155
158
  rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
156
159
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
157
160
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
158
- rslearn/utils/vector_format.py,sha256=C0-qlSXFFnhUU0S7eiOn4-4d5-pCkwU2jZ-q0Ldikk8,15299
159
- rslearn-0.0.8.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
160
- rslearn-0.0.8.dist-info/METADATA,sha256=hdHA3W7Hb4bQPPZG6SscWr1AOOVhALL0F08tEMt7neQ,36206
161
- rslearn-0.0.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
- rslearn-0.0.8.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
163
- rslearn-0.0.8.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
164
- rslearn-0.0.8.dist-info/RECORD,,
161
+ rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
162
+ rslearn-0.0.11.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
163
+ rslearn-0.0.11.dist-info/METADATA,sha256=jwB0ZZ-oLa1Y_1iuZRKCQoB4i3kOFDJ0xSeMTJP7zww,36297
164
+ rslearn-0.0.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
165
+ rslearn-0.0.11.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
166
+ rslearn-0.0.11.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
167
+ rslearn-0.0.11.dist-info/RECORD,,