rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +151 -55
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py
CHANGED
|
@@ -236,11 +236,9 @@ class BandSetConfig(BaseModel):
|
|
|
236
236
|
|
|
237
237
|
warnings.warn(
|
|
238
238
|
"`format = {'name': ...}` is deprecated; "
|
|
239
|
-
"use `{'class_path': '...', 'init_args': {...}}` instead."
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
logger.warning(
|
|
243
|
-
"BandSet.format uses legacy format; support will be removed after 2026-03-01."
|
|
239
|
+
"use `{'class_path': '...', 'init_args': {...}}` instead. "
|
|
240
|
+
"Support will be removed after 2026-03-01.",
|
|
241
|
+
FutureWarning,
|
|
244
242
|
)
|
|
245
243
|
|
|
246
244
|
legacy_name_to_class_path = {
|
|
@@ -294,16 +292,6 @@ class SpaceMode(StrEnum):
|
|
|
294
292
|
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
295
293
|
"""
|
|
296
294
|
|
|
297
|
-
COMPOSITE = "COMPOSITE"
|
|
298
|
-
"""Creates one composite covering the entire window.
|
|
299
|
-
|
|
300
|
-
During querying all items intersecting the window are placed in one group.
|
|
301
|
-
The compositing_method in the rasterlayer config specifies how these items are reduced
|
|
302
|
-
to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
|
|
303
|
-
"""
|
|
304
|
-
|
|
305
|
-
# TODO add PER_PERIOD_COMPOSITE
|
|
306
|
-
|
|
307
295
|
|
|
308
296
|
class TimeMode(StrEnum):
|
|
309
297
|
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
@@ -353,6 +341,20 @@ class QueryConfig(BaseModel):
|
|
|
353
341
|
default=timedelta(days=30),
|
|
354
342
|
description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
|
|
355
343
|
)
|
|
344
|
+
mosaic_compositing_overlaps: int = Field(
|
|
345
|
+
default=1,
|
|
346
|
+
description="For MOSAIC and PER_PERIOD_MOSAIC modes, the number of overlapping items "
|
|
347
|
+
"wanted within each item group covering the window. Set to 1 for a single coverage "
|
|
348
|
+
"(default mosaic behavior), or higher for compositing multiple overlapping items."
|
|
349
|
+
"with mean or median compositing method.",
|
|
350
|
+
)
|
|
351
|
+
per_period_mosaic_reverse_time_order: bool = Field(
|
|
352
|
+
default=True,
|
|
353
|
+
description="For PER_PERIOD_MOSAIC mode, whether to return item groups in reverse "
|
|
354
|
+
"temporal order (most recent first). Set to False for chronological order (oldest first). "
|
|
355
|
+
"Default True is deprecated and will change to False with error if still unset or set True "
|
|
356
|
+
"after 2026-04-01.",
|
|
357
|
+
)
|
|
356
358
|
|
|
357
359
|
|
|
358
360
|
class DataSourceConfig(BaseModel):
|
|
@@ -404,11 +406,9 @@ class DataSourceConfig(BaseModel):
|
|
|
404
406
|
|
|
405
407
|
warnings.warn(
|
|
406
408
|
"`Data source configuration {'name': ...}` is deprecated; "
|
|
407
|
-
"use `{'class_path': '...', 'init_args': {...}, ...}` instead."
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
logger.warning(
|
|
411
|
-
"Data source configuration uses legacy format; support will be removed after 2026-03-01."
|
|
409
|
+
"use `{'class_path': '...', 'init_args': {...}, ...}` instead. "
|
|
410
|
+
"Support will be removed after 2026-03-01.",
|
|
411
|
+
FutureWarning,
|
|
412
412
|
)
|
|
413
413
|
|
|
414
414
|
# Split the dict into the base config that is in the pydantic model, and the
|
|
@@ -431,8 +431,9 @@ class DataSourceConfig(BaseModel):
|
|
|
431
431
|
and "max_cloud_cover" in ds_init_args
|
|
432
432
|
):
|
|
433
433
|
warnings.warn(
|
|
434
|
-
"Data source configuration specifies invalid 'max_cloud_cover' option."
|
|
435
|
-
|
|
434
|
+
"Data source configuration specifies invalid 'max_cloud_cover' option."
|
|
435
|
+
"Support for ignoring this option will be removed after 2026-03-01.",
|
|
436
|
+
FutureWarning,
|
|
436
437
|
)
|
|
437
438
|
del ds_init_args["max_cloud_cover"]
|
|
438
439
|
|
|
@@ -449,7 +450,13 @@ class LayerType(StrEnum):
|
|
|
449
450
|
|
|
450
451
|
|
|
451
452
|
class CompositingMethod(StrEnum):
|
|
452
|
-
"""Method how to select pixels for the composite from corresponding items of a window.
|
|
453
|
+
"""Method how to select pixels for the composite from corresponding items of a window.
|
|
454
|
+
|
|
455
|
+
For MEAN and MEDIAN modes, mosaic_compositing_overlaps (in the QueryConfig) should
|
|
456
|
+
be set higher than 1 so that rslearn creates item groups during prepare that cover
|
|
457
|
+
the window with multiple overlaps. At each pixel/band, the mean and median can then
|
|
458
|
+
be computed across items in each group that cover that pixel.
|
|
459
|
+
"""
|
|
453
460
|
|
|
454
461
|
FIRST_VALID = "FIRST_VALID"
|
|
455
462
|
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
@@ -236,8 +236,8 @@ class RasterImporter(Importer):
|
|
|
236
236
|
"windows in the rslearn dataset. When using settings like "
|
|
237
237
|
"max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
|
|
238
238
|
"the geometry’s valid bounds to be materialized from the global raster "
|
|
239
|
-
"instead of a more appropriate source. Consider
|
|
240
|
-
"
|
|
239
|
+
"instead of a more appropriate source. Consider increasing max_matches"
|
|
240
|
+
"if this behavior is unintended."
|
|
241
241
|
)
|
|
242
242
|
|
|
243
243
|
if spec.name:
|
rslearn/data_sources/utils.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Utilities shared by data sources."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Callable
|
|
3
5
|
from dataclasses import dataclass
|
|
4
|
-
from datetime import UTC, datetime
|
|
6
|
+
from datetime import UTC, datetime
|
|
5
7
|
from typing import TypeVar
|
|
6
8
|
|
|
7
9
|
import shapely
|
|
@@ -40,13 +42,13 @@ class PendingMosaic:
|
|
|
40
42
|
completed: bool = False
|
|
41
43
|
|
|
42
44
|
|
|
43
|
-
def
|
|
45
|
+
def _create_single_coverage_mosaics(
|
|
44
46
|
window_geometry: STGeometry,
|
|
45
47
|
items: list[ItemType],
|
|
46
48
|
item_shps: list[shapely.Geometry],
|
|
47
|
-
|
|
49
|
+
max_mosaics: int,
|
|
48
50
|
) -> list[list[ItemType]]:
|
|
49
|
-
"""
|
|
51
|
+
"""Create mosaics where each mosaic covers the window geometry once.
|
|
50
52
|
|
|
51
53
|
This attempts to piece together items into mosaics that fully cover the window
|
|
52
54
|
geometry. If there are items leftover that only partially cover the window
|
|
@@ -56,15 +58,16 @@ def mosaic_matching(
|
|
|
56
58
|
window_geometry: the geometry of the window.
|
|
57
59
|
items: list of items.
|
|
58
60
|
item_shps: the item shapes projected to the window's projection.
|
|
59
|
-
|
|
61
|
+
max_mosaics: the maximum number of mosaics to create.
|
|
60
62
|
|
|
61
63
|
Returns:
|
|
62
|
-
list of item groups, each one corresponding to a different
|
|
64
|
+
list of item groups, each one corresponding to a different single-coverage
|
|
65
|
+
mosaic.
|
|
63
66
|
"""
|
|
64
67
|
# To create mosaics, we iterate over the items in order, and add each item to
|
|
65
68
|
# the first mosaic that the new item adds coverage to.
|
|
66
69
|
|
|
67
|
-
#
|
|
70
|
+
# max_mosaics could be very high if the user just wants us to create as many
|
|
68
71
|
# mosaics as possible, so we initialize the list here as empty and just add
|
|
69
72
|
# more pending mosaics when it is necessary.
|
|
70
73
|
pending_mosaics: list[PendingMosaic] = []
|
|
@@ -108,7 +111,7 @@ def mosaic_matching(
|
|
|
108
111
|
|
|
109
112
|
# See if we can add a new mosaic based on this item. There must be room for
|
|
110
113
|
# more mosaics, but the item must also intersect the requested geometry.
|
|
111
|
-
if len(pending_mosaics) >=
|
|
114
|
+
if len(pending_mosaics) >= max_mosaics:
|
|
112
115
|
continue
|
|
113
116
|
intersect_area = item_shp.intersection(window_geometry.shp).area
|
|
114
117
|
if (
|
|
@@ -127,18 +130,148 @@ def mosaic_matching(
|
|
|
127
130
|
return [pending_mosaic.items for pending_mosaic in pending_mosaics]
|
|
128
131
|
|
|
129
132
|
|
|
130
|
-
def
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
133
|
+
def _consolidate_mosaics_by_overlaps(
|
|
134
|
+
mosaics: list[list[ItemType]],
|
|
135
|
+
overlaps: int,
|
|
136
|
+
max_groups: int,
|
|
137
|
+
) -> list[list[ItemType]]:
|
|
138
|
+
"""Consolidate single-coverage mosaics into groups based on desired overlaps.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
mosaics: list of single-coverage mosaics (each mosaic is a list of items).
|
|
142
|
+
overlaps: the number of overlapping coverages wanted per group.
|
|
143
|
+
max_groups: the maximum number of groups to return.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
list of item groups, where each group contains items from multiple mosaics
|
|
147
|
+
to achieve the desired number of overlapping coverages.
|
|
148
|
+
"""
|
|
149
|
+
if overlaps <= 0:
|
|
150
|
+
overlaps = 1
|
|
151
|
+
|
|
152
|
+
groups: list[list[ItemType]] = []
|
|
153
|
+
for i in range(0, len(mosaics), overlaps):
|
|
154
|
+
if len(groups) >= max_groups:
|
|
155
|
+
break
|
|
156
|
+
# Combine `overlaps` consecutive mosaics into one group.
|
|
157
|
+
combined_items: list[ItemType] = []
|
|
158
|
+
for mosaic in mosaics[i : i + overlaps]:
|
|
159
|
+
combined_items.extend(mosaic)
|
|
160
|
+
if combined_items:
|
|
161
|
+
groups.append(combined_items)
|
|
162
|
+
|
|
163
|
+
return groups
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def match_with_space_mode_contains(
|
|
167
|
+
geometry: STGeometry,
|
|
168
|
+
items: list[ItemType],
|
|
169
|
+
item_shps: list[shapely.Geometry],
|
|
170
|
+
query_config: QueryConfig,
|
|
171
|
+
) -> list[list[ItemType]]:
|
|
172
|
+
"""Match items that fully contain the window geometry.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
geometry: the window's geometry.
|
|
176
|
+
items: list of items.
|
|
177
|
+
item_shps: the item shapes projected to the window's projection.
|
|
178
|
+
query_config: the query configuration.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
list of matched item groups, where each group contains a single item.
|
|
182
|
+
"""
|
|
183
|
+
groups: list[list[ItemType]] = []
|
|
184
|
+
for item, item_shp in zip(items, item_shps):
|
|
185
|
+
if not item_shp.contains(geometry.shp):
|
|
186
|
+
continue
|
|
187
|
+
groups.append([item])
|
|
188
|
+
if len(groups) >= query_config.max_matches:
|
|
189
|
+
break
|
|
190
|
+
return groups
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def match_with_space_mode_intersects(
|
|
194
|
+
geometry: STGeometry,
|
|
195
|
+
items: list[ItemType],
|
|
196
|
+
item_shps: list[shapely.Geometry],
|
|
197
|
+
query_config: QueryConfig,
|
|
198
|
+
) -> list[list[ItemType]]:
|
|
199
|
+
"""Match items that intersect any portion of the window geometry.
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
geometry: the window's geometry.
|
|
203
|
+
items: list of items.
|
|
204
|
+
item_shps: the item shapes projected to the window's projection.
|
|
205
|
+
query_config: the query configuration.
|
|
206
|
+
|
|
207
|
+
Returns:
|
|
208
|
+
list of matched item groups, where each group contains a single item.
|
|
209
|
+
"""
|
|
210
|
+
groups: list[list[ItemType]] = []
|
|
211
|
+
for item, item_shp in zip(items, item_shps):
|
|
212
|
+
if not shp_intersects(item_shp, geometry.shp):
|
|
213
|
+
continue
|
|
214
|
+
groups.append([item])
|
|
215
|
+
if len(groups) >= query_config.max_matches:
|
|
216
|
+
break
|
|
217
|
+
return groups
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def match_with_space_mode_mosaic(
|
|
221
|
+
geometry: STGeometry,
|
|
222
|
+
items: list[ItemType],
|
|
223
|
+
item_shps: list[shapely.Geometry],
|
|
224
|
+
query_config: QueryConfig,
|
|
225
|
+
) -> list[list[ItemType]]:
|
|
226
|
+
"""Match items into mosaic groups that cover the window geometry.
|
|
227
|
+
|
|
228
|
+
Creates groups of items that together cover the window geometry. The number of
|
|
229
|
+
overlapping coverages in each group is controlled by mosaic_compositing_overlaps.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
geometry: the window's geometry.
|
|
233
|
+
items: list of items.
|
|
234
|
+
item_shps: the item shapes projected to the window's projection.
|
|
235
|
+
query_config: the query configuration.
|
|
236
|
+
|
|
237
|
+
Returns:
|
|
238
|
+
list of matched item groups, where each group forms a mosaic covering the
|
|
239
|
+
window.
|
|
240
|
+
"""
|
|
241
|
+
overlaps = query_config.mosaic_compositing_overlaps
|
|
242
|
+
|
|
243
|
+
# Calculate how many single-coverage mosaics we need to create.
|
|
244
|
+
# We need enough mosaics to consolidate into max_matches groups with the
|
|
245
|
+
# desired number of overlaps per group.
|
|
246
|
+
max_single_mosaics = query_config.max_matches * overlaps
|
|
247
|
+
|
|
248
|
+
# Create single-coverage mosaics
|
|
249
|
+
single_mosaics = _create_single_coverage_mosaics(
|
|
250
|
+
geometry, items, item_shps, max_single_mosaics
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
# Consolidate into groups based on overlaps
|
|
254
|
+
return _consolidate_mosaics_by_overlaps(
|
|
255
|
+
single_mosaics, overlaps, query_config.max_matches
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def match_with_space_mode_per_period_mosaic(
|
|
260
|
+
geometry: STGeometry,
|
|
261
|
+
items: list[ItemType],
|
|
262
|
+
item_shps: list[shapely.Geometry],
|
|
263
|
+
query_config: QueryConfig,
|
|
135
264
|
) -> list[list[ItemType]]:
|
|
136
265
|
"""Match items to the geometry with one mosaic per period.
|
|
137
266
|
|
|
138
267
|
We divide the time range of the geometry into shorter periods. Within each period,
|
|
139
268
|
we use the items corresponding to that period to create a mosaic. The returned item
|
|
140
|
-
groups include one group per period,
|
|
141
|
-
|
|
269
|
+
groups include one group per period, up to the provided max_matches.
|
|
270
|
+
|
|
271
|
+
By default (reverse_time_order=True), groups are returned starting from the most
|
|
272
|
+
recent periods. When reverse_time_order=False, groups are returned in chronological
|
|
273
|
+
order (oldest first). reverse_time_order should always be set False, and
|
|
274
|
+
a FutureWarning is emitted if it is not.
|
|
142
275
|
|
|
143
276
|
The periods are also bounded to the window's time range, and aligned with the end
|
|
144
277
|
of that time range, i.e. the most recent window is
|
|
@@ -159,42 +292,59 @@ def per_period_mosaic_matching(
|
|
|
159
292
|
max_matches*period_duration is not equivalent to a longer window duration.
|
|
160
293
|
|
|
161
294
|
Args:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
295
|
+
geometry: the window's geometry.
|
|
296
|
+
items: list of items.
|
|
297
|
+
item_shps: the item shapes projected to the window's projection (unused here).
|
|
298
|
+
query_config: the query configuration.
|
|
166
299
|
|
|
167
300
|
Returns:
|
|
168
|
-
|
|
169
|
-
|
|
301
|
+
list of matched item groups, where each group contains items that yield a
|
|
302
|
+
per-period mosaic.
|
|
170
303
|
"""
|
|
171
|
-
if
|
|
304
|
+
if geometry.time_range is None:
|
|
172
305
|
raise ValueError(
|
|
173
306
|
"all windows must have time range for per period mosaic matching"
|
|
174
307
|
)
|
|
175
308
|
|
|
309
|
+
# Emit warning if per_period_mosaic_reverse_time_order is True (the default).
|
|
310
|
+
if query_config.per_period_mosaic_reverse_time_order:
|
|
311
|
+
warnings.warn(
|
|
312
|
+
"QueryConfig.per_period_mosaic_reverse_time_order defaults to True, which "
|
|
313
|
+
"returns item groups in reverse temporal order (most recent first) for "
|
|
314
|
+
"PER_PERIOD_MOSAIC mode. This default will change to False (chronological "
|
|
315
|
+
"order) after 2026-04-01. To silence this warning, explicitly set "
|
|
316
|
+
"per_period_mosaic_reverse_time_order=False.",
|
|
317
|
+
FutureWarning,
|
|
318
|
+
stacklevel=3,
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
period_duration = query_config.period_duration
|
|
322
|
+
|
|
176
323
|
# For each period, we create an STGeometry with modified time range matching that
|
|
177
324
|
# period, and use it with match_candidate_items_to_window to get a mosaic.
|
|
178
325
|
cur_groups: list[list[ItemType]] = []
|
|
179
|
-
period_start =
|
|
326
|
+
period_start = geometry.time_range[1] - period_duration
|
|
180
327
|
while (
|
|
181
|
-
period_start >=
|
|
328
|
+
period_start >= geometry.time_range[0]
|
|
329
|
+
and len(cur_groups) < query_config.max_matches
|
|
182
330
|
):
|
|
183
331
|
period_time_range = (
|
|
184
332
|
period_start,
|
|
185
333
|
period_start + period_duration,
|
|
186
334
|
)
|
|
187
335
|
period_start -= period_duration
|
|
188
|
-
period_geom = STGeometry(
|
|
189
|
-
window_geometry.projection, window_geometry.shp, period_time_range
|
|
190
|
-
)
|
|
336
|
+
period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range)
|
|
191
337
|
|
|
192
338
|
# We modify the QueryConfig here since caller should be asking for
|
|
193
339
|
# multiple mosaics, but we just want one mosaic per period.
|
|
194
340
|
period_groups = match_candidate_items_to_window(
|
|
195
341
|
period_geom,
|
|
196
|
-
|
|
197
|
-
QueryConfig(
|
|
342
|
+
items,
|
|
343
|
+
QueryConfig(
|
|
344
|
+
space_mode=SpaceMode.MOSAIC,
|
|
345
|
+
max_matches=1,
|
|
346
|
+
mosaic_compositing_overlaps=query_config.mosaic_compositing_overlaps,
|
|
347
|
+
),
|
|
198
348
|
)
|
|
199
349
|
|
|
200
350
|
# There should be zero or one group depending on whether there were
|
|
@@ -204,9 +354,29 @@ def per_period_mosaic_matching(
|
|
|
204
354
|
continue
|
|
205
355
|
cur_groups.append(period_groups[0])
|
|
206
356
|
|
|
357
|
+
# Currently the item groups are in reverse chronological order.
|
|
358
|
+
# Reverse it to correct chronological order if requested.
|
|
359
|
+
if not query_config.per_period_mosaic_reverse_time_order:
|
|
360
|
+
cur_groups.reverse()
|
|
361
|
+
|
|
207
362
|
return cur_groups
|
|
208
363
|
|
|
209
364
|
|
|
365
|
+
# Type alias for space mode handler functions
|
|
366
|
+
SpaceModeHandler = Callable[
|
|
367
|
+
[STGeometry, list[ItemType], list[shapely.Geometry], QueryConfig],
|
|
368
|
+
list[list[ItemType]],
|
|
369
|
+
]
|
|
370
|
+
|
|
371
|
+
# Dict mapping SpaceMode values to their handler functions
|
|
372
|
+
space_mode_handlers: dict[SpaceMode, SpaceModeHandler] = {
|
|
373
|
+
SpaceMode.CONTAINS: match_with_space_mode_contains,
|
|
374
|
+
SpaceMode.INTERSECTS: match_with_space_mode_intersects,
|
|
375
|
+
SpaceMode.MOSAIC: match_with_space_mode_mosaic,
|
|
376
|
+
SpaceMode.PER_PERIOD_MOSAIC: match_with_space_mode_per_period_mosaic,
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
|
|
210
380
|
def match_candidate_items_to_window(
|
|
211
381
|
geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
|
|
212
382
|
) -> list[list[ItemType]]:
|
|
@@ -262,43 +432,13 @@ def match_candidate_items_to_window(
|
|
|
262
432
|
item_geom = item_geom.to_projection(geometry.projection)
|
|
263
433
|
item_shps.append(item_geom.shp)
|
|
264
434
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
if not item_shp.contains(geometry.shp):
|
|
269
|
-
continue
|
|
270
|
-
groups.append([item])
|
|
271
|
-
if len(groups) >= query_config.max_matches:
|
|
272
|
-
break
|
|
273
|
-
|
|
274
|
-
elif query_config.space_mode == SpaceMode.INTERSECTS:
|
|
275
|
-
groups = []
|
|
276
|
-
for item, item_shp in zip(items, item_shps):
|
|
277
|
-
if not shp_intersects(item_shp, geometry.shp):
|
|
278
|
-
continue
|
|
279
|
-
groups.append([item])
|
|
280
|
-
if len(groups) >= query_config.max_matches:
|
|
281
|
-
break
|
|
282
|
-
|
|
283
|
-
elif query_config.space_mode == SpaceMode.MOSAIC:
|
|
284
|
-
groups = mosaic_matching(geometry, items, item_shps, query_config.max_matches)
|
|
285
|
-
|
|
286
|
-
elif query_config.space_mode == SpaceMode.PER_PERIOD_MOSAIC:
|
|
287
|
-
groups = per_period_mosaic_matching(
|
|
288
|
-
geometry, items, query_config.period_duration, query_config.max_matches
|
|
289
|
-
)
|
|
290
|
-
|
|
291
|
-
elif query_config.space_mode == SpaceMode.COMPOSITE:
|
|
292
|
-
group = []
|
|
293
|
-
for item, item_shp in zip(items, item_shps):
|
|
294
|
-
if not shp_intersects(item_shp, geometry.shp):
|
|
295
|
-
continue
|
|
296
|
-
group.append(item)
|
|
297
|
-
groups = [group]
|
|
298
|
-
|
|
299
|
-
else:
|
|
435
|
+
# Dispatch to the appropriate space mode handler
|
|
436
|
+
handler = space_mode_handlers.get(query_config.space_mode)
|
|
437
|
+
if handler is None:
|
|
300
438
|
raise ValueError(f"invalid space mode {query_config.space_mode}")
|
|
301
439
|
|
|
440
|
+
groups = handler(geometry, items, item_shps, query_config)
|
|
441
|
+
|
|
302
442
|
# Enforce minimum matches if set.
|
|
303
443
|
if len(groups) < query_config.min_matches:
|
|
304
444
|
logger.warning(
|
rslearn/dataset/materialize.py
CHANGED
|
@@ -236,7 +236,11 @@ def read_and_stack_raster_windows(
|
|
|
236
236
|
band_dtype: npt.DTypeLike,
|
|
237
237
|
resampling_method: Resampling = Resampling.bilinear,
|
|
238
238
|
) -> npt.NDArray[np.generic]:
|
|
239
|
-
"""Create a stack of
|
|
239
|
+
"""Create a stack of raster images, with one per item in the group.
|
|
240
|
+
|
|
241
|
+
We read the portion of each raster item corresponding to the window extent, and
|
|
242
|
+
stack the resulting images. This is used for the MEAN and MEDIAN compositing
|
|
243
|
+
methods so it can compute aggregate statistics across the stack.
|
|
240
244
|
|
|
241
245
|
Args:
|
|
242
246
|
group: Iterable of items (e.g., scene metadata objects) to read data from.
|
rslearn/models/clay/clay.py
CHANGED
|
@@ -105,7 +105,7 @@ class Clay(FeatureExtractor):
|
|
|
105
105
|
|
|
106
106
|
def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
|
|
107
107
|
"""Resize the image to the input resolution."""
|
|
108
|
-
new_hw =
|
|
108
|
+
new_hw = PATCH_SIZE if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
|
|
109
109
|
return F.interpolate(
|
|
110
110
|
image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
|
|
111
111
|
)
|
|
@@ -123,7 +123,8 @@ class Clay(FeatureExtractor):
|
|
|
123
123
|
device = param.device
|
|
124
124
|
|
|
125
125
|
chips = torch.stack(
|
|
126
|
-
[inp[self.modality] for inp in context.inputs],
|
|
126
|
+
[inp[self.modality].single_ts_to_chw_tensor() for inp in context.inputs],
|
|
127
|
+
dim=0,
|
|
127
128
|
) # (B, C, H, W)
|
|
128
129
|
if self.do_resizing:
|
|
129
130
|
chips = self._resize_image(chips, chips.shape[2])
|
|
@@ -203,7 +204,6 @@ class ClayNormalize(Transform):
|
|
|
203
204
|
mean=means,
|
|
204
205
|
std=stds,
|
|
205
206
|
selectors=[modality],
|
|
206
|
-
num_bands=len(means),
|
|
207
207
|
)
|
|
208
208
|
self.normalizers = torch.nn.ModuleDict(normalizers)
|
|
209
209
|
|
rslearn/models/detr/detr.py
CHANGED
|
@@ -468,7 +468,10 @@ class Detr(Predictor):
|
|
|
468
468
|
|
|
469
469
|
# Get image sizes.
|
|
470
470
|
image_sizes = torch.tensor(
|
|
471
|
-
[
|
|
471
|
+
[
|
|
472
|
+
[inp["image"].image.shape[2], inp["image"].image.shape[1]]
|
|
473
|
+
for inp in context.inputs
|
|
474
|
+
],
|
|
472
475
|
dtype=torch.int32,
|
|
473
476
|
device=features.device,
|
|
474
477
|
)
|
rslearn/models/dinov3.py
CHANGED
|
@@ -95,7 +95,9 @@ class OlmoEarth(FeatureExtractor):
|
|
|
95
95
|
"""
|
|
96
96
|
if use_legacy_timestamps:
|
|
97
97
|
warnings.warn(
|
|
98
|
-
"For new projects, don't use legacy timesteps."
|
|
98
|
+
"For new projects, don't use legacy timesteps. "
|
|
99
|
+
"Support will be removed after 2026-04-01.",
|
|
100
|
+
FutureWarning,
|
|
99
101
|
)
|
|
100
102
|
|
|
101
103
|
if (
|
|
@@ -124,6 +124,6 @@ class SegmentationPoolingDecoder(PoolingDecoder):
|
|
|
124
124
|
"""
|
|
125
125
|
output_probs = super().forward(intermediates, context)
|
|
126
126
|
# BC -> BCHW
|
|
127
|
-
h, w = context.inputs[0][self.image_key].shape[1:3]
|
|
127
|
+
h, w = context.inputs[0][self.image_key].image.shape[1:3]
|
|
128
128
|
feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
|
|
129
129
|
return FeatureMaps([feat_map])
|