rslearn 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. {rslearn-0.0.1/rslearn.egg-info → rslearn-0.0.2}/PKG-INFO +32 -27
  2. rslearn-0.0.2/extra_requirements.txt +17 -0
  3. {rslearn-0.0.1 → rslearn-0.0.2}/pyproject.toml +1 -1
  4. rslearn-0.0.2/requirements.txt +15 -0
  5. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/config/dataset.py +22 -13
  6. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/__init__.py +8 -0
  7. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/aws_landsat.py +27 -18
  8. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/aws_open_data.py +41 -42
  9. rslearn-0.0.2/rslearn/data_sources/copernicus.py +188 -0
  10. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/data_source.py +17 -10
  11. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/gcp_public_data.py +177 -100
  12. rslearn-0.0.2/rslearn/data_sources/geotiff.py +1 -0
  13. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/google_earth_engine.py +17 -15
  14. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/local_files.py +59 -32
  15. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/openstreetmap.py +27 -23
  16. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/planet.py +10 -9
  17. rslearn-0.0.2/rslearn/data_sources/planet_basemap.py +303 -0
  18. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/raster_source.py +23 -13
  19. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/usgs_landsat.py +56 -27
  20. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/utils.py +13 -6
  21. rslearn-0.0.2/rslearn/data_sources/vector_source.py +1 -0
  22. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/data_sources/xyz_tiles.py +8 -9
  23. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/add_windows.py +1 -1
  24. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/dataset.py +16 -5
  25. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/manage.py +9 -4
  26. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/materialize.py +26 -5
  27. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/window.py +5 -0
  28. rslearn-0.0.2/rslearn/log_utils.py +24 -0
  29. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/main.py +123 -59
  30. rslearn-0.0.2/rslearn/models/clip.py +62 -0
  31. rslearn-0.0.2/rslearn/models/conv.py +56 -0
  32. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/faster_rcnn.py +2 -19
  33. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/fpn.py +1 -1
  34. rslearn-0.0.2/rslearn/models/module_wrapper.py +43 -0
  35. rslearn-0.0.2/rslearn/models/molmo.py +65 -0
  36. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/multitask.py +1 -1
  37. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/pooling_decoder.py +4 -2
  38. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/satlaspretrain.py +4 -7
  39. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/simple_time_series.py +61 -55
  40. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/ssl4eo_s12.py +9 -9
  41. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/swin.py +22 -21
  42. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/unet.py +4 -2
  43. rslearn-0.0.2/rslearn/models/upsample.py +35 -0
  44. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/tile_stores/file.py +6 -3
  45. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/tile_stores/tile_store.py +19 -7
  46. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  47. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/data_module.py +5 -4
  48. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/dataset.py +79 -36
  49. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/lightning_module.py +15 -11
  50. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/prediction_writer.py +22 -11
  51. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/classification.py +9 -8
  52. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/detection.py +94 -37
  53. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/multi_task.py +1 -1
  54. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/regression.py +8 -4
  55. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/segmentation.py +23 -19
  56. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/__init__.py +1 -1
  57. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/concatenate.py +6 -2
  58. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/crop.py +6 -2
  59. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/flip.py +5 -1
  60. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/normalize.py +9 -5
  61. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/pad.py +1 -1
  62. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/transforms/transform.py +3 -3
  63. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/__init__.py +4 -5
  64. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/array.py +2 -2
  65. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/feature.py +1 -1
  66. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/fsspec.py +70 -1
  67. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/geometry.py +155 -3
  68. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/grid_index.py +5 -5
  69. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/mp.py +4 -3
  70. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/raster_format.py +81 -73
  71. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/rtree_index.py +64 -17
  72. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/sqlite_index.py +7 -1
  73. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/utils.py +11 -3
  74. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/vector_format.py +113 -17
  75. {rslearn-0.0.1 → rslearn-0.0.2/rslearn.egg-info}/PKG-INFO +32 -27
  76. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn.egg-info/SOURCES.txt +7 -1
  77. rslearn-0.0.2/rslearn.egg-info/requires.txt +33 -0
  78. rslearn-0.0.1/extra_requirements.txt +0 -11
  79. rslearn-0.0.1/requirements.txt +0 -15
  80. rslearn-0.0.1/rslearn/data_sources/copernicus.py +0 -42
  81. rslearn-0.0.1/rslearn/data_sources/geotiff.py +0 -0
  82. rslearn-0.0.1/rslearn/data_sources/vector_source.py +0 -0
  83. rslearn-0.0.1/rslearn/utils/mgrs.py +0 -24
  84. rslearn-0.0.1/rslearn.egg-info/requires.txt +0 -28
  85. {rslearn-0.0.1 → rslearn-0.0.2}/LICENSE +0 -0
  86. {rslearn-0.0.1 → rslearn-0.0.2}/README.md +0 -0
  87. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/__init__.py +0 -0
  88. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/config/__init__.py +0 -0
  89. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/const.py +0 -0
  90. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/__init__.py +0 -0
  91. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/dataset/remap.py +0 -0
  92. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/__init__.py +0 -0
  93. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/pick_features.py +0 -0
  94. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/registry.py +0 -0
  95. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/sam2_enc.py +0 -0
  96. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/models/singletask.py +0 -0
  97. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/tile_stores/__init__.py +0 -0
  98. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/__init__.py +0 -0
  99. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/callbacks/__init__.py +0 -0
  100. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/__init__.py +0 -0
  101. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/train/tasks/task.py +0 -0
  102. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/get_utm_ups_crs.py +0 -0
  103. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/spatial_index.py +0 -0
  104. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn/utils/time.py +0 -0
  105. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn.egg-info/dependency_links.txt +0 -0
  106. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn.egg-info/entry_points.txt +0 -0
  107. {rslearn-0.0.1 → rslearn-0.0.2}/rslearn.egg-info/top_level.txt +0 -0
  108. {rslearn-0.0.1 → rslearn-0.0.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rslearn
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author-email: Favyen Bastani <favyenb@allenai.org>, Yawen Zhang <yawenz@allenai.org>, Patrick Beukema <patrickb@allenai.org>, Henry Herzog <henryh@allenai.org>, Piper Wolters <piperw@allenai.org>
6
6
  License: Apache License
@@ -208,33 +208,38 @@ License: Apache License
208
208
  Requires-Python: >=3.10
209
209
  Description-Content-Type: text/markdown
210
210
  License-File: LICENSE
211
- Requires-Dist: boto3
212
- Requires-Dist: class_registry
213
- Requires-Dist: python-dateutil
214
- Requires-Dist: pytimeparse
215
- Requires-Dist: fiona
216
- Requires-Dist: fsspec[gcs,s3]
217
- Requires-Dist: Pillow
218
- Requires-Dist: pyproj
219
- Requires-Dist: rasterio
220
- Requires-Dist: shapely
221
- Requires-Dist: tqdm
222
- Requires-Dist: torch
223
- Requires-Dist: torchvision
224
- Requires-Dist: universal_pathlib
225
- Requires-Dist: lightning[pytorch-extra]
211
+ Requires-Dist: boto3>=1.35
212
+ Requires-Dist: class_registry>=2.1
213
+ Requires-Dist: fiona>=1.10
214
+ Requires-Dist: fsspec[gcs,s3]>=2024.10
215
+ Requires-Dist: lightning[pytorch-extra]>=2.4
216
+ Requires-Dist: Pillow>=11.0
217
+ Requires-Dist: pyproj>=3.7
218
+ Requires-Dist: python-dateutil>=2.9
219
+ Requires-Dist: pytimeparse>=1.1
220
+ Requires-Dist: rasterio>=1.4
221
+ Requires-Dist: shapely>=2.0
222
+ Requires-Dist: torch>=2.5
223
+ Requires-Dist: torchvision>=0.20
224
+ Requires-Dist: tqdm>=4.66
225
+ Requires-Dist: universal_pathlib>=0.2.5
226
226
  Provides-Extra: extra
227
- Requires-Dist: earthengine-api; extra == "extra"
228
- Requires-Dist: gcsfs; extra == "extra"
229
- Requires-Dist: google-cloud-storage; extra == "extra"
230
- Requires-Dist: mgrs; extra == "extra"
231
- Requires-Dist: osmium; extra == "extra"
232
- Requires-Dist: planet; extra == "extra"
233
- Requires-Dist: pycocotools; extra == "extra"
234
- Requires-Dist: rtree; extra == "extra"
235
- Requires-Dist: satlaspretrain_models; extra == "extra"
236
- Requires-Dist: scipy; extra == "extra"
237
- Requires-Dist: wandb; extra == "extra"
227
+ Requires-Dist: accelerate>=1.0; extra == "extra"
228
+ Requires-Dist: earthengine-api>=0.1; extra == "extra"
229
+ Requires-Dist: einops>=0.8; extra == "extra"
230
+ Requires-Dist: gcsfs>=2024.10; extra == "extra"
231
+ Requires-Dist: google-cloud-bigquery>=2.18; extra == "extra"
232
+ Requires-Dist: google-cloud-storage>=2.18; extra == "extra"
233
+ Requires-Dist: interrogate>=1.7; extra == "extra"
234
+ Requires-Dist: osmium>=3.7; extra == "extra"
235
+ Requires-Dist: planet>=2.10; extra == "extra"
236
+ Requires-Dist: pycocotools>=2.0; extra == "extra"
237
+ Requires-Dist: rtree>=1.2; extra == "extra"
238
+ Requires-Dist: s3fs>=2024.10.0; extra == "extra"
239
+ Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
240
+ Requires-Dist: scipy>=1.13; extra == "extra"
241
+ Requires-Dist: transformers>=4.45; extra == "extra"
242
+ Requires-Dist: wandb>=0.17; extra == "extra"
238
243
 
239
244
  Overview
240
245
  --------
@@ -0,0 +1,17 @@
1
+ # These are requirements that are specific to a subset of data sources or models.
2
+ accelerate>=1.0
3
+ earthengine-api>=0.1
4
+ einops>=0.8
5
+ gcsfs>=2024.10
6
+ google-cloud-bigquery>=2.18
7
+ google-cloud-storage>=2.18
8
+ interrogate>=1.7
9
+ osmium>=3.7
10
+ planet>=2.10
11
+ pycocotools>=2.0
12
+ rtree>=1.2
13
+ s3fs>=2024.10.0
14
+ satlaspretrain_models>=0.3
15
+ scipy>=1.13
16
+ transformers>=4.45
17
+ wandb>=0.17
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.1"
3
+ version = "0.0.2"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  {name = "Favyen Bastani", email = "favyenb@allenai.org"},
@@ -0,0 +1,15 @@
1
+ boto3>=1.35
2
+ class_registry>=2.1
3
+ fiona>=1.10
4
+ fsspec[s3,gcs]>=2024.10
5
+ lightning[pytorch-extra]>=2.4
6
+ Pillow>=11.0
7
+ pyproj>=3.7
8
+ python-dateutil>=2.9
9
+ pytimeparse>=1.1
10
+ rasterio>=1.4
11
+ shapely>=2.0
12
+ torch>=2.5
13
+ torchvision>=0.20
14
+ tqdm>=4.66
15
+ universal_pathlib>=0.2.5
@@ -112,7 +112,7 @@ class BandSetConfig:
112
112
  self,
113
113
  config_dict: dict[str, Any],
114
114
  dtype: DType,
115
- bands: list[str] | None = None,
115
+ bands: list[str],
116
116
  format: dict[str, Any] | None = None,
117
117
  zoom_offset: int = 0,
118
118
  remap: dict[str, Any] | None = None,
@@ -130,13 +130,14 @@ class BandSetConfig:
130
130
  """
131
131
  self.config_dict = config_dict
132
132
  self.bands = bands
133
- self.format = format
134
133
  self.dtype = dtype
135
134
  self.zoom_offset = zoom_offset
136
135
  self.remap = remap
137
136
 
138
- if not self.format:
137
+ if format is None:
139
138
  self.format = {"name": "geotiff"}
139
+ else:
140
+ self.format = format
140
141
 
141
142
  def serialize(self) -> dict[str, Any]:
142
143
  """Serialize this BandSetConfig to a config dict, currently unused."""
@@ -158,11 +159,12 @@ class BandSetConfig:
158
159
  kwargs = dict(
159
160
  config_dict=config,
160
161
  dtype=DType(config["dtype"]),
162
+ bands=config["bands"],
161
163
  )
162
- for k in ["bands", "format", "zoom_offset", "remap"]:
164
+ for k in ["format", "zoom_offset", "remap"]:
163
165
  if k in config:
164
166
  kwargs[k] = config[k]
165
- return BandSetConfig(**kwargs)
167
+ return BandSetConfig(**kwargs) # type: ignore
166
168
 
167
169
  def get_final_projection_and_bounds(
168
170
  self, projection: Projection, bounds: PixelBounds | None
@@ -187,11 +189,15 @@ class BandSetConfig:
187
189
  projection.x_resolution / (2**self.zoom_offset),
188
190
  projection.y_resolution / (2**self.zoom_offset),
189
191
  )
190
- if bounds:
192
+ if bounds is not None:
191
193
  if self.zoom_offset > 0:
192
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
194
+ zoom_factor = 2**self.zoom_offset
195
+ bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
193
196
  else:
194
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
197
+ bounds = tuple(
198
+ x // (2 ** (-self.zoom_offset))
199
+ for x in bounds # type: ignore
200
+ )
195
201
  return projection, bounds
196
202
 
197
203
 
@@ -422,7 +428,7 @@ class RasterLayerConfig(LayerConfig):
422
428
  ]
423
429
  if "alias" in config:
424
430
  kwargs["alias"] = config["alias"]
425
- return RasterLayerConfig(**kwargs)
431
+ return RasterLayerConfig(**kwargs) # type: ignore
426
432
 
427
433
 
428
434
  class VectorLayerConfig(LayerConfig):
@@ -456,7 +462,7 @@ class VectorLayerConfig(LayerConfig):
456
462
  Args:
457
463
  config: the config dict for this VectorLayerConfig
458
464
  """
459
- kwargs = {"layer_type": LayerType(config["type"])}
465
+ kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
460
466
  if "data_source" in config:
461
467
  kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
462
468
  if "zoom_offset" in config:
@@ -465,7 +471,7 @@ class VectorLayerConfig(LayerConfig):
465
471
  kwargs["format"] = VectorFormatConfig.from_config(config["format"])
466
472
  if "alias" in config:
467
473
  kwargs["alias"] = config["alias"]
468
- return VectorLayerConfig(**kwargs)
474
+ return VectorLayerConfig(**kwargs) # type: ignore
469
475
 
470
476
  def get_final_projection_and_bounds(
471
477
  self, projection: Projection, bounds: PixelBounds | None
@@ -488,9 +494,12 @@ class VectorLayerConfig(LayerConfig):
488
494
  )
489
495
  if bounds:
490
496
  if self.zoom_offset > 0:
491
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
497
+ bounds = tuple(x * (2**self.zoom_offset) for x in bounds) # type: ignore
492
498
  else:
493
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
499
+ bounds = tuple(
500
+ x // (2 ** (-self.zoom_offset))
501
+ for x in bounds # type: ignore
502
+ )
494
503
  return projection, bounds
495
504
 
496
505
 
@@ -10,15 +10,20 @@ Each source supports operations to lookup items that match with spatiotemporal
10
10
  geometries, and ingest those items.
11
11
  """
12
12
 
13
+ import functools
13
14
  import importlib
14
15
 
15
16
  from upath import UPath
16
17
 
17
18
  from rslearn.config import LayerConfig
19
+ from rslearn.log_utils import get_logger
18
20
 
19
21
  from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource
20
22
 
23
+ logger = get_logger(__name__)
21
24
 
25
+
26
+ @functools.cache
22
27
  def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
23
28
  """Loads a data source from config dict.
24
29
 
@@ -26,6 +31,9 @@ def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
26
31
  config: the LayerConfig containing this data source.
27
32
  ds_path: the dataset root directory.
28
33
  """
34
+ logger.debug("getting a data source for dataset at %s", ds_path)
35
+ if config.data_source is None:
36
+ raise ValueError("No data source specified")
29
37
  name = config.data_source.name
30
38
  module_name = ".".join(name.split(".")[:-1])
31
39
  class_name = name.split(".")[-1]
@@ -6,7 +6,7 @@ import shutil
6
6
  import urllib.request
7
7
  import zipfile
8
8
  from collections.abc import Generator
9
- from datetime import timedelta
9
+ from datetime import datetime, timedelta
10
10
  from typing import Any, BinaryIO
11
11
 
12
12
  import boto3
@@ -20,8 +20,7 @@ import tqdm
20
20
  from upath import UPath
21
21
 
22
22
  import rslearn.data_sources.utils
23
- import rslearn.utils.mgrs
24
- from rslearn.config import LayerConfig, RasterLayerConfig
23
+ from rslearn.config import RasterLayerConfig
25
24
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
26
25
  from rslearn.tile_stores import PrefixedTileStore, TileStore
27
26
  from rslearn.utils import STGeometry
@@ -36,7 +35,7 @@ class LandsatOliTirsItem(Item):
36
35
 
37
36
  def __init__(
38
37
  self, name: str, geometry: STGeometry, blob_path: str, cloud_cover: float
39
- ):
38
+ ) -> None:
40
39
  """Creates a new LandsatOliTirsItem.
41
40
 
42
41
  Args:
@@ -58,7 +57,7 @@ class LandsatOliTirsItem(Item):
58
57
  return d
59
58
 
60
59
  @staticmethod
61
- def deserialize(d: dict) -> Item:
60
+ def deserialize(d: dict) -> "LandsatOliTirsItem":
62
61
  """Deserializes an item from a JSON-decoded dictionary."""
63
62
  if "name" not in d:
64
63
  d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
@@ -90,7 +89,7 @@ class LandsatOliTirs(DataSource):
90
89
 
91
90
  def __init__(
92
91
  self,
93
- config: LayerConfig,
92
+ config: RasterLayerConfig,
94
93
  metadata_cache_dir: UPath,
95
94
  max_time_delta: timedelta = timedelta(days=30),
96
95
  sort_by: str | None = None,
@@ -110,15 +109,16 @@ class LandsatOliTirs(DataSource):
110
109
  self.metadata_cache_dir = metadata_cache_dir
111
110
  self.max_time_delta = max_time_delta
112
111
  self.sort_by = sort_by
112
+ print(self.metadata_cache_dir)
113
113
 
114
114
  self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
115
-
116
115
  self.metadata_cache_dir.mkdir(parents=True, exist_ok=True)
117
116
 
118
117
  @staticmethod
119
- def from_config(config: LayerConfig, ds_path: UPath) -> "LandsatOliTirs":
118
+ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
120
119
  """Creates a new LandsatOliTirs instance from a configuration dictionary."""
121
- assert isinstance(config, RasterLayerConfig)
120
+ if config.data_source is None:
121
+ raise ValueError(f"data_source is required for config dict {config}")
122
122
  d = config.data_source.config_dict
123
123
  kwargs = dict(
124
124
  config=config,
@@ -181,8 +181,9 @@ class LandsatOliTirs(DataSource):
181
181
  ts = dateutil.parser.isoparse(date_str + "T" + time_str)
182
182
 
183
183
  blob_path = obj.key.split("MTL.json")[0]
184
+ time_range: tuple[datetime, datetime] = (ts, ts)
184
185
  geometry = STGeometry(
185
- WGS84_PROJECTION, shapely.Polygon(coordinates), [ts, ts]
186
+ WGS84_PROJECTION, shapely.Polygon(coordinates), time_range
186
187
  )
187
188
  items.append(
188
189
  LandsatOliTirsItem(
@@ -216,6 +217,7 @@ class LandsatOliTirs(DataSource):
216
217
  if not shp_fname.exists():
217
218
  # Download and extract zip to cache dir.
218
219
  zip_fname = self.metadata_cache_dir / f"{prefix}.zip"
220
+ print(f"Downloading {self.wrs2_url} to {zip_fname}")
219
221
  with urllib.request.urlopen(self.wrs2_url) as response:
220
222
  with zip_fname.open("wb") as f:
221
223
  shutil.copyfileobj(response, f)
@@ -259,7 +261,7 @@ class LandsatOliTirs(DataSource):
259
261
 
260
262
  def get_items(
261
263
  self, geometries: list[STGeometry], query_config: QueryConfig
262
- ) -> list[list[list[Item]]]:
264
+ ) -> list[list[list[LandsatOliTirsItem]]]:
263
265
  """Get a list of items in the data source intersecting the given geometries.
264
266
 
265
267
  Args:
@@ -305,14 +307,16 @@ class LandsatOliTirs(DataSource):
305
307
  elif self.sort_by is not None:
306
308
  raise ValueError(f"invalid sort_by setting ({self.sort_by})")
307
309
 
308
- cur_groups = rslearn.data_sources.utils.match_candidate_items_to_window(
309
- geometry, cur_items, query_config
310
+ cur_groups: list[list[LandsatOliTirsItem]] = (
311
+ rslearn.data_sources.utils.match_candidate_items_to_window(
312
+ geometry, cur_items, query_config
313
+ )
310
314
  )
311
315
  groups.append(cur_groups)
312
316
 
313
317
  return groups
314
318
 
315
- def get_item_by_name(self, name: str) -> Item:
319
+ def get_item_by_name(self, name: str) -> LandsatOliTirsItem:
316
320
  """Gets an item by name."""
317
321
  # Product name is like LC08_L1TP_046027_20230715_20230724_02_T1.
318
322
  # We want to use _read_products so we need to extract:
@@ -330,12 +334,14 @@ class LandsatOliTirs(DataSource):
330
334
  return item
331
335
  raise ValueError(f"item {name} not found")
332
336
 
333
- def deserialize_item(self, serialized_item: Any) -> Item:
337
+ def deserialize_item(self, serialized_item: Any) -> LandsatOliTirsItem:
334
338
  """Deserializes an item from JSON-decoded data."""
335
339
  assert isinstance(serialized_item, dict)
336
340
  return LandsatOliTirsItem.deserialize(serialized_item)
337
341
 
338
- def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
342
+ def retrieve_item(
343
+ self, item: LandsatOliTirsItem
344
+ ) -> Generator[tuple[str, BinaryIO], None, None]:
339
345
  """Retrieves the rasters corresponding to an item as file streams."""
340
346
  for band in self.bands:
341
347
  buf = io.BytesIO()
@@ -351,7 +357,7 @@ class LandsatOliTirs(DataSource):
351
357
  def ingest(
352
358
  self,
353
359
  tile_store: TileStore,
354
- items: list[Item],
360
+ items: list[LandsatOliTirsItem],
355
361
  geometries: list[list[STGeometry]],
356
362
  ) -> None:
357
363
  """Ingest items into the given tile store.
@@ -368,7 +374,10 @@ class LandsatOliTirs(DataSource):
368
374
  tile_store, (item.name, "_".join(band_names))
369
375
  )
370
376
  needed_projections = get_needed_projections(
371
- cur_tile_store, band_names, self.config.band_sets, cur_geometries
377
+ cur_tile_store,
378
+ band_names,
379
+ self.config.band_sets,
380
+ cur_geometries, # type: ignore
372
381
  )
373
382
  if not needed_projections:
374
383
  continue
@@ -2,7 +2,6 @@
2
2
 
3
3
  import io
4
4
  import json
5
- import tempfile
6
5
  import xml.etree.ElementTree as ET
7
6
  from collections.abc import Callable, Generator
8
7
  from datetime import datetime, timedelta, timezone
@@ -22,19 +21,13 @@ from rasterio.crs import CRS
22
21
  from upath import UPath
23
22
 
24
23
  import rslearn.data_sources.utils
25
- import rslearn.utils.mgrs
26
- from rslearn.config import LayerConfig, RasterLayerConfig
24
+ from rslearn.config import RasterLayerConfig
27
25
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_EPSG, WGS84_PROJECTION
28
26
  from rslearn.tile_stores import PrefixedTileStore, TileStore
29
- from rslearn.utils import (
30
- GridIndex,
31
- Projection,
32
- STGeometry,
33
- daterange,
34
- )
27
+ from rslearn.utils import GridIndex, Projection, STGeometry, daterange
35
28
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
36
29
 
37
- from .copernicus import get_harmonize_callback
30
+ from .copernicus import get_harmonize_callback, get_sentinel2_tiles
38
31
  from .data_source import (
39
32
  DataSource,
40
33
  Item,
@@ -66,7 +59,7 @@ class NaipItem(Item):
66
59
  return d
67
60
 
68
61
  @staticmethod
69
- def deserialize(d: dict) -> Item:
62
+ def deserialize(d: dict) -> "NaipItem":
70
63
  """Deserializes an item from a JSON-decoded dictionary."""
71
64
  item = super(NaipItem, NaipItem).deserialize(d)
72
65
  return NaipItem(
@@ -89,7 +82,7 @@ class Naip(DataSource):
89
82
 
90
83
  def __init__(
91
84
  self,
92
- config: LayerConfig,
85
+ config: RasterLayerConfig,
93
86
  index_cache_dir: UPath,
94
87
  use_rtree_index: bool = False,
95
88
  states: list[str] | None = None,
@@ -113,25 +106,21 @@ class Naip(DataSource):
113
106
  self.years = years
114
107
 
115
108
  self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
116
-
109
+ self.rtree_index: Any | None = None
117
110
  if use_rtree_index:
118
111
  from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
119
112
 
120
- def build_fn(index: RtreeIndex):
113
+ def build_fn(index: RtreeIndex) -> None:
121
114
  for item in self._read_index_shapefiles(desc="Building rtree index"):
122
115
  index.insert(item.geometry.shp.bounds, json.dumps(item.serialize()))
123
116
 
124
- self.rtree_tmp_dir = tempfile.TemporaryDirectory()
125
- self.rtree_index = get_cached_rtree(
126
- self.index_cache_dir, self.rtree_tmp_dir.name, build_fn
127
- )
128
- else:
129
- self.rtree_index = None
117
+ self.rtree_index = get_cached_rtree(self.index_cache_dir, build_fn)
130
118
 
131
119
  @staticmethod
132
- def from_config(config: LayerConfig, ds_path: UPath) -> "Naip":
120
+ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
133
121
  """Creates a new Naip instance from a configuration dictionary."""
134
- assert isinstance(config, RasterLayerConfig)
122
+ if config.data_source is None:
123
+ raise ValueError(f"data_source is required for config dict {config}")
135
124
  d = config.data_source.config_dict
136
125
  kwargs = dict(
137
126
  config=config,
@@ -195,7 +184,9 @@ class Naip(DataSource):
195
184
  blob_path, dst, ExtraArgs={"RequestPayer": "requester"}
196
185
  )
197
186
 
198
- def _read_index_shapefiles(self, desc=None) -> Generator[NaipItem, None, None]:
187
+ def _read_index_shapefiles(
188
+ self, desc: str | None = None
189
+ ) -> Generator[NaipItem, None, None]:
199
190
  """Read the index shapefiles and yield NaipItems corresponding to each image."""
200
191
  self._download_index_shapefiles()
201
192
 
@@ -288,7 +279,7 @@ class Naip(DataSource):
288
279
 
289
280
  def get_items(
290
281
  self, geometries: list[STGeometry], query_config: QueryConfig
291
- ) -> list[list[list[Item]]]:
282
+ ) -> list[list[list[NaipItem]]]:
292
283
  """Get a list of items in the data source intersecting the given geometries.
293
284
 
294
285
  Args:
@@ -302,7 +293,7 @@ class Naip(DataSource):
302
293
  geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
303
294
  ]
304
295
 
305
- items = [[] for _ in geometries]
296
+ items: list = [[] for _ in geometries]
306
297
  if self.rtree_index:
307
298
  for idx, geometry in enumerate(wgs84_geometries):
308
299
  encoded_items = self.rtree_index.query(geometry.shp.bounds)
@@ -331,7 +322,7 @@ class Naip(DataSource):
331
322
  groups.append(cur_groups)
332
323
  return groups
333
324
 
334
- def deserialize_item(self, serialized_item: Any) -> Item:
325
+ def deserialize_item(self, serialized_item: Any) -> NaipItem:
335
326
  """Deserializes an item from JSON-decoded data."""
336
327
  assert isinstance(serialized_item, dict)
337
328
  return NaipItem.deserialize(serialized_item)
@@ -339,7 +330,7 @@ class Naip(DataSource):
339
330
  def ingest(
340
331
  self,
341
332
  tile_store: TileStore,
342
- items: list[Item],
333
+ items: list[NaipItem],
343
334
  geometries: list[list[STGeometry]],
344
335
  ) -> None:
345
336
  """Ingest items into the given tile store.
@@ -407,7 +398,7 @@ class Sentinel2Item(Item):
407
398
  return d
408
399
 
409
400
  @staticmethod
410
- def deserialize(d: dict) -> Item:
401
+ def deserialize(d: dict) -> "Sentinel2Item":
411
402
  """Deserializes an item from a JSON-decoded dictionary."""
412
403
  if "name" not in d:
413
404
  d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
@@ -420,7 +411,10 @@ class Sentinel2Item(Item):
420
411
  )
421
412
 
422
413
 
423
- class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
414
+ # TODO: Distinguish between AWS and GCP data sources in class names.
415
+ class Sentinel2(
416
+ ItemLookupDataSource[Sentinel2Item], RetrieveItemDataSource[Sentinel2Item]
417
+ ):
424
418
  """A data source for Sentinel-2 L1C and L2A imagery on AWS.
425
419
 
426
420
  Specifically, uses the sentinel-s2-l1c and sentinel-s2-l2a S3 buckets maintained by
@@ -474,7 +468,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
474
468
 
475
469
  def __init__(
476
470
  self,
477
- config: LayerConfig,
471
+ config: RasterLayerConfig,
478
472
  modality: Sentinel2Modality,
479
473
  metadata_cache_dir: UPath,
480
474
  max_time_delta: timedelta = timedelta(days=30),
@@ -506,9 +500,10 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
506
500
  self.bucket = boto3.resource("s3").Bucket(bucket_name)
507
501
 
508
502
  @staticmethod
509
- def from_config(config: LayerConfig, ds_path: UPath) -> "Sentinel2":
503
+ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
510
504
  """Creates a new Sentinel2 instance from a configuration dictionary."""
511
- assert isinstance(config, RasterLayerConfig)
505
+ if config.data_source is None:
506
+ raise ValueError("Sentinel2 data source requires a data source config")
512
507
  d = config.data_source.config_dict
513
508
  kwargs = dict(
514
509
  config=config,
@@ -528,7 +523,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
528
523
  return Sentinel2(**kwargs)
529
524
 
530
525
  def _read_products(
531
- self, needed_cell_months: set[tuple[str, int, int, int]]
526
+ self, needed_cell_months: set[tuple[str, int, int]]
532
527
  ) -> Generator[Sentinel2Item, None, None]:
533
528
  """Read productInfo.json files and yield relevant Sentinel2Items.
534
529
 
@@ -603,7 +598,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
603
598
 
604
599
  def get_items(
605
600
  self, geometries: list[STGeometry], query_config: QueryConfig
606
- ) -> list[list[list[Item]]]:
601
+ ) -> list[list[list[Sentinel2Item]]]:
607
602
  """Get a list of items in the data source intersecting the given geometries.
608
603
 
609
604
  Args:
@@ -626,14 +621,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
626
621
  raise ValueError(
627
622
  "Sentinel2 on AWS requires geometry time ranges to be set"
628
623
  )
629
- for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds):
624
+ for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
630
625
  for ts in daterange(
631
626
  wgs84_geometry.time_range[0] - self.max_time_delta,
632
627
  wgs84_geometry.time_range[1] + self.max_time_delta,
633
628
  ):
634
629
  needed_cell_months.add((cell_id, ts.year, ts.month))
635
630
 
636
- items_by_cell = {}
631
+ items_by_cell: dict[str, list[Sentinel2Item]] = {}
637
632
  for item in self._read_products(needed_cell_months):
638
633
  cell_id = "".join(item.blob_path.split("/")[1:4])
639
634
  if cell_id not in items_by_cell:
@@ -643,7 +638,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
643
638
  groups = []
644
639
  for geometry, wgs84_geometry in zip(geometries, wgs84_geometries):
645
640
  items = []
646
- for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds):
641
+ for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
647
642
  for item in items_by_cell.get(cell_id, []):
648
643
  try:
649
644
  item_geom = item.geometry.to_projection(geometry.projection)
@@ -666,7 +661,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
666
661
 
667
662
  return groups
668
663
 
669
- def get_item_by_name(self, name: str) -> Item:
664
+ def get_item_by_name(self, name: str) -> Sentinel2Item:
670
665
  """Gets an item by name."""
671
666
  # Product name is like:
672
667
  # S2B_MSIL1C_20240201T230819_N0510_R015_T51CWM_20240202T012755.
@@ -685,12 +680,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
685
680
  return item
686
681
  raise ValueError(f"item {name} not found")
687
682
 
688
- def deserialize_item(self, serialized_item: Any) -> Item:
683
+ def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
689
684
  """Deserializes an item from JSON-decoded data."""
690
685
  assert isinstance(serialized_item, dict)
691
686
  return Sentinel2Item.deserialize(serialized_item)
692
687
 
693
- def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
688
+ def retrieve_item(
689
+ self, item: Sentinel2Item
690
+ ) -> Generator[tuple[str, BinaryIO], None, None]:
694
691
  """Retrieves the rasters corresponding to an item as file streams."""
695
692
  for fname, _ in self.band_fnames[self.modality]:
696
693
  buf = io.BytesIO()
@@ -701,7 +698,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
701
698
  yield (fname, buf)
702
699
 
703
700
  def _get_harmonize_callback(
704
- self, item: Item
701
+ self, item: Sentinel2Item
705
702
  ) -> Callable[[npt.NDArray], npt.NDArray] | None:
706
703
  """Gets the harmonization callback for the given item.
707
704
 
@@ -715,6 +712,8 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
715
712
  return None
716
713
  # Search metadata XML for the RADIO_ADD_OFFSET tag.
717
714
  # This contains the per-band offset, but we assume all bands have the same offset.
715
+ if item.geometry.time_range is None:
716
+ raise ValueError("Sentinel2 on AWS requires geometry time ranges to be set")
718
717
  ts = item.geometry.time_range[0]
719
718
  metadata_fname = (
720
719
  f"products/{ts.year}/{ts.month}/{ts.day}/{item.name}/metadata.xml"
@@ -730,7 +729,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
730
729
  def ingest(
731
730
  self,
732
731
  tile_store: TileStore,
733
- items: list[Item],
732
+ items: list[Sentinel2Item],
734
733
  geometries: list[list[STGeometry]],
735
734
  ) -> None:
736
735
  """Ingest items into the given tile store.