rslearn 0.0.19.tar.gz → 0.0.20.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. {rslearn-0.0.19/rslearn.egg-info → rslearn-0.0.20}/PKG-INFO +1 -1
  2. {rslearn-0.0.19 → rslearn-0.0.20}/pyproject.toml +1 -1
  3. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/anysat.py +35 -33
  4. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clip.py +5 -2
  5. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/croma.py +11 -3
  6. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/dinov3.py +2 -1
  7. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/faster_rcnn.py +2 -1
  8. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/galileo.py +58 -31
  9. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/module_wrapper.py +6 -1
  10. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/molmo.py +4 -2
  11. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/model.py +93 -29
  12. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/norm.py +5 -3
  13. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon.py +3 -1
  14. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/presto.py +45 -15
  15. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/prithvi.py +9 -7
  16. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/sam2_enc.py +3 -1
  17. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/satlaspretrain.py +4 -1
  18. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/simple_time_series.py +36 -16
  19. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/ssl4eo_s12.py +19 -14
  20. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/swin.py +3 -1
  21. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/terramind.py +5 -4
  22. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/all_patches_dataset.py +34 -14
  23. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/dataset.py +66 -10
  24. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/model_context.py +35 -1
  25. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/classification.py +8 -2
  26. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/detection.py +3 -2
  27. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/multi_task.py +2 -3
  28. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/per_pixel_regression.py +14 -5
  29. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/regression.py +8 -2
  30. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/segmentation.py +13 -4
  31. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/task.py +2 -2
  32. rslearn-0.0.20/rslearn/train/transforms/concatenate.py +89 -0
  33. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/crop.py +22 -8
  34. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/flip.py +13 -5
  35. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/mask.py +11 -2
  36. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/normalize.py +46 -15
  37. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/pad.py +15 -3
  38. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/resize.py +18 -9
  39. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/select_bands.py +11 -2
  40. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/sentinel1.py +18 -3
  41. {rslearn-0.0.19 → rslearn-0.0.20/rslearn.egg-info}/PKG-INFO +1 -1
  42. rslearn-0.0.19/rslearn/train/transforms/concatenate.py +0 -49
  43. {rslearn-0.0.19 → rslearn-0.0.20}/LICENSE +0 -0
  44. {rslearn-0.0.19 → rslearn-0.0.20}/NOTICE +0 -0
  45. {rslearn-0.0.19 → rslearn-0.0.20}/README.md +0 -0
  46. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/__init__.py +0 -0
  47. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/arg_parser.py +0 -0
  48. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/config/__init__.py +0 -0
  49. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/config/dataset.py +0 -0
  50. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/const.py +0 -0
  51. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/__init__.py +0 -0
  52. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_landsat.py +0 -0
  53. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_open_data.py +0 -0
  54. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_sentinel1.py +0 -0
  55. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/climate_data_store.py +0 -0
  56. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/copernicus.py +0 -0
  57. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/data_source.py +0 -0
  58. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/earthdaily.py +0 -0
  59. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/earthdata_srtm.py +0 -0
  60. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/eurocrops.py +0 -0
  61. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/gcp_public_data.py +0 -0
  62. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/google_earth_engine.py +0 -0
  63. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/local_files.py +0 -0
  64. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/openstreetmap.py +0 -0
  65. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planet.py +0 -0
  66. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planet_basemap.py +0 -0
  67. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planetary_computer.py +0 -0
  68. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/usda_cdl.py +0 -0
  69. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/usgs_landsat.py +0 -0
  70. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/utils.py +0 -0
  71. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/vector_source.py +0 -0
  72. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldcereal.py +0 -0
  73. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldcover.py +0 -0
  74. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldpop.py +0 -0
  75. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/xyz_tiles.py +0 -0
  76. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/__init__.py +0 -0
  77. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/add_windows.py +0 -0
  78. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/dataset.py +0 -0
  79. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/handler_summaries.py +0 -0
  80. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/manage.py +0 -0
  81. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/materialize.py +0 -0
  82. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/remap.py +0 -0
  83. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/__init__.py +0 -0
  84. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/file.py +0 -0
  85. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/storage.py +0 -0
  86. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/window.py +0 -0
  87. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/lightning_cli.py +0 -0
  88. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/log_utils.py +0 -0
  89. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/main.py +0 -0
  90. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/__init__.py +0 -0
  91. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/attention_pooling.py +0 -0
  92. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clay/clay.py +0 -0
  93. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clay/configs/metadata.yaml +0 -0
  94. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/component.py +0 -0
  95. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/concatenate_features.py +0 -0
  96. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/conv.py +0 -0
  97. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/__init__.py +0 -0
  98. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/box_ops.py +0 -0
  99. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/detr.py +0 -0
  100. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/matcher.py +0 -0
  101. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/position_encoding.py +0 -0
  102. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/transformer.py +0 -0
  103. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/util.py +0 -0
  104. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/feature_center_crop.py +0 -0
  105. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/fpn.py +0 -0
  106. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/__init__.py +0 -0
  107. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/single_file_galileo.py +0 -0
  108. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/multitask.py +0 -0
  109. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  110. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  111. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  112. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  113. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  114. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  115. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  116. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  117. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  118. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  119. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  120. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  121. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  122. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/pick_features.py +0 -0
  123. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/pooling_decoder.py +0 -0
  124. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/__init__.py +0 -0
  125. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/single_file_presto.py +0 -0
  126. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/resize_features.py +0 -0
  127. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/singletask.py +0 -0
  128. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/task_embedding.py +0 -0
  129. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/trunk.py +0 -0
  130. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/unet.py +0 -0
  131. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/upsample.py +0 -0
  132. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/use_croma.py +0 -0
  133. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/py.typed +0 -0
  134. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/template_params.py +0 -0
  135. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/__init__.py +0 -0
  136. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/default.py +0 -0
  137. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/tile_store.py +0 -0
  138. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/__init__.py +0 -0
  139. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/__init__.py +0 -0
  140. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/adapters.py +0 -0
  141. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  142. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/gradients.py +0 -0
  143. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/peft.py +0 -0
  144. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/data_module.py +0 -0
  145. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/lightning_module.py +0 -0
  146. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/optimizer.py +0 -0
  147. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/prediction_writer.py +0 -0
  148. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/scheduler.py +0 -0
  149. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/__init__.py +0 -0
  150. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/embedding.py +0 -0
  151. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/__init__.py +0 -0
  152. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/transform.py +0 -0
  153. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/__init__.py +0 -0
  154. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/array.py +0 -0
  155. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/feature.py +0 -0
  156. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/fsspec.py +0 -0
  157. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/geometry.py +0 -0
  158. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/get_utm_ups_crs.py +0 -0
  159. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/grid_index.py +0 -0
  160. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/jsonargparse.py +0 -0
  161. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/mp.py +0 -0
  162. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/raster_format.py +0 -0
  163. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/rtree_index.py +0 -0
  164. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/spatial_index.py +0 -0
  165. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/sqlite_index.py +0 -0
  166. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/time.py +0 -0
  167. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/vector_format.py +0 -0
  168. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/SOURCES.txt +0 -0
  169. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/dependency_links.txt +0 -0
  170. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/entry_points.txt +0 -0
  171. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/requires.txt +0 -0
  172. {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/top_level.txt +0 -0
  173. {rslearn-0.0.19 → rslearn-0.0.20}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.19
+Version: 0.0.20
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rslearn"
-version = "0.0.19"
+version = "0.0.20"
 description = "A library for developing remote sensing datasets and models"
 authors = [
     { name = "OlmoEarth Team" },

rslearn/models/anysat.py
@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
 https://github.com/gastruc/AnySat for applicable license and copyright information.
 """
 
+from datetime import datetime
+
 import torch
 from einops import rearrange
 
@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
         self,
         modalities: list[str],
         patch_size_meters: int,
-        dates: dict[str, list[int]],
         output: str = "patch",
         output_modality: str | None = None,
         hub_repo: str = "gastruc/anysat",
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
             if m not in MODALITY_RESOLUTIONS:
                 raise ValueError(f"Invalid modality: {m}")
 
-        if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
-            raise ValueError("`dates` keys must be time-series modalities only.")
-        for m in modalities:
-            if m in TIME_SERIES_MODALITIES and m not in dates:
-                raise ValueError(
-                    f"Missing required dates for time-series modality '{m}'."
-                )
-
         if patch_size_meters % 10 != 0:
             raise ValueError(
                 "In AnySat, `patch_size` is in meters and must be a multiple of 10."
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):
 
         self.modalities = modalities
         self.patch_size_meters = int(patch_size_meters)
-        self.dates = dates
         self.output = output
         self.output_modality = output_modality
 
@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
         )
         self._embed_dim = 768  # base width, 'dense' returns 2x
 
+    @staticmethod
+    def time_ranges_to_doy(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
+
+        AnySat uses the doy with each timestamp, so we take the midpoint
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
+        return torch.tensor(doys, dtype=torch.int32, device=device)
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the AnySat model.
 
@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
                 raise ValueError(f"Modality '{modality}' not present in inputs.")
 
             cur = torch.stack(
-                [inp[modality] for inp in inputs], dim=0
-            )  # (B, C, H, W) or (B, T*C, H, W)
+                [inp[modality].image for inp in inputs], dim=0
+            )  # (B, C, T, H, W)
 
             if modality in TIME_SERIES_MODALITIES:
-                num_dates = len(self.dates[modality])
-                num_bands = cur.shape[1] // num_dates
-                cur = rearrange(
-                    cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
-                )
+                num_bands = cur.shape[1]
+                cur = rearrange(cur, "b c t h w -> b t c h w")
                 H, W = cur.shape[-2], cur.shape[-1]
+
+                if inputs[0][modality].timestamps is None:
+                    raise ValueError(
+                        f"Require timestamps for time series modality {modality}"
+                    )
+                timestamps = torch.stack(
+                    [
+                        self.time_ranges_to_doy(inp[modality].timestamps, cur.device)  # type: ignore
+                        for inp in inputs
+                    ],
+                    dim=0,
+                )
+                batch[f"{modality}_dates"] = timestamps
             else:
+                # take the first (assumed only) timestep
+                cur = cur[:, :, 0]
                 num_bands = cur.shape[1]
                 H, W = cur.shape[-2], cur.shape[-1]
 
@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
                 "All modalities must share the same spatial extent (H*res, W*res)."
             )
 
-        # Add *_dates
-        to_add = {}
-        for modality, x in list(batch.items()):
-            if modality in TIME_SERIES_MODALITIES:
-                B, T = x.shape[0], x.shape[1]
-                d = torch.as_tensor(
-                    self.dates[modality], dtype=torch.long, device=x.device
-                )
-                if d.ndim != 1 or d.numel() != T:
-                    raise ValueError(
-                        f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
-                    )
-                to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
-
-        batch.update(to_add)
-
         kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
         if self.output == "dense":
             kwargs["output_modality"] = self.output_modality

rslearn/models/clip.py
@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
             a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
         """
         inputs = context.inputs
-        device = inputs[0]["image"].device
+        device = inputs[0]["image"].image.device
         clip_inputs = self.processor(
-            images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
+            images=[
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+                for inp in inputs
+            ],
             return_tensors="pt",
             padding=True,
         )
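
Note (illustration, not part of the package diff): many hunks in this release swap raw `inp["image"]` tensors for `RasterImage` objects and call `single_ts_to_chw_tensor()` on them; the actual class lives in rslearn/train/model_context.py (changed in this release, +35 -1). A rough stand-in for the behaviour these call sites imply, assuming a CxTxHxW `image` field; the class name and details here are illustrative only:

import torch


class RasterImageSketch:
    # Illustrative stand-in for rslearn.train.model_context.RasterImage.
    def __init__(self, image: torch.Tensor, timestamps=None) -> None:
        self.image = image  # (C, T, H, W)
        self.timestamps = timestamps  # list of (start, end) datetime pairs, or None

    def single_ts_to_chw_tensor(self) -> torch.Tensor:
        # Single-timestep models (CLIP, CROMA, DINOv3, ...) want a CxHxW tensor,
        # so take the first (and assumed only) timestep.
        return self.image[:, 0]


x = RasterImageSketch(torch.zeros(3, 1, 224, 224))
print(x.single_ts_to_chw_tensor().shape)  # torch.Size([3, 224, 224])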

rslearn/models/croma.py
@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
+            sentinel1 = torch.stack(
+                [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
+            sentinel2 = torch.stack(
+                [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
         for modality in MODALITY_BANDS.keys():
             if modality not in input_dict:
                 continue
-            input_dict[modality] = self.apply_image(input_dict[modality], modality)
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image, modality
+            )
         return input_dict, target_dict

rslearn/models/dinov3.py
@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
             a FeatureMaps with one feature map.
         """
         cur = torch.stack(
-            [inp["image"] for inp in context.inputs], dim=0
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
         )  # (B, C, H, W)
 
         if self.do_resizing and (

rslearn/models/faster_rcnn.py
@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
            ),
        )
 
-        image_list = [inp["image"] for inp in context.inputs]
+        # take the first (and assumed to be only) timestep
+        image_list = [inp["image"].image[:, 0] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()

rslearn/models/galileo/galileo.py
@@ -3,6 +3,7 @@
 import math
 import tempfile
 from contextlib import nullcontext
+from datetime import datetime
 from enum import StrEnum
 from typing import cast
 
@@ -411,6 +412,23 @@ class GalileoModel(FeatureExtractor):
             months=months,
         )
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Galileo.
+
+        Galileo only uses the month associated with each timestamp, so we take the midpoint
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
@@ -418,16 +436,16 @@ class GalileoModel(FeatureExtractor):
             context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
-                    "s1": B (T * C) H W
-                    "s2": B (T * C) H W
-                    "era5": B (T * C) H W (we will average over the H, W dimensions)
-                    "tc": B (T * C) H W (we will average over the H, W dimensions)
-                    "viirs": B (T * C) H W (we will average over the H, W dimensions)
-                    "srtm": B C H W (SRTM has no temporal dimension)
-                    "dw": : B C H W (Dynamic World should be averaged over time)
-                    "wc": B C H W (WorldCereal has no temporal dimension)
-                    "landscan": B C H W (we will average over the H, W dimensions)
-                    "latlon": B C H W (we will average over the H, W dimensions)
+                    "s1": B C T H W
+                    "s2": B C T H W
+                    "era5": B C T H W (we will average over the H, W dimensions)
+                    "tc": B C T H W (we will average over the H, W dimensions)
+                    "viirs": B C T H W (we will average over the H, W dimensions)
+                    "srtm": B C 1 H W (SRTM has no temporal dimension)
+                    "dw": : B C 1 H W (Dynamic World should be averaged over time)
+                    "wc": B C 1 H W (WorldCereal has no temporal dimension)
+                    "landscan": B C 1 H W (we will average over the H, W dimensions)
+                    "latlon": B C 1 H W (we will average over the H, W dimensions)
 
         The output will be an embedding representing the pooled tokens. If there is
         only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
@@ -436,15 +454,35 @@ class GalileoModel(FeatureExtractor):
         If there are many spatial tokens per h/w dimension (patch_size > h,w), then we will
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
+        space_time_modalities = ["s1", "s2"]
+        time_modalities = ["era5", "tc", "viirs"]
         stacked_inputs = {}
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 stacked_inputs[key] = torch.stack(
-                    [inp[key] for inp in context.inputs], dim=0
+                    [inp[key].image for inp in context.inputs], dim=0
                 )
+                if key in space_time_modalities + time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+
+        if months is not None:
+            stacked_inputs["months"] = months
+
         s_t_channels = []
-        for space_time_modality in ["s1", "s2"]:
+        for space_time_modality in space_time_modalities:
             if space_time_modality not in stacked_inputs:
                 continue
             if space_time_modality == "s1":
@@ -452,36 +490,27 @@ class GalileoModel(FeatureExtractor):
             else:
                 s_t_channels += self.s_t_channels_s2
             cur = stacked_inputs[space_time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
-            num_timesteps = cur.shape[1] // num_bands
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
             stacked_inputs[space_time_modality] = cur
 
         for space_modality in ["srtm", "dw", "wc"]:
             if space_modality not in stacked_inputs:
                 continue
+            # take the first (and assumed only) timestep
+            stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
             stacked_inputs[space_modality] = rearrange(
                 stacked_inputs[space_modality], "b c h w -> b h w c"
             )
 
-        for time_modality in ["era5", "tc", "viirs"]:
+        for time_modality in time_modalities:
             if time_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = {
-                "era5": len(ERA5_BANDS),
-                "tc": len(TC_BANDS),
-                "viirs": len(VIIRS_BANDS),
-            }[time_modality]
-            num_timesteps = cur.shape[1] // num_bands
             # take the average over the h, w bands since Galileo
             # treats it as a pixel-timeseries
             cur = rearrange(
-                torch.nanmean(torch.nanmean(cur, dim=-1), dim=-1),
-                "b (t c) -> b t c",
-                t=num_timesteps,
+                torch.nanmean(cur, dim=(-1, -2)),
+                "b c t -> b t c",
             )
             stacked_inputs[time_modality] = cur
 
@@ -489,9 +518,8 @@ class GalileoModel(FeatureExtractor):
             if static_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[static_modality]
-            stacked_inputs[static_modality] = torch.nanmean(
-                torch.nanmean(cur, dim=-1), dim=-1
-            )
+            stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
+
         galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
         h = galileo_input.s_t_x.shape[1]
         if h < self.patch_size:
@@ -511,7 +539,6 @@ class GalileoModel(FeatureExtractor):
             torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
-
         with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
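
Note (illustration, not part of the package diff): with the new BxCxTxHxW layout, the pixel-timeseries modalities (era5, tc, viirs) no longer need the band-count bookkeeping that the removed lines performed; they are averaged over the spatial dimensions and reshaped to BxTxC. A small sketch of that reduction with made-up shapes:

import torch
from einops import rearrange

# A fake "era5" input: batch of 2, 3 bands, 6 timesteps, 4x4 pixels.
era5 = torch.randn(2, 3, 6, 4, 4)

# Average over H and W (nanmean ignores NaN pixels), then move time before channels.
pixel_series = rearrange(torch.nanmean(era5, dim=(-1, -2)), "b c t -> b t c")
print(pixel_series.shape)  # torch.Size([2, 6, 3])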

rslearn/models/module_wrapper.py
@@ -53,7 +53,12 @@ class EncoderModuleWrapper(FeatureExtractor):
         Returns:
             the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        # take the first and only timestep. Currently no intermediate
+        # components support multi temporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
             cur = m(cur, context)

rslearn/models/molmo.py
@@ -47,11 +47,13 @@ class Molmo(FeatureExtractor):
             a FeatureMaps. Molmo produces features at one scale, so it will contain one
             feature map that is a Bx24x24x2048 tensor.
         """
-        device = context.inputs[0]["image"].device
+        device = context.inputs[0]["image"].image.device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
         for inp in context.inputs:
-            image = inp["image"].cpu().numpy().transpose(1, 2, 0)
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
             processed = self.processor.process(
                 images=[image],
                 text="",

rslearn/models/olmoearth_pretrain/model.py
@@ -1,26 +1,27 @@
 """OlmoEarth model wrapper for fine-tuning in rslearn."""
 
 import json
+import warnings
 from contextlib import nullcontext
+from datetime import datetime
 from typing import Any
 
 import torch
 from einops import rearrange
-from olmo_core.config import Config
-from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.config import Config, require_olmo_core
 from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
 from olmoearth_pretrain.model_loader import (
     ModelID,
     load_model_from_id,
     load_model_from_path,
 )
 from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
-from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
 from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
-from rslearn.train.model_context import ModelContext
+from rslearn.train.model_context import ModelContext, RasterImage
 
 logger = get_logger(__name__)
 
@@ -61,6 +62,7 @@ class OlmoEarth(FeatureExtractor):
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
         token_pooling: bool = True,
+        use_legacy_timestamps: bool = True,
     ):
         """Create a new OlmoEarth model.
 
@@ -87,7 +89,15 @@ class OlmoEarth(FeatureExtractor):
             token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
                 there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
                 dimensions.
+            use_legacy_timestamps: In our original implementation of OlmoEarth, we applied timestamps starting
+                from 0 (instead of the actual timestamps of the input). The option to do this is preserved
+                for backwards compatability with finetuned models which were trained against this implementation.
         """
+        if use_legacy_timestamps:
+            warnings.warn(
+                "For new projects, don't use legacy timesteps.", DeprecationWarning
+            )
+
         if (
             sum(
                 [
@@ -138,6 +148,7 @@ class OlmoEarth(FeatureExtractor):
             model = model[part]
         self.model = model
         self.token_pooling = token_pooling
+        self.use_legacy_timestamps = use_legacy_timestamps
 
     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -148,9 +159,12 @@ class OlmoEarth(FeatureExtractor):
         that contains the distributed checkpoint. This is the format produced by
         pre-training runs in olmoearth_pretrain.
         """
-        # Load the model config and initialize it.
         # We avoid loading the train module here because it depends on running within
         # olmo_core.
+        # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
+        require_olmo_core("_load_model_from_checkpoint")
+        from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
         with (checkpoint_upath / "config.json").open() as f:
             config_dict = json.load(f)
         model_config = Config.from_dict(config_dict["model"])
@@ -165,6 +179,32 @@ class OlmoEarth(FeatureExtractor):
 
         return model
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        max_timestamps: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by OlmoEarth.
+
+        OlmoEarth only uses the month associated with each timestamp, so we take the midpoint
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        timestamps[: len(time_ranges), 0] = torch.tensor(
+            [d.day for d in mid_ranges], dtype=torch.int32
+        )
+        # months are indexed 0-11
+        timestamps[: len(time_ranges), 1] = torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32
+        )
+        timestamps[: len(time_ranges), 2] = torch.tensor(
+            [d.year for d in mid_ranges], dtype=torch.int32
+        )
+        return timestamps
+
     def _prepare_modality_inputs(
         self, context: ModelContext
     ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
@@ -190,43 +230,55 @@ class OlmoEarth(FeatureExtractor):
         # We'll have to fix all that.
         max_timesteps = 1
         modality_data = {}
+        # we will just store the longest time range
+        # per instance in the batch. This means it may not be
+        # aligned per modality
+        timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
+            context.inputs
+        )
         for modality in MODALITY_NAMES:
             if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-            tensors = [inp[modality] for inp in context.inputs]
+            tensors = []
+            for idx, inp in enumerate(context.inputs):
+                assert isinstance(inp, RasterImage)
+                tensors.append(inp[modality].image)
+                cur_timestamps = inp[modality].timestamps
+                if cur_timestamps is not None and len(cur_timestamps) > len(
+                    timestamps_per_instance[idx]
+                ):
+                    timestamps_per_instance[idx] = cur_timestamps
+            tensors = [inp[modality].image for inp in context.inputs]
             device = tensors[0].device
-            num_bands = Modality.get(modality).num_bands
-            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_t = max(t.shape[1] for t in tensors)
             max_timesteps = max(max_timesteps, max_t)
             modality_data[modality] = (
                 tensors,
-                num_bands,
                 len(Modality.get(modality).band_sets),
             )
 
         # Second pass: pad and process each modality with global max_timesteps
         for modality in present_modalities:
-            tensors, num_bands, num_band_sets = modality_data[modality]
-            target_ch = max_timesteps * num_bands
+            tensors, num_band_sets = modality_data[modality]
 
            # Pad tensors to target_ch and track original timesteps for masking
            padded = []
            original_timesteps = []
            for t in tensors:
-                orig_t = t.shape[0] // num_bands
+                orig_t = t.shape[1]
                original_timesteps.append(orig_t)
-                if t.shape[0] < target_ch:
+                if orig_t < max_timesteps:
                    pad = torch.zeros(
-                        (target_ch - t.shape[0],) + t.shape[1:],
+                        t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
                        dtype=t.dtype,
                        device=device,
                    )
-                    t = torch.cat([t, pad], dim=0)
+                    t = torch.cat([t, pad], dim=1)
                padded.append(t)
 
            cur = torch.stack(padded, dim=0)
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
            kwargs[modality] = cur
 
            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
@@ -242,19 +294,31 @@ class OlmoEarth(FeatureExtractor):
                 mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask
 
-        # Timestamps is required.
-        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should
-        # handle varying timestamps per input.
-        timestamps = torch.zeros(
-            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
-        )
-        timestamps[:, :, 0] = 1  # day
-        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
-            None, :
-        ]  # month
-        timestamps[:, :, 2] = 2024  # year
-        kwargs["timestamps"] = timestamps
+        if self.use_legacy_timestamps:
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            timestamps = torch.zeros(
+                (len(context.inputs), max_timesteps, 3),
+                dtype=torch.int32,
+                device=device,
+            )
+            timestamps[:, :, 0] = 1  # day
+            timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+                None, :
+            ]  # month
+            timestamps[:, :, 2] = 2024  # year
+            kwargs["timestamps"] = timestamps
+        else:
+            if max([len(t) for t in timestamps_per_instance]) == 0:
+                # Timestamps is required.
+                raise ValueError("No inputs had timestamps.")
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            kwargs["timestamps"] = torch.stack(
+                [
+                    self.time_ranges_to_timestamps(time_range, max_timesteps, device)
+                    for time_range in timestamps_per_instance
+                ],
+                dim=0,
+            )
 
         return MaskedOlmoEarthSample(**kwargs), present_modalities, device
 
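
Note (illustration, not part of the package diff): when `use_legacy_timestamps` is disabled, the new code builds a (T, 3) tensor of (day, month-1, year) per sample from the midpoints of the stored time ranges, zero-padding up to the longest series in the batch. A standalone sketch mirroring the `time_ranges_to_timestamps` helper above (variable names are illustrative):

from datetime import datetime

import torch


def ranges_to_day_month_year(
    time_ranges: list[tuple[datetime, datetime]], max_timestamps: int
) -> torch.Tensor:
    # Rows beyond len(time_ranges) stay zero; they correspond to padded timesteps.
    out = torch.zeros((max_timestamps, 3), dtype=torch.int32)
    mids = [start + (end - start) / 2 for start, end in time_ranges]
    out[: len(mids), 0] = torch.tensor([d.day for d in mids], dtype=torch.int32)
    out[: len(mids), 1] = torch.tensor([d.month - 1 for d in mids], dtype=torch.int32)  # months 0-11
    out[: len(mids), 2] = torch.tensor([d.year for d in mids], dtype=torch.int32)
    return out


print(ranges_to_day_month_year([(datetime(2024, 5, 10), datetime(2024, 5, 20))], max_timestamps=2))
# tensor([[  15,    4, 2024],
#         [   0,    0,    0]], dtype=torch.int32)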

rslearn/models/olmoearth_pretrain/norm.py
@@ -64,8 +64,8 @@ class OlmoEarthNormalize(Transform):
             band_norms = self.norm_config[modality_name]
             image = input_dict[modality_name]
             # Keep a set of indices to make sure that we normalize all of them.
-            needed_band_indices = set(range(image.shape[0]))
-            num_timesteps = image.shape[0] // len(cur_band_names)
+            needed_band_indices = set(range(image.image.shape[0]))
+            num_timesteps = image.image.shape[0] // len(cur_band_names)
 
             for band, norm_dict in band_norms.items():
                 # If multitemporal, normalize each timestep separately.
@@ -73,7 +73,9 @@ class OlmoEarthNormalize(Transform):
                     band_idx = cur_band_names.index(band) + t * len(cur_band_names)
                     min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                     max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
-                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                    image.image[band_idx] = (image.image[band_idx] - min_val) / (
+                        max_val - min_val
+                    )
                     needed_band_indices.remove(band_idx)
 
             if len(needed_band_indices) > 0:
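
Note (illustration, not part of the package diff): the normalization math itself is unchanged; each band is min-max scaled into the range mean ± std_multiplier * std. A tiny numeric example with made-up statistics:

import torch

mean, std, std_multiplier = 1250.0, 500.0, 2.0
min_val = mean - std_multiplier * std  # 250.0
max_val = mean + std_multiplier * std  # 2250.0

band = torch.tensor([250.0, 1250.0, 2250.0])
print((band - min_val) / (max_val - min_val))  # tensor([0.0000, 0.5000, 1.0000])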

rslearn/models/panopticon.py
@@ -142,7 +142,9 @@ class Panopticon(FeatureExtractor):
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            key: torch.stack(
+                [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+            )
             for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)