rslearn 0.0.12__tar.gz → 0.0.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (166)
  1. {rslearn-0.0.12/rslearn.egg-info → rslearn-0.0.13}/PKG-INFO +2 -2
  2. {rslearn-0.0.12 → rslearn-0.0.13}/pyproject.toml +2 -2
  3. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/config/dataset.py +23 -4
  4. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/planetary_computer.py +52 -0
  5. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/handler_summaries.py +1 -0
  6. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/manage.py +16 -2
  7. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/model.py +1 -0
  8. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/prediction_writer.py +25 -8
  9. rslearn-0.0.13/rslearn/train/tasks/embedding.py +116 -0
  10. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/raster_format.py +38 -0
  11. {rslearn-0.0.12 → rslearn-0.0.13/rslearn.egg-info}/PKG-INFO +2 -2
  12. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn.egg-info/SOURCES.txt +1 -0
  13. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn.egg-info/requires.txt +1 -1
  14. {rslearn-0.0.12 → rslearn-0.0.13}/LICENSE +0 -0
  15. {rslearn-0.0.12 → rslearn-0.0.13}/NOTICE +0 -0
  16. {rslearn-0.0.12 → rslearn-0.0.13}/README.md +0 -0
  17. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/__init__.py +0 -0
  18. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/arg_parser.py +0 -0
  19. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/config/__init__.py +0 -0
  20. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/const.py +0 -0
  21. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/__init__.py +0 -0
  22. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/aws_landsat.py +0 -0
  23. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/aws_open_data.py +0 -0
  24. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/aws_sentinel1.py +0 -0
  25. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/climate_data_store.py +0 -0
  26. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/copernicus.py +0 -0
  27. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/data_source.py +0 -0
  28. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/earthdaily.py +0 -0
  29. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/earthdata_srtm.py +0 -0
  30. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/eurocrops.py +0 -0
  31. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/gcp_public_data.py +0 -0
  32. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/geotiff.py +0 -0
  33. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/google_earth_engine.py +0 -0
  34. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/local_files.py +0 -0
  35. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/openstreetmap.py +0 -0
  36. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/planet.py +0 -0
  37. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/planet_basemap.py +0 -0
  38. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/raster_source.py +0 -0
  39. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/usda_cdl.py +0 -0
  40. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/usgs_landsat.py +0 -0
  41. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/utils.py +0 -0
  42. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/vector_source.py +0 -0
  43. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/worldcereal.py +0 -0
  44. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/worldcover.py +0 -0
  45. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/worldpop.py +0 -0
  46. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/data_sources/xyz_tiles.py +0 -0
  47. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/__init__.py +0 -0
  48. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/add_windows.py +0 -0
  49. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/dataset.py +0 -0
  50. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/index.py +0 -0
  51. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/materialize.py +0 -0
  52. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/remap.py +0 -0
  53. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/dataset/window.py +0 -0
  54. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/log_utils.py +0 -0
  55. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/main.py +0 -0
  56. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/__init__.py +0 -0
  57. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/anysat.py +0 -0
  58. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/clay/clay.py +0 -0
  59. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/clay/configs/metadata.yaml +0 -0
  60. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/clip.py +0 -0
  61. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/conv.py +0 -0
  62. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/croma.py +0 -0
  63. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/__init__.py +0 -0
  64. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/box_ops.py +0 -0
  65. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/detr.py +0 -0
  66. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/matcher.py +0 -0
  67. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/position_encoding.py +0 -0
  68. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/transformer.py +0 -0
  69. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/detr/util.py +0 -0
  70. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/dinov3.py +0 -0
  71. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/faster_rcnn.py +0 -0
  72. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/feature_center_crop.py +0 -0
  73. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/fpn.py +0 -0
  74. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/galileo/__init__.py +0 -0
  75. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/galileo/galileo.py +0 -0
  76. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/galileo/single_file_galileo.py +0 -0
  77. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/module_wrapper.py +0 -0
  78. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/molmo.py +0 -0
  79. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/multitask.py +0 -0
  80. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  81. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  82. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon.py +0 -0
  83. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  84. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  85. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  86. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  87. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  88. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  89. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  90. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  91. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  92. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  93. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  94. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  95. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/pick_features.py +0 -0
  96. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/pooling_decoder.py +0 -0
  97. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/presto/__init__.py +0 -0
  98. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/presto/presto.py +0 -0
  99. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/presto/single_file_presto.py +0 -0
  100. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/prithvi.py +0 -0
  101. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/registry.py +0 -0
  102. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/resize_features.py +0 -0
  103. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/sam2_enc.py +0 -0
  104. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/satlaspretrain.py +0 -0
  105. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/simple_time_series.py +0 -0
  106. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/singletask.py +0 -0
  107. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/ssl4eo_s12.py +0 -0
  108. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/swin.py +0 -0
  109. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/task_embedding.py +0 -0
  110. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/terramind.py +0 -0
  111. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/trunk.py +0 -0
  112. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/unet.py +0 -0
  113. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/upsample.py +0 -0
  114. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/models/use_croma.py +0 -0
  115. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/py.typed +0 -0
  116. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/template_params.py +0 -0
  117. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/tile_stores/__init__.py +0 -0
  118. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/tile_stores/default.py +0 -0
  119. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/tile_stores/tile_store.py +0 -0
  120. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/__init__.py +0 -0
  121. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/callbacks/__init__.py +0 -0
  122. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/callbacks/adapters.py +0 -0
  123. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  124. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/callbacks/gradients.py +0 -0
  125. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/callbacks/peft.py +0 -0
  126. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/data_module.py +0 -0
  127. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/dataset.py +0 -0
  128. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/lightning_module.py +0 -0
  129. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/optimizer.py +0 -0
  130. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/scheduler.py +0 -0
  131. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/__init__.py +0 -0
  132. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/classification.py +0 -0
  133. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/detection.py +0 -0
  134. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/multi_task.py +0 -0
  135. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/per_pixel_regression.py +0 -0
  136. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/regression.py +0 -0
  137. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/segmentation.py +0 -0
  138. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/tasks/task.py +0 -0
  139. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/__init__.py +0 -0
  140. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/concatenate.py +0 -0
  141. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/crop.py +0 -0
  142. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/flip.py +0 -0
  143. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/mask.py +0 -0
  144. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/normalize.py +0 -0
  145. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/pad.py +0 -0
  146. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/select_bands.py +0 -0
  147. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/sentinel1.py +0 -0
  148. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/train/transforms/transform.py +0 -0
  149. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/__init__.py +0 -0
  150. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/array.py +0 -0
  151. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/feature.py +0 -0
  152. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/fsspec.py +0 -0
  153. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/geometry.py +0 -0
  154. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/get_utm_ups_crs.py +0 -0
  155. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/grid_index.py +0 -0
  156. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/jsonargparse.py +0 -0
  157. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/mp.py +0 -0
  158. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/rtree_index.py +0 -0
  159. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/spatial_index.py +0 -0
  160. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/sqlite_index.py +0 -0
  161. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/time.py +0 -0
  162. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn/utils/vector_format.py +0 -0
  163. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn.egg-info/dependency_links.txt +0 -0
  164. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn.egg-info/entry_points.txt +0 -0
  165. {rslearn-0.0.12 → rslearn-0.0.13}/rslearn.egg-info/top_level.txt +0 -0
  166. {rslearn-0.0.12 → rslearn-0.0.13}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.12
3
+ Version: 0.0.13
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -214,7 +214,7 @@ License-File: LICENSE
214
214
  License-File: NOTICE
215
215
  Requires-Dist: boto3>=1.39
216
216
  Requires-Dist: fiona>=1.10
217
- Requires-Dist: fsspec>=2025.9.0
217
+ Requires-Dist: fsspec>=2025.10.0
218
218
  Requires-Dist: jsonargparse>=4.35.0
219
219
  Requires-Dist: lightning>=2.5.1.post0
220
220
  Requires-Dist: Pillow>=11.3
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.12"
3
+ version = "0.0.13"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -11,7 +11,7 @@ requires-python = ">=3.11"
11
11
  dependencies = [
12
12
  "boto3>=1.39",
13
13
  "fiona>=1.10",
14
- "fsspec>=2025.9.0", # this is used both directly and indirectly (via universal_pathlib) in our code
14
+ "fsspec>=2025.10.0", # this is used both directly and indirectly (via universal_pathlib) in our code
15
15
  "jsonargparse>=4.35.0",
16
16
  "lightning>=2.5.1.post0",
17
17
  "Pillow>=11.3",
@@ -125,7 +125,8 @@ class BandSetConfig:
125
125
  self,
126
126
  config_dict: dict[str, Any],
127
127
  dtype: DType,
128
- bands: list[str],
128
+ bands: list[str] | None = None,
129
+ num_bands: int | None = None,
129
130
  format: dict[str, Any] | None = None,
130
131
  zoom_offset: int = 0,
131
132
  remap: dict[str, Any] | None = None,
@@ -137,7 +138,10 @@ class BandSetConfig:
137
138
  Args:
138
139
  config_dict: the config dict used to configure this BandSetConfig
139
140
  dtype: the pixel value type to store tiles in
140
- bands: list of band names in this BandSetConfig
141
+ bands: list of band names in this BandSetConfig. One of bands or num_bands
142
+ must be set.
143
+ num_bands: the number of bands in this band set. The bands will be named
144
+ B0, B1, B2, etc.
141
145
  format: the format to store tiles in, defaults to geotiff
142
146
  zoom_offset: store images at a resolution higher or lower than the window
143
147
  resolution. This enables keeping source data at its native resolution,
@@ -155,6 +159,14 @@ class BandSetConfig:
155
159
  materialization when creating mosaics, to determine which parts of the
156
160
  source images should be copied.
157
161
  """
162
+ if (bands is None and num_bands is None) or (
163
+ bands is not None and num_bands is not None
164
+ ):
165
+ raise ValueError("exactly one of bands and num_bands must be set")
166
+ if bands is None:
167
+ assert num_bands is not None
168
+ bands = [f"B{idx}" for idx in range(num_bands)]
169
+
158
170
  if class_names is not None and len(bands) != len(class_names):
159
171
  raise ValueError(
160
172
  f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
@@ -187,9 +199,16 @@ class BandSetConfig:
187
199
  kwargs = dict(
188
200
  config_dict=config,
189
201
  dtype=DType(config["dtype"]),
190
- bands=config["bands"],
191
202
  )
192
- for k in ["format", "zoom_offset", "remap", "class_names", "nodata_vals"]:
203
+ for k in [
204
+ "bands",
205
+ "num_bands",
206
+ "format",
207
+ "zoom_offset",
208
+ "remap",
209
+ "class_names",
210
+ "nodata_vals",
211
+ ]:
193
212
  if k in config:
194
213
  kwargs[k] = config[k]
195
214
  return BandSetConfig(**kwargs) # type: ignore
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
827
827
  kwargs[k] = d[k]
828
828
 
829
829
  return Sentinel1(**kwargs)
830
+
831
+
832
class Naip(PlanetaryComputer):
    """A data source for NAIP data on Microsoft Planetary Computer.

    The collection name and the asset-to-band mapping are fixed by this class: the
    "image" asset provides the R, G, B, and NIR bands.

    See https://planetarycomputer.microsoft.com/dataset/naip.
    """

    COLLECTION_NAME = "naip"
    ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}

    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize a new Naip instance.

        Args:
            kwargs: additional arguments to pass to PlanetaryComputer. These must not
                include collection_name or asset_bands, since this class always sets
                them (passing them again would raise a duplicate-keyword TypeError).
        """
        super().__init__(
            collection_name=self.COLLECTION_NAME,
            asset_bands=self.ASSET_BANDS,
            **kwargs,
        )

    @staticmethod
    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
        """Creates a new Naip instance from a configuration dictionary.

        Args:
            config: the raster layer config; options are read from
                config.data_source.config_dict.
            ds_path: the dataset path; a configured cache_dir is joined onto it.

        Returns:
            the configured Naip data source.

        Raises:
            ValueError: if the layer config has no data source.
        """
        if config.data_source is None:
            raise ValueError("config.data_source is required")
        d = config.data_source.config_dict
        kwargs: dict[str, Any] = {}

        # The config stores a plain number of seconds; the constructor takes a
        # timedelta under the name "timeout".
        if "timeout_seconds" in d:
            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])

        if "cache_dir" in d:
            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

        # Options forwarded to PlanetaryComputer unchanged when present.
        simple_optionals = [
            "query",
            "sort_by",
            "sort_ascending",
            "max_items_per_client",
        ]
        for k in simple_optionals:
            if k in d:
                kwargs[k] = d[k]

        return Naip(**kwargs)
@@ -20,6 +20,7 @@ class LayerPrepareSummary:
20
20
  # Counts
21
21
  windows_prepared: int
22
22
  windows_skipped: int
23
+ windows_rejected: int
23
24
  get_items_attempts: int
24
25
 
25
26
 
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
118
118
  duration_seconds=time.monotonic() - layer_start_time,
119
119
  windows_prepared=0,
120
120
  windows_skipped=len(windows),
121
+ windows_rejected=0,
121
122
  get_items_attempts=0,
122
123
  )
123
124
  )
@@ -141,6 +142,7 @@ def prepare_dataset_windows(
141
142
  duration_seconds=time.monotonic() - layer_start_time,
142
143
  windows_prepared=0,
143
144
  windows_skipped=len(windows),
145
+ windows_rejected=0,
144
146
  get_items_attempts=0,
145
147
  )
146
148
  )
@@ -181,6 +183,9 @@ def prepare_dataset_windows(
181
183
  attempts_counter=attempts_counter,
182
184
  )
183
185
 
186
+ windows_prepared = 0
187
+ windows_rejected = 0
188
+ min_matches = data_source_cfg.query_config.min_matches
184
189
  for window, result in zip(needed_windows, results):
185
190
  layer_datas = window.load_layer_datas()
186
191
  layer_datas[layer_name] = WindowLayerData(
@@ -191,13 +196,22 @@ def prepare_dataset_windows(
191
196
  )
192
197
  window.save_layer_datas(layer_datas)
193
198
 
199
+ # If result is empty and min_matches > 0, window was rejected due to min_matches
200
+ if len(result) == 0 and min_matches > 0:
201
+ windows_rejected += 1
202
+ else:
203
+ windows_prepared += 1
204
+
205
+ windows_skipped = len(windows) - len(needed_windows)
206
+
194
207
  layer_summaries.append(
195
208
  LayerPrepareSummary(
196
209
  layer_name=layer_name,
197
210
  data_source_name=data_source_cfg.name,
198
211
  duration_seconds=time.monotonic() - layer_start_time,
199
- windows_prepared=len(needed_windows), # we assume all have succeeded
200
- windows_skipped=len(windows) - len(needed_windows),
212
+ windows_prepared=windows_prepared,
213
+ windows_skipped=windows_skipped,
214
+ windows_rejected=windows_rejected,
201
215
  get_items_attempts=attempts_counter.value,
202
216
  )
203
217
  )
@@ -40,6 +40,7 @@ EMBEDDING_SIZES = {
40
40
  ModelID.OLMOEARTH_V1_NANO: 128,
41
41
  ModelID.OLMOEARTH_V1_TINY: 192,
42
42
  ModelID.OLMOEARTH_V1_BASE: 768,
43
+ ModelID.OLMOEARTH_V1_LARGE: 1024,
43
44
  }
44
45
 
45
46
 
@@ -22,7 +22,11 @@ from rslearn.log_utils import get_logger
22
22
  from rslearn.utils.array import copy_spatial_array
23
23
  from rslearn.utils.feature import Feature
24
24
  from rslearn.utils.geometry import PixelBounds
25
- from rslearn.utils.raster_format import RasterFormat, load_raster_format
25
+ from rslearn.utils.raster_format import (
26
+ RasterFormat,
27
+ adjust_projection_and_bounds_for_array,
28
+ load_raster_format,
29
+ )
26
30
  from rslearn.utils.vector_format import VectorFormat, load_vector_format
27
31
 
28
32
  from .lightning_module import RslearnLightningModule
@@ -68,15 +72,18 @@ class VectorMerger(PatchPredictionMerger):
68
72
  class RasterMerger(PatchPredictionMerger):
69
73
  """Merger for raster data that copies the rasters to the output."""
70
74
 
71
- def __init__(self, padding: int | None = None):
75
+ def __init__(self, padding: int | None = None, downsample_factor: int = 1):
72
76
  """Create a new RasterMerger.
73
77
 
74
78
  Args:
75
79
  padding: the padding around the individual patch outputs to remove. This is
76
80
  typically used when leveraging overlapping patches. Portions of outputs
77
81
  at the border of the window will still be retained.
82
+ downsample_factor: the factor by which the rasters output by the task are
83
+ lower in resolution relative to the window resolution.
78
84
  """
79
85
  self.padding = padding
86
+ self.downsample_factor = downsample_factor
80
87
 
81
88
  def merge(
82
89
  self, window: Window, outputs: Sequence[PendingPatchOutput]
@@ -87,8 +94,8 @@ class RasterMerger(PatchPredictionMerger):
87
94
  merged_image = np.zeros(
88
95
  (
89
96
  num_channels,
90
- window.bounds[3] - window.bounds[1],
91
- window.bounds[2] - window.bounds[0],
97
+ (window.bounds[3] - window.bounds[1]) // self.downsample_factor,
98
+ (window.bounds[2] - window.bounds[0]) // self.downsample_factor,
92
99
  ),
93
100
  dtype=dtype,
94
101
  )
@@ -104,7 +111,10 @@ class RasterMerger(PatchPredictionMerger):
104
111
  # If the output is not on the left or top boundary, then we should apply
105
112
  # the padding (if set).
106
113
  src = output.output
107
- src_offset = (output.bounds[0], output.bounds[1])
114
+ src_offset = (
115
+ output.bounds[0] // self.downsample_factor,
116
+ output.bounds[1] // self.downsample_factor,
117
+ )
108
118
  if self.padding is not None and output.bounds[0] != window.bounds[0]:
109
119
  src = src[:, :, self.padding :]
110
120
  src_offset = (src_offset[0] + self.padding, src_offset[1])
@@ -116,7 +126,10 @@ class RasterMerger(PatchPredictionMerger):
116
126
  src=src,
117
127
  dst=merged_image,
118
128
  src_offset=src_offset,
119
- dst_offset=(window.bounds[0], window.bounds[1]),
129
+ dst_offset=(
130
+ window.bounds[0] // self.downsample_factor,
131
+ window.bounds[1] // self.downsample_factor,
132
+ ),
120
133
  )
121
134
 
122
135
  return merged_image
@@ -330,9 +343,13 @@ class RslearnWriter(BasePredictionWriter):
330
343
  self.output_layer, self.layer_config.band_sets[0].bands
331
344
  )
332
345
  assert isinstance(self.format, RasterFormat)
333
- self.format.encode_raster(
334
- raster_dir, window.projection, window.bounds, merged_output
346
+
347
+ # In case the merged_output is at a different resolution than the window,
348
+ # get adjusted projection and bounds for writing it.
349
+ projection, bounds = adjust_projection_and_bounds_for_array(
350
+ window.projection, window.bounds, merged_output
335
351
  )
352
+ self.format.encode_raster(raster_dir, projection, bounds, merged_output)
336
353
 
337
354
  elif self.layer_config.layer_type == LayerType.VECTOR:
338
355
  layer_dir = window.get_layer_dir(self.output_layer)
@@ -0,0 +1,116 @@
1
+ """Embedding task."""
2
+
3
+ from typing import Any
4
+
5
+ import numpy.typing as npt
6
+ import torch
7
+ from torchmetrics import MetricCollection
8
+
9
+ from rslearn.utils import Feature
10
+
11
+ from .task import Task
12
+
13
+
14
class EmbeddingTask(Task):
    """A dummy task for computing embeddings.

    This task does not compute any targets or loss. Instead, it is just set up for
    inference, to save embeddings from the configured model.
    """

    def process_inputs(
        self,
        raw_inputs: dict[str, torch.Tensor],
        metadata: dict[str, Any],
        load_targets: bool = True,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Processes the data into inputs and targets.

        The embedding task contributes no extra inputs and has no targets, so both
        returned dicts are always empty regardless of the arguments.

        Args:
            raw_inputs: raster or vector data to process (ignored).
            metadata: metadata about the patch being read (ignored).
            load_targets: whether to load the targets or only inputs (ignored).

        Returns:
            tuple (input_dict, target_dict), both empty.
        """
        return {}, {}

    def process_output(
        self, raw_output: Any, metadata: dict[str, Any]
    ) -> npt.NDArray[Any] | list[Feature]:
        """Processes an output into raster data.

        Args:
            raw_output: the output from the prediction head; expected to be a torch
                tensor (presumably the selected embedding feature map — TODO confirm
                against the configured model).
            metadata: metadata about the patch being read (ignored).

        Returns:
            the embedding as a numpy array on the CPU, suitable for writing as a
            raster.
        """
        # detach() so the conversion also works if the tensor still tracks gradients
        # (Tensor.numpy() refuses tensors that require grad).
        output = raw_output.detach().cpu()
        # numpy has no bfloat16 dtype, so Tensor.numpy() raises TypeError for bf16
        # tensors (e.g. produced under bf16 mixed-precision inference). Upcast to
        # float32 so the embedding can still be converted and saved.
        if output.dtype == torch.bfloat16:
            output = output.float()
        return output.numpy()

    def visualize(
        self,
        input_dict: dict[str, Any],
        target_dict: dict[str, Any] | None,
        output: Any,
    ) -> dict[str, npt.NDArray[Any]]:
        """Visualize the outputs and targets.

        Args:
            input_dict: the input dict from process_inputs
            target_dict: the target dict from process_inputs
            output: the prediction

        Returns:
            a dictionary mapping image name to visualization image

        Raises:
            NotImplementedError: always; this task has no visualization.
        """
        # EmbeddingTask is only set up to support `model predict`.
        raise NotImplementedError

    def get_metrics(self) -> MetricCollection:
        """Get the metrics for this task (always an empty collection)."""
        return MetricCollection({})
77
+
78
+
79
+ class EmbeddingHead(torch.nn.Module):
80
+ """Head for embedding task.
81
+
82
+ This picks one feature map from the input list of feature maps to output. It also
83
+ returns a dummy loss.
84
+ """
85
+
86
+ def __init__(self, feature_map_index: int | None = 0):
87
+ """Create a new EmbeddingHead.
88
+
89
+ Args:
90
+ feature_map_index: the index of the feature map to choose from the input
91
+ list of multi-scale feature maps (default 0). If the input is already
92
+ a single feature map, then set to None.
93
+ """
94
+ super().__init__()
95
+ self.feature_map_index = feature_map_index
96
+
97
+ def forward(
98
+ self,
99
+ features: torch.Tensor,
100
+ inputs: list[dict[str, Any]],
101
+ targets: list[dict[str, Any]] | None = None,
102
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
103
+ """Select the desired feature map and return it along with a dummy loss.
104
+
105
+ Args:
106
+ features: list of BCHW feature maps (or one feature map, if feature_map_index is None).
107
+ inputs: original inputs (ignored).
108
+ targets: should contain classes key that stores the per-pixel class labels.
109
+
110
+ Returns:
111
+ tuple of outputs and loss dict
112
+ """
113
+ if self.feature_map_index is not None:
114
+ features = features[self.feature_map_index]
115
+
116
+ return features, {"loss": 0}
@@ -123,6 +123,44 @@ def get_transform_from_projection_and_bounds(
123
123
  )
124
124
 
125
125
 
126
def adjust_projection_and_bounds_for_array(
    projection: Projection, bounds: PixelBounds, array: npt.NDArray
) -> tuple[Projection, PixelBounds]:
    """Adjust the projection and bounds to correspond to the resolution of the array.

    The returned projection and bounds cover (approximately) the same spatial extent
    as the inputs, but are updated so that the width and height match that of the
    array. The pixel resolution is scaled by the ratio of the bounds size to the array
    size. The new top-left pixel coordinate is rounded with Python's round(), so when
    the bounds do not scale to whole pixels the extent may shift by a fraction of a
    pixel.

    Args:
        projection: the original projection.
        bounds: the original bounds.
        array: the CHW array for which to compute an updated projection and bounds. The
            returned bounds will have the same width and height as this array.

    Returns:
        a tuple of adjusted (projection, bounds). If the array size already matches
        the bounds, the inputs are returned unchanged.

    Raises:
        ValueError: if the array is not 3-dimensional (CHW), or if the bounds have
            zero width or height while the array size differs (the scale factor
            would be undefined).
    """
    if array.ndim != 3:
        raise ValueError(f"expected a CHW array but got shape {array.shape}")

    bounds_width = bounds[2] - bounds[0]
    bounds_height = bounds[3] - bounds[1]
    array_height, array_width = array.shape[1], array.shape[2]

    if array_width == bounds_width and array_height == bounds_height:
        return (projection, bounds)

    # Previously this surfaced as a bare ZeroDivisionError below.
    if bounds_width == 0 or bounds_height == 0:
        raise ValueError(
            f"cannot rescale empty bounds {bounds} to array of width {array_width} and height {array_height}"
        )

    x_factor = array_width / bounds_width
    y_factor = array_height / bounds_height
    adjusted_projection = Projection(
        projection.crs,
        projection.x_resolution / x_factor,
        projection.y_resolution / y_factor,
    )
    # Scale the top-left corner into the new pixel grid; the bottom-right corner is
    # derived from the array size so the bounds match the array exactly.
    start_col = round(bounds[0] * x_factor)
    start_row = round(bounds[1] * y_factor)
    adjusted_bounds = (
        start_col,
        start_row,
        start_col + array_width,
        start_row + array_height,
    )
    return (adjusted_projection, adjusted_bounds)
162
+
163
+
126
164
  class RasterFormat:
127
165
  """An abstract class for writing raster data.
128
166
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.12
3
+ Version: 0.0.13
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -214,7 +214,7 @@ License-File: LICENSE
214
214
  License-File: NOTICE
215
215
  Requires-Dist: boto3>=1.39
216
216
  Requires-Dist: fiona>=1.10
217
- Requires-Dist: fsspec>=2025.9.0
217
+ Requires-Dist: fsspec>=2025.10.0
218
218
  Requires-Dist: jsonargparse>=4.35.0
219
219
  Requires-Dist: lightning>=2.5.1.post0
220
220
  Requires-Dist: Pillow>=11.3
@@ -131,6 +131,7 @@ rslearn/train/callbacks/peft.py
131
131
  rslearn/train/tasks/__init__.py
132
132
  rslearn/train/tasks/classification.py
133
133
  rslearn/train/tasks/detection.py
134
+ rslearn/train/tasks/embedding.py
134
135
  rslearn/train/tasks/multi_task.py
135
136
  rslearn/train/tasks/per_pixel_regression.py
136
137
  rslearn/train/tasks/regression.py
@@ -1,6 +1,6 @@
1
1
  boto3>=1.39
2
2
  fiona>=1.10
3
- fsspec>=2025.9.0
3
+ fsspec>=2025.10.0
4
4
  jsonargparse>=4.35.0
5
5
  lightning>=2.5.1.post0
6
6
  Pillow>=11.3
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes