rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,28 @@
1
1
  """Data source for raster data in ESA Copernicus API."""
2
2
 
3
+ import functools
4
+ import io
5
+ import json
6
+ import shutil
7
+ import urllib.request
3
8
  import xml.etree.ElementTree as ET
9
+ import zipfile
4
10
  from collections.abc import Callable
5
11
 
6
12
  import numpy as np
7
13
  import numpy.typing as npt
14
+ from upath import UPath
15
+
16
+ from rslearn.const import WGS84_PROJECTION
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.utils.fsspec import open_atomic
19
+ from rslearn.utils.geometry import STGeometry, flatten_shape
20
+ from rslearn.utils.grid_index import GridIndex
21
+
22
+ SENTINEL2_TILE_URL = "https://sentiwiki.copernicus.eu/__attachments/1692737/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.zip"
23
+ SENTINEL2_KML_NAMESPACE = "{http://www.opengis.net/kml/2.2}"
24
+
25
+ logger = get_logger(__name__)
8
26
 
9
27
 
10
28
  def get_harmonize_callback(
@@ -24,6 +42,8 @@ def get_harmonize_callback(
24
42
  """
25
43
  offset = None
26
44
  for el in tree.iter("RADIO_ADD_OFFSET"):
45
+ if el.text is None:
46
+ raise ValueError(f"text is missing in {el}")
27
47
  value = int(el.text)
28
48
  if offset is None:
29
49
  offset = value
@@ -36,7 +56,133 @@ def get_harmonize_callback(
36
56
  if offset is None or offset == 0:
37
57
  return None
38
58
 
39
- def callback(array):
40
- return np.clip(array, -offset, None) + offset
59
+ def callback(array: npt.NDArray) -> npt.NDArray:
60
+ return np.clip(array, -offset, None) + offset # type: ignore
41
61
 
42
62
  return callback
63
+
64
+
65
+ def _cache_sentinel2_tile_index(cache_dir: UPath) -> None:
66
+ """Cache the tiles from SENTINEL2_TILE_URL.
67
+
68
+ This way we just need to download it once.
69
+ """
70
+ json_fname = cache_dir / "tile_index.json"
71
+
72
+ if json_fname.exists():
73
+ return
74
+
75
+ logger.info(f"caching list of Sentinel-2 tiles to {json_fname}")
76
+
77
+ # Identify the Sentinel-2 tile names and bounds using the KML file.
78
+ # First, download the zip file and extract and parse the KML.
79
+ buf = io.BytesIO()
80
+ with urllib.request.urlopen(SENTINEL2_TILE_URL) as response:
81
+ shutil.copyfileobj(response, buf)
82
+ buf.seek(0)
83
+ with zipfile.ZipFile(buf, "r") as zipf:
84
+ member_names = zipf.namelist()
85
+ if len(member_names) != 1:
86
+ raise ValueError(
87
+ "Sentinel-2 tile zip file unexpectedly contains more than one file"
88
+ )
89
+
90
+ with zipf.open(member_names[0]) as memberf:
91
+ tree = ET.parse(memberf)
92
+
93
+ # Map from the tile name to the longitude/latitude bounds.
94
+ json_data: dict[str, tuple[float, float, float, float]] = {}
95
+
96
+ # The KML is list of Placemark so iterate over those.
97
+ for placemark_node in tree.iter(SENTINEL2_KML_NAMESPACE + "Placemark"):
98
+ # The <name> node specifies the Sentinel-2 tile name.
99
+ name_node = placemark_node.find(SENTINEL2_KML_NAMESPACE + "name")
100
+ if name_node is None or name_node.text is None:
101
+ raise ValueError("Sentinel-2 KML has Placemark without valid name node")
102
+
103
+ tile_name = name_node.text
104
+
105
+ # There may be one or more <coordinates> nodes depending on whether it is a
106
+ # MultiGeometry. Here we just iterate over all of the coordinates since we are
107
+ # only interested in the bounds in WGS-84 coordinates.
108
+ lons = []
109
+ lats = []
110
+ for coord_node in placemark_node.iter(SENTINEL2_KML_NAMESPACE + "coordinates"):
111
+ # It is list of space-separated coordinates like:
112
+ # 180,-73.0597374076,0 176.8646237862,-72.9914734628,0 ...
113
+ if coord_node.text is None:
114
+ raise ValueError("Sentinel-2 KML has coordinates node missing text")
115
+
116
+ point_strs = coord_node.text.strip().split()
117
+ for point_str in point_strs:
118
+ parts = point_str.split(",")
119
+ if len(parts) != 2 and len(parts) != 3:
120
+ continue
121
+
122
+ lon = float(parts[0])
123
+ lat = float(parts[1])
124
+ lons.append(lon)
125
+ lats.append(lat)
126
+
127
+ if len(lons) == 0 or len(lats) == 0:
128
+ raise ValueError("Sentinel-2 KML has Placemark with no coordinates")
129
+
130
+ bounds = (
131
+ min(lons),
132
+ min(lats),
133
+ max(lons),
134
+ max(lats),
135
+ )
136
+ json_data[tile_name] = bounds
137
+
138
+ with open_atomic(json_fname, "w") as f:
139
+ json.dump(json_data, f)
140
+
141
+
142
+ @functools.cache
143
+ def load_sentinel2_tile_index(cache_dir: UPath) -> GridIndex:
144
+ """Load a GridIndex over Sentinel-2 tiles.
145
+
146
+ This function is cached so the GridIndex only needs to be constructed once (per
147
+ process).
148
+
149
+ Args:
150
+ cache_dir: the directory to cache the list of Sentinel-2 tiles.
151
+
152
+ Returns:
153
+ GridIndex over the tile names
154
+ """
155
+ _cache_sentinel2_tile_index(cache_dir)
156
+ json_fname = cache_dir / "tile_index.json"
157
+ with json_fname.open() as f:
158
+ json_data = json.load(f)
159
+
160
+ grid_index = GridIndex(0.5)
161
+ for tile_name, bounds in json_data.items():
162
+ grid_index.insert(bounds, tile_name)
163
+
164
+ return grid_index
165
+
166
+
167
+ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
168
+ """Get all Sentinel-2 tiles (like 01CCV) intersecting the given geometry.
169
+
170
+ Args:
171
+ geometry: the geometry to check.
172
+ cache_dir: directory to cache the tiles.
173
+
174
+ Returns:
175
+ list of Sentinel-2 tile names that intersect the geometry.
176
+ """
177
+ tile_index = load_sentinel2_tile_index(cache_dir)
178
+ wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
179
+ # If the shape is a collection, it could be cutting across prime meridian.
180
+ # So we query each component shape separately and collect the results to avoid
181
+ # issues.
182
+ # We assume the caller has already applied split_at_prime_meridian.
183
+ results = set()
184
+ for shp in flatten_shape(wgs84_geometry.shp):
185
+ for result in tile_index.query(shp.bounds):
186
+ assert isinstance(result, str)
187
+ results.add(result)
188
+ return list(results)
@@ -1,7 +1,7 @@
1
1
  """Base classes for rslearn data sources."""
2
2
 
3
3
  from collections.abc import Generator
4
- from typing import Any, BinaryIO
4
+ from typing import Any, BinaryIO, Generic, TypeVar
5
5
 
6
6
  from rslearn.config import LayerConfig, QueryConfig
7
7
  from rslearn.dataset import Window
@@ -51,15 +51,20 @@ class Item:
51
51
  return hash(self.name)
52
52
 
53
53
 
54
- class DataSource:
54
+ ItemType = TypeVar("ItemType", bound="Item")
55
+
56
+
57
+ class DataSource(Generic[ItemType]):
55
58
  """A set of raster or vector files that can be retrieved.
56
59
 
57
60
  Data sources should support at least one of ingest and materialize.
58
61
  """
59
62
 
63
+ TIMEOUT = 1000000 # Set very high to start
64
+
60
65
  def get_items(
61
66
  self, geometries: list[STGeometry], query_config: QueryConfig
62
- ) -> list[list[list[Item]]]:
67
+ ) -> list[list[list[ItemType]]]:
63
68
  """Get a list of items in the data source intersecting the given geometries.
64
69
 
65
70
  Args:
@@ -71,14 +76,14 @@ class DataSource:
71
76
  """
72
77
  raise NotImplementedError
73
78
 
74
- def deserialize_item(self, serialized_item: Any) -> Item:
79
+ def deserialize_item(self, serialized_item: Any) -> ItemType:
75
80
  """Deserializes an item from JSON-decoded data."""
76
81
  raise NotImplementedError
77
82
 
78
83
  def ingest(
79
84
  self,
80
85
  tile_store: TileStore,
81
- items: list[Item],
86
+ items: list[ItemType],
82
87
  geometries: list[list[STGeometry]],
83
88
  ) -> None:
84
89
  """Ingest items into the given tile store.
@@ -93,7 +98,7 @@ class DataSource:
93
98
  def materialize(
94
99
  self,
95
100
  window: Window,
96
- item_groups: list[list[Item]],
101
+ item_groups: list[list[ItemType]],
97
102
  layer_name: str,
98
103
  layer_cfg: LayerConfig,
99
104
  ) -> None:
@@ -108,17 +113,19 @@ class DataSource:
108
113
  raise NotImplementedError
109
114
 
110
115
 
111
- class ItemLookupDataSource(DataSource):
116
+ class ItemLookupDataSource(DataSource[ItemType]):
112
117
  """A data source that can look up items by name."""
113
118
 
114
- def get_item_by_name(self, name: str) -> Item:
119
+ def get_item_by_name(self, name: str) -> ItemType:
115
120
  """Gets an item by name."""
116
121
  raise NotImplementedError
117
122
 
118
123
 
119
- class RetrieveItemDataSource(DataSource):
124
+ class RetrieveItemDataSource(DataSource[ItemType]):
120
125
  """A data source that can retrieve items in their raw format."""
121
126
 
122
- def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
127
+ def retrieve_item(
128
+ self, item: ItemType
129
+ ) -> Generator[tuple[str, BinaryIO], None, None]:
123
130
  """Retrieves the rasters corresponding to an item as file streams."""
124
131
  raise NotImplementedError