rslearn 0.0.28-py3-none-any.whl → 0.0.29-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- rslearn/data_sources/aws_google_satellite_embedding_v1.py +358 -0
- rslearn/data_sources/direct_materialize_data_source.py +6 -17
- rslearn/data_sources/utils.py +10 -3
- rslearn/models/global_pool.py +74 -0
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/METADATA +1 -1
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/RECORD +11 -9
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/WHEEL +0 -0
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.28.dist-info → rslearn-0.0.29.dist-info}/top_level.txt +0 -0

rslearn/data_sources/aws_google_satellite_embedding_v1.py
ADDED (+358 lines)

"""Data source for Google Satellite Embedding V1 dataset on AWS Open Data."""

import os
import tempfile
from collections.abc import Callable
from datetime import UTC, datetime
from typing import Any

import boto3
import numpy as np
import numpy.typing as npt
import pandas as pd
import rasterio
import rasterio.vrt
import shapely
import shapely.wkt
from botocore import UNSIGNED
from botocore.config import Config
from rasterio.enums import Resampling
from upath import UPath

import rslearn.data_sources.utils
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources.data_source import (
    DataSourceContext,
    Item,
    QueryConfig,
)
from rslearn.data_sources.direct_materialize_data_source import (
    DirectMaterializeDataSource,
)
from rslearn.utils.fsspec import join_upath
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
from rslearn.utils.grid_index import GridIndex

# Band names for the 64 embedding channels
BANDS = [f"A{idx:02d}" for idx in range(64)]

# S3 bucket configuration
BUCKET_NAME = "us-west-2.opendata.source.coop"
BUCKET_PREFIX = "tge-labs/aef/v1/annual"
INDEX_KEY = f"{BUCKET_PREFIX}/aef_index.csv"
HTTP_URL_BASE = f"https://s3.us-west-2.amazonaws.com/{BUCKET_NAME}"

# Grid index cell size for spatial queries
GRID_SIZE = 1.0


class GoogleSatelliteEmbeddingV1Item(Item):
    """An item in the GoogleSatelliteEmbeddingV1 data source."""

    def __init__(
        self,
        name: str,
        geometry: STGeometry,
        s3_path: str,
    ) -> None:
        """Creates a new GoogleSatelliteEmbeddingV1Item.

        Args:
            name: unique name of the item (the filename without extension)
            geometry: the spatial and temporal extent of the item
            s3_path: full S3 path to the TIFF file
        """
        super().__init__(name, geometry)
        self.s3_path = s3_path

    def serialize(self) -> dict:
        """Serializes the item to a JSON-encodable dictionary."""
        d = super().serialize()
        d["s3_path"] = self.s3_path
        return d

    @staticmethod
    def deserialize(d: dict) -> "GoogleSatelliteEmbeddingV1Item":
        """Deserializes an item from a JSON-decoded dictionary."""
        item = super(
            GoogleSatelliteEmbeddingV1Item, GoogleSatelliteEmbeddingV1Item
        ).deserialize(d)
        return GoogleSatelliteEmbeddingV1Item(
            name=item.name,
            geometry=item.geometry,
            s3_path=d["s3_path"],
        )


class GoogleSatelliteEmbeddingV1(
    DirectMaterializeDataSource[GoogleSatelliteEmbeddingV1Item]
):
    """Data source for Google Satellite Embedding V1 on AWS Open Data.

    It consists of annual satellite embeddings at 10m resolution with 64 bands
    (A00-A63). The data is stored as Cloud-Optimized GeoTIFFs organized by year and UTM
    zone. Each file covers 8192x8192 pixels.

    Available years: 2018-2024.

    See https://registry.opendata.aws/aef-source/ for details.
    """

    def __init__(
        self,
        metadata_cache_dir: str,
        apply_dequantization: bool = True,
        context: DataSourceContext = DataSourceContext(),
    ) -> None:
        """Initialize a new GoogleSatelliteEmbeddingV1 instance.

        Args:
            metadata_cache_dir: directory to cache the index file.
            apply_dequantization: whether to apply de-quantization to convert
                int8 values to float32. The raw data is quantized int8; the
                de-quantization maps values to [-1, 1] using the formula:
                ((values / 127.5) ** 2) * sign(values). The raw data has nodata value
                -128 while with dequantization the nodata value is -1.0. See
                https://source.coop/tge-labs/aef for details.
            context: the data source context.
        """
        # We have a single asset containing all 64 bands. Here "image" is an arbitrary
        # name, since DirectMaterializeDataSource requires an asset name.
        super().__init__(asset_bands={"image": BANDS})

        self.apply_dequantization = apply_dequantization

        # Set up cache directory
        if context.ds_path is not None:
            self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
        else:
            self.metadata_cache_dir = UPath(metadata_cache_dir)
        self.metadata_cache_dir.mkdir(parents=True, exist_ok=True)

        # S3 client with anonymous access (only used for downloading index)
        self.s3_client = boto3.client(
            "s3",
            config=Config(signature_version=UNSIGNED),
            region_name="us-west-2",
        )

        # Lazy-loaded grid index
        self._grid_index: GridIndex | None = None
        self._items_by_name: dict[str, GoogleSatelliteEmbeddingV1Item] | None = None

    def _read_index_csv(self) -> pd.DataFrame:
        """Read the index CSV, downloading from S3 if not cached.

        Returns:
            DataFrame with WKT, path, and year columns.
        """
        cache_file = self.metadata_cache_dir / "aef_index.csv"
        if not cache_file.exists():
            response = self.s3_client.get_object(Bucket=BUCKET_NAME, Key=INDEX_KEY)
            content = response["Body"].read()
            with cache_file.open("wb") as f:
                f.write(content)

        return pd.read_csv(
            cache_file,
            header=None,
            usecols=[0, 2, 3],
            names=["WKT", "path", "year"],
        )

    def _load_index(
        self,
    ) -> tuple[GridIndex, dict[str, GoogleSatelliteEmbeddingV1Item]]:
        """Load the index file and build spatial index.

        Returns:
            Tuple of (grid_index, items_by_name dict).
        """
        if self._grid_index is not None and self._items_by_name is not None:
            return self._grid_index, self._items_by_name

        df = self._read_index_csv()

        grid_index = GridIndex(GRID_SIZE)
        items_by_name: dict[str, GoogleSatelliteEmbeddingV1Item] = {}

        for _, row in df.iterrows():
            shp = shapely.wkt.loads(row["WKT"])

            year = int(row["year"])
            time_range = (
                datetime(year, 1, 1, tzinfo=UTC),
                datetime(year, 12, 31, 23, 59, 59, tzinfo=UTC),
            )

            s3_path = row["path"]
            name = s3_path.split("/")[-1].replace(".tiff", "")

            geometry = STGeometry(WGS84_PROJECTION, shp, time_range)
            item = GoogleSatelliteEmbeddingV1Item(
                name=name,
                geometry=geometry,
                s3_path=s3_path,
            )

            grid_index.insert(shp.bounds, item)
            items_by_name[name] = item

        self._grid_index = grid_index
        self._items_by_name = items_by_name
        return grid_index, items_by_name

    # --- DataSource implementation ---

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[GoogleSatelliteEmbeddingV1Item]]]:
        """Get a list of items in the data source intersecting the given geometries."""
        grid_index, _ = self._load_index()

        wgs84_geometries = [
            geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
        ]

        groups = []
        for geometry, wgs84_geometry in zip(geometries, wgs84_geometries):
            cur_items = []
            for item in grid_index.query(wgs84_geometry.shp.bounds):
                if not wgs84_geometry.shp.intersects(item.geometry.shp):
                    continue
                # Check time range if specified
                if wgs84_geometry.time_range is not None:
                    item_start, item_end = item.geometry.time_range
                    query_start, query_end = wgs84_geometry.time_range
                    if item_end < query_start or item_start > query_end:
                        continue
                cur_items.append(item)

            cur_items.sort(key=lambda item: item.geometry.time_range[0])

            cur_groups: list[list[GoogleSatelliteEmbeddingV1Item]] = (
                rslearn.data_sources.utils.match_candidate_items_to_window(
                    geometry, cur_items, query_config
                )
            )
            groups.append(cur_groups)

        return groups

    def get_item_by_name(self, name: str) -> GoogleSatelliteEmbeddingV1Item:
        """Gets an item by name."""
        _, items_by_name = self._load_index()
        if name not in items_by_name:
            raise ValueError(f"item {name} not found")
        return items_by_name[name]

    def deserialize_item(self, serialized_item: dict) -> GoogleSatelliteEmbeddingV1Item:
        """Deserializes an item from JSON-decoded data."""
        return GoogleSatelliteEmbeddingV1Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: Any,
        items: list[GoogleSatelliteEmbeddingV1Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Note: Each file is 2-3GB so this can be slow. Direct materialization via
        read_raster or materialize is recommended for most use cases.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            if tile_store.is_raster_ready(item.name, BANDS):
                continue

            # Download the TIFF file directly to disk
            key = item.s3_path.replace(f"s3://{BUCKET_NAME}/", "")
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_path = os.path.join(tmp_dir, f"{item.name}.tiff")
                self.s3_client.download_file(BUCKET_NAME, key, local_path)
                tile_store.write_raster_file(item.name, BANDS, UPath(local_path))

    # --- DirectMaterializeDataSource implementation ---

    def get_asset_url(self, item_name: str, asset_key: str) -> str:
        """Get the HTTP URL to read the asset.

        Returns a /vsicurl/ URL that rasterio can read directly over HTTP.
        """
        item = self.get_item_by_name(item_name)
        # Convert s3://bucket/path to HTTP URL
        key = item.s3_path.replace(f"s3://{BUCKET_NAME}/", "")
        return f"/vsicurl/{HTTP_URL_BASE}/{key}"

    def get_read_callback(
        self, item_name: str, asset_key: str
    ) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
        """Return a callback to apply de-quantization if enabled."""
        if not self.apply_dequantization:
            return None

        def dequantize(data: npt.NDArray[Any]) -> npt.NDArray[np.float32]:
            # Handle nodata (-128)
            nodata_mask = data == -128
            float_data = data.astype(np.float32)
            # This is the dequantization formula recommended at https://source.coop/tge-labs/aef.
            result = ((float_data / 127.5) ** 2) * np.sign(float_data)
            # We make sure that NODATA is exactly -1.0 so user can handle it appropriately.
            result[nodata_mask] = -1.0
            return result

        return dequantize

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read raster data from the store.

        Overrides base class to handle band selection (the base class reads all bands).
        """
        asset_url = self.get_asset_url(item_name, "image")

        # Determine which band indices to read (1-indexed for rasterio)
        if bands == BANDS:
            band_indices = list(range(1, 65))
        else:
            band_indices = [BANDS.index(b) + 1 for b in bands]

        # Construct the transform for the requested bounds
        wanted_transform = rasterio.transform.Affine(
            projection.x_resolution,
            0,
            bounds[0] * projection.x_resolution,
            0,
            projection.y_resolution,
            bounds[1] * projection.y_resolution,
        )

        with rasterio.open(asset_url) as src:
            with rasterio.vrt.WarpedVRT(
                src,
                crs=projection.crs,
                transform=wanted_transform,
                width=bounds[2] - bounds[0],
                height=bounds[3] - bounds[1],
                resampling=resampling,
            ) as vrt:
                data = vrt.read(indexes=band_indices)

        # Apply callback if dequantization is enabled
        callback = self.get_read_callback(item_name, "image")
        if callback is not None:
            data = callback(data)

        return data
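
For reference, the de-quantization applied by get_read_callback above can be reproduced on its own. The following minimal numpy sketch mirrors the formula stated in the docstring (quantized int8 values mapped to [-1, 1], nodata -128 forced to -1.0); the sample array values are invented for illustration.

import numpy as np

def dequantize(data: np.ndarray) -> np.ndarray:
    """Map quantized int8 embedding values to float32 in [-1, 1]."""
    nodata_mask = data == -128
    float_data = data.astype(np.float32)
    # ((v / 127.5) ** 2) * sign(v), per the formula in the data source docstring.
    result = ((float_data / 127.5) ** 2) * np.sign(float_data)
    # Nodata (-128) becomes exactly -1.0.
    result[nodata_mask] = -1.0
    return result

quantized = np.array([[-128, -127, 0, 64, 127]], dtype=np.int8)
print(dequantize(quantized))
# approximately [[-1.0, -0.992, 0.0, 0.252, 0.992]]
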
rslearn/data_sources/direct_materialize_data_source.py
CHANGED (+6 -17)

@@ -11,14 +11,16 @@ import rasterio.vrt
 from rasterio.enums import Resampling

 from rslearn.config import LayerConfig
-from rslearn.data_sources.data_source import
+from rslearn.data_sources.data_source import ItemLookupDataSource, ItemType
 from rslearn.dataset import Window
 from rslearn.dataset.materialize import RasterMaterializer
 from rslearn.tile_stores import TileStore, TileStoreWithLayer
 from rslearn.utils.geometry import PixelBounds, Projection


-class DirectMaterializeDataSource(
+class DirectMaterializeDataSource(
+    ItemLookupDataSource[ItemType], TileStore, Generic[ItemType]
+):
     """Base class for data sources that support direct materialization via TileStore.

     This class provides common TileStore functionality for data sources that can read
@@ -27,9 +29,10 @@ class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemT

     Subclasses must implement:
     - get_asset_url(): Get the URL for an asset given item name and bands
-    - get_item_by_name(): Get an item by its name

     Subclasses may optionally override:
+    - get_item_by_name(): Inherited from ItemLookupDataSource. If also inheriting
+      from a class that provides it (e.g., StacDataSource), no override needed.
     - get_raster_bands(): By default, we assume that items have all assets. If
       items may have a subset of assets, override get_raster_bands to return
       the sets of bands available for that item.
@@ -77,20 +80,6 @@ class DirectMaterializeDataSource(DataSource[ItemType], TileStore, Generic[ItemT
         """
         raise NotImplementedError

-    def get_item_by_name(self, name: str) -> ItemType:
-        """Get an item by its name.
-
-        Subclasses must implement this method, either directly or by inheriting from
-        a class that provides it (e.g., StacDataSource).
-
-        Args:
-            name: the name of the item to get.
-
-        Returns:
-            the item object.
-        """
-        raise NotImplementedError
-
     # --- Optional hooks for subclasses ---

     def get_read_callback(
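
With get_item_by_name now coming from ItemLookupDataSource, a direct-materialize subclass mainly has to supply an asset URL (plus, optionally, a read callback). Below is a rough sketch under that assumption; the COGDataSource name, the "B01" band, and the example.com URL are hypothetical, and the rest of the data source interface (get_items, get_item_by_name, deserialize_item, ingest) that a real implementation such as GoogleSatelliteEmbeddingV1 above provides is omitted.

# Sketch only: relies on the class and method names shown in the diffs above.
from collections.abc import Callable
from typing import Any

import numpy.typing as npt

from rslearn.data_sources.data_source import Item
from rslearn.data_sources.direct_materialize_data_source import (
    DirectMaterializeDataSource,
)


class COGDataSource(DirectMaterializeDataSource[Item]):
    """Hypothetical data source reading Cloud-Optimized GeoTIFFs over HTTP."""

    def __init__(self) -> None:
        # One asset named "image" with a single illustrative band.
        super().__init__(asset_bands={"image": ["B01"]})

    def get_asset_url(self, item_name: str, asset_key: str) -> str:
        # Required hook: return a URL that rasterio can open (e.g. /vsicurl/ over HTTP).
        return f"/vsicurl/https://example.com/cogs/{item_name}.tif"

    def get_read_callback(
        self, item_name: str, asset_key: str
    ) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
        # Optional hook: no post-processing in this sketch.
        return None
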
rslearn/data_sources/utils.py
CHANGED (+10 -3)

@@ -418,7 +418,8 @@ def match_candidate_items_to_window(
     )

     # Now apply space mode.
-
+    acceptable_items = []
+    acceptable_item_shps = []
     for item in items:
         item_geom = item.geometry
         # We need to re-project items to the geometry projection for the spatial checks
@@ -430,14 +431,20 @@
             item_geom = geometry
         else:
             item_geom = item_geom.to_projection(geometry.projection)
-
+
+        if item_geom.shp.area == 0:
+            # Must have been an item that didn't quite match the window's spatial extent.
+            continue
+
+        acceptable_items.append(item)
+        acceptable_item_shps.append(item_geom.shp)

     # Dispatch to the appropriate space mode handler
     handler = space_mode_handlers.get(query_config.space_mode)
     if handler is None:
         raise ValueError(f"invalid space mode {query_config.space_mode}")

-    groups = handler(geometry,
+    groups = handler(geometry, acceptable_items, acceptable_item_shps, query_config)

     # Enforce minimum matches if set.
     if len(groups) < query_config.min_matches:
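
The new zero-area filter keeps degenerate candidate shapes away from the space-mode handlers. The surrounding lines of match_candidate_items_to_window are not visible in this hunk, but the in-code comment points at items that only graze the window; a small shapely illustration (coordinates invented) of how a candidate can pass an intersects() check yet contribute zero area:

from shapely.geometry import box

window = box(0, 0, 10, 10)
neighbor = box(10, 0, 20, 10)  # touches the window only along the x=10 edge

print(window.intersects(neighbor))         # True
print(window.intersection(neighbor).area)  # 0.0 -> the kind of shape the new check skips
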
rslearn/models/global_pool.py
ADDED (+74 lines)

"""Global pooling decoder for spatial feature maps."""

from typing import Any, Literal

import torch

from rslearn.train.model_context import ModelContext

from .component import FeatureMaps, FeatureVector, IntermediateComponent


class GlobalPool(IntermediateComponent):
    """Apply global pooling to reduce spatial dimensions.

    This component applies global average or max pooling over the spatial dimensions
    of input feature maps. By default, it produces FeatureVector (BxC) suitable for
    ClassificationHead or RegressionHead. When keep_spatial_dims=True, it produces
    1x1 FeatureMaps suitable for EmbeddingHead.
    """

    def __init__(
        self,
        mode: Literal["mean", "max"] = "mean",
        keep_spatial_dims: bool = False,
    ) -> None:
        """Create a new GlobalPool.

        Args:
            mode: the pooling mode, either "mean" for global average pooling or
                "max" for global max pooling. Defaults to "mean".
            keep_spatial_dims: if True, returns FeatureMaps with 1x1 spatial dimensions.
                If False (default), returns FeatureVector (BxC). Defaults to False.
        """
        super().__init__()
        if mode not in ("mean", "max"):
            raise ValueError(f"mode must be 'mean' or 'max', got '{mode}'")
        self.mode = mode
        self.keep_spatial_dims = keep_spatial_dims

    def forward(
        self, intermediates: Any, context: ModelContext
    ) -> FeatureMaps | FeatureVector:
        """Apply global pooling on the feature maps.

        Args:
            intermediates: output from the previous model component, which must be
                a FeatureMaps.
            context: the model context.

        Returns:
            If keep_spatial_dims=False (default): FeatureVector (BxC) suitable for
            ClassificationHead or RegressionHead.
            If keep_spatial_dims=True: FeatureMaps with 1x1 spatial dimensions suitable
            for EmbeddingHead.
        """
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to GlobalPool must be FeatureMaps")

        pooled_features = []
        for feat in intermediates.feature_maps:
            # feat is BCHW
            if self.mode == "mean":
                pooled = feat.mean(dim=(2, 3), keepdim=self.keep_spatial_dims)
            else:
                pooled = torch.amax(feat, dim=(2, 3), keepdim=self.keep_spatial_dims)
            pooled_features.append(pooled)

        if self.keep_spatial_dims:
            return FeatureMaps(pooled_features)
        else:
            if len(pooled_features) == 1:
                return FeatureVector(pooled_features[0])
            else:
                return FeatureVector(torch.cat(pooled_features, dim=1))
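
The pooling in GlobalPool.forward is plain torch. This sketch just shows the two output shapes the component produces, using a random BCHW tensor in place of the FeatureMaps/FeatureVector wrappers; the tensor sizes are illustrative.

import torch

feat = torch.randn(2, 64, 8, 8)  # BCHW feature map

# keep_spatial_dims=False (default): pooled down to a BxC feature vector,
# the form expected by ClassificationHead or RegressionHead.
vec = feat.mean(dim=(2, 3), keepdim=False)
print(vec.shape)  # torch.Size([2, 64])

# keep_spatial_dims=True: pooled to 1x1 feature maps (BxCx1x1), for EmbeddingHead.
maps = torch.amax(feat, dim=(2, 3), keepdim=True)
print(maps.shape)  # torch.Size([2, 64, 1, 1])
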
rslearn-0.0.28.dist-info/RECORD → rslearn-0.0.29.dist-info/RECORD
CHANGED (+11 -9)

@@ -9,6 +9,7 @@ rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,79
 rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
 rslearn/config/dataset.py,sha256=abxIUFDAYmCd4pzGnkPnW_pYyws1yhXFWJ5HnVU4WHo,23942
 rslearn/data_sources/__init__.py,sha256=FZQckYwsnnLokMeYmi0ktUyQd9bAHyLN1_-Xc3qYLag,767
+rslearn/data_sources/aws_google_satellite_embedding_v1.py,sha256=ga5G8uDXdMj2pw8qGC2PD11PBg8Spuf3b4QEwsVJaBY,12805
 rslearn/data_sources/aws_landsat.py,sha256=bJmwBbUV4vjKBNp1MHt4sHhnIjhMis_jOI3FpksQc6w,16435
 rslearn/data_sources/aws_open_data.py,sha256=fum34DqqDiuiiYBfZtGFrNNOLylE9o3o7Cyb2e0Eo0g,29101
 rslearn/data_sources/aws_sentinel1.py,sha256=LfJLDhsd_6h_JinD8PbiiAyxajkIdvAc--5BJryUKlo,4674
@@ -16,7 +17,7 @@ rslearn/data_sources/aws_sentinel2_element84.py,sha256=qeCuiSlvhWChSY3AwYsKT6nZU
 rslearn/data_sources/climate_data_store.py,sha256=mqLJfYubD6m9VwxpLunoIv_MNFN6Ue1hBBVj552e8uQ,18289
 rslearn/data_sources/copernicus.py,sha256=ushAgYGxU2MzPcUNnEvEfPgO0RCC9Rbjzi189xq0jgc,35001
 rslearn/data_sources/data_source.py,sha256=xojlCoAnGTCHKbEx98JkW0oYzAKBbgGMNc0kicEjHWk,4863
-rslearn/data_sources/direct_materialize_data_source.py,sha256=
+rslearn/data_sources/direct_materialize_data_source.py,sha256=UnFuCSJED9-YSFp12-MosV8bMFj6AqCb75a9ADu_Cxw,11030
 rslearn/data_sources/earthdaily.py,sha256=qUtHUG1oV5IlCWXVovUcYxQhqdNDKWaEe-BKnooWX88,14623
 rslearn/data_sources/earthdatahub.py,sha256=KRf1VnxPI9jsT0utEkeYvsCwu7LXo9t-RvMi8gXehag,15889
 rslearn/data_sources/eurocrops.py,sha256=dJ4d0xvt-rID_HuAchyucFJBuAQL-Kk1h_qm6GOH-mE,8641
@@ -32,7 +33,7 @@ rslearn/data_sources/soilgrids.py,sha256=qbnnCIOa6tlN8wxmNCzAj60-pghKEbRxa7lVIgM
 rslearn/data_sources/stac.py,sha256=Gj8TZ5pifVzWPCuzgphrle2ekQ02OET54rj-02sR2nw,10705
 rslearn/data_sources/usda_cdl.py,sha256=3GhcgTB50T7GA44nB9WwItqDJliELquw_YbiAVxh6kc,6808
 rslearn/data_sources/usgs_landsat.py,sha256=IsQOhWY8nwmgixJu1uMSR4CqsC3igcP3TArdBXkETd8,10178
-rslearn/data_sources/utils.py,sha256=
+rslearn/data_sources/utils.py,sha256=EAVFCYzjFvuHWd7E2ghTub9f-bbDhq83p3x9IJDjgvk,16843
 rslearn/data_sources/vector_source.py,sha256=NCa7CxIrGKe9yRT0NyyFKFQboDGDZ1h7663PV9OfMOM,44
 rslearn/data_sources/worldcereal.py,sha256=OWZA0pvQQiKvuA5AVAc0lw8JStMEeF4DYOh0n2vdg6I,21521
 rslearn/data_sources/worldcover.py,sha256=ahyrGoXMAGWsIUDHSrqPywiK7ycwUD3E3BruNMxpo90,6057
@@ -61,6 +62,7 @@ rslearn/models/dinov3.py,sha256=Q9X7VTwzjllLSEvc235C9BY_jMnIoSybsiOkeA58uHo,6472
 rslearn/models/faster_rcnn.py,sha256=yOipLPmVHbadvYCR9xfCYgmkU9Mot6fgDK-kKicVTlo,8685
 rslearn/models/feature_center_crop.py,sha256=_Mu3E4iJLBug9I4ZIBIpB_VJo-xGterHmhtIFGaHR34,1808
 rslearn/models/fpn.py,sha256=qm7nKMgsZrCoAdz8ASmNKU2nvZ6USm5CedMfy_w_gwE,2079
+rslearn/models/global_pool.py,sha256=Bl48AVJ7g70hPmVLJbK1y_JN9_FTyANc_7tr6YOHANY,2782
 rslearn/models/module_wrapper.py,sha256=73JspaglnNabUGZB2EiCYF_dZ3-Kicg_OpoTfUWHONk,2271
 rslearn/models/molmo.py,sha256=lXnevwTCNyc1XcnJUB5_pK1G2AJGYMvQYU21mZFf5u0,2246
 rslearn/models/multitask.py,sha256=bpFxvtFowRyT-tvRSdY7AKbEx_i1y7sToEzZgTMcF4s,16264
@@ -176,10 +178,10 @@ rslearn/vis/render_sensor_image.py,sha256=D0ynK6ABPV046970lIKwF98klpSCtrsUvZTwtZ
 rslearn/vis/render_vector_label.py,sha256=ncwgRKCYCJCK1-wTpjgksOiDDebku37LpAyq6wsg4jg,14939
 rslearn/vis/utils.py,sha256=Zop3dEmyaXUYhPiGdYzrTO8BRXWscP2dEZy2myQUnNk,2765
 rslearn/vis/vis_server.py,sha256=kIGnhTy-yfu5lBOVCoo8VVG259i974JPszudCePbzfI,20157
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
+rslearn-0.0.29.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+rslearn-0.0.29.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
+rslearn-0.0.29.dist-info/METADATA,sha256=PMjB15sAZg5VA7qHkMrhG_2hSULc0Wopxr7G3op20Hg,38714
+rslearn-0.0.29.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+rslearn-0.0.29.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+rslearn-0.0.29.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+rslearn-0.0.29.dist-info/RECORD,,
The remaining dist-info files (WHEEL, entry_points.txt, licenses/LICENSE, licenses/NOTICE, top_level.txt) only move from rslearn-0.0.28.dist-info to rslearn-0.0.29.dist-info; their contents are unchanged.