rslearn 0.0.19__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. {rslearn-0.0.19/rslearn.egg-info → rslearn-0.0.21}/PKG-INFO +1 -1
  2. {rslearn-0.0.19 → rslearn-0.0.21}/pyproject.toml +1 -1
  3. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/climate_data_store.py +216 -29
  4. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/anysat.py +35 -33
  5. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clip.py +5 -2
  6. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/croma.py +11 -3
  7. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/dinov3.py +2 -1
  8. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/faster_rcnn.py +2 -1
  9. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/galileo.py +58 -31
  10. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/module_wrapper.py +6 -1
  11. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/molmo.py +4 -2
  12. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/model.py +95 -32
  13. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/norm.py +5 -3
  14. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon.py +3 -1
  15. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/presto.py +45 -15
  16. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/prithvi.py +9 -7
  17. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/sam2_enc.py +3 -1
  18. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/satlaspretrain.py +4 -1
  19. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/simple_time_series.py +36 -16
  20. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/ssl4eo_s12.py +19 -14
  21. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/swin.py +3 -1
  22. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/terramind.py +5 -4
  23. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/all_patches_dataset.py +34 -14
  24. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/dataset.py +73 -8
  25. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/model_context.py +35 -1
  26. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/classification.py +8 -2
  27. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/detection.py +3 -2
  28. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/multi_task.py +2 -3
  29. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/per_pixel_regression.py +14 -5
  30. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/regression.py +8 -2
  31. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/segmentation.py +13 -4
  32. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/task.py +2 -2
  33. rslearn-0.0.21/rslearn/train/transforms/concatenate.py +89 -0
  34. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/crop.py +22 -8
  35. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/flip.py +13 -5
  36. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/mask.py +11 -2
  37. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/normalize.py +46 -15
  38. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/pad.py +15 -3
  39. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/resize.py +18 -9
  40. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/select_bands.py +11 -2
  41. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/sentinel1.py +18 -3
  42. {rslearn-0.0.19 → rslearn-0.0.21/rslearn.egg-info}/PKG-INFO +1 -1
  43. rslearn-0.0.19/rslearn/train/transforms/concatenate.py +0 -49
  44. {rslearn-0.0.19 → rslearn-0.0.21}/LICENSE +0 -0
  45. {rslearn-0.0.19 → rslearn-0.0.21}/NOTICE +0 -0
  46. {rslearn-0.0.19 → rslearn-0.0.21}/README.md +0 -0
  47. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/__init__.py +0 -0
  48. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/arg_parser.py +0 -0
  49. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/config/__init__.py +0 -0
  50. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/config/dataset.py +0 -0
  51. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/const.py +0 -0
  52. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/__init__.py +0 -0
  53. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_landsat.py +0 -0
  54. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_open_data.py +0 -0
  55. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/aws_sentinel1.py +0 -0
  56. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/copernicus.py +0 -0
  57. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/data_source.py +0 -0
  58. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/earthdaily.py +0 -0
  59. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/earthdata_srtm.py +0 -0
  60. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/eurocrops.py +0 -0
  61. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/gcp_public_data.py +0 -0
  62. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/google_earth_engine.py +0 -0
  63. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/local_files.py +0 -0
  64. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/openstreetmap.py +0 -0
  65. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planet.py +0 -0
  66. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planet_basemap.py +0 -0
  67. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/planetary_computer.py +0 -0
  68. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/usda_cdl.py +0 -0
  69. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/usgs_landsat.py +0 -0
  70. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/utils.py +0 -0
  71. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/vector_source.py +0 -0
  72. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldcereal.py +0 -0
  73. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldcover.py +0 -0
  74. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/worldpop.py +0 -0
  75. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/data_sources/xyz_tiles.py +0 -0
  76. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/__init__.py +0 -0
  77. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/add_windows.py +0 -0
  78. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/dataset.py +0 -0
  79. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/handler_summaries.py +0 -0
  80. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/manage.py +0 -0
  81. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/materialize.py +0 -0
  82. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/remap.py +0 -0
  83. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/__init__.py +0 -0
  84. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/file.py +0 -0
  85. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/storage/storage.py +0 -0
  86. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/dataset/window.py +0 -0
  87. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/lightning_cli.py +0 -0
  88. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/log_utils.py +0 -0
  89. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/main.py +0 -0
  90. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/__init__.py +0 -0
  91. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/attention_pooling.py +0 -0
  92. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clay/clay.py +0 -0
  93. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/clay/configs/metadata.yaml +0 -0
  94. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/component.py +0 -0
  95. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/concatenate_features.py +0 -0
  96. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/conv.py +0 -0
  97. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/__init__.py +0 -0
  98. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/box_ops.py +0 -0
  99. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/detr.py +0 -0
  100. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/matcher.py +0 -0
  101. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/position_encoding.py +0 -0
  102. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/transformer.py +0 -0
  103. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/detr/util.py +0 -0
  104. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/feature_center_crop.py +0 -0
  105. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/fpn.py +0 -0
  106. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/__init__.py +0 -0
  107. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/galileo/single_file_galileo.py +0 -0
  108. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/multitask.py +0 -0
  109. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  110. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  111. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  112. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  113. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  114. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  115. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  116. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  117. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  118. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  119. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  120. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  121. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  122. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/pick_features.py +0 -0
  123. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/pooling_decoder.py +0 -0
  124. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/__init__.py +0 -0
  125. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/presto/single_file_presto.py +0 -0
  126. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/resize_features.py +0 -0
  127. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/singletask.py +0 -0
  128. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/task_embedding.py +0 -0
  129. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/trunk.py +0 -0
  130. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/unet.py +0 -0
  131. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/upsample.py +0 -0
  132. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/models/use_croma.py +0 -0
  133. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/py.typed +0 -0
  134. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/template_params.py +0 -0
  135. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/__init__.py +0 -0
  136. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/default.py +0 -0
  137. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/tile_stores/tile_store.py +0 -0
  138. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/__init__.py +0 -0
  139. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/__init__.py +0 -0
  140. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/adapters.py +0 -0
  141. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  142. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/gradients.py +0 -0
  143. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/callbacks/peft.py +0 -0
  144. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/data_module.py +0 -0
  145. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/lightning_module.py +0 -0
  146. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/optimizer.py +0 -0
  147. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/prediction_writer.py +0 -0
  148. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/scheduler.py +0 -0
  149. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/__init__.py +0 -0
  150. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/tasks/embedding.py +0 -0
  151. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/__init__.py +0 -0
  152. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/train/transforms/transform.py +0 -0
  153. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/__init__.py +0 -0
  154. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/array.py +0 -0
  155. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/feature.py +0 -0
  156. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/fsspec.py +0 -0
  157. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/geometry.py +0 -0
  158. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/get_utm_ups_crs.py +0 -0
  159. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/grid_index.py +0 -0
  160. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/jsonargparse.py +0 -0
  161. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/mp.py +0 -0
  162. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/raster_format.py +0 -0
  163. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/rtree_index.py +0 -0
  164. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/spatial_index.py +0 -0
  165. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/sqlite_index.py +0 -0
  166. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/time.py +0 -0
  167. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn/utils/vector_format.py +0 -0
  168. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/SOURCES.txt +0 -0
  169. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/dependency_links.txt +0 -0
  170. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/entry_points.txt +0 -0
  171. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/requires.txt +0 -0
  172. {rslearn-0.0.19 → rslearn-0.0.21}/rslearn.egg-info/top_level.txt +0 -0
  173. {rslearn-0.0.19 → rslearn-0.0.21}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.19
3
+ Version: 0.0.21
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.19"
3
+ version = "0.0.21"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -24,51 +24,55 @@ from rslearn.utils.geometry import STGeometry
24
24
  logger = get_logger(__name__)
25
25
 
26
26
 
27
- class ERA5LandMonthlyMeans(DataSource):
28
- """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
27
+ class ERA5Land(DataSource):
28
+ """Base class for ingesting ERA5 land data from the Copernicus Climate Data Store.
29
29
 
30
30
  An API key must be passed either in the configuration or via the CDSAPI_KEY
31
31
  environment variable. You can acquire an API key by going to the Climate Data Store
32
32
  website (https://cds.climate.copernicus.eu/), registering an account and logging
33
- in, and then
33
+ in, and then getting the API key from the user profile page.
34
34
 
35
35
  The band names should match CDS variable names (see the reference at
36
36
  https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation). However,
37
37
  replace "_" with "-" in the variable names when specifying bands in the layer
38
38
  configuration.
39
39
 
40
- This data source corresponds to the reanalysis-era5-land-monthly-means product.
41
-
42
- All requests to the API will be for the whole globe. Although the API supports arbitrary
43
- bounds in the requests, using the whole available area helps to reduce the total number of
44
- requests.
40
+ By default, all requests to the API will be for the whole globe. To speed up ingestion,
41
+ we recommend specifying the bounds of the area of interest, in particular for hourly data.
45
42
  """
46
43
 
47
44
  api_url = "https://cds.climate.copernicus.eu/api"
48
-
49
- # see: https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land-monthly-means
50
- DATASET = "reanalysis-era5-land-monthly-means"
51
- PRODUCT_TYPE = "monthly_averaged_reanalysis"
52
45
  DATA_FORMAT = "netcdf"
53
46
  DOWNLOAD_FORMAT = "unarchived"
54
47
  PIXEL_SIZE = 0.1 # degrees, native resolution is 9km
55
48
 
56
49
  def __init__(
57
50
  self,
51
+ dataset: str,
52
+ product_type: str,
58
53
  band_names: list[str] | None = None,
59
54
  api_key: str | None = None,
55
+ bounds: list[float] | None = None,
60
56
  context: DataSourceContext = DataSourceContext(),
61
57
  ):
62
- """Initialize a new ERA5LandMonthlyMeans instance.
58
+ """Initialize a new ERA5Land instance.
63
59
 
64
60
  Args:
61
+ dataset: the CDS dataset name (e.g., "reanalysis-era5-land-monthly-means").
62
+ product_type: the CDS product type (e.g., "monthly_averaged_reanalysis").
65
63
  band_names: list of band names to acquire. These should correspond to CDS
66
64
  variable names but with "_" replaced with "-". This will only be used
67
65
  if the layer config is missing from the context.
68
66
  api_key: the API key. If not set, it should be set via the CDSAPI_KEY
69
67
  environment variable.
68
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
69
+ If not specified, the whole globe will be used.
70
70
  context: the data source context.
71
71
  """
72
+ self.dataset = dataset
73
+ self.product_type = product_type
74
+ self.bounds = bounds
75
+
72
76
  self.band_names: list[str]
73
77
  if context.layer_config is not None:
74
78
  self.band_names = []
@@ -134,8 +138,11 @@ class ERA5LandMonthlyMeans(DataSource):
134
138
  # Collect Item list corresponding to the current month.
135
139
  items = []
136
140
  item_name = f"era5land_monthlyaveraged_{cur_date.year}_{cur_date.month}"
137
- # Space is the whole globe.
138
- bounds = (-180, -90, 180, 90)
141
+ # Use bounds if set, otherwise use whole globe
142
+ if self.bounds is not None:
143
+ bounds = self.bounds
144
+ else:
145
+ bounds = [-180, -90, 180, 90]
139
146
  # Time is just the given month.
140
147
  start_date = datetime(cur_date.year, cur_date.month, 1, tzinfo=UTC)
141
148
  time_range = (
@@ -172,7 +179,9 @@ class ERA5LandMonthlyMeans(DataSource):
172
179
  # But the list of variables should include the bands we want in the correct
173
180
  # order. And we can distinguish those bands from other "variables" because they
174
181
  # will be 3D while the others will be scalars or 1D.
175
- bands_data = []
182
+
183
+ band_arrays = []
184
+ num_time_steps = None
176
185
  for band_name in nc.variables:
177
186
  band_data = nc.variables[band_name]
178
187
  if len(band_data.shape) != 3:
@@ -182,18 +191,27 @@ class ERA5LandMonthlyMeans(DataSource):
182
191
  logger.debug(
183
192
  f"NC file {nc_path} has variable {band_name} with shape {band_data.shape}"
184
193
  )
185
- # Variable data is stored in a 3D array (1, height, width)
186
- if band_data.shape[0] != 1:
194
+ # Variable data is stored in a 3D array (time, height, width)
195
+ # For hourly data, time is number of days in the month x 24 hours
196
+ if num_time_steps is None:
197
+ num_time_steps = band_data.shape[0]
198
+ elif band_data.shape[0] != num_time_steps:
187
199
  raise ValueError(
188
- f"Bad shape for band {band_name}, expected 1 band but got {band_data.shape[0]}"
200
+ f"Variable {band_name} has {band_data.shape[0]} time steps, "
201
+ f"but expected {num_time_steps}"
189
202
  )
190
- bands_data.append(band_data[0, :, :])
203
+ # Original shape: (time, height, width)
204
+ band_array = np.array(band_data[:])
205
+ band_array = np.expand_dims(band_array, axis=1)
206
+ band_arrays.append(band_array)
191
207
 
192
- array = np.array(bands_data) # (num_bands, height, width)
193
- if array.shape[0] != len(self.band_names):
194
- raise ValueError(
195
- f"Expected to get {len(self.band_names)} bands but got {array.shape[0]}"
196
- )
208
+ # After concatenation: (time, num_variables, height, width)
209
+ stacked_array = np.concatenate(band_arrays, axis=1)
210
+
211
+ # After reshaping: (time x num_variables, height, width)
212
+ array = stacked_array.reshape(
213
+ -1, stacked_array.shape[2], stacked_array.shape[3]
214
+ )
197
215
 
198
216
  # Get metadata for the GeoTIFF
199
217
  lat = nc.variables["latitude"][:]
@@ -235,6 +253,58 @@ class ERA5LandMonthlyMeans(DataSource):
235
253
  ) as dst:
236
254
  dst.write(array)
237
255
 
256
+ def ingest(
257
+ self,
258
+ tile_store: TileStoreWithLayer,
259
+ items: list[Item],
260
+ geometries: list[list[STGeometry]],
261
+ ) -> None:
262
+ """Ingest items into the given tile store.
263
+
264
+ This method should be overridden by subclasses.
265
+
266
+ Args:
267
+ tile_store: the tile store to ingest into
268
+ items: the items to ingest
269
+ geometries: a list of geometries needed for each item
270
+ """
271
+ raise NotImplementedError("Subclasses must implement ingest method")
272
+
273
+
274
+ class ERA5LandMonthlyMeans(ERA5Land):
275
+ """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
276
+
277
+ This data source corresponds to the reanalysis-era5-land-monthly-means product.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ band_names: list[str] | None = None,
283
+ api_key: str | None = None,
284
+ bounds: list[float] | None = None,
285
+ context: DataSourceContext = DataSourceContext(),
286
+ ):
287
+ """Initialize a new ERA5LandMonthlyMeans instance.
288
+
289
+ Args:
290
+ band_names: list of band names to acquire. These should correspond to CDS
291
+ variable names but with "_" replaced with "-". This will only be used
292
+ if the layer config is missing from the context.
293
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
294
+ environment variable.
295
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
296
+ If not specified, the whole globe will be used.
297
+ context: the data source context.
298
+ """
299
+ super().__init__(
300
+ dataset="reanalysis-era5-land-monthly-means",
301
+ product_type="monthly_averaged_reanalysis",
302
+ band_names=band_names,
303
+ api_key=api_key,
304
+ bounds=bounds,
305
+ context=context,
306
+ )
307
+
238
308
  def ingest(
239
309
  self,
240
310
  tile_store: TileStoreWithLayer,
@@ -256,25 +326,142 @@ class ERA5LandMonthlyMeans(DataSource):
256
326
  continue
257
327
 
258
328
  # Send the request to the CDS API
259
- # If area is not specified, the whole globe will be requested
329
+ if self.bounds is not None:
330
+ min_lon, min_lat, max_lon, max_lat = self.bounds
331
+ area = [max_lat, min_lon, min_lat, max_lon]
332
+ else:
333
+ area = [90, -180, -90, 180] # Whole globe
334
+
260
335
  request = {
261
- "product_type": [self.PRODUCT_TYPE],
336
+ "product_type": [self.product_type],
262
337
  "variable": variable_names,
263
338
  "year": [f"{item.geometry.time_range[0].year}"], # type: ignore
264
339
  "month": [
265
340
  f"{item.geometry.time_range[0].month:02d}" # type: ignore
266
341
  ],
267
342
  "time": ["00:00"],
343
+ "area": area,
344
+ "data_format": self.DATA_FORMAT,
345
+ "download_format": self.DOWNLOAD_FORMAT,
346
+ }
347
+ logger.debug(
348
+ f"CDS API request for year={request['year']} month={request['month']} area={area}"
349
+ )
350
+ with tempfile.TemporaryDirectory() as tmp_dir:
351
+ local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
352
+ local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
353
+ self.client.retrieve(self.dataset, request, local_nc_fname)
354
+ self._convert_nc_to_tif(
355
+ UPath(local_nc_fname),
356
+ UPath(local_tif_fname),
357
+ )
358
+ tile_store.write_raster_file(
359
+ item.name, self.band_names, UPath(local_tif_fname)
360
+ )
361
+
362
+
363
+ class ERA5LandHourly(ERA5Land):
364
+ """A data source for ingesting ERA5 land hourly data from the Copernicus Climate Data Store.
365
+
366
+ This data source corresponds to the reanalysis-era5-land product.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ band_names: list[str] | None = None,
372
+ api_key: str | None = None,
373
+ bounds: list[float] | None = None,
374
+ context: DataSourceContext = DataSourceContext(),
375
+ ):
376
+ """Initialize a new ERA5LandHourly instance.
377
+
378
+ Args:
379
+ band_names: list of band names to acquire. These should correspond to CDS
380
+ variable names but with "_" replaced with "-". This will only be used
381
+ if the layer config is missing from the context.
382
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
383
+ environment variable.
384
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
385
+ If not specified, the whole globe will be used.
386
+ context: the data source context.
387
+ """
388
+ super().__init__(
389
+ dataset="reanalysis-era5-land",
390
+ product_type="reanalysis",
391
+ band_names=band_names,
392
+ api_key=api_key,
393
+ bounds=bounds,
394
+ context=context,
395
+ )
396
+
397
+ def ingest(
398
+ self,
399
+ tile_store: TileStoreWithLayer,
400
+ items: list[Item],
401
+ geometries: list[list[STGeometry]],
402
+ ) -> None:
403
+ """Ingest items into the given tile store.
404
+
405
+ Args:
406
+ tile_store: the tile store to ingest into
407
+ items: the items to ingest
408
+ geometries: a list of geometries needed for each item
409
+ """
410
+ # for CDS variable names, replace "-" with "_"
411
+ variable_names = [band.replace("-", "_") for band in self.band_names]
412
+
413
+ for item in items:
414
+ if tile_store.is_raster_ready(item.name, self.band_names):
415
+ continue
416
+
417
+ # Send the request to the CDS API
418
+ # If area is not specified, the whole globe will be requested
419
+ time_range = item.geometry.time_range
420
+ if time_range is None:
421
+ raise ValueError("Item must have a time range")
422
+
423
+ # For hourly data, request all days in the month and all 24 hours
424
+ start_time = time_range[0]
425
+
426
+ # Get all days in the month
427
+ year = start_time.year
428
+ month = start_time.month
429
+ # Get the last day of the month
430
+ if month == 12:
431
+ last_day = 31
432
+ else:
433
+ next_month = datetime(year, month + 1, 1, tzinfo=UTC)
434
+ last_day = (next_month - relativedelta(days=1)).day
435
+
436
+ days = [f"{day:02d}" for day in range(1, last_day + 1)]
437
+
438
+ # Get all 24 hours
439
+ hours = [f"{hour:02d}:00" for hour in range(24)]
440
+
441
+ if self.bounds is not None:
442
+ min_lon, min_lat, max_lon, max_lat = self.bounds
443
+ area = [max_lat, min_lon, min_lat, max_lon]
444
+ else:
445
+ area = [90, -180, -90, 180] # Whole globe
446
+
447
+ request = {
448
+ "product_type": [self.product_type],
449
+ "variable": variable_names,
450
+ "year": [f"{year}"],
451
+ "month": [f"{month:02d}"],
452
+ "day": days,
453
+ "time": hours,
454
+ "area": area,
268
455
  "data_format": self.DATA_FORMAT,
269
456
  "download_format": self.DOWNLOAD_FORMAT,
270
457
  }
271
458
  logger.debug(
272
- f"CDS API request for the whole globe for year={request['year']} month={request['month']}"
459
+ f"CDS API request for year={request['year']} month={request['month']} days={len(days)} hours={len(hours)} area={area}"
273
460
  )
274
461
  with tempfile.TemporaryDirectory() as tmp_dir:
275
462
  local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
276
463
  local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
277
- self.client.retrieve(self.DATASET, request, local_nc_fname)
464
+ self.client.retrieve(self.dataset, request, local_nc_fname)
278
465
  self._convert_nc_to_tif(
279
466
  UPath(local_nc_fname),
280
467
  UPath(local_tif_fname),
@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
4
4
  https://github.com/gastruc/AnySat for applicable license and copyright information.
5
5
  """
6
6
 
7
+ from datetime import datetime
8
+
7
9
  import torch
8
10
  from einops import rearrange
9
11
 
@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
53
55
  self,
54
56
  modalities: list[str],
55
57
  patch_size_meters: int,
56
- dates: dict[str, list[int]],
57
58
  output: str = "patch",
58
59
  output_modality: str | None = None,
59
60
  hub_repo: str = "gastruc/anysat",
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
85
86
  if m not in MODALITY_RESOLUTIONS:
86
87
  raise ValueError(f"Invalid modality: {m}")
87
88
 
88
- if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
89
- raise ValueError("`dates` keys must be time-series modalities only.")
90
- for m in modalities:
91
- if m in TIME_SERIES_MODALITIES and m not in dates:
92
- raise ValueError(
93
- f"Missing required dates for time-series modality '{m}'."
94
- )
95
-
96
89
  if patch_size_meters % 10 != 0:
97
90
  raise ValueError(
98
91
  "In AnySat, `patch_size` is in meters and must be a multiple of 10."
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):
106
99
 
107
100
  self.modalities = modalities
108
101
  self.patch_size_meters = int(patch_size_meters)
109
- self.dates = dates
110
102
  self.output = output
111
103
  self.output_modality = output_modality
112
104
 
@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
119
111
  )
120
112
  self._embed_dim = 768 # base width, 'dense' returns 2x
121
113
 
114
+ @staticmethod
115
+ def time_ranges_to_doy(
116
+ time_ranges: list[tuple[datetime, datetime]],
117
+ device: torch.device,
118
+ ) -> torch.Tensor:
119
+ """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
120
+
121
+ AnySat uses the doy with each timestamp, so we take the midpoint of
122
+ the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
123
+ time so that start_time == end_time == mid_time.
124
+ """
125
+ doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
126
+ return torch.tensor(doys, dtype=torch.int32, device=device)
127
+
122
128
  def forward(self, context: ModelContext) -> FeatureMaps:
123
129
  """Forward pass for the AnySat model.
124
130
 
@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
139
145
  raise ValueError(f"Modality '{modality}' not present in inputs.")
140
146
 
141
147
  cur = torch.stack(
142
- [inp[modality] for inp in inputs], dim=0
143
- ) # (B, C, H, W) or (B, T*C, H, W)
148
+ [inp[modality].image for inp in inputs], dim=0
149
+ ) # (B, C, T, H, W)
144
150
 
145
151
  if modality in TIME_SERIES_MODALITIES:
146
- num_dates = len(self.dates[modality])
147
- num_bands = cur.shape[1] // num_dates
148
- cur = rearrange(
149
- cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
150
- )
152
+ num_bands = cur.shape[1]
153
+ cur = rearrange(cur, "b c t h w -> b t c h w")
151
154
  H, W = cur.shape[-2], cur.shape[-1]
155
+
156
+ if inputs[0][modality].timestamps is None:
157
+ raise ValueError(
158
+ f"Require timestamps for time series modality {modality}"
159
+ )
160
+ timestamps = torch.stack(
161
+ [
162
+ self.time_ranges_to_doy(inp[modality].timestamps, cur.device) # type: ignore
163
+ for inp in inputs
164
+ ],
165
+ dim=0,
166
+ )
167
+ batch[f"{modality}_dates"] = timestamps
152
168
  else:
169
+ # take the first (assumed only) timestep
170
+ cur = cur[:, :, 0]
153
171
  num_bands = cur.shape[1]
154
172
  H, W = cur.shape[-2], cur.shape[-1]
155
173
 
@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
173
191
  "All modalities must share the same spatial extent (H*res, W*res)."
174
192
  )
175
193
 
176
- # Add *_dates
177
- to_add = {}
178
- for modality, x in list(batch.items()):
179
- if modality in TIME_SERIES_MODALITIES:
180
- B, T = x.shape[0], x.shape[1]
181
- d = torch.as_tensor(
182
- self.dates[modality], dtype=torch.long, device=x.device
183
- )
184
- if d.ndim != 1 or d.numel() != T:
185
- raise ValueError(
186
- f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
187
- )
188
- to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
189
-
190
- batch.update(to_add)
191
-
192
194
  kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
193
195
  if self.output == "dense":
194
196
  kwargs["output_modality"] = self.output_modality
@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
43
43
  a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
44
44
  """
45
45
  inputs = context.inputs
46
- device = inputs[0]["image"].device
46
+ device = inputs[0]["image"].image.device
47
47
  clip_inputs = self.processor(
48
- images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
48
+ images=[
49
+ inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
50
+ for inp in inputs
51
+ ],
49
52
  return_tensors="pt",
50
53
  padding=True,
51
54
  )
@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
175
175
  sentinel1: torch.Tensor | None = None
176
176
  sentinel2: torch.Tensor | None = None
177
177
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
178
- sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
178
+ sentinel1 = torch.stack(
179
+ [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
180
+ dim=0,
181
+ )
179
182
  sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
180
183
  if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
181
- sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
184
+ sentinel2 = torch.stack(
185
+ [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
186
+ dim=0,
187
+ )
182
188
  sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
183
189
 
184
190
  outputs = self.model(
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
294
300
  for modality in MODALITY_BANDS.keys():
295
301
  if modality not in input_dict:
296
302
  continue
297
- input_dict[modality] = self.apply_image(input_dict[modality], modality)
303
+ input_dict[modality].image = self.apply_image(
304
+ input_dict[modality].image, modality
305
+ )
298
306
  return input_dict, target_dict
@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
104
104
  a FeatureMaps with one feature map.
105
105
  """
106
106
  cur = torch.stack(
107
- [inp["image"] for inp in context.inputs], dim=0
107
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
108
+ dim=0,
108
109
  ) # (B, C, H, W)
109
110
 
110
111
  if self.do_resizing and (
@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
210
210
  ),
211
211
  )
212
212
 
213
- image_list = [inp["image"] for inp in context.inputs]
213
+ # take the first (and assumed to be only) timestep
214
+ image_list = [inp["image"].image[:, 0] for inp in context.inputs]
214
215
  images, targets = self.noop_transform(image_list, targets)
215
216
 
216
217
  feature_dict = collections.OrderedDict()