rslearn 0.0.18__tar.gz → 0.0.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. {rslearn-0.0.18/rslearn.egg-info → rslearn-0.0.20}/PKG-INFO +1 -1
  2. {rslearn-0.0.18 → rslearn-0.0.20}/pyproject.toml +1 -1
  3. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/arg_parser.py +2 -9
  4. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/config/dataset.py +15 -16
  5. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/dataset.py +28 -22
  6. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/lightning_cli.py +22 -11
  7. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/main.py +1 -1
  8. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/anysat.py +35 -33
  9. rslearn-0.0.20/rslearn/models/attention_pooling.py +177 -0
  10. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/clip.py +5 -2
  11. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/component.py +12 -0
  12. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/croma.py +11 -3
  13. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/dinov3.py +2 -1
  14. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/faster_rcnn.py +2 -1
  15. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/galileo/galileo.py +58 -31
  16. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/module_wrapper.py +6 -1
  17. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/molmo.py +4 -2
  18. rslearn-0.0.20/rslearn/models/olmoearth_pretrain/model.py +422 -0
  19. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/norm.py +5 -3
  20. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon.py +3 -1
  21. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/presto/presto.py +45 -15
  22. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/prithvi.py +9 -7
  23. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/sam2_enc.py +3 -1
  24. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/satlaspretrain.py +4 -1
  25. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/simple_time_series.py +43 -17
  26. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/ssl4eo_s12.py +19 -14
  27. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/swin.py +3 -1
  28. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/terramind.py +5 -4
  29. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/all_patches_dataset.py +96 -28
  30. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/dataset.py +102 -53
  31. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/model_context.py +35 -1
  32. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/scheduler.py +15 -0
  33. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/classification.py +8 -2
  34. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/detection.py +3 -2
  35. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/multi_task.py +2 -3
  36. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/per_pixel_regression.py +14 -5
  37. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/regression.py +8 -2
  38. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/segmentation.py +13 -4
  39. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/task.py +2 -2
  40. rslearn-0.0.20/rslearn/train/transforms/concatenate.py +89 -0
  41. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/crop.py +22 -8
  42. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/flip.py +13 -5
  43. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/mask.py +11 -2
  44. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/normalize.py +46 -15
  45. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/pad.py +15 -3
  46. rslearn-0.0.20/rslearn/train/transforms/resize.py +83 -0
  47. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/select_bands.py +11 -2
  48. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/sentinel1.py +18 -3
  49. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/geometry.py +73 -0
  50. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/jsonargparse.py +66 -0
  51. {rslearn-0.0.18 → rslearn-0.0.20/rslearn.egg-info}/PKG-INFO +1 -1
  52. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn.egg-info/SOURCES.txt +2 -0
  53. rslearn-0.0.18/rslearn/models/olmoearth_pretrain/model.py +0 -267
  54. rslearn-0.0.18/rslearn/train/transforms/concatenate.py +0 -49
  55. {rslearn-0.0.18 → rslearn-0.0.20}/LICENSE +0 -0
  56. {rslearn-0.0.18 → rslearn-0.0.20}/NOTICE +0 -0
  57. {rslearn-0.0.18 → rslearn-0.0.20}/README.md +0 -0
  58. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/__init__.py +0 -0
  59. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/config/__init__.py +0 -0
  60. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/const.py +0 -0
  61. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/__init__.py +0 -0
  62. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/aws_landsat.py +0 -0
  63. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/aws_open_data.py +0 -0
  64. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/aws_sentinel1.py +0 -0
  65. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/climate_data_store.py +0 -0
  66. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/copernicus.py +0 -0
  67. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/data_source.py +0 -0
  68. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/earthdaily.py +0 -0
  69. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/earthdata_srtm.py +0 -0
  70. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/eurocrops.py +0 -0
  71. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/gcp_public_data.py +0 -0
  72. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/google_earth_engine.py +0 -0
  73. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/local_files.py +0 -0
  74. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/openstreetmap.py +0 -0
  75. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/planet.py +0 -0
  76. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/planet_basemap.py +0 -0
  77. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/planetary_computer.py +0 -0
  78. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/usda_cdl.py +0 -0
  79. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/usgs_landsat.py +0 -0
  80. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/utils.py +0 -0
  81. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/vector_source.py +0 -0
  82. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/worldcereal.py +0 -0
  83. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/worldcover.py +0 -0
  84. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/worldpop.py +0 -0
  85. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/data_sources/xyz_tiles.py +0 -0
  86. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/__init__.py +0 -0
  87. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/add_windows.py +0 -0
  88. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/handler_summaries.py +0 -0
  89. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/manage.py +0 -0
  90. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/materialize.py +0 -0
  91. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/remap.py +0 -0
  92. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/storage/__init__.py +0 -0
  93. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/storage/file.py +0 -0
  94. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/storage/storage.py +0 -0
  95. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/dataset/window.py +0 -0
  96. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/log_utils.py +0 -0
  97. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/__init__.py +0 -0
  98. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/clay/clay.py +0 -0
  99. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/clay/configs/metadata.yaml +0 -0
  100. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/concatenate_features.py +0 -0
  101. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/conv.py +0 -0
  102. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/__init__.py +0 -0
  103. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/box_ops.py +0 -0
  104. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/detr.py +0 -0
  105. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/matcher.py +0 -0
  106. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/position_encoding.py +0 -0
  107. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/transformer.py +0 -0
  108. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/detr/util.py +0 -0
  109. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/feature_center_crop.py +0 -0
  110. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/fpn.py +0 -0
  111. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/galileo/__init__.py +0 -0
  112. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/galileo/single_file_galileo.py +0 -0
  113. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/multitask.py +0 -0
  114. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  115. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  116. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  117. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  118. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  119. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  120. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  121. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  122. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  123. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  124. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  125. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  126. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  127. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/pick_features.py +0 -0
  128. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/pooling_decoder.py +0 -0
  129. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/presto/__init__.py +0 -0
  130. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/presto/single_file_presto.py +0 -0
  131. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/resize_features.py +0 -0
  132. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/singletask.py +0 -0
  133. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/task_embedding.py +0 -0
  134. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/trunk.py +0 -0
  135. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/unet.py +0 -0
  136. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/upsample.py +0 -0
  137. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/models/use_croma.py +0 -0
  138. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/py.typed +0 -0
  139. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/template_params.py +0 -0
  140. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/tile_stores/__init__.py +0 -0
  141. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/tile_stores/default.py +0 -0
  142. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/tile_stores/tile_store.py +0 -0
  143. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/__init__.py +0 -0
  144. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/callbacks/__init__.py +0 -0
  145. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/callbacks/adapters.py +0 -0
  146. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  147. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/callbacks/gradients.py +0 -0
  148. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/callbacks/peft.py +0 -0
  149. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/data_module.py +0 -0
  150. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/lightning_module.py +0 -0
  151. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/optimizer.py +0 -0
  152. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/prediction_writer.py +0 -0
  153. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/__init__.py +0 -0
  154. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/tasks/embedding.py +0 -0
  155. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/__init__.py +0 -0
  156. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/train/transforms/transform.py +0 -0
  157. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/__init__.py +0 -0
  158. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/array.py +0 -0
  159. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/feature.py +0 -0
  160. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/fsspec.py +0 -0
  161. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/get_utm_ups_crs.py +0 -0
  162. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/grid_index.py +0 -0
  163. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/mp.py +0 -0
  164. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/raster_format.py +0 -0
  165. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/rtree_index.py +0 -0
  166. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/spatial_index.py +0 -0
  167. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/sqlite_index.py +0 -0
  168. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/time.py +0 -0
  169. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn/utils/vector_format.py +0 -0
  170. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn.egg-info/dependency_links.txt +0 -0
  171. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn.egg-info/entry_points.txt +0 -0
  172. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn.egg-info/requires.txt +0 -0
  173. {rslearn-0.0.18 → rslearn-0.0.20}/rslearn.egg-info/top_level.txt +0 -0
  174. {rslearn-0.0.18 → rslearn-0.0.20}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.18
+ Version: 0.0.20
  Summary: A library for developing remote sensing datasets and models
  Author: OlmoEarth Team
  License: Apache License
@@ -1,6 +1,6 @@
  [project]
  name = "rslearn"
- version = "0.0.18"
+ version = "0.0.20"
  description = "A library for developing remote sensing datasets and models"
  authors = [
  { name = "OlmoEarth Team" },
@@ -1,6 +1,5 @@
  """Custom Lightning ArgumentParser with environment variable substitution support."""

- import os
  from typing import Any

  from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
  def parse_string(
  self,
  cfg_str: str,
- cfg_path: str | os.PathLike = "",
- ext_vars: dict | None = None,
- env: bool | None = None,
- defaults: bool = True,
- with_meta: bool | None = None,
+ *args: Any,
  **kwargs: Any,
  ) -> Namespace:
  """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
  substituted_cfg_str = substitute_env_vars_in_string(cfg_str)

  # Call the parent method with the substituted config
- return super().parse_string(
- substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
- )
+ return super().parse_string(substituted_cfg_str, *args, **kwargs)
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
  from upath import UPath

  from rslearn.log_utils import get_logger
- from rslearn.utils import PixelBounds, Projection
+ from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
  from rslearn.utils.raster_format import RasterFormat
  from rslearn.utils.vector_format import VectorFormat

@@ -215,22 +215,12 @@ class BandSetConfig(BaseModel):
  Returns:
  tuple of updated projection and bounds with zoom offset applied
  """
- if self.zoom_offset == 0:
- return projection, bounds
- projection = Projection(
- projection.crs,
- projection.x_resolution / (2**self.zoom_offset),
- projection.y_resolution / (2**self.zoom_offset),
- )
- if self.zoom_offset > 0:
- zoom_factor = 2**self.zoom_offset
- bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
+ if self.zoom_offset >= 0:
+ factor = ResolutionFactor(numerator=2**self.zoom_offset)
  else:
- bounds = tuple(
- x // (2 ** (-self.zoom_offset))
- for x in bounds # type: ignore
- )
- return projection, bounds
+ factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
+
+ return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))

  @field_validator("format", mode="before")
  @classmethod
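Note on the zoom_offset change above: the manual scaling of projection resolution and pixel bounds is now delegated to ResolutionFactor. A minimal standalone sketch of the equivalent arithmetic (the helper below is hypothetical; only multiply_projection/multiply_bounds appear in this diff):

    # Hypothetical sketch of what the zoom_offset arithmetic amounts to; this is
    # not the rslearn ResolutionFactor API itself, just the equivalent math.
    def apply_zoom_offset(
        resolution: float, bounds: tuple[int, int, int, int], zoom_offset: int
    ) -> tuple[float, tuple[int, ...]]:
        if zoom_offset >= 0:
            factor = 2**zoom_offset
            # Finer resolution, so pixel coordinates grow by the same factor.
            return resolution / factor, tuple(x * factor for x in bounds)
        factor = 2 ** (-zoom_offset)
        # Coarser resolution, so pixel coordinates shrink (floor division).
        return resolution * factor, tuple(x // factor for x in bounds)

    # zoom_offset=1 halves the resolution value (finer pixels) and doubles the bounds.
    print(apply_zoom_offset(10.0, (0, 0, 64, 64), 1))   # (5.0, (0, 0, 128, 128))
    print(apply_zoom_offset(10.0, (0, 0, 64, 64), -1))  # (20.0, (0, 0, 32, 32))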
@@ -645,3 +635,12 @@ class DatasetConfig(BaseModel):
  default_factory=lambda: StorageConfig(),
  description="jsonargparse configuration for the WindowStorageFactory.",
  )
+
+ @field_validator("layers", mode="after")
+ @classmethod
+ def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
+ """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
+ for layer_name in v.keys():
+ if "." in layer_name:
+ raise ValueError(f"layer names must not contain periods: {layer_name}")
+ return v
@@ -23,7 +23,7 @@ class Dataset:
  .. code-block:: none

  dataset/
- config.json
+ config.json # optional, if config provided as runtime object
  windows/
  group1/
  epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
  materialize.
  """

- def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
+ def __init__(
+ self,
+ path: UPath,
+ disabled_layers: list[str] = [],
+ dataset_config: DatasetConfig | None = None,
+ ) -> None:
  """Initializes a new Dataset.

  Args:
  path: the root directory of the dataset
  disabled_layers: list of layers to disable
+ dataset_config: optional dataset configuration to use instead of loading from the dataset directory
  """
  self.path = path

- # Load dataset configuration.
- with (self.path / "config.json").open("r") as f:
- config_content = f.read()
- config_content = substitute_env_vars_in_string(config_content)
- config = DatasetConfig.model_validate(json.loads(config_content))
-
- self.layers = {}
- for layer_name, layer_config in config.layers.items():
- # Layer names must not contain period, since we use period to
- # distinguish different materialized groups within a layer.
- assert "." not in layer_name, "layer names must not contain periods"
- if layer_name in disabled_layers:
- logger.warning(f"Layer {layer_name} is disabled")
- continue
- self.layers[layer_name] = layer_config
-
- self.tile_store_config = config.tile_store
- self.storage = (
- config.storage.instantiate_window_storage_factory().get_storage(
- self.path
+ if dataset_config is None:
+ # Load dataset configuration from the dataset directory.
+ with (self.path / "config.json").open("r") as f:
+ config_content = f.read()
+ config_content = substitute_env_vars_in_string(config_content)
+ dataset_config = DatasetConfig.model_validate(
+ json.loads(config_content)
  )
+
+ self.layers = {}
+ for layer_name, layer_config in dataset_config.layers.items():
+ if layer_name in disabled_layers:
+ logger.warning(f"Layer {layer_name} is disabled")
+ continue
+ self.layers[layer_name] = layer_config
+
+ self.tile_store_config = dataset_config.tile_store
+ self.storage = (
+ dataset_config.storage.instantiate_window_storage_factory().get_storage(
+ self.path
  )
+ )

  def load_windows(
  self,
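For orientation, a hedged usage sketch of the two construction paths Dataset now supports (import locations and file names here are assumptions, not taken from this diff):

    import json

    from upath import UPath

    # Assumed import locations; adjust to wherever Dataset and DatasetConfig live.
    from rslearn.config.dataset import DatasetConfig
    from rslearn.dataset import Dataset

    root = UPath("/data/my_dataset")  # hypothetical dataset root

    # Existing behavior: the configuration is read from <root>/config.json.
    ds_from_disk = Dataset(root)

    # New option: validate a configuration yourself and pass it in, so config.json
    # in the dataset directory becomes optional.
    with open("my_config.json") as f:  # hypothetical config kept outside the dataset
        cfg = DatasetConfig.model_validate(json.load(f))
    ds_from_object = Dataset(root, dataset_config=cfg)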
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
  from rslearn.train.data_module import RslearnDataModule
  from rslearn.train.lightning_module import RslearnLightningModule
  from rslearn.utils.fsspec import open_atomic
+ from rslearn.utils.jsonargparse import init_jsonargparse

  WANDB_ID_FNAME = "wandb_id"

@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):

  Sets the dataset path for any configured RslearnPredictionWriter callbacks.
  """
- subcommand = self.config.subcommand
- c = self.config[subcommand]
+ if not hasattr(self.config, "subcommand"):
+ logger.warning(
+ "Config does not have subcommand attribute, assuming we are in run=False mode"
+ )
+ subcommand = None
+ c = self.config
+ else:
+ subcommand = self.config.subcommand
+ c = self.config[subcommand]

  # If there is a RslearnPredictionWriter, set its path.
  prediction_writer_callback = None
@@ -415,16 +423,17 @@ class RslearnLightningCLI(LightningCLI):
  if subcommand == "predict":
  c.return_predictions = False

- # For now we use DDP strategy with find_unused_parameters=True.
+ # Default to DDP with find_unused_parameters. Likely won't get called with unified config
  if subcommand == "fit":
- c.trainer.strategy = jsonargparse.Namespace(
- {
- "class_path": "lightning.pytorch.strategies.DDPStrategy",
- "init_args": jsonargparse.Namespace(
- {"find_unused_parameters": True}
- ),
- }
- )
+ if not c.trainer.strategy:
+ c.trainer.strategy = jsonargparse.Namespace(
+ {
+ "class_path": "lightning.pytorch.strategies.DDPStrategy",
+ "init_args": jsonargparse.Namespace(
+ {"find_unused_parameters": True}
+ ),
+ }
+ )

  if c.management_dir:
  self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@ class RslearnLightningCLI(LightningCLI):

  def model_handler() -> None:
  """Handler for any rslearn model X commands."""
+ init_jsonargparse()
+
  RslearnLightningCLI(
  model_class=RslearnLightningModule,
  datamodule_class=RslearnDataModule,
@@ -380,7 +380,7 @@ def apply_on_windows(

  def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
  """Call apply_on_windows with arguments passed via command-line interface."""
- dataset = Dataset(UPath(args.root), args.disabled_layers)
+ dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
  apply_on_windows(
  f=f,
  dataset=dataset,
@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
  https://github.com/gastruc/AnySat for applicable license and copyright information.
  """

+ from datetime import datetime
+
  import torch
  from einops import rearrange

@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
  self,
  modalities: list[str],
  patch_size_meters: int,
- dates: dict[str, list[int]],
  output: str = "patch",
  output_modality: str | None = None,
  hub_repo: str = "gastruc/anysat",
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
  if m not in MODALITY_RESOLUTIONS:
  raise ValueError(f"Invalid modality: {m}")

- if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
- raise ValueError("`dates` keys must be time-series modalities only.")
- for m in modalities:
- if m in TIME_SERIES_MODALITIES and m not in dates:
- raise ValueError(
- f"Missing required dates for time-series modality '{m}'."
- )
-
  if patch_size_meters % 10 != 0:
  raise ValueError(
  "In AnySat, `patch_size` is in meters and must be a multiple of 10."
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):

  self.modalities = modalities
  self.patch_size_meters = int(patch_size_meters)
- self.dates = dates
  self.output = output
  self.output_modality = output_modality

@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
  )
  self._embed_dim = 768 # base width, 'dense' returns 2x

+ @staticmethod
+ def time_ranges_to_doy(
+ time_ranges: list[tuple[datetime, datetime]],
+ device: torch.device,
+ ) -> torch.Tensor:
+ """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
+
+ AnySat uses the doy with each timestamp, so we take the midpoint of
+ the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+ time so that start_time == end_time == mid_time.
+ """
+ doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
+ return torch.tensor(doys, dtype=torch.int32, device=device)
+
  def forward(self, context: ModelContext) -> FeatureMaps:
  """Forward pass for the AnySat model.

@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
  raise ValueError(f"Modality '{modality}' not present in inputs.")

  cur = torch.stack(
- [inp[modality] for inp in inputs], dim=0
- ) # (B, C, H, W) or (B, T*C, H, W)
+ [inp[modality].image for inp in inputs], dim=0
+ ) # (B, C, T, H, W)

  if modality in TIME_SERIES_MODALITIES:
- num_dates = len(self.dates[modality])
- num_bands = cur.shape[1] // num_dates
- cur = rearrange(
- cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
- )
+ num_bands = cur.shape[1]
+ cur = rearrange(cur, "b c t h w -> b t c h w")
  H, W = cur.shape[-2], cur.shape[-1]
+
+ if inputs[0][modality].timestamps is None:
+ raise ValueError(
+ f"Require timestamps for time series modality {modality}"
+ )
+ timestamps = torch.stack(
+ [
+ self.time_ranges_to_doy(inp[modality].timestamps, cur.device) # type: ignore
+ for inp in inputs
+ ],
+ dim=0,
+ )
+ batch[f"{modality}_dates"] = timestamps
  else:
+ # take the first (assumed only) timestep
+ cur = cur[:, :, 0]
  num_bands = cur.shape[1]
  H, W = cur.shape[-2], cur.shape[-1]

@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
  "All modalities must share the same spatial extent (H*res, W*res)."
  )

- # Add *_dates
- to_add = {}
- for modality, x in list(batch.items()):
- if modality in TIME_SERIES_MODALITIES:
- B, T = x.shape[0], x.shape[1]
- d = torch.as_tensor(
- self.dates[modality], dtype=torch.long, device=x.device
- )
- if d.ndim != 1 or d.numel() != T:
- raise ValueError(
- f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
- )
- to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
-
- batch.update(to_add)
-
  kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
  if self.output == "dense":
  kwargs["output_modality"] = self.output_modality
@@ -0,0 +1,177 @@
+ """An attention pooling layer."""
+
+ import math
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import nn
+
+ from rslearn.models.component import (
+ FeatureMaps,
+ IntermediateComponent,
+ TokenFeatureMaps,
+ )
+ from rslearn.train.model_context import ModelContext
+
+
+ class SimpleAttentionPool(IntermediateComponent):
+ """Simple Attention Pooling.
+
+ Given a token feature map of shape BCHWN,
+ learn an attention layer which aggregates over
+ the N dimension.
+
+ This is done simply by learning a mapping D->1 which is the weight
+ which should be assigned to each token during averaging:
+
+ output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+ """
+
+ def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+ """Initialize the simple attention pooling layer.
+
+ Args:
+ in_dim: the encoding dimension D
+ hidden_linear: whether to apply an additional linear transformation D -> D
+ to the feat tokens. If this is True, a ReLU activation is applied
+ after the first linear transformation.
+ """
+ super().__init__()
+ if hidden_linear:
+ self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+ else:
+ self.hidden_linear = None
+ self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+ def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+ """Attention pooling for a single feature map (BCHWN tensor)."""
+ B, D, H, W, N = feat_tokens.shape
+ feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+ if self.hidden_linear is not None:
+ feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+ attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+ feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+ return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+ def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+ """Forward pass for attention pooling linear probe.
+
+ Args:
+ intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+ We pool over the final dimension in the TokenFeatureMaps. If multiple maps
+ are passed, we apply the same linear layers to all of them.
+ context: the model context.
+ feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+ Returns:
+ torch.Tensor:
+ - output, attentioned pool over the last dimension (B, C, H, W)
+ """
+ if not isinstance(intermediates, TokenFeatureMaps):
+ raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+ features = []
+ for feat in intermediates.feature_maps:
+ features.append(self.forward_for_map(feat))
+ return FeatureMaps(features)
+
+
+ class AttentionPool(IntermediateComponent):
+ """Attention Pooling.
+
+ Given a feature map of shape BCHWN,
+ learn an attention layer which aggregates over
+ the N dimension.
+
+ We do this by learning a query token, and applying a standard
+ attention mechanism against this learned query token.
+ """
+
+ def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+ """Initialize the attention pooling layer.
+
+ Args:
+ in_dim: the encoding dimension D
+ num_heads: the number of heads to use
+ linear_on_kv: Whether to apply a linear layer on the input tokens
+ to create the key and value tokens.
+ """
+ super().__init__()
+ self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+ if linear_on_kv:
+ self.k_linear = nn.Linear(in_dim, in_dim)
+ self.v_linear = nn.Linear(in_dim, in_dim)
+ else:
+ self.k_linear = None
+ self.v_linear = None
+ if in_dim % num_heads != 0:
+ raise ValueError(
+ f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+ )
+ self.num_heads = num_heads
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Initialize weights for the probe."""
+ nn.init.trunc_normal_(self.query_token, std=0.02)
+
+ def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+ """Attention pooling for a single feature map (BCHWN tensor)."""
+ B, D, H, W, N = feat_tokens.shape
+ feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+ collapsed_dim = B * H * W
+ q = self.query_token.expand(collapsed_dim, 1, -1)
+ q = q.reshape(
+ collapsed_dim, 1, self.num_heads, D // self.num_heads
+ ) # [B, 1, head, D_head]
+ q = rearrange(q, "b h n d -> b n h d")
+ if self.k_linear is not None:
+ assert self.v_linear is not None
+ k = self.k_linear(feat_tokens).reshape(
+ collapsed_dim, N, self.num_heads, D // self.num_heads
+ )
+ v = self.v_linear(feat_tokens).reshape(
+ collapsed_dim, N, self.num_heads, D // self.num_heads
+ )
+ else:
+ k = feat_tokens.reshape(
+ collapsed_dim, N, self.num_heads, D // self.num_heads
+ )
+ v = feat_tokens.reshape(
+ collapsed_dim, N, self.num_heads, D // self.num_heads
+ )
+ k = rearrange(k, "b n h d -> b h n d")
+ v = rearrange(v, "b n h d -> b h n d")
+
+ # Compute attention scores
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+ D // self.num_heads
+ )
+ attn_weights = F.softmax(attn_scores, dim=-1)
+ x = torch.matmul(attn_weights, v) # [B, head, 1, D_head]
+ return x.reshape(B, D, H, W)
+
+ def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+ """Forward pass for attention pooling linear probe.
+
+ Args:
+ intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+ We pool over the final dimension in the TokenFeatureMaps. If multiple feature
+ maps are passed, we apply the same attention weights (query token and linear k, v layers)
+ to all the maps.
+ context: the model context.
+ feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+ Returns:
+ torch.Tensor:
+ - output, attentioned pool over the last dimension (B, C, H, W)
+ """
+ if not isinstance(intermediates, TokenFeatureMaps):
+ raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+ features = []
+ for feat in intermediates.feature_maps:
+ features.append(self.forward_for_map(feat))
+ return FeatureMaps(features)
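To make the shapes in the new attention_pooling module concrete, a short sketch that pools random BCHWN token features with SimpleAttentionPool (calling forward_for_map directly so the ModelContext plumbing can be skipped):

    import torch

    from rslearn.models.attention_pooling import SimpleAttentionPool

    # 2 samples, 16-dim features, a 4x4 spatial grid, 8 unpooled tokens per cell.
    feat_tokens = torch.randn(2, 16, 4, 4, 8)

    pool = SimpleAttentionPool(in_dim=16, hidden_linear=True)
    pooled = pool.forward_for_map(feat_tokens)
    print(pooled.shape)  # torch.Size([2, 16, 4, 4]); the token dimension is pooled away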
@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
  a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
  """
  inputs = context.inputs
- device = inputs[0]["image"].device
+ device = inputs[0]["image"].image.device
  clip_inputs = self.processor(
- images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
+ images=[
+ inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+ for inp in inputs
+ ],
  return_tensors="pt",
  padding=True,
  )
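Several of the encoder changes in this release (CLIP, CROMA, DINOv3, Faster R-CNN) switch from raw tensors to a raster input object with an .image tensor and a single_ts_to_chw_tensor() helper. The helper itself is not shown in this diff; based on the shape comments above it presumably selects the single timestep from a (C, T, H, W) tensor, roughly as in this hypothetical sketch:

    import torch

    # Hypothetical stand-in for the helper used in these hunks, assuming the image
    # is stored as (C, T, H, W) with exactly one timestep.
    def single_ts_to_chw_tensor(image: torch.Tensor) -> torch.Tensor:
        if image.shape[1] != 1:
            raise ValueError("expected exactly one timestep")
        return image[:, 0]

    print(single_ts_to_chw_tensor(torch.zeros(3, 1, 224, 224)).shape)  # torch.Size([3, 224, 224])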
@@ -91,6 +91,18 @@ class FeatureMaps:
  feature_maps: list[torch.Tensor]


+ @dataclass
+ class TokenFeatureMaps:
+ """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+ Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+ """
+
+ # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+ # (most fine-grained) to lowest resolution (coarsest).
+ feature_maps: list[torch.Tensor]
+
+
  @dataclass
  class FeatureVector:
  """An intermediate output type for a flat feature vector."""
@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
  sentinel1: torch.Tensor | None = None
  sentinel2: torch.Tensor | None = None
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
- sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
+ sentinel1 = torch.stack(
+ [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
+ dim=0,
+ )
  sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
- sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
+ sentinel2 = torch.stack(
+ [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
+ dim=0,
+ )
  sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2

  outputs = self.model(
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
  for modality in MODALITY_BANDS.keys():
  if modality not in input_dict:
  continue
- input_dict[modality] = self.apply_image(input_dict[modality], modality)
+ input_dict[modality].image = self.apply_image(
+ input_dict[modality].image, modality
+ )
  return input_dict, target_dict
@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
  a FeatureMaps with one feature map.
  """
  cur = torch.stack(
- [inp["image"] for inp in context.inputs], dim=0
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+ dim=0,
  ) # (B, C, H, W)

  if self.do_resizing and (
@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
  ),
  )

- image_list = [inp["image"] for inp in context.inputs]
+ # take the first (and assumed to be only) timestep
+ image_list = [inp["image"].image[:, 0] for inp in context.inputs]
  images, targets = self.noop_transform(image_list, targets)

  feature_dict = collections.OrderedDict()