rslearn 0.0.18__tar.gz → 0.0.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. {rslearn-0.0.18/rslearn.egg-info → rslearn-0.0.19}/PKG-INFO +1 -1
  2. {rslearn-0.0.18 → rslearn-0.0.19}/pyproject.toml +1 -1
  3. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/arg_parser.py +2 -9
  4. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/config/dataset.py +15 -16
  5. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/dataset.py +28 -22
  6. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/lightning_cli.py +22 -11
  7. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/main.py +1 -1
  8. rslearn-0.0.19/rslearn/models/attention_pooling.py +177 -0
  9. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/component.py +12 -0
  10. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/model.py +125 -34
  11. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/simple_time_series.py +7 -1
  12. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/all_patches_dataset.py +67 -19
  13. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/dataset.py +36 -43
  14. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/scheduler.py +15 -0
  15. rslearn-0.0.19/rslearn/train/transforms/resize.py +74 -0
  16. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/geometry.py +73 -0
  17. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/jsonargparse.py +66 -0
  18. {rslearn-0.0.18 → rslearn-0.0.19/rslearn.egg-info}/PKG-INFO +1 -1
  19. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/SOURCES.txt +2 -0
  20. {rslearn-0.0.18 → rslearn-0.0.19}/LICENSE +0 -0
  21. {rslearn-0.0.18 → rslearn-0.0.19}/NOTICE +0 -0
  22. {rslearn-0.0.18 → rslearn-0.0.19}/README.md +0 -0
  23. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/__init__.py +0 -0
  24. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/config/__init__.py +0 -0
  25. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/const.py +0 -0
  26. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/__init__.py +0 -0
  27. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_landsat.py +0 -0
  28. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_open_data.py +0 -0
  29. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_sentinel1.py +0 -0
  30. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/climate_data_store.py +0 -0
  31. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/copernicus.py +0 -0
  32. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/data_source.py +0 -0
  33. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/earthdaily.py +0 -0
  34. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/earthdata_srtm.py +0 -0
  35. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/eurocrops.py +0 -0
  36. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/gcp_public_data.py +0 -0
  37. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/google_earth_engine.py +0 -0
  38. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/local_files.py +0 -0
  39. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/openstreetmap.py +0 -0
  40. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planet.py +0 -0
  41. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planet_basemap.py +0 -0
  42. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planetary_computer.py +0 -0
  43. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/usda_cdl.py +0 -0
  44. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/usgs_landsat.py +0 -0
  45. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/utils.py +0 -0
  46. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/vector_source.py +0 -0
  47. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldcereal.py +0 -0
  48. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldcover.py +0 -0
  49. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldpop.py +0 -0
  50. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/xyz_tiles.py +0 -0
  51. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/__init__.py +0 -0
  52. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/add_windows.py +0 -0
  53. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/handler_summaries.py +0 -0
  54. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/manage.py +0 -0
  55. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/materialize.py +0 -0
  56. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/remap.py +0 -0
  57. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/__init__.py +0 -0
  58. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/file.py +0 -0
  59. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/storage.py +0 -0
  60. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/window.py +0 -0
  61. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/log_utils.py +0 -0
  62. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/__init__.py +0 -0
  63. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/anysat.py +0 -0
  64. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clay/clay.py +0 -0
  65. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clay/configs/metadata.yaml +0 -0
  66. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clip.py +0 -0
  67. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/concatenate_features.py +0 -0
  68. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/conv.py +0 -0
  69. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/croma.py +0 -0
  70. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/__init__.py +0 -0
  71. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/box_ops.py +0 -0
  72. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/detr.py +0 -0
  73. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/matcher.py +0 -0
  74. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/position_encoding.py +0 -0
  75. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/transformer.py +0 -0
  76. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/util.py +0 -0
  77. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/dinov3.py +0 -0
  78. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/faster_rcnn.py +0 -0
  79. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/feature_center_crop.py +0 -0
  80. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/fpn.py +0 -0
  81. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/__init__.py +0 -0
  82. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/galileo.py +0 -0
  83. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/single_file_galileo.py +0 -0
  84. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/module_wrapper.py +0 -0
  85. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/molmo.py +0 -0
  86. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/multitask.py +0 -0
  87. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  88. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  89. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon.py +0 -0
  90. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  91. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  92. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  93. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  94. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  95. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  96. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  97. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  98. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  99. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  100. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  101. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  102. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/pick_features.py +0 -0
  103. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/pooling_decoder.py +0 -0
  104. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/__init__.py +0 -0
  105. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/presto.py +0 -0
  106. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/single_file_presto.py +0 -0
  107. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/prithvi.py +0 -0
  108. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/resize_features.py +0 -0
  109. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/sam2_enc.py +0 -0
  110. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/satlaspretrain.py +0 -0
  111. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/singletask.py +0 -0
  112. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/ssl4eo_s12.py +0 -0
  113. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/swin.py +0 -0
  114. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/task_embedding.py +0 -0
  115. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/terramind.py +0 -0
  116. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/trunk.py +0 -0
  117. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/unet.py +0 -0
  118. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/upsample.py +0 -0
  119. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/use_croma.py +0 -0
  120. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/py.typed +0 -0
  121. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/template_params.py +0 -0
  122. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/__init__.py +0 -0
  123. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/default.py +0 -0
  124. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/tile_store.py +0 -0
  125. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/__init__.py +0 -0
  126. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/__init__.py +0 -0
  127. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/adapters.py +0 -0
  128. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  129. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/gradients.py +0 -0
  130. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/peft.py +0 -0
  131. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/data_module.py +0 -0
  132. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/lightning_module.py +0 -0
  133. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/model_context.py +0 -0
  134. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/optimizer.py +0 -0
  135. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/prediction_writer.py +0 -0
  136. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/__init__.py +0 -0
  137. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/classification.py +0 -0
  138. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/detection.py +0 -0
  139. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/embedding.py +0 -0
  140. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/multi_task.py +0 -0
  141. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/per_pixel_regression.py +0 -0
  142. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/regression.py +0 -0
  143. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/segmentation.py +0 -0
  144. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/task.py +0 -0
  145. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/__init__.py +0 -0
  146. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/concatenate.py +0 -0
  147. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/crop.py +0 -0
  148. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/flip.py +0 -0
  149. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/mask.py +0 -0
  150. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/normalize.py +0 -0
  151. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/pad.py +0 -0
  152. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/select_bands.py +0 -0
  153. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/sentinel1.py +0 -0
  154. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/transform.py +0 -0
  155. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/__init__.py +0 -0
  156. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/array.py +0 -0
  157. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/feature.py +0 -0
  158. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/fsspec.py +0 -0
  159. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/get_utm_ups_crs.py +0 -0
  160. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/grid_index.py +0 -0
  161. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/mp.py +0 -0
  162. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/raster_format.py +0 -0
  163. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/rtree_index.py +0 -0
  164. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/spatial_index.py +0 -0
  165. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/sqlite_index.py +0 -0
  166. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/time.py +0 -0
  167. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/vector_format.py +0 -0
  168. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/dependency_links.txt +0 -0
  169. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/entry_points.txt +0 -0
  170. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/requires.txt +0 -0
  171. {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/top_level.txt +0 -0
  172. {rslearn-0.0.18 → rslearn-0.0.19}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.18
3
+ Version: 0.0.19
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.18"
3
+ version = "0.0.19"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -1,6 +1,5 @@
1
1
  """Custom Lightning ArgumentParser with environment variable substitution support."""
2
2
 
3
- import os
4
3
  from typing import Any
5
4
 
6
5
  from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
21
20
  def parse_string(
22
21
  self,
23
22
  cfg_str: str,
24
- cfg_path: str | os.PathLike = "",
25
- ext_vars: dict | None = None,
26
- env: bool | None = None,
27
- defaults: bool = True,
28
- with_meta: bool | None = None,
23
+ *args: Any,
29
24
  **kwargs: Any,
30
25
  ) -> Namespace:
31
26
  """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
33
28
  substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
34
29
 
35
30
  # Call the parent method with the substituted config
36
- return super().parse_string(
37
- substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
38
- )
31
+ return super().parse_string(substituted_cfg_str, *args, **kwargs)
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
25
25
  from upath import UPath
26
26
 
27
27
  from rslearn.log_utils import get_logger
28
- from rslearn.utils import PixelBounds, Projection
28
+ from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
29
29
  from rslearn.utils.raster_format import RasterFormat
30
30
  from rslearn.utils.vector_format import VectorFormat
31
31
 
@@ -215,22 +215,12 @@ class BandSetConfig(BaseModel):
215
215
  Returns:
216
216
  tuple of updated projection and bounds with zoom offset applied
217
217
  """
218
- if self.zoom_offset == 0:
219
- return projection, bounds
220
- projection = Projection(
221
- projection.crs,
222
- projection.x_resolution / (2**self.zoom_offset),
223
- projection.y_resolution / (2**self.zoom_offset),
224
- )
225
- if self.zoom_offset > 0:
226
- zoom_factor = 2**self.zoom_offset
227
- bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
218
+ if self.zoom_offset >= 0:
219
+ factor = ResolutionFactor(numerator=2**self.zoom_offset)
228
220
  else:
229
- bounds = tuple(
230
- x // (2 ** (-self.zoom_offset))
231
- for x in bounds # type: ignore
232
- )
233
- return projection, bounds
221
+ factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
222
+
223
+ return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
234
224
 
235
225
  @field_validator("format", mode="before")
236
226
  @classmethod
@@ -645,3 +635,12 @@ class DatasetConfig(BaseModel):
645
635
  default_factory=lambda: StorageConfig(),
646
636
  description="jsonargparse configuration for the WindowStorageFactory.",
647
637
  )
638
+
639
@field_validator("layers", mode="after")
@classmethod
def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
    """Reject any layer whose name contains a period.

    A period is reserved for separating a layer name from the index of a
    materialized group within that layer, so it cannot appear in the name.

    Raises:
        ValueError: naming the first offending layer, in dict iteration order.
    """
    offending = next((name for name in v if "." in name), None)
    if offending is not None:
        raise ValueError(f"layer names must not contain periods: {offending}")
    return v
@@ -23,7 +23,7 @@ class Dataset:
23
23
  .. code-block:: none
24
24
 
25
25
  dataset/
26
- config.json
26
+ config.json # optional, if config provided as runtime object
27
27
  windows/
28
28
  group1/
29
29
  epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
40
40
  materialize.
41
41
  """
42
42
 
43
- def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
43
+ def __init__(
44
+ self,
45
+ path: UPath,
46
+ disabled_layers: list[str] = [],
47
+ dataset_config: DatasetConfig | None = None,
48
+ ) -> None:
44
49
  """Initializes a new Dataset.
45
50
 
46
51
  Args:
47
52
  path: the root directory of the dataset
48
53
  disabled_layers: list of layers to disable
54
+ dataset_config: optional dataset configuration to use instead of loading from the dataset directory
49
55
  """
50
56
  self.path = path
51
57
 
52
- # Load dataset configuration.
53
- with (self.path / "config.json").open("r") as f:
54
- config_content = f.read()
55
- config_content = substitute_env_vars_in_string(config_content)
56
- config = DatasetConfig.model_validate(json.loads(config_content))
57
-
58
- self.layers = {}
59
- for layer_name, layer_config in config.layers.items():
60
- # Layer names must not contain period, since we use period to
61
- # distinguish different materialized groups within a layer.
62
- assert "." not in layer_name, "layer names must not contain periods"
63
- if layer_name in disabled_layers:
64
- logger.warning(f"Layer {layer_name} is disabled")
65
- continue
66
- self.layers[layer_name] = layer_config
67
-
68
- self.tile_store_config = config.tile_store
69
- self.storage = (
70
- config.storage.instantiate_window_storage_factory().get_storage(
71
- self.path
58
+ if dataset_config is None:
59
+ # Load dataset configuration from the dataset directory.
60
+ with (self.path / "config.json").open("r") as f:
61
+ config_content = f.read()
62
+ config_content = substitute_env_vars_in_string(config_content)
63
+ dataset_config = DatasetConfig.model_validate(
64
+ json.loads(config_content)
72
65
  )
66
+
67
+ self.layers = {}
68
+ for layer_name, layer_config in dataset_config.layers.items():
69
+ if layer_name in disabled_layers:
70
+ logger.warning(f"Layer {layer_name} is disabled")
71
+ continue
72
+ self.layers[layer_name] = layer_config
73
+
74
+ self.tile_store_config = dataset_config.tile_store
75
+ self.storage = (
76
+ dataset_config.storage.instantiate_window_storage_factory().get_storage(
77
+ self.path
73
78
  )
79
+ )
74
80
 
75
81
  def load_windows(
76
82
  self,
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
21
21
  from rslearn.train.data_module import RslearnDataModule
22
22
  from rslearn.train.lightning_module import RslearnLightningModule
23
23
  from rslearn.utils.fsspec import open_atomic
24
+ from rslearn.utils.jsonargparse import init_jsonargparse
24
25
 
25
26
  WANDB_ID_FNAME = "wandb_id"
26
27
 
@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):
390
391
 
391
392
  Sets the dataset path for any configured RslearnPredictionWriter callbacks.
392
393
  """
393
- subcommand = self.config.subcommand
394
- c = self.config[subcommand]
394
+ if not hasattr(self.config, "subcommand"):
395
+ logger.warning(
396
+ "Config does not have subcommand attribute, assuming we are in run=False mode"
397
+ )
398
+ subcommand = None
399
+ c = self.config
400
+ else:
401
+ subcommand = self.config.subcommand
402
+ c = self.config[subcommand]
395
403
 
396
404
  # If there is a RslearnPredictionWriter, set its path.
397
405
  prediction_writer_callback = None
@@ -415,16 +423,17 @@ class RslearnLightningCLI(LightningCLI):
415
423
  if subcommand == "predict":
416
424
  c.return_predictions = False
417
425
 
418
- # For now we use DDP strategy with find_unused_parameters=True.
426
+ # Default to DDP with find_unused_parameters. Likely won't get called with unified config
419
427
  if subcommand == "fit":
420
- c.trainer.strategy = jsonargparse.Namespace(
421
- {
422
- "class_path": "lightning.pytorch.strategies.DDPStrategy",
423
- "init_args": jsonargparse.Namespace(
424
- {"find_unused_parameters": True}
425
- ),
426
- }
427
- )
428
+ if not c.trainer.strategy:
429
+ c.trainer.strategy = jsonargparse.Namespace(
430
+ {
431
+ "class_path": "lightning.pytorch.strategies.DDPStrategy",
432
+ "init_args": jsonargparse.Namespace(
433
+ {"find_unused_parameters": True}
434
+ ),
435
+ }
436
+ )
428
437
 
429
438
  if c.management_dir:
430
439
  self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@ class RslearnLightningCLI(LightningCLI):
432
441
 
433
442
  def model_handler() -> None:
434
443
  """Handler for any rslearn model X commands."""
444
+ init_jsonargparse()
445
+
435
446
  RslearnLightningCLI(
436
447
  model_class=RslearnLightningModule,
437
448
  datamodule_class=RslearnDataModule,
@@ -380,7 +380,7 @@ def apply_on_windows(
380
380
 
381
381
  def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
382
382
  """Call apply_on_windows with arguments passed via command-line interface."""
383
- dataset = Dataset(UPath(args.root), args.disabled_layers)
383
+ dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
384
384
  apply_on_windows(
385
385
  f=f,
386
386
  dataset=dataset,
@@ -0,0 +1,177 @@
1
+ """An attention pooling layer."""
2
+
3
+ import math
4
+ from typing import Any
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+ from rslearn.models.component import (
12
+ FeatureMaps,
13
+ IntermediateComponent,
14
+ TokenFeatureMaps,
15
+ )
16
+ from rslearn.train.model_context import ModelContext
17
+
18
+
19
class SimpleAttentionPool(IntermediateComponent):
    """Simple attention pooling over the token dimension.

    Given a token feature map of shape BxCxHxWxN, learn a scoring layer that
    aggregates over the N (token) dimension. A linear map C -> 1 produces one
    score per token. The scores are softmax-normalized across the N tokens at
    each spatial location and used as weights in a weighted sum:

        output = sum_n softmax_n(W(feat_n)) * feat_n

    When ``hidden_linear`` is enabled, the tokens first pass through a C -> C
    linear layer followed by ReLU. Both the scores and the weighted sum are then
    computed on these transformed tokens.
    """

    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
        """Initialize the simple attention pooling layer.

        Args:
            in_dim: the encoding (channel) dimension C of the input tokens.
            hidden_linear: whether to apply an additional C -> C linear layer,
                followed by a ReLU activation, to the tokens before scoring
                and pooling them.
        """
        super().__init__()
        # Created before `self.linear`, as in the original, so parameter
        # registration/initialization order is unchanged.
        self.hidden_linear: nn.Linear | None = (
            nn.Linear(in_features=in_dim, out_features=in_dim)
            if hidden_linear
            else None
        )
        self.linear = nn.Linear(in_features=in_dim, out_features=1)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention-pool a single feature map over its token dimension.

        Args:
            feat_tokens: tensor of shape (B, C, H, W, N).

        Returns:
            tensor of shape (B, C, H, W).
        """
        B, _, H, W, _ = feat_tokens.shape
        # Each spatial location is pooled independently: (B*H*W, N, C).
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        if self.hidden_linear is not None:
            feat_tokens = F.relu(self.hidden_linear(feat_tokens))
        # One score per token, normalized across the N tokens (dim=1).
        attention_scores = F.softmax(self.linear(feat_tokens), dim=1)
        pooled = (attention_scores * feat_tokens).sum(dim=1)
        return rearrange(pooled, "(b h w) d -> b d h w", b=B, h=H, w=W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Attention-pool every map in a TokenFeatureMaps.

        Args:
            intermediates: the output from the previous component, which must be
                a TokenFeatureMaps whose maps have shape (B, C, H, W, N). The
                same linear layers are applied to every map.
            context: the model context (not used by this component).

        Returns:
            FeatureMaps holding one (B, C, H, W) tensor per input map, in the
            same order as the input.

        Raises:
            ValueError: if ``intermediates`` is not a TokenFeatureMaps.
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        return FeatureMaps(
            [self.forward_for_map(feat) for feat in intermediates.feature_maps]
        )
79
+
80
+
81
class AttentionPool(IntermediateComponent):
    """Attention pooling with a learned query token.

    Given a token feature map of shape BxCxHxWxN, aggregate over the N (token)
    dimension with multi-head attention. At each spatial location, a single
    learned query token attends to the N tokens. The attended values, with heads
    concatenated back into C channels, form the pooled output.
    """

    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
        """Initialize the attention pooling layer.

        Args:
            in_dim: the encoding (channel) dimension C of the input tokens.
            num_heads: the number of attention heads; must be positive and
                divide ``in_dim``.
            linear_on_kv: whether to apply a C -> C linear layer to the input
                tokens to create the key and value tokens. If False, the raw
                tokens are used directly as keys and values.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``in_dim``.
        """
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive. Got {num_heads}.")
        if in_dim % num_heads != 0:
            raise ValueError(
                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
            )
        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
        self.k_linear: nn.Linear | None
        self.v_linear: nn.Linear | None
        if linear_on_kv:
            self.k_linear = nn.Linear(in_dim, in_dim)
            self.v_linear = nn.Linear(in_dim, in_dim)
        else:
            self.k_linear = None
            self.v_linear = None
        self.num_heads = num_heads
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the learned query token from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.query_token, std=0.02)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention-pool a single feature map over its token dimension.

        Args:
            feat_tokens: tensor of shape (B, C, H, W, N), where C equals the
                ``in_dim`` given at construction.

        Returns:
            tensor of shape (B, C, H, W).

        Raises:
            ValueError: if the channel dimension does not match ``in_dim``.
        """
        B, D, H, W, _ = feat_tokens.shape
        if D != self.query_token.shape[0]:
            raise ValueError(
                f"expected {self.query_token.shape[0]} channels, got {D}"
            )
        head_dim = D // self.num_heads

        # Each spatial location is an independent sequence of N tokens: (B*H*W, N, D).
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        collapsed_dim = feat_tokens.shape[0]

        # Split channels into heads head-major (D = num_heads * head_dim).
        q = self.query_token.expand(collapsed_dim, 1, -1)
        q = rearrange(q, "b q (nh dh) -> b nh q dh", nh=self.num_heads)
        k_src = feat_tokens if self.k_linear is None else self.k_linear(feat_tokens)
        v_src = feat_tokens if self.v_linear is None else self.v_linear(feat_tokens)
        k = rearrange(k_src, "b n (nh dh) -> b nh n dh", nh=self.num_heads)
        v = rearrange(v_src, "b n (nh dh) -> b nh n dh", nh=self.num_heads)

        # Scaled dot-product attention of the single query against the N tokens.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        attn_weights = F.softmax(attn_scores, dim=-1)
        x = torch.matmul(attn_weights, v)  # (B*H*W, num_heads, 1, head_dim)

        # Merge heads back into channels, using the same head-major order as the
        # split above, then restore the spatial layout. A plain
        # x.reshape(B, D, H, W) would be wrong: x is laid out as
        # (b, h, w, head, head_dim), so that reshape mixes channel and spatial
        # values whenever H*W > 1.
        x = x.squeeze(2)
        return rearrange(x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Attention-pool every map in a TokenFeatureMaps.

        Args:
            intermediates: the output from the previous component, which must be
                a TokenFeatureMaps whose maps have shape (B, C, H, W, N). The
                same attention weights (query token and linear k, v layers) are
                applied to every map.
            context: the model context (not used by this component).

        Returns:
            FeatureMaps holding one (B, C, H, W) tensor per input map, in the
            same order as the input.

        Raises:
            ValueError: if ``intermediates`` is not a TokenFeatureMaps.
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        return FeatureMaps(
            [self.forward_for_map(feat) for feat in intermediates.feature_maps]
        )
@@ -91,6 +91,18 @@ class FeatureMaps:
91
91
  feature_maps: list[torch.Tensor]
92
92
 
93
93
 
94
@dataclass
class TokenFeatureMaps:
    """Multi-scale feature maps that still carry an unpooled token axis.

    This is the BCHWN counterpart of `FeatureMaps`: each map keeps an extra
    trailing dimension of N tokens per spatial location, for a later component
    to pool over.
    """

    # One BxCxHxWxN tensor per scale. Index 0 is the finest (highest-resolution)
    # map, and later entries are progressively coarser.
    feature_maps: list[torch.Tensor]
104
+
105
+
94
106
  @dataclass
95
107
  class FeatureVector:
96
108
  """An intermediate output type for a flat feature vector."""