rslearn 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. rslearn/models/anysat.py +35 -33
  2. rslearn/models/clip.py +5 -2
  3. rslearn/models/croma.py +11 -3
  4. rslearn/models/dinov3.py +2 -1
  5. rslearn/models/faster_rcnn.py +2 -1
  6. rslearn/models/galileo/galileo.py +58 -31
  7. rslearn/models/module_wrapper.py +6 -1
  8. rslearn/models/molmo.py +4 -2
  9. rslearn/models/olmoearth_pretrain/model.py +93 -29
  10. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  11. rslearn/models/panopticon.py +3 -1
  12. rslearn/models/presto/presto.py +45 -15
  13. rslearn/models/prithvi.py +9 -7
  14. rslearn/models/sam2_enc.py +3 -1
  15. rslearn/models/satlaspretrain.py +4 -1
  16. rslearn/models/simple_time_series.py +36 -16
  17. rslearn/models/ssl4eo_s12.py +19 -14
  18. rslearn/models/swin.py +3 -1
  19. rslearn/models/terramind.py +5 -4
  20. rslearn/train/all_patches_dataset.py +34 -14
  21. rslearn/train/dataset.py +66 -10
  22. rslearn/train/model_context.py +35 -1
  23. rslearn/train/tasks/classification.py +8 -2
  24. rslearn/train/tasks/detection.py +3 -2
  25. rslearn/train/tasks/multi_task.py +2 -3
  26. rslearn/train/tasks/per_pixel_regression.py +14 -5
  27. rslearn/train/tasks/regression.py +8 -2
  28. rslearn/train/tasks/segmentation.py +13 -4
  29. rslearn/train/tasks/task.py +2 -2
  30. rslearn/train/transforms/concatenate.py +45 -5
  31. rslearn/train/transforms/crop.py +22 -8
  32. rslearn/train/transforms/flip.py +13 -5
  33. rslearn/train/transforms/mask.py +11 -2
  34. rslearn/train/transforms/normalize.py +46 -15
  35. rslearn/train/transforms/pad.py +15 -3
  36. rslearn/train/transforms/resize.py +18 -9
  37. rslearn/train/transforms/select_bands.py +11 -2
  38. rslearn/train/transforms/sentinel1.py +18 -3
  39. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  40. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/RECORD +45 -45
  41. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  42. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  43. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  44. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  45. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/presto/presto.py CHANGED
@@ -2,6 +2,7 @@
 
 import logging
 import tempfile
+from datetime import datetime
 
 import torch
 from einops import rearrange, repeat
@@ -118,21 +119,21 @@ class Presto(FeatureExtractor):
             of each timestep for that pixel
         """
         bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
-        hs = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
-        ws = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+        ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
+        hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+        ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
         devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]
 
         assert len(set(bs)) == 1
         assert len(set(hs)) == 1
         assert len(set(ws)) == 1
         assert len(set(devices)) == 1
-        b, h, w, device = bs[0], hs[0], ws[0], devices[0]
-
+        assert len(set(ts)) == 1
+        b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
         # these values will be initialized as
         # we iterate through the data
         x: torch.Tensor | None = None
         mask: torch.Tensor | None = None
-        t: int | None = None
 
         for band_group in [
             (s1, s1_bands),
@@ -146,14 +147,7 @@ class Presto(FeatureExtractor):
             else:
                 continue
 
-            m_t = data.shape[1] // len(input_bands)
-            if t is None:
-                t = m_t
-            else:
-                if t != m_t:
-                    raise ValueError("inconsistent values for t")
-
-            data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
+            data = rearrange(data, "b c t h w -> b t h w c")
             if x is None:
                 x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
             if mask is None:
@@ -184,6 +178,23 @@
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage into timestamps accepted by Presto.
+
+        Presto only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.
 
@@ -194,17 +205,36 @@
             a FeatureMaps with one feature map that is at the same resolution as the
             input (since Presto operates per-pixel).
         """
+        time_modalities = ["s1", "s2", "era5"]
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
+                    latlons = torch.stack(
+                        [inp[key].image for inp in context.inputs], dim=0
+                    )
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in context.inputs], dim=0
+                        [inp[key].image for inp in context.inputs], dim=0
                     )
+                    if key in time_modalities:
+                        if months is None:
+                            if context.inputs[0][key].timestamps is not None:
+                                months = torch.stack(
+                                    [
+                                        self.time_ranges_to_timestamps(
+                                            inp[key].timestamps,  # type: ignore
+                                            device=stacked_inputs[key].device,
+                                        )
+                                        for inp in context.inputs
+                                    ],
+                                    dim=0,
+                                )
+        if months is not None:
+            stacked_inputs["months"] = months
 
         (
             x,
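Presto only consumes the month of each timestamp; the new `time_ranges_to_timestamps` helper takes the midpoint of each (start, end) range and converts it to a 0-indexed month. A standalone sketch of that conversion (the example dates are illustrative, not from the package):

```python
from datetime import datetime

import torch

# Midpoint-of-range month extraction, mirroring the helper added above.
time_ranges = [
    (datetime(2023, 1, 1), datetime(2023, 3, 1)),    # midpoint 2023-01-30 -> month 0
    (datetime(2023, 6, 15), datetime(2023, 6, 15)),  # start == end -> month 5
]
mid_points = [start + (end - start) / 2 for start, end in time_ranges]
months = torch.tensor([d.month - 1 for d in mid_points], dtype=torch.int32)
print(months)  # tensor([0, 5], dtype=torch.int32)
```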
rslearn/models/prithvi.py CHANGED
@@ -144,13 +144,15 @@ class PrithviV2(FeatureExtractor):
         """Process individual modality data.
 
         Args:
-            data: Input tensor of shape [B, C, H, W]
+            data: Input tensor of shape [B, C, T, H, W]
 
         Returns:
-            list of tensors of shape [B, C, H, W]
+            list of tensors of shape [B, C, T, H, W]
         """
         # Get original dimensions
-        original_height = data.shape[2]
+        B, C, T, H, W = data.shape
+        data = rearrange(data, "b c t h w -> b (c t) h w")
+        original_height = H
         new_height = self.patch_size if original_height == 1 else self.image_resolution
         data = F.interpolate(
             data,
@@ -158,6 +160,7 @@
             mode="bilinear",
             align_corners=False,
         )
+        data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
         return data
 
     def forward(self, context: ModelContext) -> FeatureMaps:
@@ -171,17 +174,16 @@
             a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
             feature maps across the 11 transformer blocks.
         """
-        x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
+        # x has shape BCTHW
+        x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
         x = self._resize_data(x)
-        num_timesteps = x.shape[1] // len(self.bands)
-        x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
         features = self.model.encoder.forward_features(x)
         # prepare_features_for_image_model was slightly modified since we already
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
         # to keep the embeddings reasonably sized.
         result = self.model.encoder.prepare_features_for_image_model(
-            features, num_timesteps
+            features, x.shape[2]
         )
         return FeatureMaps([torch.cat(result, dim=1)])
 
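`F.interpolate` with `mode="bilinear"` only accepts 4D NCHW input, which is why `_resize_data` now folds the time axis into channels before resizing and unfolds it afterwards. A minimal round-trip sketch (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.randn(2, 6, 3, 32, 32)  # B, C, T, H, W
B, C, T, H, W = x.shape

# Fold T into the channel axis so 4D-only bilinear interpolation applies.
flat = rearrange(x, "b c t h w -> b (c t) h w")
flat = F.interpolate(flat, size=(224, 224), mode="bilinear", align_corners=False)

# Restore the separate channel and time axes.
x = rearrange(flat, "b (c t) h w -> b c t h w", c=C, t=T)
print(x.shape)  # torch.Size([2, 6, 3, 224, 224])
```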
rslearn/models/sam2_enc.py CHANGED
@@ -95,7 +95,9 @@ class SAM2Encoder(FeatureExtractor):
         Returns:
             feature maps from the encoder.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         features = self.encoder(images)
         return FeatureMaps(features)
 
rslearn/models/satlaspretrain.py CHANGED
@@ -76,7 +76,10 @@ class SatlasPretrain(FeatureExtractor):
         Returns:
             multi-resolution feature maps computed by the model.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        # take the first (assumed to be only) timestep
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         feature_maps = self.model(self.maybe_resize(images))
         return FeatureMaps(feature_maps)
 
rslearn/models/simple_time_series.py CHANGED
@@ -3,8 +3,9 @@
 from typing import Any
 
 import torch
+from einops import rearrange
 
-from rslearn.train.model_context import ModelContext
+from rslearn.train.model_context import ModelContext, RasterImage
 
 from .component import FeatureExtractor, FeatureMaps
 
@@ -163,23 +164,44 @@ class SimpleTimeSeries(FeatureExtractor):
 
     def _get_batched_images(
         self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
-    ) -> torch.Tensor:
+    ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.
 
         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
         """
         images = torch.stack(
-            [input_dict[image_key] for input_dict in input_dicts], dim=0
+            [input_dict[image_key].image for input_dict in input_dicts], dim=0
+        )  # B, C, T, H, W
+        timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+        # if image_channels is not equal to the actual number of channels, then
+        # every N images should be batched together. For example, if the actual
+        # number of input channels c == 4 and image_channels == 2, then we
+        # want to pass 2 timesteps to the model.
+        # TODO: make this behaviour clearer, but let's leave it like
+        # this for now to not break things.
+        num_timesteps = images.shape[1] // image_channels
+        batched_timesteps = images.shape[2] // num_timesteps
+        images = rearrange(
+            images,
+            "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+            b_t=batched_timesteps,
+            k_t=num_timesteps,
         )
-        n_batch = images.shape[0]
-        n_images = images.shape[1] // image_channels
-        n_height = images.shape[2]
-        n_width = images.shape[3]
-        batched_images = images.reshape(
-            n_batch * n_images, image_channels, n_height, n_width
-        )
-        return batched_images
+        if timestamps[0] is None:
+            new_timestamps = [None] * images.shape[0]
+        else:
+            # we also need to split the timestamps
+            new_timestamps = []
+            for t in timestamps:
+                for i in range(batched_timesteps):
+                    new_timestamps.append(
+                        t[i * num_timesteps : (i + 1) * num_timesteps]
+                    )
+        return [
+            RasterImage(image=image, timestamps=timestamps)
+            for image, timestamps in zip(images, new_timestamps)
+        ]  # C, T, H, W
 
     def forward(
         self,
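The einops pattern in `_get_batched_images` splits the time axis into `b_t` groups of `k_t` timesteps and folds the group axis into the batch, so the wrapped model sees each group as a separate example. A toy illustration with made-up sizes:

```python
import torch
from einops import rearrange

images = torch.randn(2, 4, 6, 8, 8)  # B=2, C=4, T=6, H=8, W=8
k_t = 2                              # timesteps passed to the model per call
b_t = images.shape[2] // k_t         # number of groups folded into the batch

grouped = rearrange(
    images, "b c (b_t k_t) h w -> (b b_t) c k_t h w", b_t=b_t, k_t=k_t
)
print(grouped.shape)  # torch.Size([6, 4, 2, 8, 8])
```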
@@ -208,8 +230,8 @@
 
         if batched_inputs is None:
             batched_inputs = [{} for _ in batched_images]
-            n_images = batched_images.shape[0] // n_batch
-        elif n_images != batched_images.shape[0] // n_batch:
+            n_images = len(batched_images) // n_batch
+        elif n_images != len(batched_images) // n_batch:
             raise ValueError(
                 "expected all modalities to have the same number of timesteps"
             )
@@ -223,10 +245,9 @@
                 context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
-            n_images = batched_images.shape[0] // n_batch
+            n_images = len(batched_images) // n_batch
         assert n_images is not None
 
-
         # Now we can apply the underlying FeatureExtractor.
         # Its output must be a FeatureMaps.
         assert batched_inputs is not None
@@ -250,7 +271,6 @@
             )
             for feat_map in encoder_output.feature_maps
         ]
-
         # Groups defaults to flattening all the feature maps.
         groups = self.groups
         if not groups:
rslearn/models/ssl4eo_s12.py CHANGED
@@ -13,7 +13,7 @@ class Ssl4eoS12(FeatureExtractor):
 
     def __init__(
         self,
-        backbone_ckpt_path: str,
+        backbone_ckpt_path: str | None,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
     ) -> None:
@@ -39,19 +39,22 @@
         else:
             raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")
 
-        state_dict = torch.load(backbone_ckpt_path, weights_only=True)
-        state_dict = state_dict["teacher"]
-        prefix = "module.backbone."
-        state_dict = {
-            k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
-        }
-        missing_keys, unexpected_keys = self.model.load_state_dict(
-            state_dict, strict=False
-        )
-        if missing_keys or unexpected_keys:
-            print(
-                f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+        if backbone_ckpt_path is not None:
+            state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+            state_dict = state_dict["teacher"]
+            prefix = "module.backbone."
+            state_dict = {
+                k[len(prefix) :]: v
+                for k, v in state_dict.items()
+                if k.startswith(prefix)
+            }
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
             )
+            if missing_keys or unexpected_keys:
+                print(
+                    f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+                )
 
     def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
@@ -91,7 +94,9 @@
         Returns:
             feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        x = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
rslearn/models/swin.py CHANGED
@@ -151,7 +151,9 @@ class Swin(FeatureExtractor):
             a FeatureVector if the configured output_layers is None, or a FeatureMaps
             otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
 
         if self.output_layers:
             layer_features = []
rslearn/models/terramind.py CHANGED
@@ -143,7 +143,8 @@ class Terramind(FeatureExtractor):
             if modality not in context.inputs[0]:
                 continue
             cur = torch.stack(
-                [inp[modality] for inp in context.inputs], dim=0
+                [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
             )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
@@ -219,7 +220,7 @@ class TerramindNormalize(Transform):
         Returns:
             The normalized image.
         """
-        images = image.float()  # (C, H, W)
+        images = image.float()  # (C, 1, H, W)
         if images.shape[0] % len(means) != 0:
             raise ValueError(
                 f"the number of image channels {images.shape[0]} is not multiple of expected number of bands {len(means)}"
@@ -247,8 +248,8 @@
             band_info = PRETRAINED_BANDS[modality]
             means = [band_info[band][0] for band in band_info]
             stds = [band_info[band][1] for band in band_info]
-            input_dict[modality] = self.apply_image(
-                input_dict[modality],
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image,
                 means,
                 stds,
             )
rslearn/train/all_patches_dataset.py CHANGED
@@ -10,7 +10,7 @@ import torch
 
 from rslearn.dataset import Window
 from rslearn.train.dataset import DataInput, ModelDataset
-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
@@ -99,6 +99,30 @@
     return raw_inputs, passthrough_inputs
 
 
+def crop_tensor_or_rasterimage(
+    x: torch.Tensor | RasterImage, start: tuple[int, int], end: tuple[int, int]
+) -> torch.Tensor | RasterImage:
+    """Crop a tensor or a RasterImage."""
+    if isinstance(x, torch.Tensor):
+        # Crop the CHW tensor with scaled coordinates.
+        return x[
+            :,
+            start[1] : end[1],
+            start[0] : end[0],
+        ].clone()
+    else:
+        # Crop the CTHW tensor with scaled coordinates.
+        return RasterImage(
+            x.image[
+                :,
+                :,
+                start[1] : end[1],
+                start[0] : end[0],
+            ].clone(),
+            x.timestamps,
+        )
+
+
 class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
     """This wraps a ModelDataset to iterate over all patches in that dataset.
 
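A quick usage sketch of the new helper, assuming rslearn 0.0.20 is installed (shapes and coordinates are illustrative):

```python
import torch

from rslearn.train.all_patches_dataset import crop_tensor_or_rasterimage
from rslearn.train.model_context import RasterImage

chw = torch.zeros(3, 64, 64)                   # plain CHW tensor
cthw = RasterImage(torch.zeros(3, 2, 64, 64))  # CTHW RasterImage with 2 timesteps

# start=(x, y) and end=(x, y): the helper slices H with the y range, W with x.
print(crop_tensor_or_rasterimage(chw, (0, 0), (32, 32)).shape)         # (3, 32, 32)
print(crop_tensor_or_rasterimage(cthw, (0, 0), (32, 32)).image.shape)  # (3, 2, 32, 32)
```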
@@ -281,7 +305,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
             cropped = {}
             for input_name, value in d.items():
-                if isinstance(value, torch.Tensor):
+                if isinstance(value, torch.Tensor | RasterImage):
                     # Get resolution scale for this input
                     rf = self.inputs[input_name].resolution_factor
                     scale = rf.numerator / rf.denominator
@@ -294,12 +318,9 @@
                         int(end_offset[0] * scale),
                         int(end_offset[1] * scale),
                     )
-                    # Crop the CHW tensor with scaled coordinates.
-                    cropped[input_name] = value[
-                        :,
-                        scaled_start[1] : scaled_end[1],
-                        scaled_start[0] : scaled_end[0],
-                    ].clone()
+                    cropped[input_name] = crop_tensor_or_rasterimage(
+                        value, scaled_start, scaled_end
+                    )
                 elif isinstance(value, list):
                     cropped[input_name] = [
                         feat
@@ -429,7 +450,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         """
         cropped = {}
         for input_name, value in d.items():
-            if isinstance(value, torch.Tensor):
+            if isinstance(value, torch.Tensor | RasterImage):
                 # Get resolution scale for this input
                 rf = self.inputs[input_name].resolution_factor
                 scale = rf.numerator / rf.denominator
@@ -442,11 +463,10 @@
                     int(end_offset[0] * scale),
                     int(end_offset[1] * scale),
                 )
-                cropped[input_name] = value[
-                    :,
-                    scaled_start[1] : scaled_end[1],
-                    scaled_start[0] : scaled_end[0],
-                ].clone()
+                cropped[input_name] = crop_tensor_or_rasterimage(
+                    value, scaled_start, scaled_end
+                )
+
 
             elif isinstance(value, list):
                 cropped[input_name] = [
                     feat for feat in value if cur_geom.intersects(feat.geometry)
rslearn/train/dataset.py CHANGED
@@ -8,6 +8,7 @@ import random
 import tempfile
 import time
 import uuid
+from datetime import datetime
 from typing import Any
 
 import torch
@@ -19,10 +20,16 @@ from rslearn.config import (
     DType,
     LayerConfig,
 )
+from rslearn.data_sources.data_source import Item
 from rslearn.dataset.dataset import Dataset
 from rslearn.dataset.storage.file import FileWindowStorage
-from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
+from rslearn.dataset.window import (
+    Window,
+    WindowLayerData,
+    get_layer_and_group_from_dir_name,
+)
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import RasterImage
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered
@@ -198,7 +205,8 @@ def read_raster_layer_for_data_input(
     group_idx: int,
     layer_config: LayerConfig,
     data_input: DataInput,
-) -> torch.Tensor:
+    layer_data: WindowLayerData | None,
+) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
     """Read a raster layer for a DataInput.
 
     This scans the available rasters for the layer at the window to determine which
@@ -211,9 +219,11 @@
         group_idx: the item group.
         layer_config: the layer configuration.
         data_input: the DataInput that specifies the bands and dtype.
+        layer_data: the WindowLayerData associated with this layer and window.
 
     Returns:
-        tensor containing raster data.
+        a tuple of the tensor containing raster data and the time range
+        associated with that data.
     """
     # See what different sets of bands we need to read to get all the
     # configured bands.
@@ -284,7 +294,34 @@
             src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
         )
 
-    return image
+    # add the timestamp. this is a tuple defining the start and end of the time range.
+    time_range = None
+    if layer_data is not None:
+        item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
+        if item.geometry.time_range is not None:
+            # we assume if one layer data has a geometry & time range, all of them do
+            time_ranges = [
+                (
+                    datetime.fromisoformat(
+                        Item.deserialize(
+                            layer_data.serialized_item_groups[group_idx][idx]
+                        ).geometry.time_range[0]  # type: ignore
+                    ),
+                    datetime.fromisoformat(
+                        Item.deserialize(
+                            layer_data.serialized_item_groups[group_idx][idx]
+                        ).geometry.time_range[1]  # type: ignore
+                    ),
+                )
+                for idx in range(len(layer_data.serialized_item_groups[group_idx]))
+            ]
+            # take the min and max
+            time_range = (
+                min([t[0] for t in time_ranges]),
+                max([t[1] for t in time_ranges]),
+            )
+
+    return image, time_range
 
 
 def read_data_input(
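The aggregation above reduces the per-item time ranges to a single range covering all items (earliest start, latest end). Standalone, with illustrative dates:

```python
from datetime import datetime

time_ranges = [
    (datetime(2023, 3, 1), datetime(2023, 3, 10)),
    (datetime(2023, 2, 20), datetime(2023, 3, 5)),
]
# Covering range: earliest start to latest end.
time_range = (min(t[0] for t in time_ranges), max(t[1] for t in time_ranges))
print(time_range)  # 2023-02-20 00:00 through 2023-03-10 00:00
```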
@@ -293,7 +330,7 @@
     bounds: PixelBounds,
     data_input: DataInput,
     rng: random.Random,
-) -> torch.Tensor | list[Feature]:
+) -> RasterImage | list[Feature]:
     """Read the data specified by the DataInput from the window.
 
     Args:
@@ -335,15 +372,34 @@
         layers_to_read = [rng.choice(layer_options)]
 
     if data_input.data_type == "raster":
+        # load it once here
+        layer_datas = window.load_layer_datas()
         images: list[torch.Tensor] = []
+        time_ranges: list[tuple[datetime, datetime] | None] = []
         for layer_name, group_idx in layers_to_read:
             layer_config = dataset.layers[layer_name]
-            images.append(
-                read_raster_layer_for_data_input(
-                    window, bounds, layer_name, group_idx, layer_config, data_input
-                )
+            image, time_range = read_raster_layer_for_data_input(
+                window,
+                bounds,
+                layer_name,
+                group_idx,
+                layer_config,
+                data_input,
+                # some layers (e.g. "label_raster") won't have associated
+                # layer datas
+                layer_datas[layer_name] if layer_name in layer_datas else None,
             )
-        return torch.cat(images, dim=0)
+            if len(time_ranges) > 0:
+                if type(time_ranges[-1]) is not type(time_range):
+                    raise ValueError(
+                        f"All time ranges should be datetime tuples or None. Got {type(time_range)} and {type(time_ranges[-1])}"
+                    )
+            images.append(image)
+            time_ranges.append(time_range)
+        return RasterImage(
+            torch.stack(images, dim=1),
+            time_ranges if time_ranges[0] is not None else None,  # type: ignore
+        )
 
     elif data_input.data_type == "vector":
         # We don't really support time series for vector data currently, we just
rslearn/train/model_context.py CHANGED
@@ -10,6 +10,40 @@ import torch
 from rslearn.utils.geometry import PixelBounds, Projection
 
 
+@dataclass
+class RasterImage:
+    """A raster image is a torch.Tensor of images plus their associated timestamps."""
+
+    # image is a 4D CTHW tensor
+    image: torch.Tensor
+    # if timestamps is not None, len(timestamps) must match the T dimension of the tensor
+    timestamps: list[tuple[datetime, datetime]] | None = None
+
+    @property
+    def shape(self) -> torch.Size:
+        """The shape of the image."""
+        return self.image.shape
+
+    def dim(self) -> int:
+        """The number of dimensions of the image."""
+        return self.image.dim()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The image dtype."""
+        return self.image.dtype
+
+    def single_ts_to_chw_tensor(self) -> torch.Tensor:
+        """Single-timestep models expect single-timestep inputs.
+
+        This function (1) checks that this raster image only has 1 timestep and
+        (2) returns the tensor for that (single) timestep (going from CTHW to CHW).
+        """
+        if self.image.shape[1] != 1:
+            raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
+        return self.image[:, 0]
+
+
 @dataclass
 class SampleMetadata:
     """Metadata pertaining to an example."""
@@ -32,7 +66,7 @@ class ModelContext:
     """Context to pass to all model components."""
 
     # One input dict per example in the batch.
-    inputs: list[dict[str, torch.Tensor]]
+    inputs: list[dict[str, torch.Tensor | RasterImage]]
     # One SampleMetadata per example in the batch.
     metadatas: list[SampleMetadata]
     # Arbitrary dict that components can add to.
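To illustrate the new container, a hypothetical construction showing the invariants the dataclass documents (band counts and dates are made up):

```python
from datetime import datetime

import torch

from rslearn.train.model_context import RasterImage

# Two timesteps with (start, end) ranges; len(timestamps) matches T.
img = RasterImage(
    image=torch.zeros(12, 2, 256, 256),  # C, T, H, W
    timestamps=[
        (datetime(2023, 6, 1), datetime(2023, 6, 1)),
        (datetime(2023, 7, 1), datetime(2023, 7, 1)),
    ],
)
print(img.shape, img.dtype)  # torch.Size([12, 2, 256, 256]) torch.float32

# single_ts_to_chw_tensor() drops the time axis, raising unless T == 1.
single = RasterImage(torch.zeros(3, 1, 64, 64))
print(single.single_ts_to_chw_tensor().shape)  # torch.Size([3, 64, 64])
```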