rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/local_files.py +2 -2
  3. rslearn/data_sources/utils.py +204 -64
  4. rslearn/dataset/materialize.py +5 -1
  5. rslearn/models/clay/clay.py +3 -3
  6. rslearn/models/detr/detr.py +4 -1
  7. rslearn/models/dinov3.py +0 -1
  8. rslearn/models/olmoearth_pretrain/model.py +3 -1
  9. rslearn/models/pooling_decoder.py +1 -1
  10. rslearn/models/prithvi.py +0 -1
  11. rslearn/models/simple_time_series.py +97 -35
  12. rslearn/train/data_module.py +5 -0
  13. rslearn/train/dataset.py +151 -55
  14. rslearn/train/dataset_index.py +156 -0
  15. rslearn/train/model_context.py +16 -0
  16. rslearn/train/tasks/per_pixel_regression.py +13 -13
  17. rslearn/train/tasks/segmentation.py +26 -13
  18. rslearn/train/transforms/concatenate.py +17 -27
  19. rslearn/train/transforms/crop.py +8 -19
  20. rslearn/train/transforms/flip.py +4 -10
  21. rslearn/train/transforms/mask.py +9 -15
  22. rslearn/train/transforms/normalize.py +31 -82
  23. rslearn/train/transforms/pad.py +7 -13
  24. rslearn/train/transforms/resize.py +5 -22
  25. rslearn/train/transforms/select_bands.py +16 -36
  26. rslearn/train/transforms/sentinel1.py +4 -16
  27. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
  28. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
  29. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
  30. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
  31. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
  32. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
  33. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py CHANGED
@@ -236,11 +236,9 @@ class BandSetConfig(BaseModel):
236
236
 
237
237
  warnings.warn(
238
238
  "`format = {'name': ...}` is deprecated; "
239
- "use `{'class_path': '...', 'init_args': {...}}` instead.",
240
- DeprecationWarning,
241
- )
242
- logger.warning(
243
- "BandSet.format uses legacy format; support will be removed after 2026-03-01."
239
+ "use `{'class_path': '...', 'init_args': {...}}` instead. "
240
+ "Support will be removed after 2026-03-01.",
241
+ FutureWarning,
244
242
  )
245
243
 
246
244
  legacy_name_to_class_path = {
@@ -294,16 +292,6 @@ class SpaceMode(StrEnum):
294
292
  The duration of the sub-periods is controlled by another option in QueryConfig.
295
293
  """
296
294
 
297
- COMPOSITE = "COMPOSITE"
298
- """Creates one composite covering the entire window.
299
-
300
- During querying all items intersecting the window are placed in one group.
301
- The compositing_method in the rasterlayer config specifies how these items are reduced
302
- to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
303
- """
304
-
305
- # TODO add PER_PERIOD_COMPOSITE
306
-
307
295
 
308
296
  class TimeMode(StrEnum):
309
297
  """Temporal matching mode when looking up items corresponding to a window."""
@@ -353,6 +341,20 @@ class QueryConfig(BaseModel):
353
341
  default=timedelta(days=30),
354
342
  description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
355
343
  )
344
+ mosaic_compositing_overlaps: int = Field(
345
+ default=1,
346
+ description="For MOSAIC and PER_PERIOD_MOSAIC modes, the number of overlapping items "
347
+ "wanted within each item group covering the window. Set to 1 for a single coverage "
348
+ "(default mosaic behavior), or higher for compositing multiple overlapping items."
349
+ "with mean or median compositing method.",
350
+ )
351
+ per_period_mosaic_reverse_time_order: bool = Field(
352
+ default=True,
353
+ description="For PER_PERIOD_MOSAIC mode, whether to return item groups in reverse "
354
+ "temporal order (most recent first). Set to False for chronological order (oldest first). "
355
+ "Default True is deprecated and will change to False with error if still unset or set True "
356
+ "after 2026-04-01.",
357
+ )
356
358
 
357
359
 
358
360
  class DataSourceConfig(BaseModel):
@@ -404,11 +406,9 @@ class DataSourceConfig(BaseModel):
404
406
 
405
407
  warnings.warn(
406
408
  "`Data source configuration {'name': ...}` is deprecated; "
407
- "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
408
- DeprecationWarning,
409
- )
410
- logger.warning(
411
- "Data source configuration uses legacy format; support will be removed after 2026-03-01."
409
+ "use `{'class_path': '...', 'init_args': {...}, ...}` instead. "
410
+ "Support will be removed after 2026-03-01.",
411
+ FutureWarning,
412
412
  )
413
413
 
414
414
  # Split the dict into the base config that is in the pydantic model, and the
@@ -431,8 +431,9 @@ class DataSourceConfig(BaseModel):
431
431
  and "max_cloud_cover" in ds_init_args
432
432
  ):
433
433
  warnings.warn(
434
- "Data source configuration specifies invalid 'max_cloud_cover' option.",
435
- DeprecationWarning,
434
+ "Data source configuration specifies invalid 'max_cloud_cover' option."
435
+ "Support for ignoring this option will be removed after 2026-03-01.",
436
+ FutureWarning,
436
437
  )
437
438
  del ds_init_args["max_cloud_cover"]
438
439
 
@@ -449,7 +450,13 @@ class LayerType(StrEnum):
449
450
 
450
451
 
451
452
  class CompositingMethod(StrEnum):
452
- """Method how to select pixels for the composite from corresponding items of a window."""
453
+ """Method how to select pixels for the composite from corresponding items of a window.
454
+
455
+ For MEAN and MEDIAN modes, mosaic_compositing_overlaps (in the QueryConfig) should
456
+ be set higher than 1 so that rslearn creates item groups during prepare that cover
457
+ the window with multiple overlaps. At each pixel/band, the mean and median can then
458
+ be computed across items in each group that cover that pixel.
459
+ """
453
460
 
454
461
  FIRST_VALID = "FIRST_VALID"
455
462
  """Select first valid pixel in order of corresponding items (might be sorted)"""
@@ -236,8 +236,8 @@ class RasterImporter(Importer):
236
236
  "windows in the rslearn dataset. When using settings like "
237
237
  "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
238
238
  "the geometry’s valid bounds to be materialized from the global raster "
239
- "instead of a more appropriate source. Consider using COMPOSITE mode, "
240
- "or increasing max_matches if this behavior is unintended."
239
+ "instead of a more appropriate source. Consider increasing max_matches"
240
+ "if this behavior is unintended."
241
241
  )
242
242
 
243
243
  if spec.name:
@@ -1,7 +1,9 @@
1
1
  """Utilities shared by data sources."""
2
2
 
3
+ import warnings
4
+ from collections.abc import Callable
3
5
  from dataclasses import dataclass
4
- from datetime import UTC, datetime, timedelta
6
+ from datetime import UTC, datetime
5
7
  from typing import TypeVar
6
8
 
7
9
  import shapely
@@ -40,13 +42,13 @@ class PendingMosaic:
40
42
  completed: bool = False
41
43
 
42
44
 
43
- def mosaic_matching(
45
+ def _create_single_coverage_mosaics(
44
46
  window_geometry: STGeometry,
45
47
  items: list[ItemType],
46
48
  item_shps: list[shapely.Geometry],
47
- max_matches: int,
49
+ max_mosaics: int,
48
50
  ) -> list[list[ItemType]]:
49
- """Spatial item matching for mosaic space mode.
51
+ """Create mosaics where each mosaic covers the window geometry once.
50
52
 
51
53
  This attempts to piece together items into mosaics that fully cover the window
52
54
  geometry. If there are items leftover that only partially cover the window
@@ -56,15 +58,16 @@ def mosaic_matching(
56
58
  window_geometry: the geometry of the window.
57
59
  items: list of items.
58
60
  item_shps: the item shapes projected to the window's projection.
59
- max_matches: the maximum number of matches (mosaics) to create.
61
+ max_mosaics: the maximum number of mosaics to create.
60
62
 
61
63
  Returns:
62
- list of item groups, each one corresponding to a different mosaic.
64
+ list of item groups, each one corresponding to a different single-coverage
65
+ mosaic.
63
66
  """
64
67
  # To create mosaics, we iterate over the items in order, and add each item to
65
68
  # the first mosaic that the new item adds coverage to.
66
69
 
67
- # max_matches could be very high if the user just wants us to create as many
70
+ # max_mosaics could be very high if the user just wants us to create as many
68
71
  # mosaics as possible, so we initialize the list here as empty and just add
69
72
  # more pending mosaics when it is necessary.
70
73
  pending_mosaics: list[PendingMosaic] = []
@@ -108,7 +111,7 @@ def mosaic_matching(
108
111
 
109
112
  # See if we can add a new mosaic based on this item. There must be room for
110
113
  # more mosaics, but the item must also intersect the requested geometry.
111
- if len(pending_mosaics) >= max_matches:
114
+ if len(pending_mosaics) >= max_mosaics:
112
115
  continue
113
116
  intersect_area = item_shp.intersection(window_geometry.shp).area
114
117
  if (
@@ -127,18 +130,148 @@ def mosaic_matching(
127
130
  return [pending_mosaic.items for pending_mosaic in pending_mosaics]
128
131
 
129
132
 
130
- def per_period_mosaic_matching(
131
- window_geometry: STGeometry,
132
- item_list: list[ItemType],
133
- period_duration: timedelta,
134
- max_matches: int,
133
+ def _consolidate_mosaics_by_overlaps(
134
+ mosaics: list[list[ItemType]],
135
+ overlaps: int,
136
+ max_groups: int,
137
+ ) -> list[list[ItemType]]:
138
+ """Consolidate single-coverage mosaics into groups based on desired overlaps.
139
+
140
+ Args:
141
+ mosaics: list of single-coverage mosaics (each mosaic is a list of items).
142
+ overlaps: the number of overlapping coverages wanted per group.
143
+ max_groups: the maximum number of groups to return.
144
+
145
+ Returns:
146
+ list of item groups, where each group contains items from multiple mosaics
147
+ to achieve the desired number of overlapping coverages.
148
+ """
149
+ if overlaps <= 0:
150
+ overlaps = 1
151
+
152
+ groups: list[list[ItemType]] = []
153
+ for i in range(0, len(mosaics), overlaps):
154
+ if len(groups) >= max_groups:
155
+ break
156
+ # Combine overlaps consecutive mosaics into one group
157
+ combined_items: list[ItemType] = []
158
+ for mosaic in mosaics[i : i + overlaps]:
159
+ combined_items.extend(mosaic)
160
+ if combined_items:
161
+ groups.append(combined_items)
162
+
163
+ return groups
164
+
165
+
166
+ def match_with_space_mode_contains(
167
+ geometry: STGeometry,
168
+ items: list[ItemType],
169
+ item_shps: list[shapely.Geometry],
170
+ query_config: QueryConfig,
171
+ ) -> list[list[ItemType]]:
172
+ """Match items that fully contain the window geometry.
173
+
174
+ Args:
175
+ geometry: the window's geometry.
176
+ items: list of items.
177
+ item_shps: the item shapes projected to the window's projection.
178
+ query_config: the query configuration.
179
+
180
+ Returns:
181
+ list of matched item groups, where each group contains a single item.
182
+ """
183
+ groups: list[list[ItemType]] = []
184
+ for item, item_shp in zip(items, item_shps):
185
+ if not item_shp.contains(geometry.shp):
186
+ continue
187
+ groups.append([item])
188
+ if len(groups) >= query_config.max_matches:
189
+ break
190
+ return groups
191
+
192
+
193
+ def match_with_space_mode_intersects(
194
+ geometry: STGeometry,
195
+ items: list[ItemType],
196
+ item_shps: list[shapely.Geometry],
197
+ query_config: QueryConfig,
198
+ ) -> list[list[ItemType]]:
199
+ """Match items that intersect any portion of the window geometry.
200
+
201
+ Args:
202
+ geometry: the window's geometry.
203
+ items: list of items.
204
+ item_shps: the item shapes projected to the window's projection.
205
+ query_config: the query configuration.
206
+
207
+ Returns:
208
+ list of matched item groups, where each group contains a single item.
209
+ """
210
+ groups: list[list[ItemType]] = []
211
+ for item, item_shp in zip(items, item_shps):
212
+ if not shp_intersects(item_shp, geometry.shp):
213
+ continue
214
+ groups.append([item])
215
+ if len(groups) >= query_config.max_matches:
216
+ break
217
+ return groups
218
+
219
+
220
+ def match_with_space_mode_mosaic(
221
+ geometry: STGeometry,
222
+ items: list[ItemType],
223
+ item_shps: list[shapely.Geometry],
224
+ query_config: QueryConfig,
225
+ ) -> list[list[ItemType]]:
226
+ """Match items into mosaic groups that cover the window geometry.
227
+
228
+ Creates groups of items that together cover the window geometry. The number of
229
+ overlapping coverages in each group is controlled by mosaic_compositing_overlaps.
230
+
231
+ Args:
232
+ geometry: the window's geometry.
233
+ items: list of items.
234
+ item_shps: the item shapes projected to the window's projection.
235
+ query_config: the query configuration.
236
+
237
+ Returns:
238
+ list of matched item groups, where each group forms a mosaic covering the
239
+ window.
240
+ """
241
+ overlaps = query_config.mosaic_compositing_overlaps
242
+
243
+ # Calculate how many single-coverage mosaics we need to create.
244
+ # We need enough mosaics to consolidate into max_matches groups with the
245
+ # desired number of overlaps per group.
246
+ max_single_mosaics = query_config.max_matches * overlaps
247
+
248
+ # Create single-coverage mosaics
249
+ single_mosaics = _create_single_coverage_mosaics(
250
+ geometry, items, item_shps, max_single_mosaics
251
+ )
252
+
253
+ # Consolidate into groups based on overlaps
254
+ return _consolidate_mosaics_by_overlaps(
255
+ single_mosaics, overlaps, query_config.max_matches
256
+ )
257
+
258
+
259
+ def match_with_space_mode_per_period_mosaic(
260
+ geometry: STGeometry,
261
+ items: list[ItemType],
262
+ item_shps: list[shapely.Geometry],
263
+ query_config: QueryConfig,
135
264
  ) -> list[list[ItemType]]:
136
265
  """Match items to the geometry with one mosaic per period.
137
266
 
138
267
  We divide the time range of the geometry into shorter periods. Within each period,
139
268
  we use the items corresponding to that period to create a mosaic. The returned item
140
- groups include one group per period, starting from the most recent periods, up to
141
- the provided max_matches.
269
+ groups include one group per period, up to the provided max_matches.
270
+
271
+ By default (reverse_time_order=True), groups are returned starting from the most
272
+ recent periods. When reverse_time_order=False, groups are returned in chronological
273
+ order (oldest first). reverse_time_order should always be set False, and
274
+ a FutureWarning is emitted if it is not.
142
275
 
143
276
  The periods are also bounded to the window's time range, and aligned with the end
144
277
  of that time range, i.e. the most recent window is
@@ -159,42 +292,59 @@ def per_period_mosaic_matching(
159
292
  max_matches*period_duration is not equivalent to a longer window duration.
160
293
 
161
294
  Args:
162
- window_geometry: the window geometry to match items to.
163
- item_list: the list of items.
164
- period_duration: the duration of one period.
165
- max_matches: the number of per-period mosaics to create.
295
+ geometry: the window's geometry.
296
+ items: list of items.
297
+ item_shps: the item shapes projected to the window's projection (unused here)
298
+ query_config: the query configuration.
166
299
 
167
300
  Returns:
168
- the matched item groups, where each group contains items that yield a
169
- per-period mosaic.
301
+ list of matched item groups, where each group contains items that yield a
302
+ per-period mosaic.
170
303
  """
171
- if window_geometry.time_range is None:
304
+ if geometry.time_range is None:
172
305
  raise ValueError(
173
306
  "all windows must have time range for per period mosaic matching"
174
307
  )
175
308
 
309
+ # Emit warning if per_period_mosaic_reverse_time_order is True (the default).
310
+ if query_config.per_period_mosaic_reverse_time_order:
311
+ warnings.warn(
312
+ "QueryConfig.per_period_mosaic_reverse_time_order defaults to True, which "
313
+ "returns item groups in reverse temporal order (most recent first) for "
314
+ "PER_PERIOD_MOSAIC mode. This default will change to False (chronological "
315
+ "order) after 2026-04-01. To silence this warning, explicitly set "
316
+ "per_period_mosaic_reverse_time_order=False.",
317
+ FutureWarning,
318
+ stacklevel=3,
319
+ )
320
+
321
+ period_duration = query_config.period_duration
322
+
176
323
  # For each period, we create an STGeometry with modified time range matching that
177
324
  # period, and use it with match_candidate_items_to_window to get a mosaic.
178
325
  cur_groups: list[list[ItemType]] = []
179
- period_start = window_geometry.time_range[1] - period_duration
326
+ period_start = geometry.time_range[1] - period_duration
180
327
  while (
181
- period_start >= window_geometry.time_range[0] and len(cur_groups) < max_matches
328
+ period_start >= geometry.time_range[0]
329
+ and len(cur_groups) < query_config.max_matches
182
330
  ):
183
331
  period_time_range = (
184
332
  period_start,
185
333
  period_start + period_duration,
186
334
  )
187
335
  period_start -= period_duration
188
- period_geom = STGeometry(
189
- window_geometry.projection, window_geometry.shp, period_time_range
190
- )
336
+ period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range)
191
337
 
192
338
  # We modify the QueryConfig here since caller should be asking for
193
339
  # multiple mosaics, but we just want one mosaic per period.
194
340
  period_groups = match_candidate_items_to_window(
195
341
  period_geom,
196
- item_list,
197
- QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1),
342
+ items,
343
+ QueryConfig(
344
+ space_mode=SpaceMode.MOSAIC,
345
+ max_matches=1,
346
+ mosaic_compositing_overlaps=query_config.mosaic_compositing_overlaps,
347
+ ),
198
348
  )
199
349
 
200
350
  # There should be zero or one group depending on whether there were
@@ -204,9 +354,29 @@ def per_period_mosaic_matching(
204
354
  continue
205
355
  cur_groups.append(period_groups[0])
206
356
 
357
+ # Currently the item groups are in reverse chronological order.
358
+ # Reverse it to correct chronological order if requested.
359
+ if not query_config.per_period_mosaic_reverse_time_order:
360
+ cur_groups.reverse()
361
+
207
362
  return cur_groups
208
363
 
209
364
 
365
+ # Type alias for space mode handler functions
366
+ SpaceModeHandler = Callable[
367
+ [STGeometry, list[ItemType], list[shapely.Geometry], QueryConfig],
368
+ list[list[ItemType]],
369
+ ]
370
+
371
+ # Dict mapping SpaceMode values to their handler functions
372
+ space_mode_handlers: dict[SpaceMode, SpaceModeHandler] = {
373
+ SpaceMode.CONTAINS: match_with_space_mode_contains,
374
+ SpaceMode.INTERSECTS: match_with_space_mode_intersects,
375
+ SpaceMode.MOSAIC: match_with_space_mode_mosaic,
376
+ SpaceMode.PER_PERIOD_MOSAIC: match_with_space_mode_per_period_mosaic,
377
+ }
378
+
379
+
210
380
  def match_candidate_items_to_window(
211
381
  geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
212
382
  ) -> list[list[ItemType]]:
@@ -262,43 +432,13 @@ def match_candidate_items_to_window(
262
432
  item_geom = item_geom.to_projection(geometry.projection)
263
433
  item_shps.append(item_geom.shp)
264
434
 
265
- if query_config.space_mode == SpaceMode.CONTAINS:
266
- groups = []
267
- for item, item_shp in zip(items, item_shps):
268
- if not item_shp.contains(geometry.shp):
269
- continue
270
- groups.append([item])
271
- if len(groups) >= query_config.max_matches:
272
- break
273
-
274
- elif query_config.space_mode == SpaceMode.INTERSECTS:
275
- groups = []
276
- for item, item_shp in zip(items, item_shps):
277
- if not shp_intersects(item_shp, geometry.shp):
278
- continue
279
- groups.append([item])
280
- if len(groups) >= query_config.max_matches:
281
- break
282
-
283
- elif query_config.space_mode == SpaceMode.MOSAIC:
284
- groups = mosaic_matching(geometry, items, item_shps, query_config.max_matches)
285
-
286
- elif query_config.space_mode == SpaceMode.PER_PERIOD_MOSAIC:
287
- groups = per_period_mosaic_matching(
288
- geometry, items, query_config.period_duration, query_config.max_matches
289
- )
290
-
291
- elif query_config.space_mode == SpaceMode.COMPOSITE:
292
- group = []
293
- for item, item_shp in zip(items, item_shps):
294
- if not shp_intersects(item_shp, geometry.shp):
295
- continue
296
- group.append(item)
297
- groups = [group]
298
-
299
- else:
435
+ # Dispatch to the appropriate space mode handler
436
+ handler = space_mode_handlers.get(query_config.space_mode)
437
+ if handler is None:
300
438
  raise ValueError(f"invalid space mode {query_config.space_mode}")
301
439
 
440
+ groups = handler(geometry, items, item_shps, query_config)
441
+
302
442
  # Enforce minimum matches if set.
303
443
  if len(groups) < query_config.min_matches:
304
444
  logger.warning(
@@ -236,7 +236,11 @@ def read_and_stack_raster_windows(
236
236
  band_dtype: npt.DTypeLike,
237
237
  resampling_method: Resampling = Resampling.bilinear,
238
238
  ) -> npt.NDArray[np.generic]:
239
- """Create a stack of extent aligned raster windows.
239
+ """Create a stack of raster images, with one per item in the group.
240
+
241
+ We read the portion of each raster item corresponding to the window extent, and
242
+ stack the resulting images. This is used for the MEAN and MEDIAN compositing
243
+ methods so it can compute aggregate statistics across the stack.
240
244
 
241
245
  Args:
242
246
  group: Iterable of items (e.g., scene metadata objects) to read data from.
@@ -105,7 +105,7 @@ class Clay(FeatureExtractor):
105
105
 
106
106
  def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
107
107
  """Resize the image to the input resolution."""
108
- new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
108
+ new_hw = PATCH_SIZE if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
109
109
  return F.interpolate(
110
110
  image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
111
111
  )
@@ -123,7 +123,8 @@ class Clay(FeatureExtractor):
123
123
  device = param.device
124
124
 
125
125
  chips = torch.stack(
126
- [inp[self.modality] for inp in context.inputs], dim=0
126
+ [inp[self.modality].single_ts_to_chw_tensor() for inp in context.inputs],
127
+ dim=0,
127
128
  ) # (B, C, H, W)
128
129
  if self.do_resizing:
129
130
  chips = self._resize_image(chips, chips.shape[2])
@@ -203,7 +204,6 @@ class ClayNormalize(Transform):
203
204
  mean=means,
204
205
  std=stds,
205
206
  selectors=[modality],
206
- num_bands=len(means),
207
207
  )
208
208
  self.normalizers = torch.nn.ModuleDict(normalizers)
209
209
 
@@ -468,7 +468,10 @@ class Detr(Predictor):
468
468
 
469
469
  # Get image sizes.
470
470
  image_sizes = torch.tensor(
471
- [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
471
+ [
472
+ [inp["image"].image.shape[2], inp["image"].image.shape[1]]
473
+ for inp in context.inputs
474
+ ],
472
475
  dtype=torch.int32,
473
476
  device=features.device,
474
477
  )
rslearn/models/dinov3.py CHANGED
@@ -159,7 +159,6 @@ class DinoV3Normalize(Transform):
159
159
  self.normalize = Normalize(
160
160
  [value * 255 for value in mean],
161
161
  [value * 255 for value in std],
162
- num_bands=3,
163
162
  )
164
163
 
165
164
  def forward(
@@ -95,7 +95,9 @@ class OlmoEarth(FeatureExtractor):
95
95
  """
96
96
  if use_legacy_timestamps:
97
97
  warnings.warn(
98
- "For new projects, don't use legacy timesteps.", DeprecationWarning
98
+ "For new projects, don't use legacy timesteps. "
99
+ "Support will be removed after 2026-04-01.",
100
+ FutureWarning,
99
101
  )
100
102
 
101
103
  if (
@@ -124,6 +124,6 @@ class SegmentationPoolingDecoder(PoolingDecoder):
124
124
  """
125
125
  output_probs = super().forward(intermediates, context)
126
126
  # BC -> BCHW
127
- h, w = context.inputs[0][self.image_key].shape[1:3]
127
+ h, w = context.inputs[0][self.image_key].image.shape[1:3]
128
128
  feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
129
129
  return FeatureMaps([feat_map])
rslearn/models/prithvi.py CHANGED
@@ -230,7 +230,6 @@ class PrithviNormalize(Transform):
230
230
  self.normalizer = Normalize(
231
231
  mean=config["mean"],
232
232
  std=config["std"],
233
- num_bands=len(config["mean"]),
234
233
  selectors=[PrithviV2.INPUT_KEY],
235
234
  )
236
235