rslearn 0.0.20__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. {rslearn-0.0.20/rslearn.egg-info → rslearn-0.0.21}/PKG-INFO +1 -1
  2. {rslearn-0.0.20 → rslearn-0.0.21}/pyproject.toml +1 -1
  3. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/climate_data_store.py +216 -29
  4. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/model.py +7 -8
  5. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/dataset.py +44 -35
  6. {rslearn-0.0.20 → rslearn-0.0.21/rslearn.egg-info}/PKG-INFO +1 -1
  7. {rslearn-0.0.20 → rslearn-0.0.21}/LICENSE +0 -0
  8. {rslearn-0.0.20 → rslearn-0.0.21}/NOTICE +0 -0
  9. {rslearn-0.0.20 → rslearn-0.0.21}/README.md +0 -0
  10. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/__init__.py +0 -0
  11. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/arg_parser.py +0 -0
  12. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/config/__init__.py +0 -0
  13. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/config/dataset.py +0 -0
  14. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/const.py +0 -0
  15. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/__init__.py +0 -0
  16. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/aws_landsat.py +0 -0
  17. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/aws_open_data.py +0 -0
  18. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/aws_sentinel1.py +0 -0
  19. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/copernicus.py +0 -0
  20. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/data_source.py +0 -0
  21. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/earthdaily.py +0 -0
  22. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/earthdata_srtm.py +0 -0
  23. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/eurocrops.py +0 -0
  24. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/gcp_public_data.py +0 -0
  25. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/google_earth_engine.py +0 -0
  26. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/local_files.py +0 -0
  27. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/openstreetmap.py +0 -0
  28. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/planet.py +0 -0
  29. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/planet_basemap.py +0 -0
  30. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/planetary_computer.py +0 -0
  31. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/usda_cdl.py +0 -0
  32. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/usgs_landsat.py +0 -0
  33. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/utils.py +0 -0
  34. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/vector_source.py +0 -0
  35. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/worldcereal.py +0 -0
  36. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/worldcover.py +0 -0
  37. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/worldpop.py +0 -0
  38. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/data_sources/xyz_tiles.py +0 -0
  39. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/__init__.py +0 -0
  40. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/add_windows.py +0 -0
  41. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/dataset.py +0 -0
  42. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/handler_summaries.py +0 -0
  43. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/manage.py +0 -0
  44. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/materialize.py +0 -0
  45. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/remap.py +0 -0
  46. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/storage/__init__.py +0 -0
  47. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/storage/file.py +0 -0
  48. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/storage/storage.py +0 -0
  49. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/dataset/window.py +0 -0
  50. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/lightning_cli.py +0 -0
  51. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/log_utils.py +0 -0
  52. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/main.py +0 -0
  53. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/__init__.py +0 -0
  54. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/anysat.py +0 -0
  55. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/attention_pooling.py +0 -0
  56. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/clay/clay.py +0 -0
  57. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/clay/configs/metadata.yaml +0 -0
  58. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/clip.py +0 -0
  59. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/component.py +0 -0
  60. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/concatenate_features.py +0 -0
  61. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/conv.py +0 -0
  62. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/croma.py +0 -0
  63. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/__init__.py +0 -0
  64. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/box_ops.py +0 -0
  65. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/detr.py +0 -0
  66. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/matcher.py +0 -0
  67. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/position_encoding.py +0 -0
  68. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/transformer.py +0 -0
  69. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/detr/util.py +0 -0
  70. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/dinov3.py +0 -0
  71. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/faster_rcnn.py +0 -0
  72. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/feature_center_crop.py +0 -0
  73. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/fpn.py +0 -0
  74. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/galileo/__init__.py +0 -0
  75. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/galileo/galileo.py +0 -0
  76. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/galileo/single_file_galileo.py +0 -0
  77. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/module_wrapper.py +0 -0
  78. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/molmo.py +0 -0
  79. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/multitask.py +0 -0
  80. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  81. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  82. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon.py +0 -0
  83. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  84. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  85. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  86. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  87. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  88. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  89. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  90. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  91. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  92. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  93. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  94. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  95. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/pick_features.py +0 -0
  96. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/pooling_decoder.py +0 -0
  97. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/presto/__init__.py +0 -0
  98. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/presto/presto.py +0 -0
  99. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/presto/single_file_presto.py +0 -0
  100. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/prithvi.py +0 -0
  101. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/resize_features.py +0 -0
  102. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/sam2_enc.py +0 -0
  103. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/satlaspretrain.py +0 -0
  104. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/simple_time_series.py +0 -0
  105. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/singletask.py +0 -0
  106. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/ssl4eo_s12.py +0 -0
  107. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/swin.py +0 -0
  108. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/task_embedding.py +0 -0
  109. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/terramind.py +0 -0
  110. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/trunk.py +0 -0
  111. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/unet.py +0 -0
  112. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/upsample.py +0 -0
  113. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/models/use_croma.py +0 -0
  114. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/py.typed +0 -0
  115. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/template_params.py +0 -0
  116. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/tile_stores/__init__.py +0 -0
  117. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/tile_stores/default.py +0 -0
  118. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/tile_stores/tile_store.py +0 -0
  119. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/__init__.py +0 -0
  120. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/all_patches_dataset.py +0 -0
  121. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/callbacks/__init__.py +0 -0
  122. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/callbacks/adapters.py +0 -0
  123. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  124. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/callbacks/gradients.py +0 -0
  125. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/callbacks/peft.py +0 -0
  126. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/data_module.py +0 -0
  127. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/lightning_module.py +0 -0
  128. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/model_context.py +0 -0
  129. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/optimizer.py +0 -0
  130. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/prediction_writer.py +0 -0
  131. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/scheduler.py +0 -0
  132. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/__init__.py +0 -0
  133. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/classification.py +0 -0
  134. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/detection.py +0 -0
  135. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/embedding.py +0 -0
  136. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/multi_task.py +0 -0
  137. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/per_pixel_regression.py +0 -0
  138. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/regression.py +0 -0
  139. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/segmentation.py +0 -0
  140. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/tasks/task.py +0 -0
  141. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/__init__.py +0 -0
  142. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/concatenate.py +0 -0
  143. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/crop.py +0 -0
  144. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/flip.py +0 -0
  145. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/mask.py +0 -0
  146. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/normalize.py +0 -0
  147. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/pad.py +0 -0
  148. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/resize.py +0 -0
  149. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/select_bands.py +0 -0
  150. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/sentinel1.py +0 -0
  151. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/train/transforms/transform.py +0 -0
  152. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/__init__.py +0 -0
  153. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/array.py +0 -0
  154. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/feature.py +0 -0
  155. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/fsspec.py +0 -0
  156. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/geometry.py +0 -0
  157. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/get_utm_ups_crs.py +0 -0
  158. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/grid_index.py +0 -0
  159. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/jsonargparse.py +0 -0
  160. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/mp.py +0 -0
  161. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/raster_format.py +0 -0
  162. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/rtree_index.py +0 -0
  163. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/spatial_index.py +0 -0
  164. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/sqlite_index.py +0 -0
  165. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/time.py +0 -0
  166. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn/utils/vector_format.py +0 -0
  167. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn.egg-info/SOURCES.txt +0 -0
  168. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn.egg-info/dependency_links.txt +0 -0
  169. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn.egg-info/entry_points.txt +0 -0
  170. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn.egg-info/requires.txt +0 -0
  171. {rslearn-0.0.20 → rslearn-0.0.21}/rslearn.egg-info/top_level.txt +0 -0
  172. {rslearn-0.0.20 → rslearn-0.0.21}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.20"
3
+ version = "0.0.21"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -24,51 +24,55 @@ from rslearn.utils.geometry import STGeometry
24
24
  logger = get_logger(__name__)
25
25
 
26
26
 
27
- class ERA5LandMonthlyMeans(DataSource):
28
- """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
27
+ class ERA5Land(DataSource):
28
+ """Base class for ingesting ERA5 land data from the Copernicus Climate Data Store.
29
29
 
30
30
  An API key must be passed either in the configuration or via the CDSAPI_KEY
31
31
  environment variable. You can acquire an API key by going to the Climate Data Store
32
32
  website (https://cds.climate.copernicus.eu/), registering an account and logging
33
- in, and then
33
+ in, and then getting the API key from the user profile page.
34
34
 
35
35
  The band names should match CDS variable names (see the reference at
36
36
  https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation). However,
37
37
  replace "_" with "-" in the variable names when specifying bands in the layer
38
38
  configuration.
39
39
 
40
- This data source corresponds to the reanalysis-era5-land-monthly-means product.
41
-
42
- All requests to the API will be for the whole globe. Although the API supports arbitrary
43
- bounds in the requests, using the whole available area helps to reduce the total number of
44
- requests.
40
+ By default, all requests to the API will be for the whole globe. To speed up ingestion,
41
+ we recommend specifying the bounds of the area of interest, in particular for hourly data.
45
42
  """
46
43
 
47
44
  api_url = "https://cds.climate.copernicus.eu/api"
48
-
49
- # see: https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land-monthly-means
50
- DATASET = "reanalysis-era5-land-monthly-means"
51
- PRODUCT_TYPE = "monthly_averaged_reanalysis"
52
45
  DATA_FORMAT = "netcdf"
53
46
  DOWNLOAD_FORMAT = "unarchived"
54
47
  PIXEL_SIZE = 0.1 # degrees, native resolution is 9km
55
48
 
56
49
  def __init__(
57
50
  self,
51
+ dataset: str,
52
+ product_type: str,
58
53
  band_names: list[str] | None = None,
59
54
  api_key: str | None = None,
55
+ bounds: list[float] | None = None,
60
56
  context: DataSourceContext = DataSourceContext(),
61
57
  ):
62
- """Initialize a new ERA5LandMonthlyMeans instance.
58
+ """Initialize a new ERA5Land instance.
63
59
 
64
60
  Args:
61
+ dataset: the CDS dataset name (e.g., "reanalysis-era5-land-monthly-means").
62
+ product_type: the CDS product type (e.g., "monthly_averaged_reanalysis").
65
63
  band_names: list of band names to acquire. These should correspond to CDS
66
64
  variable names but with "_" replaced with "-". This will only be used
67
65
  if the layer config is missing from the context.
68
66
  api_key: the API key. If not set, it should be set via the CDSAPI_KEY
69
67
  environment variable.
68
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
69
+ If not specified, the whole globe will be used.
70
70
  context: the data source context.
71
71
  """
72
+ self.dataset = dataset
73
+ self.product_type = product_type
74
+ self.bounds = bounds
75
+
72
76
  self.band_names: list[str]
73
77
  if context.layer_config is not None:
74
78
  self.band_names = []
@@ -134,8 +138,11 @@ class ERA5LandMonthlyMeans(DataSource):
134
138
  # Collect Item list corresponding to the current month.
135
139
  items = []
136
140
  item_name = f"era5land_monthlyaveraged_{cur_date.year}_{cur_date.month}"
137
- # Space is the whole globe.
138
- bounds = (-180, -90, 180, 90)
141
+ # Use bounds if set, otherwise use whole globe
142
+ if self.bounds is not None:
143
+ bounds = self.bounds
144
+ else:
145
+ bounds = [-180, -90, 180, 90]
139
146
  # Time is just the given month.
140
147
  start_date = datetime(cur_date.year, cur_date.month, 1, tzinfo=UTC)
141
148
  time_range = (
@@ -172,7 +179,9 @@ class ERA5LandMonthlyMeans(DataSource):
172
179
  # But the list of variables should include the bands we want in the correct
173
180
  # order. And we can distinguish those bands from other "variables" because they
174
181
  # will be 3D while the others will be scalars or 1D.
175
- bands_data = []
182
+
183
+ band_arrays = []
184
+ num_time_steps = None
176
185
  for band_name in nc.variables:
177
186
  band_data = nc.variables[band_name]
178
187
  if len(band_data.shape) != 3:
@@ -182,18 +191,27 @@ class ERA5LandMonthlyMeans(DataSource):
182
191
  logger.debug(
183
192
  f"NC file {nc_path} has variable {band_name} with shape {band_data.shape}"
184
193
  )
185
- # Variable data is stored in a 3D array (1, height, width)
186
- if band_data.shape[0] != 1:
194
+ # Variable data is stored in a 3D array (time, height, width)
195
+ # For hourly data, time is number of days in the month x 24 hours
196
+ if num_time_steps is None:
197
+ num_time_steps = band_data.shape[0]
198
+ elif band_data.shape[0] != num_time_steps:
187
199
  raise ValueError(
188
- f"Bad shape for band {band_name}, expected 1 band but got {band_data.shape[0]}"
200
+ f"Variable {band_name} has {band_data.shape[0]} time steps, "
201
+ f"but expected {num_time_steps}"
189
202
  )
190
- bands_data.append(band_data[0, :, :])
203
+ # Original shape: (time, height, width)
204
+ band_array = np.array(band_data[:])
205
+ band_array = np.expand_dims(band_array, axis=1)
206
+ band_arrays.append(band_array)
191
207
 
192
- array = np.array(bands_data) # (num_bands, height, width)
193
- if array.shape[0] != len(self.band_names):
194
- raise ValueError(
195
- f"Expected to get {len(self.band_names)} bands but got {array.shape[0]}"
196
- )
208
+ # After concatenation: (time, num_variables, height, width)
209
+ stacked_array = np.concatenate(band_arrays, axis=1)
210
+
211
+ # After reshaping: (time x num_variables, height, width)
212
+ array = stacked_array.reshape(
213
+ -1, stacked_array.shape[2], stacked_array.shape[3]
214
+ )
197
215
 
198
216
  # Get metadata for the GeoTIFF
199
217
  lat = nc.variables["latitude"][:]
@@ -235,6 +253,58 @@ class ERA5LandMonthlyMeans(DataSource):
235
253
  ) as dst:
236
254
  dst.write(array)
237
255
 
256
+ def ingest(
257
+ self,
258
+ tile_store: TileStoreWithLayer,
259
+ items: list[Item],
260
+ geometries: list[list[STGeometry]],
261
+ ) -> None:
262
+ """Ingest items into the given tile store.
263
+
264
+ This method should be overridden by subclasses.
265
+
266
+ Args:
267
+ tile_store: the tile store to ingest into
268
+ items: the items to ingest
269
+ geometries: a list of geometries needed for each item
270
+ """
271
+ raise NotImplementedError("Subclasses must implement ingest method")
272
+
273
+
274
+ class ERA5LandMonthlyMeans(ERA5Land):
275
+ """A data source for ingesting ERA5 land monthly averaged data from the Copernicus Climate Data Store.
276
+
277
+ This data source corresponds to the reanalysis-era5-land-monthly-means product.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ band_names: list[str] | None = None,
283
+ api_key: str | None = None,
284
+ bounds: list[float] | None = None,
285
+ context: DataSourceContext = DataSourceContext(),
286
+ ):
287
+ """Initialize a new ERA5LandMonthlyMeans instance.
288
+
289
+ Args:
290
+ band_names: list of band names to acquire. These should correspond to CDS
291
+ variable names but with "_" replaced with "-". This will only be used
292
+ if the layer config is missing from the context.
293
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
294
+ environment variable.
295
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
296
+ If not specified, the whole globe will be used.
297
+ context: the data source context.
298
+ """
299
+ super().__init__(
300
+ dataset="reanalysis-era5-land-monthly-means",
301
+ product_type="monthly_averaged_reanalysis",
302
+ band_names=band_names,
303
+ api_key=api_key,
304
+ bounds=bounds,
305
+ context=context,
306
+ )
307
+
238
308
  def ingest(
239
309
  self,
240
310
  tile_store: TileStoreWithLayer,
@@ -256,25 +326,142 @@ class ERA5LandMonthlyMeans(DataSource):
256
326
  continue
257
327
 
258
328
  # Send the request to the CDS API
259
- # If area is not specified, the whole globe will be requested
329
+ if self.bounds is not None:
330
+ min_lon, min_lat, max_lon, max_lat = self.bounds
331
+ area = [max_lat, min_lon, min_lat, max_lon]
332
+ else:
333
+ area = [90, -180, -90, 180] # Whole globe
334
+
260
335
  request = {
261
- "product_type": [self.PRODUCT_TYPE],
336
+ "product_type": [self.product_type],
262
337
  "variable": variable_names,
263
338
  "year": [f"{item.geometry.time_range[0].year}"], # type: ignore
264
339
  "month": [
265
340
  f"{item.geometry.time_range[0].month:02d}" # type: ignore
266
341
  ],
267
342
  "time": ["00:00"],
343
+ "area": area,
344
+ "data_format": self.DATA_FORMAT,
345
+ "download_format": self.DOWNLOAD_FORMAT,
346
+ }
347
+ logger.debug(
348
+ f"CDS API request for year={request['year']} month={request['month']} area={area}"
349
+ )
350
+ with tempfile.TemporaryDirectory() as tmp_dir:
351
+ local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
352
+ local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
353
+ self.client.retrieve(self.dataset, request, local_nc_fname)
354
+ self._convert_nc_to_tif(
355
+ UPath(local_nc_fname),
356
+ UPath(local_tif_fname),
357
+ )
358
+ tile_store.write_raster_file(
359
+ item.name, self.band_names, UPath(local_tif_fname)
360
+ )
361
+
362
+
363
+ class ERA5LandHourly(ERA5Land):
364
+ """A data source for ingesting ERA5 land hourly data from the Copernicus Climate Data Store.
365
+
366
+ This data source corresponds to the reanalysis-era5-land product.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ band_names: list[str] | None = None,
372
+ api_key: str | None = None,
373
+ bounds: list[float] | None = None,
374
+ context: DataSourceContext = DataSourceContext(),
375
+ ):
376
+ """Initialize a new ERA5LandHourly instance.
377
+
378
+ Args:
379
+ band_names: list of band names to acquire. These should correspond to CDS
380
+ variable names but with "_" replaced with "-". This will only be used
381
+ if the layer config is missing from the context.
382
+ api_key: the API key. If not set, it should be set via the CDSAPI_KEY
383
+ environment variable.
384
+ bounds: optional bounding box as [min_lon, min_lat, max_lon, max_lat].
385
+ If not specified, the whole globe will be used.
386
+ context: the data source context.
387
+ """
388
+ super().__init__(
389
+ dataset="reanalysis-era5-land",
390
+ product_type="reanalysis",
391
+ band_names=band_names,
392
+ api_key=api_key,
393
+ bounds=bounds,
394
+ context=context,
395
+ )
396
+
397
+ def ingest(
398
+ self,
399
+ tile_store: TileStoreWithLayer,
400
+ items: list[Item],
401
+ geometries: list[list[STGeometry]],
402
+ ) -> None:
403
+ """Ingest items into the given tile store.
404
+
405
+ Args:
406
+ tile_store: the tile store to ingest into
407
+ items: the items to ingest
408
+ geometries: a list of geometries needed for each item
409
+ """
410
+ # for CDS variable names, replace "-" with "_"
411
+ variable_names = [band.replace("-", "_") for band in self.band_names]
412
+
413
+ for item in items:
414
+ if tile_store.is_raster_ready(item.name, self.band_names):
415
+ continue
416
+
417
+ # Send the request to the CDS API
418
+ # If area is not specified, the whole globe will be requested
419
+ time_range = item.geometry.time_range
420
+ if time_range is None:
421
+ raise ValueError("Item must have a time range")
422
+
423
+ # For hourly data, request all days in the month and all 24 hours
424
+ start_time = time_range[0]
425
+
426
+ # Get all days in the month
427
+ year = start_time.year
428
+ month = start_time.month
429
+ # Get the last day of the month
430
+ if month == 12:
431
+ last_day = 31
432
+ else:
433
+ next_month = datetime(year, month + 1, 1, tzinfo=UTC)
434
+ last_day = (next_month - relativedelta(days=1)).day
435
+
436
+ days = [f"{day:02d}" for day in range(1, last_day + 1)]
437
+
438
+ # Get all 24 hours
439
+ hours = [f"{hour:02d}:00" for hour in range(24)]
440
+
441
+ if self.bounds is not None:
442
+ min_lon, min_lat, max_lon, max_lat = self.bounds
443
+ area = [max_lat, min_lon, min_lat, max_lon]
444
+ else:
445
+ area = [90, -180, -90, 180] # Whole globe
446
+
447
+ request = {
448
+ "product_type": [self.product_type],
449
+ "variable": variable_names,
450
+ "year": [f"{year}"],
451
+ "month": [f"{month:02d}"],
452
+ "day": days,
453
+ "time": hours,
454
+ "area": area,
268
455
  "data_format": self.DATA_FORMAT,
269
456
  "download_format": self.DOWNLOAD_FORMAT,
270
457
  }
271
458
  logger.debug(
272
- f"CDS API request for the whole globe for year={request['year']} month={request['month']}"
459
+ f"CDS API request for year={request['year']} month={request['month']} days={len(days)} hours={len(hours)} area={area}"
273
460
  )
274
461
  with tempfile.TemporaryDirectory() as tmp_dir:
275
462
  local_nc_fname = os.path.join(tmp_dir, f"{item.name}.nc")
276
463
  local_tif_fname = os.path.join(tmp_dir, f"{item.name}.tif")
277
- self.client.retrieve(self.DATASET, request, local_nc_fname)
464
+ self.client.retrieve(self.dataset, request, local_nc_fname)
278
465
  self._convert_nc_to_tif(
279
466
  UPath(local_nc_fname),
280
467
  UPath(local_tif_fname),
@@ -159,20 +159,19 @@ class OlmoEarth(FeatureExtractor):
159
159
  that contains the distributed checkpoint. This is the format produced by
160
160
  pre-training runs in olmoearth_pretrain.
161
161
  """
162
- # We avoid loading the train module here because it depends on running within
163
- # olmo_core.
164
- # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
165
- require_olmo_core("_load_model_from_checkpoint")
166
- from olmo_core.distributed.checkpoint import load_model_and_optim_state
167
-
168
162
  with (checkpoint_upath / "config.json").open() as f:
169
163
  config_dict = json.load(f)
170
164
  model_config = Config.from_dict(config_dict["model"])
171
165
 
172
166
  model = model_config.build()
173
167
 
174
- # Load the checkpoint.
168
+ # Load the checkpoint (requires olmo_core for distributed checkpoint loading).
175
169
  if not random_initialization:
170
+ require_olmo_core(
171
+ "_load_model_from_checkpoint with random_initialization=False"
172
+ )
173
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
174
+
176
175
  train_module_dir = checkpoint_upath / "model_and_optim"
177
176
  load_model_and_optim_state(str(train_module_dir), model)
178
177
  logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
@@ -242,7 +241,7 @@ class OlmoEarth(FeatureExtractor):
242
241
  present_modalities.append(modality)
243
242
  tensors = []
244
243
  for idx, inp in enumerate(context.inputs):
245
- assert isinstance(inp, RasterImage)
244
+ assert isinstance(inp[modality], RasterImage)
246
245
  tensors.append(inp[modality].image)
247
246
  cur_timestamps = inp[modality].timestamps
248
247
  if cur_timestamps is not None and len(cur_timestamps) > len(
@@ -205,8 +205,7 @@ def read_raster_layer_for_data_input(
205
205
  group_idx: int,
206
206
  layer_config: LayerConfig,
207
207
  data_input: DataInput,
208
- layer_data: WindowLayerData | None,
209
- ) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
208
+ ) -> torch.Tensor:
210
209
  """Read a raster layer for a DataInput.
211
210
 
212
211
  This scans the available rasters for the layer at the window to determine which
@@ -219,11 +218,9 @@ def read_raster_layer_for_data_input(
219
218
  group_idx: the item group.
220
219
  layer_config: the layer configuration.
221
220
  data_input: the DataInput that specifies the bands and dtype.
222
- layer_data: the WindowLayerData associated with this layer and window.
223
221
 
224
222
  Returns:
225
- RasterImage containing raster data and the timestamp associated
226
- with that data.
223
+ Raster data as a tensor.
227
224
  """
228
225
  # See what different sets of bands we need to read to get all the
229
226
  # configured bands.
@@ -294,34 +291,46 @@ def read_raster_layer_for_data_input(
294
291
  src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
295
292
  )
296
293
 
297
- # add the timestamp. this is a tuple defining the start and end of the time range.
298
- time_range = None
299
- if layer_data is not None:
300
- item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
301
- if item.geometry.time_range is not None:
302
- # we assume if one layer data has a geometry & time range, all of them do
303
- time_ranges = [
304
- (
305
- datetime.fromisoformat(
306
- Item.deserialize(
307
- layer_data.serialized_item_groups[group_idx][idx]
308
- ).geometry.time_range[0] # type: ignore
309
- ),
310
- datetime.fromisoformat(
311
- Item.deserialize(
312
- layer_data.serialized_item_groups[group_idx][idx]
313
- ).geometry.time_range[1] # type: ignore
314
- ),
315
- )
316
- for idx in range(len(layer_data.serialized_item_groups[group_idx]))
317
- ]
318
- # take the min and max
319
- time_range = (
320
- min([t[0] for t in time_ranges]),
321
- max([t[1] for t in time_ranges]),
294
+ return image
295
+
296
+
297
+ def read_layer_time_range(
298
+ layer_data: WindowLayerData | None, group_idx: int
299
+ ) -> tuple[datetime, datetime] | None:
300
+ """Extract the combined time range from all items in a layer data group.
301
+
302
+ Returns the min start time and max end time across all items, or None if
303
+ no items have time ranges.
304
+
305
+ Raises:
306
+ ValueError: If some items have time_range and others don't.
307
+ """
308
+ if layer_data is None:
309
+ return None
310
+
311
+ serialized_items = layer_data.serialized_item_groups[group_idx]
312
+ if not serialized_items:
313
+ return None
314
+
315
+ first_item = Item.deserialize(serialized_items[0])
316
+ if first_item.geometry.time_range is None:
317
+ return None
318
+
319
+ # If the first item has a time_range, all items must have one
320
+ time_ranges: list[tuple[datetime, datetime]] = []
321
+ for serialized_item in serialized_items:
322
+ item = Item.deserialize(serialized_item)
323
+ if item.geometry.time_range is None:
324
+ raise ValueError(
325
+ f"Item '{item.name}' has no time_range, but first item does. "
326
+ "All items in a group must consistently have or lack time_range."
322
327
  )
328
+ time_ranges.append(item.geometry.time_range)
323
329
 
324
- return image, time_range
330
+ return (
331
+ min(tr[0] for tr in time_ranges),
332
+ max(tr[1] for tr in time_ranges),
333
+ )
325
334
 
326
335
 
327
336
  def read_data_input(
@@ -378,17 +387,17 @@ def read_data_input(
378
387
  time_ranges: list[tuple[datetime, datetime] | None] = []
379
388
  for layer_name, group_idx in layers_to_read:
380
389
  layer_config = dataset.layers[layer_name]
381
- image, time_range = read_raster_layer_for_data_input(
390
+ image = read_raster_layer_for_data_input(
382
391
  window,
383
392
  bounds,
384
393
  layer_name,
385
394
  group_idx,
386
395
  layer_config,
387
396
  data_input,
388
- # some layers (e.g. "label_raster") won't have associated
389
- # layer datas
390
- layer_datas[layer_name] if layer_name in layer_datas else None,
391
397
  )
398
+ # some layers (e.g. "label_raster") won't have associated layer datas
399
+ layer_data = layer_datas.get(layer_name)
400
+ time_range = read_layer_time_range(layer_data, group_idx)
392
401
  if len(time_ranges) > 0:
393
402
  if type(time_ranges[-1]) is not type(time_range):
394
403
  raise ValueError(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes