rslearn 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry; it is provided for informational purposes only.
rslearn/data_sources/planetary_computer.py CHANGED
@@ -3,7 +3,7 @@
  import os
  import tempfile
  import xml.etree.ElementTree as ET
- from datetime import timedelta
+ from datetime import datetime, timedelta
  from typing import Any

  import affine
@@ -12,6 +12,7 @@ import planetary_computer
  import rasterio
  import requests
  from rasterio.enums import Resampling
+ from typing_extensions import override
  from upath import UPath

  from rslearn.config import LayerConfig
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
  from rslearn.utils.fsspec import join_upath
  from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
+ from rslearn.utils.stac import StacClient, StacItem

  from .copernicus import get_harmonize_callback

  logger = get_logger(__name__)

+ # Max limit accepted by Planetary Computer API.
+ PLANETARY_COMPUTER_LIMIT = 1000
+
+
+ class PlanetaryComputerStacClient(StacClient):
+     """A StacClient subclass that handles Planetary Computer's pagination limits.
+
+     Planetary Computer STAC API does not support standard pagination and has a max
+     limit of 1000. If the initial query returns 1000 items, this client paginates
+     by sorting by ID and using gt (greater than) queries to fetch subsequent pages.
+     """
+
+     @override
+     def search(
+         self,
+         collections: list[str] | None = None,
+         bbox: tuple[float, float, float, float] | None = None,
+         intersects: dict[str, Any] | None = None,
+         date_time: datetime | tuple[datetime, datetime] | None = None,
+         ids: list[str] | None = None,
+         limit: int | None = None,
+         query: dict[str, Any] | None = None,
+         sortby: list[dict[str, str]] | None = None,
+     ) -> list[StacItem]:
+         # We will use sortby for pagination, so the caller must not set it.
+         if sortby is not None:
+             raise ValueError("sortby must not be set for PlanetaryComputerStacClient")
+
+         # First, try a simple query with the PC limit to detect if pagination is needed.
+         # We always use PLANETARY_COMPUTER_LIMIT for the request because PC doesn't
+         # support standard pagination, and we need to detect when we hit the limit
+         # to switch to ID-based pagination.
+         # We could just start sorting by ID here and do pagination, but we treat it as
+         # a special case to avoid sorting since that seems to speed up the query.
+         stac_items = super().search(
+             collections=collections,
+             bbox=bbox,
+             intersects=intersects,
+             date_time=date_time,
+             ids=ids,
+             limit=PLANETARY_COMPUTER_LIMIT,
+             query=query,
+         )
+
+         # If we got fewer than the PC limit, we have all the results.
+         if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+             return stac_items
+
+         # We hit the limit, so we need to paginate by ID.
+         # Re-fetch with sorting by ID to ensure consistent ordering for pagination.
+         logger.debug(
+             "Initial request returned %d items (at limit), switching to ID pagination",
+             len(stac_items),
+         )
+
+         all_items: list[StacItem] = []
+         last_id: str | None = None
+
+         while True:
+             # Build query with id > last_id if we're paginating.
+             combined_query: dict[str, Any] = dict(query) if query else {}
+             if last_id is not None:
+                 combined_query["id"] = {"gt": last_id}
+
+             stac_items = super().search(
+                 collections=collections,
+                 bbox=bbox,
+                 intersects=intersects,
+                 date_time=date_time,
+                 ids=ids,
+                 limit=PLANETARY_COMPUTER_LIMIT,
+                 query=combined_query if combined_query else None,
+                 sortby=[{"field": "id", "direction": "asc"}],
+             )
+
+             all_items.extend(stac_items)
+
+             # If we got fewer than the limit, we've fetched everything.
+             if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+                 break
+
+             # Otherwise, paginate using the last item's ID.
+             last_id = stac_items[-1].id
+             logger.debug(
+                 "Got %d items, paginating with id > %s",
+                 len(stac_items),
+                 last_id,
+             )
+
+         logger.debug("Total items fetched: %d", len(all_items))
+         return all_items
+

  class PlanetaryComputer(StacDataSource, TileStore):
      """Modality-agnostic data source for data on Microsoft Planetary Computer.
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
              required_assets=required_assets,
              cache_dir=cache_upath,
          )
+
+         # Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
+         self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
+
          self.asset_bands = asset_bands
          self.timeout = timeout
          self.skip_items_missing_assets = skip_items_missing_assets
rslearn/data_sources/stac.py CHANGED
@@ -12,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
  from rslearn.data_sources.data_source import Item, ItemLookupDataSource
  from rslearn.data_sources.utils import match_candidate_items_to_window
  from rslearn.log_utils import get_logger
+ from rslearn.utils.fsspec import open_atomic
  from rslearn.utils.geometry import STGeometry
  from rslearn.utils.stac import StacClient, StacItem

@@ -187,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

          # Finally we cache it if cache_dir is set.
          if cache_fname is not None:
-             with cache_fname.open("w") as f:
+             with open_atomic(cache_fname, "w") as f:
                  json.dump(item.serialize(), f)

          return item
@@ -259,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
              cache_fname = self.cache_dir / f"{item.name}.json"
              if cache_fname.exists():
                  continue
-             with cache_fname.open("w") as f:
+             with open_atomic(cache_fname, "w") as f:
                  json.dump(item.serialize(), f)

          cur_groups = match_candidate_items_to_window(
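
Both caching paths now go through open_atomic from rslearn.utils.fsspec, so an interrupted write cannot leave a truncated JSON file for a later read to choke on. For context, here is a sketch of the usual write-to-temp-then-rename pattern behind such a helper (an assumed illustration of the behavior, not rslearn's actual implementation):

```python
# Sketch of the temp-file-then-rename pattern behind atomic writes; this is an
# assumed illustration of what an open_atomic helper provides, not rslearn's code.
import contextlib
import os
import tempfile
from collections.abc import Iterator
from typing import IO


@contextlib.contextmanager
def open_atomic_sketch(path: str, mode: str = "w") -> Iterator[IO]:
    # Write to a temporary file in the same directory so the final rename is atomic.
    dirname = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=dirname)
    try:
        with os.fdopen(fd, mode) as f:
            yield f
        os.replace(tmp_path, path)  # atomic on POSIX; readers never see a partial file
    except BaseException:
        os.unlink(tmp_path)
        raise
```
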
rslearn/models/simple_time_series.py CHANGED
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
              # want to pass 2 timesteps to the model.
              # TODO is probably to make this behaviour clearer but lets leave it like
              # this for now to not break things.
-             num_timesteps = images.shape[1] // image_channels
+             num_timesteps = image_channels // images.shape[1]
              batched_timesteps = images.shape[2] // num_timesteps
              images = rearrange(
                  images,
rslearn/train/tasks/segmentation.py CHANGED
@@ -59,6 +59,8 @@ class SegmentationTask(BasicTask):
          miou_metric_kwargs: dict[str, Any] = {},
          prob_scales: list[float] | None = None,
          other_metrics: dict[str, Metric] = {},
+         output_probs: bool = False,
+         output_class_idx: int | None = None,
          **kwargs: Any,
      ) -> None:
          """Initialize a new SegmentationTask.
@@ -92,6 +94,10 @@ class SegmentationTask(BasicTask):
                  this is only applied during prediction, not when computing val or test
                  metrics.
              other_metrics: additional metrics to configure on this task.
+             output_probs: if True, output raw softmax probabilities instead of class IDs
+                 during prediction.
+             output_class_idx: if set along with output_probs, only output the probability
+                 for this specific class index (single-channel output).
              kwargs: additional arguments to pass to BasicTask
          """
          super().__init__(**kwargs)
@@ -116,6 +122,8 @@ class SegmentationTask(BasicTask):
          self.miou_metric_kwargs = miou_metric_kwargs
          self.prob_scales = prob_scales
          self.other_metrics = other_metrics
+         self.output_probs = output_probs
+         self.output_class_idx = output_class_idx

      def process_inputs(
          self,
@@ -171,7 +179,9 @@ class SegmentationTask(BasicTask):
              metadata: metadata about the patch being read

          Returns:
-             CHW numpy array with one channel, containing the predicted class IDs.
+             CHW numpy array. If output_probs is False, returns one channel with
+                 predicted class IDs. If output_probs is True, returns softmax probabilities
+                 (num_classes channels, or 1 channel if output_class_idx is set).
          """
          if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
              raise ValueError("the output for SegmentationTask must be a CHW tensor")
@@ -183,6 +193,15 @@ class SegmentationTask(BasicTask):
                  self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
              )[:, None, None]
          )
+
+         if self.output_probs:
+             # Return raw softmax probabilities
+             probs = raw_output.cpu().numpy()
+             if self.output_class_idx is not None:
+                 # Return only the specified class probability
+                 return probs[self.output_class_idx : self.output_class_idx + 1, :, :]
+             return probs
+
          classes = raw_output.argmax(dim=0).cpu().numpy()
          return classes[None, :, :]

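
The added options only change what the prediction step emits for each patch. A small sketch of the three output shapes on a dummy CHW probability tensor (values are random; shapes match the docstring above):

```python
# Illustrates the three prediction outputs of SegmentationTask on a dummy
# CHW probability tensor (3 classes, 4x4 pixels); shapes match the docstring above.
import torch

raw_output = torch.softmax(torch.randn(3, 4, 4), dim=0)  # CHW probabilities

# Default (output_probs=False): one channel of class IDs.
class_ids = raw_output.argmax(dim=0).cpu().numpy()[None, :, :]
print(class_ids.shape)  # (1, 4, 4)

# output_probs=True: all per-class probabilities.
probs = raw_output.cpu().numpy()
print(probs.shape)  # (3, 4, 4)

# output_probs=True with output_class_idx=1: a single class's probability channel.
single = probs[1:2, :, :]
print(single.shape)  # (1, 4, 4)
```
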
rslearn/utils/raster_format.py CHANGED
@@ -476,6 +476,7 @@ class GeotiffRasterFormat(RasterFormat):
          bounds: PixelBounds,
          array: npt.NDArray[Any],
          fname: str | None = None,
+         nodata_val: int | float | None = None,
      ) -> None:
          """Encodes raster data.

@@ -485,6 +486,7 @@ class GeotiffRasterFormat(RasterFormat):
              bounds: the bounds of the raster data in the projection
              array: the raster data
              fname: override the filename to save as
+             nodata_val: set the nodata value when writing the raster.
          """
          if fname is None:
              fname = self.fname
@@ -520,6 +522,9 @@ class GeotiffRasterFormat(RasterFormat):
              profile["tiled"] = True
              profile["blockxsize"] = self.block_size
              profile["blockysize"] = self.block_size
+         # Set nodata_val if provided.
+         if nodata_val is not None:
+             profile["nodata"] = nodata_val

          profile.update(self.geotiff_options)

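
On the write side, nodata_val is passed straight into the rasterio profile. A minimal sketch of the equivalent plain-rasterio write, outside of rslearn (array, transform, CRS, and filename are made up for illustration):

```python
# Writes a single-band GeoTIFF with an explicit nodata value via the rasterio
# profile, mirroring the nodata_val handling added to the encode path above.
# All values here (array, transform, CRS, filename) are illustrative.
import numpy as np
import rasterio
from rasterio.transform import from_origin

array = np.zeros((1, 256, 256), dtype=np.uint8)
profile = {
    "driver": "GTiff",
    "height": array.shape[1],
    "width": array.shape[2],
    "count": array.shape[0],
    "dtype": array.dtype.name,
    "crs": "EPSG:32610",
    "transform": from_origin(500000, 4000000, 10, 10),
    "nodata": 255,  # equivalent of passing nodata_val=255
}
with rasterio.open("example.tif", "w", **profile) as dst:
    dst.write(array)
```
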
@@ -535,6 +540,7 @@ class GeotiffRasterFormat(RasterFormat):
          bounds: PixelBounds,
          resampling: Resampling = Resampling.bilinear,
          fname: str | None = None,
+         nodata_val: int | float | None = None,
      ) -> npt.NDArray[Any]:
          """Decodes raster data.

@@ -544,6 +550,16 @@ class GeotiffRasterFormat(RasterFormat):
              bounds: the bounds to read in the given projection.
              resampling: resampling method to use in case resampling is needed.
              fname: override the filename to read from
+             nodata_val: override the nodata value in the raster when reading. Pixels in
+                 bounds that are not present in the source raster will be initialized to
+                 this value. Note that, if the raster specifies a nodata value, and
+                 some source pixels have that value, they will still be read under their
+                 original value; overriding the nodata value is primarily useful if the
+                 user wants out of bounds pixels to have a different value from the
+                 source pixels, e.g. if the source data has background and foreground
+                 classes (with background being nodata) but we want to read it in a
+                 different projection and have out of bounds pixels be a third "invalid"
+                 value.

          Returns:
              the raster data
@@ -561,6 +577,7 @@ class GeotiffRasterFormat(RasterFormat):
              width=bounds[2] - bounds[0],
              height=bounds[3] - bounds[1],
              resampling=resampling,
+             src_nodata=nodata_val,
          ) as vrt:
              return vrt.read()

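
On the read side, the override is forwarded to the WarpedVRT as src_nodata, so pixels of the requested bounds that fall outside the source raster come back filled with that value. A sketch of the same effect in plain rasterio (the filename and the 128-pixel offset are illustrative):

```python
# Reads a window larger than the source raster through a WarpedVRT so that
# out-of-bounds pixels are filled with the overridden nodata value, as described
# in the decode docstring above. Filename and geotransform are illustrative.
import rasterio
from rasterio.enums import Resampling
from rasterio.transform import from_origin
from rasterio.vrt import WarpedVRT

with rasterio.open("example.tif") as src:
    with WarpedVRT(
        src,
        crs=src.crs,
        # Shift the window 128 pixels up and left so part of it falls outside the source.
        transform=from_origin(
            src.bounds.left - 128 * src.res[0],
            src.bounds.top + 128 * src.res[1],
            src.res[0],
            src.res[1],
        ),
        width=src.width + 128,
        height=src.height + 128,
        resampling=Resampling.bilinear,
        src_nodata=255,  # out-of-bounds pixels come back as 255
    ) as vrt:
        data = vrt.read()
```
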
rslearn/utils/stac.py CHANGED
@@ -101,6 +101,7 @@ class StacClient:
          ids: list[str] | None = None,
          limit: int | None = None,
          query: dict[str, Any] | None = None,
+         sortby: list[dict[str, str]] | None = None,
      ) -> list[StacItem]:
          """Execute a STAC item search.

@@ -117,6 +118,7 @@ class StacClient:
              limit: number of items per page. We will read all the pages.
              query: query dict, if STAC query extension is supported by this API. See
                  https://github.com/stac-api-extensions/query.
+             sortby: list of sort specifications, e.g. [{"field": "id", "direction": "asc"}].

          Returns:
              list of matching STAC items.
@@ -142,6 +144,8 @@ class StacClient:
              request_data["limit"] = limit
          if query is not None:
              request_data["query"] = query
+         if sortby is not None:
+             request_data["sortby"] = sortby

          # Handle pagination.
          cur_url = self.endpoint + "/search"
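
A quick usage example of the new parameter, assuming the endpoint-only constructor shown elsewhere in this diff and a target STAC API that implements the sort extension (the collection and bbox are illustrative):

```python
# Example of the new sortby parameter on StacClient.search; endpoint, collection,
# and bbox are illustrative, and the server must support the STAC sort extension
# for the requested ordering to be honored.
from rslearn.utils.stac import StacClient

client = StacClient("https://planetarycomputer.microsoft.com/api/stac/v1")
items = client.search(
    collections=["sentinel-2-l2a"],
    bbox=(-122.6, 47.5, -122.2, 47.8),
    limit=100,
    sortby=[{"field": "id", "direction": "asc"}],
)
```
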
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.23
+ Version: 0.0.24
  Summary: A library for developing remote sensing datasets and models
  Author: OlmoEarth Team
  License: Apache License
@@ -25,9 +25,9 @@ rslearn/data_sources/local_files.py,sha256=mo5W_BxBl89EPTIHNDEpXM6qBjrP225KK0Pcm
  rslearn/data_sources/openstreetmap.py,sha256=TzZfouc2Z4_xjx2v_uv7aPn4tVW3flRVQN4qBfl507E,18161
  rslearn/data_sources/planet.py,sha256=6FWQ0bl1k3jwvwp4EVGi2qs3OD1QhnKOKP36mN4HELI,9446
  rslearn/data_sources/planet_basemap.py,sha256=e9R6FlagJjg8Z6Rc1dC6zK3xMkCohz8eohXqXmd29xg,9670
- rslearn/data_sources/planetary_computer.py,sha256=8kVatSXnwPUZljVOjj9vnVbOsmWhRdROi5YTiCmYmII,22594
+ rslearn/data_sources/planetary_computer.py,sha256=k6CO5Yim2I-frlD8r2_uBo0CQFw89mN3_5mrv0Xk2WU,26449
  rslearn/data_sources/soilgrids.py,sha256=rwO4goFPQ7lx420FvYBHYFXdihnZqn_-IjdqtxQ9j2g,12455
- rslearn/data_sources/stac.py,sha256=l7V1QzvpNtoH_funiTSl1J8Lj1P3nMj24_fRpgCAslQ,10692
+ rslearn/data_sources/stac.py,sha256=Xty1JDueAAonNVLRo8vfNBhlHrVLjhmZ-uRBYbrGvtA,10753
  rslearn/data_sources/usda_cdl.py,sha256=_WvxZkm0fbXfniRs6NT8iVCbTTmVPflDhsFT2ci6_Dk,6879
  rslearn/data_sources/usgs_landsat.py,sha256=kPOb3hsZe5-guUcFZZkwzcRpYZ3Zo7Bk4E829q_xiyU,18516
  rslearn/data_sources/utils.py,sha256=v_90ALOuts7RHNcx-j8o-aQ_aFjh8ZhXrmsaa9uEGDA,11651
@@ -69,7 +69,7 @@ rslearn/models/prithvi.py,sha256=J45eC1pd4l5AGlr19Qjrjrw5PPwvYE9bNM5qCFoznmg,403
  rslearn/models/resize_features.py,sha256=U7ZIVwwToJJnwchFG59wLWWP9eikHDB_1c4OtpubxHU,1693
  rslearn/models/sam2_enc.py,sha256=WZOtlp0FjaVztW4gEVIcsFQdKArS9iblRODP0b6Oc8M,3641
  rslearn/models/satlaspretrain.py,sha256=2R48ulbtd44Qy2FYJCkllE2Wk35eZxkc79ruSgkmgcQ,3384
- rslearn/models/simple_time_series.py,sha256=Nfk5E3d9W-4AyLQiy-P8p-JvxmFYE3FBrvOgttjXSMw,14678
+ rslearn/models/simple_time_series.py,sha256=farQwt_nJVyAbgaM2UzdyqpDuIO0SLmHr9e9EVPSWCE,14678
  rslearn/models/singletask.py,sha256=9DM9a9-Mv3vVQqRhPOIXG2HHuVqVa_zuvgafeeYh4r0,1903
  rslearn/models/ssl4eo_s12.py,sha256=DOlpIj6NfjIlWyJ27m9Xo8TMlovBDstFq0ARnmAJ6qY,3919
  rslearn/models/swin.py,sha256=Xqr3SswbHP6IhwT2atZMAPF2TUzQqfxvihksb8WSeRo,6065
@@ -133,7 +133,7 @@ rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbA
  rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
  rslearn/train/tasks/per_pixel_regression.py,sha256=Clrod6LQGjgNC0IAR4HLY7eCGWMHj2mk4d4moZCl4Qc,10209
  rslearn/train/tasks/regression.py,sha256=bVS_ApZSpbL0NaaM8Mu5Bsu4SBUyLpVtrPslulvvZHs,12695
- rslearn/train/tasks/segmentation.py,sha256=Y3Sm2oOzR3yJCpagwBmp1yCwa024MQN2v1PcpiaWBf8,28425
+ rslearn/train/tasks/segmentation.py,sha256=LZeuveHhMQsjNOQfMcwqSI4Ux3k9zfa58A2eZHSif8Y,29391
  rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
  rslearn/train/transforms/concatenate.py,sha256=hVVBaxIdk1Cx8JHPirj54TGpbWAJx5y_xD7k1rmGmT0,3166
@@ -155,17 +155,17 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
  rslearn/utils/jsonargparse.py,sha256=TRyZA151KzhjJlZczIHYguEA-YxCDYaZ2IwCRgx76nM,4791
  rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
- rslearn/utils/raster_format.py,sha256=qZpbODF4I7BsOxf43O6vTmH2TSNw6N8PP0wsFUVvdIw,26267
+ rslearn/utils/raster_format.py,sha256=fwotJBadwqYSdK8UokiKOk1pOF8JMim3kP_VwLWivPI,27382
  rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
  rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
- rslearn/utils/stac.py,sha256=z93N5ZeEe1oUikX5ILMA5sQEZX276sAeMjsg0TShnSk,5776
+ rslearn/utils/stac.py,sha256=c8NwOCKWvUwA-FSKlxZn-t7RZYweuye53OufT0bAK4A,5996
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
  rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
- rslearn-0.0.23.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
- rslearn-0.0.23.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
- rslearn-0.0.23.dist-info/METADATA,sha256=YFo7HcByJFrlgbSqcCUat2Z7nn1RU0aQzR0InaDSKEg,37936
- rslearn-0.0.23.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
- rslearn-0.0.23.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
- rslearn-0.0.23.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
- rslearn-0.0.23.dist-info/RECORD,,
+ rslearn-0.0.24.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+ rslearn-0.0.24.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
+ rslearn-0.0.24.dist-info/METADATA,sha256=gV5mgeYPYiKWrEu7D8acOubWvg76Nn_4ICvlD7iTpcs,37936
+ rslearn-0.0.24.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+ rslearn-0.0.24.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+ rslearn-0.0.24.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+ rslearn-0.0.24.dist-info/RECORD,,