rslearn 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,358 @@
+ """Data source for Google Satellite Embedding V1 dataset on AWS Open Data."""
+
+ import os
+ import tempfile
+ from collections.abc import Callable
+ from datetime import UTC, datetime
+ from typing import Any
+
+ import boto3
+ import numpy as np
+ import numpy.typing as npt
+ import pandas as pd
+ import rasterio
+ import rasterio.vrt
+ import shapely
+ import shapely.wkt
+ from botocore import UNSIGNED
+ from botocore.config import Config
+ from rasterio.enums import Resampling
+ from upath import UPath
+
+ import rslearn.data_sources.utils
+ from rslearn.const import WGS84_PROJECTION
+ from rslearn.data_sources.data_source import (
+     DataSourceContext,
+     Item,
+     QueryConfig,
+ )
+ from rslearn.data_sources.direct_materialize_data_source import (
+     DirectMaterializeDataSource,
+ )
+ from rslearn.utils.fsspec import join_upath
+ from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
+ from rslearn.utils.grid_index import GridIndex
+
+ # Band names for the 64 embedding channels
+ BANDS = [f"A{idx:02d}" for idx in range(64)]
+
+ # S3 bucket configuration
+ BUCKET_NAME = "us-west-2.opendata.source.coop"
+ BUCKET_PREFIX = "tge-labs/aef/v1/annual"
+ INDEX_KEY = f"{BUCKET_PREFIX}/aef_index.csv"
+ HTTP_URL_BASE = f"https://s3.us-west-2.amazonaws.com/{BUCKET_NAME}"
+
+ # Grid index cell size for spatial queries
+ GRID_SIZE = 1.0
+
+
+ class GoogleSatelliteEmbeddingV1Item(Item):
+     """An item in the GoogleSatelliteEmbeddingV1 data source."""
+
+     def __init__(
+         self,
+         name: str,
+         geometry: STGeometry,
+         s3_path: str,
+     ) -> None:
+         """Creates a new GoogleSatelliteEmbeddingV1Item.
+
+         Args:
+             name: unique name of the item (the filename without extension)
+             geometry: the spatial and temporal extent of the item
+             s3_path: full S3 path to the TIFF file
+         """
+         super().__init__(name, geometry)
+         self.s3_path = s3_path
+
+     def serialize(self) -> dict:
+         """Serializes the item to a JSON-encodable dictionary."""
+         d = super().serialize()
+         d["s3_path"] = self.s3_path
+         return d
+
+     @staticmethod
+     def deserialize(d: dict) -> "GoogleSatelliteEmbeddingV1Item":
+         """Deserializes an item from a JSON-decoded dictionary."""
+         item = super(
+             GoogleSatelliteEmbeddingV1Item, GoogleSatelliteEmbeddingV1Item
+         ).deserialize(d)
+         return GoogleSatelliteEmbeddingV1Item(
+             name=item.name,
+             geometry=item.geometry,
+             s3_path=d["s3_path"],
+         )
+
+
+ class GoogleSatelliteEmbeddingV1(
+     DirectMaterializeDataSource[GoogleSatelliteEmbeddingV1Item]
+ ):
+     """Data source for Google Satellite Embedding V1 on AWS Open Data.
+
+     It consists of annual satellite embeddings at 10m resolution with 64 bands
+     (A00-A63). The data is stored as Cloud-Optimized GeoTIFFs organized by year and UTM
+     zone. Each file covers 8192x8192 pixels.
+
+     Available years: 2018-2024.
+
+     See https://registry.opendata.aws/aef-source/ for details.
+     """
+
+     def __init__(
+         self,
+         metadata_cache_dir: str,
+         apply_dequantization: bool = True,
+         context: DataSourceContext = DataSourceContext(),
+     ) -> None:
+         """Initialize a new GoogleSatelliteEmbeddingV1 instance.
+
+         Args:
+             metadata_cache_dir: directory to cache the index file.
+             apply_dequantization: whether to apply de-quantization to convert
+                 int8 values to float32. The raw data is quantized int8; the
+                 de-quantization maps values to [-1, 1] using the formula:
+                 ((values / 127.5) ** 2) * sign(values). The raw data has nodata value
+                 -128 while with dequantization the nodata value is -1.0. See
+                 https://source.coop/tge-labs/aef for details.
+             context: the data source context.
+         """
+         # We have a single asset containing all 64 bands. Here "image" is an arbitrary
+         # name, since DirectMaterializeDataSource requires an asset name.
+         super().__init__(asset_bands={"image": BANDS})
+
+         self.apply_dequantization = apply_dequantization
+
+         # Set up cache directory
+         if context.ds_path is not None:
+             self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
+         else:
+             self.metadata_cache_dir = UPath(metadata_cache_dir)
+         self.metadata_cache_dir.mkdir(parents=True, exist_ok=True)
+
+         # S3 client with anonymous access (only used for downloading index)
+         self.s3_client = boto3.client(
+             "s3",
+             config=Config(signature_version=UNSIGNED),
+             region_name="us-west-2",
+         )
+
+         # Lazy-loaded grid index
+         self._grid_index: GridIndex | None = None
+         self._items_by_name: dict[str, GoogleSatelliteEmbeddingV1Item] | None = None
+
+     def _read_index_csv(self) -> pd.DataFrame:
+         """Read the index CSV, downloading from S3 if not cached.
+
+         Returns:
+             DataFrame with WKT, path, and year columns.
+         """
+         cache_file = self.metadata_cache_dir / "aef_index.csv"
+         if not cache_file.exists():
+             response = self.s3_client.get_object(Bucket=BUCKET_NAME, Key=INDEX_KEY)
+             content = response["Body"].read()
+             with cache_file.open("wb") as f:
+                 f.write(content)
+
+         return pd.read_csv(
+             cache_file,
+             header=None,
+             usecols=[0, 2, 3],
+             names=["WKT", "path", "year"],
+         )
+
+     def _load_index(
+         self,
+     ) -> tuple[GridIndex, dict[str, GoogleSatelliteEmbeddingV1Item]]:
+         """Load the index file and build spatial index.
+
+         Returns:
+             Tuple of (grid_index, items_by_name dict).
+         """
+         if self._grid_index is not None and self._items_by_name is not None:
+             return self._grid_index, self._items_by_name
+
+         df = self._read_index_csv()
+
+         grid_index = GridIndex(GRID_SIZE)
+         items_by_name: dict[str, GoogleSatelliteEmbeddingV1Item] = {}
+
+         for _, row in df.iterrows():
+             shp = shapely.wkt.loads(row["WKT"])
+
+             year = int(row["year"])
+             time_range = (
+                 datetime(year, 1, 1, tzinfo=UTC),
+                 datetime(year, 12, 31, 23, 59, 59, tzinfo=UTC),
+             )
+
+             s3_path = row["path"]
+             name = s3_path.split("/")[-1].replace(".tiff", "")
+
+             geometry = STGeometry(WGS84_PROJECTION, shp, time_range)
+             item = GoogleSatelliteEmbeddingV1Item(
+                 name=name,
+                 geometry=geometry,
+                 s3_path=s3_path,
+             )
+
+             grid_index.insert(shp.bounds, item)
+             items_by_name[name] = item
+
+         self._grid_index = grid_index
+         self._items_by_name = items_by_name
+         return grid_index, items_by_name
+
+     # --- DataSource implementation ---
+
+     def get_items(
+         self, geometries: list[STGeometry], query_config: QueryConfig
+     ) -> list[list[list[GoogleSatelliteEmbeddingV1Item]]]:
+         """Get a list of items in the data source intersecting the given geometries."""
+         grid_index, _ = self._load_index()
+
+         wgs84_geometries = [
+             geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
+         ]
+
+         groups = []
+         for geometry, wgs84_geometry in zip(geometries, wgs84_geometries):
+             cur_items = []
+             for item in grid_index.query(wgs84_geometry.shp.bounds):
+                 if not wgs84_geometry.shp.intersects(item.geometry.shp):
+                     continue
+                 # Check time range if specified
+                 if wgs84_geometry.time_range is not None:
+                     item_start, item_end = item.geometry.time_range
+                     query_start, query_end = wgs84_geometry.time_range
+                     if item_end < query_start or item_start > query_end:
+                         continue
+                 cur_items.append(item)
+
+             cur_items.sort(key=lambda item: item.geometry.time_range[0])
+
+             cur_groups: list[list[GoogleSatelliteEmbeddingV1Item]] = (
+                 rslearn.data_sources.utils.match_candidate_items_to_window(
+                     geometry, cur_items, query_config
+                 )
+             )
+             groups.append(cur_groups)
+
+         return groups
+
+     def get_item_by_name(self, name: str) -> GoogleSatelliteEmbeddingV1Item:
+         """Gets an item by name."""
+         _, items_by_name = self._load_index()
+         if name not in items_by_name:
+             raise ValueError(f"item {name} not found")
+         return items_by_name[name]
+
+     def deserialize_item(self, serialized_item: dict) -> GoogleSatelliteEmbeddingV1Item:
+         """Deserializes an item from JSON-decoded data."""
+         return GoogleSatelliteEmbeddingV1Item.deserialize(serialized_item)
+
+     def ingest(
+         self,
+         tile_store: Any,
+         items: list[GoogleSatelliteEmbeddingV1Item],
+         geometries: list[list[STGeometry]],
+     ) -> None:
+         """Ingest items into the given tile store.
+
+         Note: Each file is 2-3GB so this can be slow. Direct materialization via
+         read_raster or materialize is recommended for most use cases.
+
+         Args:
+             tile_store: the tile store to ingest into
+             items: the items to ingest
+             geometries: a list of geometries needed for each item
+         """
+         for item in items:
+             if tile_store.is_raster_ready(item.name, BANDS):
+                 continue
+
+             # Download the TIFF file directly to disk
+             key = item.s3_path.replace(f"s3://{BUCKET_NAME}/", "")
+             with tempfile.TemporaryDirectory() as tmp_dir:
+                 local_path = os.path.join(tmp_dir, f"{item.name}.tiff")
+                 self.s3_client.download_file(BUCKET_NAME, key, local_path)
+                 tile_store.write_raster_file(item.name, BANDS, UPath(local_path))
+
+     # --- DirectMaterializeDataSource implementation ---
+
+     def get_asset_url(self, item_name: str, asset_key: str) -> str:
+         """Get the HTTP URL to read the asset.
+
+         Returns a /vsicurl/ URL that rasterio can read directly over HTTP.
+         """
+         item = self.get_item_by_name(item_name)
+         # Convert s3://bucket/path to HTTP URL
+         key = item.s3_path.replace(f"s3://{BUCKET_NAME}/", "")
+         return f"/vsicurl/{HTTP_URL_BASE}/{key}"
+
+     def get_read_callback(
+         self, item_name: str, asset_key: str
+     ) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
+         """Return a callback to apply de-quantization if enabled."""
+         if not self.apply_dequantization:
+             return None
+
+         def dequantize(data: npt.NDArray[Any]) -> npt.NDArray[np.float32]:
+             # Handle nodata (-128)
+             nodata_mask = data == -128
+             float_data = data.astype(np.float32)
+             # This is the dequantization formula recommended at https://source.coop/tge-labs/aef.
+             result = ((float_data / 127.5) ** 2) * np.sign(float_data)
+             # We make sure that NODATA is exactly -1.0 so user can handle it appropriately.
+             result[nodata_mask] = -1.0
+             return result
+
+         return dequantize
+
+     def read_raster(
+         self,
+         layer_name: str,
+         item_name: str,
+         bands: list[str],
+         projection: Projection,
+         bounds: PixelBounds,
+         resampling: Resampling = Resampling.bilinear,
+     ) -> npt.NDArray[Any]:
+         """Read raster data from the store.
+
+         Overrides base class to handle band selection (the base class reads all bands).
+         """
+         asset_url = self.get_asset_url(item_name, "image")
+
+         # Determine which band indices to read (1-indexed for rasterio)
+         if bands == BANDS:
+             band_indices = list(range(1, 65))
+         else:
+             band_indices = [BANDS.index(b) + 1 for b in bands]
+
+         # Construct the transform for the requested bounds
+         wanted_transform = rasterio.transform.Affine(
+             projection.x_resolution,
+             0,
+             bounds[0] * projection.x_resolution,
+             0,
+             projection.y_resolution,
+             bounds[1] * projection.y_resolution,
+         )
+
+         with rasterio.open(asset_url) as src:
+             with rasterio.vrt.WarpedVRT(
+                 src,
+                 crs=projection.crs,
+                 transform=wanted_transform,
+                 width=bounds[2] - bounds[0],
+                 height=bounds[3] - bounds[1],
+                 resampling=resampling,
+             ) as vrt:
+                 data = vrt.read(indexes=band_indices)
+
+         # Apply callback if dequantization is enabled
+         callback = self.get_read_callback(item_name, "image")
+         if callback is not None:
+             data = callback(data)
+
+         return data
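
As a standalone illustration of the de-quantization applied by get_read_callback above (a minimal sketch with made-up int8 values, independent of the rslearn classes in this diff):

import numpy as np

# Quantized int8 embedding values; -128 marks nodata in the raw data.
quantized = np.array([-128, -127, -64, 0, 64, 127], dtype=np.int8)

nodata_mask = quantized == -128
values = quantized.astype(np.float32)
# ((v / 127.5) ** 2) * sign(v) maps the quantized values into [-1, 1].
dequantized = ((values / 127.5) ** 2) * np.sign(values)
# Nodata is forced to exactly -1.0, matching the callback above.
dequantized[nodata_mask] = -1.0
print(dequantized)  # approximately [-1.0, -0.992, -0.252, 0.0, 0.252, 0.992]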
@@ -11,14 +11,16 @@ import rasterio.vrt
  from rasterio.enums import Resampling
  
  from rslearn.config import LayerConfig
- from rslearn.data_sources.data_source import DataSource, ItemType
+ from rslearn.data_sources.data_source import ItemLookupDataSource, ItemType
  from rslearn.dataset import Window
  from rslearn.dataset.materialize import RasterMaterializer
  from rslearn.tile_stores import TileStore, TileStoreWithLayer
  from rslearn.utils.geometry import PixelBounds, Projection
  
  
- class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemType]):
+ class DirectMaterializeDataSource(
+     ItemLookupDataSource[ItemType], TileStore, Generic[ItemType]
+ ):
      """Base class for data sources that support direct materialization via TileStore.
  
      This class provides common TileStore functionality for data sources that can read
@@ -27,9 +29,10 @@ class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemT
  
      Subclasses must implement:
      - get_asset_url(): Get the URL for an asset given item name and bands
-     - get_item_by_name(): Get an item by its name
  
      Subclasses may optionally override:
+     - get_item_by_name(): Inherited from ItemLookupDataSource. If also inheriting
+       from a class that provides it (e.g., StacDataSource), no override needed.
      - get_raster_bands(): By default, we assume that items have all assets. If
        items may have a subset of assets, override get_raster_bands to return
        the sets of bands available for that item.
@@ -77,20 +80,6 @@ class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemT
          """
          raise NotImplementedError
  
-     def get_item_by_name(self, name: str) -> ItemType:
-         """Get an item by its name.
- 
-         Subclasses must implement this method, either directly or by inheriting from
-         a class that provides it (e.g., StacDataSource).
- 
-         Args:
-             name: the name of the item to get.
- 
-         Returns:
-             the item object.
-         """
-         raise NotImplementedError
- 
      # --- Optional hooks for subclasses ---
  
      def get_read_callback(
@@ -418,7 +418,8 @@ def match_candidate_items_to_window(
      )
  
      # Now apply space mode.
-     item_shps = []
+     acceptable_items = []
+     acceptable_item_shps = []
      for item in items:
          item_geom = item.geometry
          # We need to re-project items to the geometry projection for the spatial checks
@@ -430,14 +431,20 @@
              item_geom = geometry
          else:
              item_geom = item_geom.to_projection(geometry.projection)
-         item_shps.append(item_geom.shp)
+ 
+         if item_geom.shp.area == 0:
+             # Must have been an item that didn't quite match the window's spatial extent.
+             continue
+ 
+         acceptable_items.append(item)
+         acceptable_item_shps.append(item_geom.shp)
  
      # Dispatch to the appropriate space mode handler
      handler = space_mode_handlers.get(query_config.space_mode)
      if handler is None:
          raise ValueError(f"invalid space mode {query_config.space_mode}")
  
-     groups = handler(geometry, items, item_shps, query_config)
+     groups = handler(geometry, acceptable_items, acceptable_item_shps, query_config)
  
      # Enforce minimum matches if set.
      if len(groups) < query_config.min_matches:
@@ -15,7 +15,7 @@ from rslearn.dataset.window import (
      get_window_layer_dir,
  )
  from rslearn.log_utils import get_logger
- from rslearn.utils.fsspec import open_atomic
+ from rslearn.utils.fsspec import iter_nonhidden_subdirs, open_atomic
  from rslearn.utils.mp import star_imap_unordered
  
  from .storage import WindowStorage, WindowStorageFactory
@@ -77,8 +77,8 @@ class FileWindowStorage(WindowStorage):
          window_dirs = []
          if not groups:
              groups = []
-             for p in (self.path / "windows").iterdir():
-                 groups.append(p.name)
+             for group_dir in iter_nonhidden_subdirs(self.path / "windows"):
+                 groups.append(group_dir.name)
          for group in groups:
              group_dir = self.path / "windows" / group
              if not group_dir.exists():
@@ -86,16 +86,20 @@ class FileWindowStorage(WindowStorage):
                      f"Skipping group directory {group_dir} since it does not exist"
                  )
                  continue
+             if not group_dir.is_dir():
+                 logger.warning(
+                     f"Skipping group path {group_dir} since it is not a directory"
+                 )
+                 continue
              if names:
-                 cur_names = names
+                 for window_name in names:
+                     window_dir = group_dir / window_name
+                     if not window_dir.is_dir():
+                         continue
+                     window_dirs.append(window_dir)
              else:
-                 cur_names = []
-                 for p in group_dir.iterdir():
-                     cur_names.append(p.name)
- 
-             for window_name in cur_names:
-                 window_dir = group_dir / window_name
-                 window_dirs.append(window_dir)
+                 for window_dir in iter_nonhidden_subdirs(group_dir):
+                     window_dirs.append(window_dir)
  
          if workers == 0:
              windows = [load_window(self, window_dir) for window_dir in window_dirs]
@@ -162,7 +166,7 @@ class FileWindowStorage(WindowStorage):
              return []
  
          completed_layers = []
-         for layer_dir in layers_directory.iterdir():
+         for layer_dir in iter_nonhidden_subdirs(layers_directory):
              layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
              if not self.is_layer_completed(group, name, layer_name, group_idx):
                  continue
@@ -0,0 +1,74 @@
+ """Global pooling decoder for spatial feature maps."""
+
+ from typing import Any, Literal
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureMaps, FeatureVector, IntermediateComponent
+
+
+ class GlobalPool(IntermediateComponent):
+     """Apply global pooling to reduce spatial dimensions.
+
+     This component applies global average or max pooling over the spatial dimensions
+     of input feature maps. By default, it produces FeatureVector (BxC) suitable for
+     ClassificationHead or RegressionHead. When keep_spatial_dims=True, it produces
+     1x1 FeatureMaps suitable for EmbeddingHead.
+     """
+
+     def __init__(
+         self,
+         mode: Literal["mean", "max"] = "mean",
+         keep_spatial_dims: bool = False,
+     ) -> None:
+         """Create a new GlobalPool.
+
+         Args:
+             mode: the pooling mode, either "mean" for global average pooling or
+                 "max" for global max pooling. Defaults to "mean".
+             keep_spatial_dims: if True, returns FeatureMaps with 1x1 spatial dimensions.
+                 If False (default), returns FeatureVector (BxC). Defaults to False.
+         """
+         super().__init__()
+         if mode not in ("mean", "max"):
+             raise ValueError(f"mode must be 'mean' or 'max', got '{mode}'")
+         self.mode = mode
+         self.keep_spatial_dims = keep_spatial_dims
+
+     def forward(
+         self, intermediates: Any, context: ModelContext
+     ) -> FeatureMaps | FeatureVector:
+         """Apply global pooling on the feature maps.
+
+         Args:
+             intermediates: output from the previous model component, which must be
+                 a FeatureMaps.
+             context: the model context.
+
+         Returns:
+             If keep_spatial_dims=False (default): FeatureVector (BxC) suitable for
+             ClassificationHead or RegressionHead.
+             If keep_spatial_dims=True: FeatureMaps with 1x1 spatial dimensions suitable
+             for EmbeddingHead.
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to GlobalPool must be FeatureMaps")
+
+         pooled_features = []
+         for feat in intermediates.feature_maps:
+             # feat is BCHW
+             if self.mode == "mean":
+                 pooled = feat.mean(dim=(2, 3), keepdim=self.keep_spatial_dims)
+             else:
+                 pooled = torch.amax(feat, dim=(2, 3), keepdim=self.keep_spatial_dims)
+             pooled_features.append(pooled)
+
+         if self.keep_spatial_dims:
+             return FeatureMaps(pooled_features)
+         else:
+             if len(pooled_features) == 1:
+                 return FeatureVector(pooled_features[0])
+             else:
+                 return FeatureVector(torch.cat(pooled_features, dim=1))
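
To illustrate the tensor shapes GlobalPool produces, here is a standalone torch sketch over a random BCHW feature map; rslearn's FeatureMaps/FeatureVector wrappers and the ModelContext argument are omitted:

import torch

feat = torch.randn(2, 64, 8, 8)  # batch of 2, 64 channels, 8x8 spatial

# mode="mean", keep_spatial_dims=False: BxC output, wrapped as a FeatureVector.
mean_pooled = feat.mean(dim=(2, 3), keepdim=False)
print(mean_pooled.shape)  # torch.Size([2, 64])

# mode="max", keep_spatial_dims=True: BxCx1x1 output, wrapped as 1x1 FeatureMaps.
max_pooled = torch.amax(feat, dim=(2, 3), keepdim=True)
print(max_pooled.shape)  # torch.Size([2, 64, 1, 1])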
@@ -15,6 +15,8 @@ from upath import UPath
  from rslearn.const import WGS84_PROJECTION
  from rslearn.utils.feature import Feature
  from rslearn.utils.fsspec import (
+     iter_nonhidden_files,
+     iter_nonhidden_subdirs,
      join_upath,
      open_atomic,
      open_rasterio_upath_reader,
@@ -129,7 +131,7 @@ class DefaultTileStore(TileStore):
              ValueError: if no file is found.
          """
          raster_dir = self._get_raster_dir(layer_name, item_name, bands)
-         for fname in raster_dir.iterdir():
+         for fname in iter_nonhidden_files(raster_dir):
              # Ignore completed sentinel files, bands files, as well as temporary files created by
              # open_atomic (in case this tile store is on local filesystem).
              if fname.name == COMPLETED_FNAME:
@@ -175,7 +177,7 @@ class DefaultTileStore(TileStore):
              return []
  
          bands: list[list[str]] = []
-         for raster_dir in item_dir.iterdir():
+         for raster_dir in iter_nonhidden_subdirs(item_dir):
              if not (raster_dir / BANDS_FNAME).exists():
                  # This is likely a legacy directory where the bands are only encoded in
                  # the directory name, so we have to rely on that.
@@ -108,10 +108,10 @@ class RslearnDataModule(L.LightningDataModule):
          self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
          self.index_mode = index_mode
          self.split_configs = {
-             "train": default_config.update(train_config),
-             "val": default_config.update(val_config),
-             "test": default_config.update(test_config),
-             "predict": default_config.update(predict_config),
+             "train": SplitConfig.merge_and_validate([default_config, train_config]),
+             "val": SplitConfig.merge_and_validate([default_config, val_config]),
+             "test": SplitConfig.merge_and_validate([default_config, test_config]),
+             "predict": SplitConfig.merge_and_validate([default_config, predict_config]),
          }
  
      def setup(
@@ -141,7 +141,7 @@ class RslearnDataModule(L.LightningDataModule):
              task=self.task,
              workers=self.init_workers,
              name=self.name,
-             fix_patch_pick=(split != "train"),
+             fix_crop_pick=(split != "train"),
              index_mode=self.index_mode,
          )
          logger.info(f"got {len(dataset)} examples in split {split}")
@@ -203,13 +203,16 @@ class RslearnDataModule(L.LightningDataModule):
          # Enable persistent workers unless we are using main process.
          persistent_workers = self.num_workers > 0
  
-         # If using all patches, limit number of workers to the number of windows.
+         # If using all crops, limit number of workers to the number of windows.
          # Otherwise it has to distribute the same window to different workers which can
          # cause issues for RslearnWriter.
          # If the number of windows is 0, then we can set positive number of workers
          # since they won't yield anything anyway.
          num_workers = self.num_workers
-         if split_config.load_all_crops and len(dataset.get_dataset_examples()) > 0:
+         if (
+             split_config.get_load_all_crops()
+             and len(dataset.get_dataset_examples()) > 0
+         ):
              num_workers = min(num_workers, len(dataset.get_dataset_examples()))
  
          kwargs: dict[str, Any] = dict(