rslearn 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,255 @@
+ """A partial data source implementation providing get_items using a STAC API."""
+
+ import json
+ from typing import Any
+
+ import shapely
+ from upath import UPath
+
+ from rslearn.config import QueryConfig
+ from rslearn.const import WGS84_PROJECTION
+ from rslearn.data_sources.data_source import Item, ItemLookupDataSource
+ from rslearn.data_sources.utils import match_candidate_items_to_window
+ from rslearn.log_utils import get_logger
+ from rslearn.utils.geometry import STGeometry
+ from rslearn.utils.stac import StacClient, StacItem
+
+ logger = get_logger(__name__)
+
+
+ class SourceItem(Item):
+     """An item in the StacDataSource data source."""
+
+     def __init__(
+         self,
+         name: str,
+         geometry: STGeometry,
+         asset_urls: dict[str, str],
+         properties: dict[str, str],
+     ):
+         """Creates a new SourceItem.
+
+         Args:
+             name: unique name of the item
+             geometry: the spatial and temporal extent of the item
+             asset_urls: map from asset key to the unsigned asset URL.
+             properties: properties requested by the data source implementation.
+         """
+         super().__init__(name, geometry)
+         self.asset_urls = asset_urls
+         self.properties = properties
+
+     def serialize(self) -> dict[str, Any]:
+         """Serializes the item to a JSON-encodable dictionary."""
+         d = super().serialize()
+         d["asset_urls"] = self.asset_urls
+         d["properties"] = self.properties
+         return d
+
+     @staticmethod
+     def deserialize(d: dict[str, Any]) -> "SourceItem":
+         """Deserializes an item from a JSON-decoded dictionary."""
+         item = super(SourceItem, SourceItem).deserialize(d)
+         return SourceItem(
+             name=item.name,
+             geometry=item.geometry,
+             asset_urls=d["asset_urls"],
+             properties=d["properties"],
+         )
+
+
+ class StacDataSource(ItemLookupDataSource[SourceItem]):
+     """A partial data source implementing get_items using a STAC API.
+
+     This is a helper class that full implementations can extend so that they do
+     not have to worry about the get_items and get_item_by_name implementations.
+     """
+
+     def __init__(
+         self,
+         endpoint: str,
+         collection_name: str,
+         query: dict[str, Any] | None = None,
+         sort_by: str | None = None,
+         sort_ascending: bool = True,
+         required_assets: list[str] | None = None,
+         cache_dir: UPath | None = None,
+         limit: int = 100,
+         properties_to_record: list[str] = [],
+     ):
+         """Create a new StacDataSource.
+
+         Args:
+             endpoint: the STAC endpoint to use.
+             collection_name: the STAC collection name.
+             query: optional STAC query dict to include in searches, e.g.
+                 {"eo:cloud_cover": {"lt": 50}}.
+             sort_by: sort results by this STAC property.
+             sort_ascending: if sort_by is set, sort in ascending order (default).
+                 Otherwise sort in descending order.
+             required_assets: if set, we ignore items that do not have all of these
+                 asset keys.
+             cache_dir: optional cache directory to cache items. This is recommended
+                 if allowing direct materialization from the data source, since it
+                 will likely be necessary to make lots of get_item_by_name calls
+                 during materialization. TODO: give direct materialization access to
+                 the Item object.
+             limit: limit to pass to search queries.
+             properties_to_record: if these properties on the STAC item exist, they
+                 are retained in the SourceItem when we initialize it.
+         """
+         self.client = StacClient(endpoint)
+         self.collection_name = collection_name
+         self.query = query
+         self.sort_by = sort_by
+         self.sort_ascending = sort_ascending
+         self.required_assets = required_assets
+         self.cache_dir = cache_dir
+         self.limit = limit
+         self.properties_to_record = properties_to_record
+
+     def _stac_item_to_item(self, stac_item: StacItem) -> SourceItem:
+         # Make sure geometry, time range, and assets are set.
+         if stac_item.geometry is None:
+             raise ValueError("got unexpected item with no geometry")
+         if stac_item.time_range is None:
+             raise ValueError("got unexpected item with no time range")
+         if stac_item.assets is None:
+             raise ValueError("got unexpected item with no assets")
+
+         shp = shapely.geometry.shape(stac_item.geometry)
+         geom = STGeometry(WGS84_PROJECTION, shp, stac_item.time_range)
+         asset_urls = {
+             asset_key: asset_obj.href
+             for asset_key, asset_obj in stac_item.assets.items()
+         }
+
+         # Keep any properties requested by the data source implementation.
+         properties = {}
+         for prop_name in self.properties_to_record:
+             if prop_name not in stac_item.properties:
+                 continue
+             properties[prop_name] = stac_item.properties[prop_name]
+
+         return SourceItem(stac_item.id, geom, asset_urls, properties)
+
+     def get_item_by_name(self, name: str) -> SourceItem:
+         """Gets an item by name.
+
+         Args:
+             name: the name of the item to get
+
+         Returns:
+             the item object
+         """
+         # If cache_dir is set, we cache the item. First, check whether it is
+         # already in the cache.
+         cache_fname: UPath | None = None
+         if self.cache_dir:
+             cache_fname = self.cache_dir / f"{name}.json"
+         if cache_fname is not None and cache_fname.exists():
+             with cache_fname.open() as f:
+                 return SourceItem.deserialize(json.load(f))
+
+         # No cache or not in cache, so we need to make the STAC request.
+         logger.debug(f"Getting STAC item {name}")
+         stac_items = self.client.search(ids=[name], collections=[self.collection_name])
+
+         if len(stac_items) == 0:
+             raise ValueError(
+                 f"Item {name} not found in collection {self.collection_name}"
+             )
+         if len(stac_items) > 1:
+             raise ValueError(
+                 f"Multiple items found for ID {name} in collection {self.collection_name}"
+             )
+
+         stac_item = stac_items[0]
+         item = self._stac_item_to_item(stac_item)
+
+         # Finally we cache it if cache_dir is set.
+         if cache_fname is not None:
+             with cache_fname.open("w") as f:
+                 json.dump(item.serialize(), f)
+
+         return item
+
+     def get_items(
+         self, geometries: list[STGeometry], query_config: QueryConfig
+     ) -> list[list[list[SourceItem]]]:
+         """Get a list of items in the data source intersecting the given geometries.
+
+         Args:
+             geometries: the spatiotemporal geometries
+             query_config: the query configuration
+
+         Returns:
+             List of groups of items that should be retrieved for each geometry.
+         """
+         groups = []
+         for geometry in geometries:
+             # Get potentially relevant items from the collection by performing one
+             # search for each requested geometry.
+             wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+             logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+             stac_items = self.client.search(
+                 collections=[self.collection_name],
+                 intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
+                 date_time=wgs84_geometry.time_range,
+                 query=self.query,
+                 limit=self.limit,
+             )
+             logger.debug("STAC search yielded %d items", len(stac_items))
+
+             if self.required_assets is not None:
+                 # Filter out items that are missing any of the assets in
+                 # self.required_assets.
+                 good_stac_items = []
+                 for stac_item in stac_items:
+                     if stac_item.assets is None:
+                         raise ValueError(f"got STAC item {stac_item.id} with no assets")
+
+                     good = True
+                     for asset_key in self.required_assets:
+                         if asset_key in stac_item.assets:
+                             continue
+                         good = False
+                         break
+                     if good:
+                         good_stac_items.append(stac_item)
+                 logger.debug(
+                     "required_assets filter from %d to %d items",
+                     len(stac_items),
+                     len(good_stac_items),
+                 )
+                 stac_items = good_stac_items
+
+             if self.sort_by is not None:
+                 sort_by = self.sort_by
+                 stac_items.sort(
+                     key=lambda stac_item: stac_item.properties[sort_by],
+                     reverse=not self.sort_ascending,
+                 )
+
+             candidate_items = [
+                 self._stac_item_to_item(stac_item) for stac_item in stac_items
+             ]
+
+             # Since we made the STAC request, might as well save these to the cache.
+             if self.cache_dir is not None:
+                 for item in candidate_items:
+                     cache_fname = self.cache_dir / f"{item.name}.json"
+                     if cache_fname.exists():
+                         continue
+                     with cache_fname.open("w") as f:
+                         json.dump(item.serialize(), f)
+
+             cur_groups = match_candidate_items_to_window(
+                 geometry, candidate_items, query_config
+             )
+             groups.append(cur_groups)
+
+         return groups
+
+     def deserialize_item(self, serialized_item: Any) -> SourceItem:
+         """Deserializes an item from JSON-decoded data."""
+         assert isinstance(serialized_item, dict)
+         return SourceItem.deserialize(serialized_item)
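For orientation, here is a minimal usage sketch of the new class. This is not part of the diff: the module path rslearn.data_sources.stac, the Element84 endpoint, the collection name, and the no-argument QueryConfig() construction are all assumptions for illustration.

    # Hypothetical usage of StacDataSource (module path and endpoint assumed).
    from datetime import datetime, timezone

    import shapely

    from rslearn.config import QueryConfig
    from rslearn.const import WGS84_PROJECTION
    from rslearn.data_sources.stac import StacDataSource  # assumed module path
    from rslearn.utils.geometry import STGeometry

    ds = StacDataSource(
        endpoint="https://earth-search.aws.element84.com/v1",  # illustrative
        collection_name="sentinel-2-l2a",
        query={"eo:cloud_cover": {"lt": 50}},
        sort_by="eo:cloud_cover",
        required_assets=["red", "green", "blue"],
    )
    # One geometry with a WGS84 bounding box and a one-month time range.
    geom = STGeometry(
        WGS84_PROJECTION,
        shapely.box(-122.4, 47.5, -122.2, 47.7),
        (
            datetime(2024, 6, 1, tzinfo=timezone.utc),
            datetime(2024, 7, 1, tzinfo=timezone.utc),
        ),
    )
    groups = ds.get_items([geom], QueryConfig())  # default QueryConfig assumed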
@@ -150,8 +150,11 @@ class AttentionPool(IntermediateComponent):
              D // self.num_heads
          )
          attn_weights = F.softmax(attn_scores, dim=-1)
-         x = torch.matmul(attn_weights, v) # [B, head, 1, D_head]
-         return x.reshape(B, D, H, W)
+         x = torch.matmul(attn_weights, v) # [B*H*W, num_heads, 1, D_head]
+         x = x.squeeze(-2) # [B*H*W, num_heads, D_head]
+         return rearrange(
+             x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W
+         ) # [B, D, H, W]

      def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
          """Forward pass for attention pooling linear probe.
@@ -159,20 +159,19 @@ class OlmoEarth(FeatureExtractor):
          that contains the distributed checkpoint. This is the format produced by
          pre-training runs in olmoearth_pretrain.
          """
-         # We avoid loading the train module here because it depends on running within
-         # olmo_core.
-         # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
-         require_olmo_core("_load_model_from_checkpoint")
-         from olmo_core.distributed.checkpoint import load_model_and_optim_state
-
          with (checkpoint_upath / "config.json").open() as f:
              config_dict = json.load(f)
          model_config = Config.from_dict(config_dict["model"])

          model = model_config.build()

-         # Load the checkpoint.
+         # Load the checkpoint (requires olmo_core for distributed checkpoint loading).
          if not random_initialization:
+             require_olmo_core(
+                 "_load_model_from_checkpoint with random_initialization=False"
+             )
+             from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
              train_module_dir = checkpoint_upath / "model_and_optim"
              load_model_and_optim_state(str(train_module_dir), model)
              logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
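The effect of this change is that olmo_core becomes a lazy, conditional dependency: it is only imported when a distributed checkpoint actually needs to be loaded, so random_initialization=True works without olmo_core installed. require_olmo_core is rslearn's own guard and its body is not shown in this diff; a generic, hypothetical stand-in for the pattern might look like:

    # Generic optional-dependency guard, a hypothetical stand-in for
    # rslearn's require_olmo_core (actual implementation not in this diff).
    import importlib.util

    def require_module(module_name: str, feature: str) -> None:
        """Raise a clear error when an optional dependency is missing."""
        if importlib.util.find_spec(module_name) is None:
            raise ImportError(
                f"{feature} requires the optional dependency '{module_name}'"
            )

    # e.g. require_module("olmo_core", "loading a distributed checkpoint")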
@@ -242,7 +241,7 @@ class OlmoEarth(FeatureExtractor):
              present_modalities.append(modality)
              tensors = []
              for idx, inp in enumerate(context.inputs):
-                 assert isinstance(inp, RasterImage)
+                 assert isinstance(inp[modality], RasterImage)
                  tensors.append(inp[modality].image)
                  cur_timestamps = inp[modality].timestamps
                  if cur_timestamps is not None and len(cur_timestamps) > len(
rslearn/train/dataset.py CHANGED
@@ -205,8 +205,7 @@ def read_raster_layer_for_data_input(
      group_idx: int,
      layer_config: LayerConfig,
      data_input: DataInput,
-     layer_data: WindowLayerData | None,
- ) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
+ ) -> torch.Tensor:
      """Read a raster layer for a DataInput.

      This scans the available rasters for the layer at the window to determine which
@@ -219,11 +218,9 @@ def read_raster_layer_for_data_input(
          group_idx: the item group.
          layer_config: the layer configuration.
          data_input: the DataInput that specifies the bands and dtype.
-         layer_data: the WindowLayerData associated with this layer and window.

      Returns:
-         RasterImage containing raster data and the timestamp associated
-         with that data.
+         Raster data as a tensor.
      """
      # See what different sets of bands we need to read to get all the
      # configured bands.
@@ -294,34 +291,46 @@ def read_raster_layer_for_data_input(
              src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
          )

-     # add the timestamp. this is a tuple defining the start and end of the time range.
-     time_range = None
-     if layer_data is not None:
-         item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
-         if item.geometry.time_range is not None:
-             # we assume if one layer data has a geometry & time range, all of them do
-             time_ranges = [
-                 (
-                     datetime.fromisoformat(
-                         Item.deserialize(
-                             layer_data.serialized_item_groups[group_idx][idx]
-                         ).geometry.time_range[0]  # type: ignore
-                     ),
-                     datetime.fromisoformat(
-                         Item.deserialize(
-                             layer_data.serialized_item_groups[group_idx][idx]
-                         ).geometry.time_range[1]  # type: ignore
-                     ),
-                 )
-                 for idx in range(len(layer_data.serialized_item_groups[group_idx]))
-             ]
-             # take the min and max
-             time_range = (
-                 min([t[0] for t in time_ranges]),
-                 max([t[1] for t in time_ranges]),
-             )
-
-     return image, time_range
+     return image
+
+
+ def read_layer_time_range(
+     layer_data: WindowLayerData | None, group_idx: int
+ ) -> tuple[datetime, datetime] | None:
+     """Extract the combined time range from all items in a layer data group.
+
+     Returns the min start time and max end time across all items, or None if
+     no items have time ranges.
+
+     Raises:
+         ValueError: If some items have time_range and others don't.
+     """
+     if layer_data is None:
+         return None
+
+     serialized_items = layer_data.serialized_item_groups[group_idx]
+     if not serialized_items:
+         return None
+
+     first_item = Item.deserialize(serialized_items[0])
+     if first_item.geometry.time_range is None:
+         return None
+
+     # If the first item has a time_range, all items must have one.
+     time_ranges: list[tuple[datetime, datetime]] = []
+     for serialized_item in serialized_items:
+         item = Item.deserialize(serialized_item)
+         if item.geometry.time_range is None:
+             raise ValueError(
+                 f"Item '{item.name}' has no time_range, but first item does. "
+                 "All items in a group must consistently have or lack time_range."
+             )
+         time_ranges.append(item.geometry.time_range)
+
+     return (
+         min(tr[0] for tr in time_ranges),
+         max(tr[1] for tr in time_ranges),
+     )


  def read_data_input(
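To make the new helper's contract concrete, the final step reduces per-item time ranges to one covering range. A tiny demonstration on plain tuples (values are illustrative, not from the package):

    # Illustrative values only: how per-item time ranges collapse to one range.
    from datetime import datetime, timezone

    def utc(*args: int) -> datetime:
        return datetime(*args, tzinfo=timezone.utc)

    time_ranges = [
        (utc(2024, 6, 3), utc(2024, 6, 3)),
        (utc(2024, 6, 1), utc(2024, 6, 8)),
    ]
    combined = (
        min(tr[0] for tr in time_ranges),
        max(tr[1] for tr in time_ranges),
    )
    assert combined == (utc(2024, 6, 1), utc(2024, 6, 8))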
@@ -378,17 +387,17 @@ def read_data_input(
      time_ranges: list[tuple[datetime, datetime] | None] = []
      for layer_name, group_idx in layers_to_read:
          layer_config = dataset.layers[layer_name]
-         image, time_range = read_raster_layer_for_data_input(
+         image = read_raster_layer_for_data_input(
              window,
              bounds,
              layer_name,
              group_idx,
              layer_config,
              data_input,
-             # some layers (e.g. "label_raster") won't have associated
-             # layer datas
-             layer_datas[layer_name] if layer_name in layer_datas else None,
          )
+         # some layers (e.g. "label_raster") won't have associated layer datas
+         layer_data = layer_datas.get(layer_name)
+         time_range = read_layer_time_range(layer_data, group_idx)
          if len(time_ranges) > 0:
              if type(time_ranges[-1]) is not type(time_range):
                  raise ValueError(
@@ -345,14 +345,14 @@ class RslearnLightningModule(L.LightningModule):
          )

          if self.visualize_dir:
-             for idx, (inp, target, output, metadata) in enumerate(
-                 zip(inputs, targets, outputs, metadatas)
+             for inp, target, output, metadata in zip(
+                 inputs, targets, outputs, metadatas
              ):
                  images = self.task.visualize(inp, target, output)
                  for image_suffix, image in images.items():
                      out_fname = os.path.join(
                          self.visualize_dir,
-                         f"{metadata['window_name']}_{metadata['bounds'][0]}_{metadata['bounds'][1]}_{image_suffix}.png",
+                         f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
                      )
                      Image.fromarray(image).save(out_fname)

@@ -6,7 +6,7 @@ import numpy.typing as npt
  import torch
  from torchmetrics import MetricCollection

- from rslearn.models.component import FeatureMaps
+ from rslearn.models.component import FeatureMaps, Predictor
  from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
  from rslearn.utils import Feature

@@ -83,7 +83,7 @@ class EmbeddingTask(Task):
          return MetricCollection({})


- class EmbeddingHead:
+ class EmbeddingHead(Predictor):
      """Head for embedding task.

      It just adds a dummy loss to act as a Predictor.
@@ -108,8 +108,10 @@ class BasicTask(Task):
          Returns:
              a dictionary mapping image name to visualization image
          """
-         image = input_dict["image"].cpu()
-         image = image[self.image_bands, :, :]
+         raster_image = input_dict["image"]
+         assert isinstance(raster_image, RasterImage)
+         # We don't really handle time series here, just use the first timestep.
+         image = raster_image.image.cpu()[self.image_bands, 0, :, :]
          if self.remap_values:
              factor = (self.remap_values[1][1] - self.remap_values[1][0]) / (
                  self.remap_values[0][1] - self.remap_values[0][0]
rslearn/utils/geometry.py CHANGED
@@ -153,8 +153,8 @@ class ResolutionFactor:
          else:
              return Projection(
                  projection.crs,
-                 projection.x_resolution // self.numerator,
-                 projection.y_resolution // self.numerator,
+                 projection.x_resolution / self.numerator,
+                 projection.y_resolution / self.numerator,
              )

      def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
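The geometry.py hunk replaces floor division with true division when dividing a projection's resolution by the factor's numerator, so fractional resolutions are no longer silently truncated. A quick arithmetic check, with an assumed resolution of 10.0 and numerator of 3:

    # Assumed values: resolution 10.0, numerator 3.
    x_resolution, numerator = 10.0, 3
    assert x_resolution // numerator == 3.0   # old behavior: truncates
    assert x_resolution / numerator == 10 / 3  # new behavior: true division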
rslearn/utils/stac.py ADDED
@@ -0,0 +1,173 @@
+ """STAC API client."""
+
+ import logging
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any
+
+ import requests
+
+ logger = logging.getLogger(__name__)
+
+ Bbox = tuple[float, float, float, float]
+
+
+ @dataclass(frozen=True)
+ class StacAsset:
+     """A STAC asset."""
+
+     href: str
+     title: str | None
+     type: str | None
+     roles: list[str] | None
+
+
+ @dataclass(frozen=True)
+ class StacItem:
+     """A STAC item."""
+
+     id: str
+     properties: dict[str, Any]
+     collection: str | None
+     bbox: Bbox | None
+     geometry: dict[str, Any] | None
+     assets: dict[str, StacAsset] | None
+     time_range: tuple[datetime, datetime] | None
+
+     @classmethod
+     def from_dict(cls, item: dict[str, Any]) -> "StacItem":
+         """Create a STAC item from the item dict returned from the API."""
+         properties = item.get("properties", {})
+
+         # Parse bbox.
+         bbox: Bbox | None = None
+         if "bbox" in item:
+             if len(item["bbox"]) != 4:
+                 raise NotImplementedError(
+                     f"got bbox with {len(item['bbox'])} coordinates but only 4 coordinates is implemented"
+                 )
+             bbox = tuple(item["bbox"])
+
+         # Parse assets.
+         assets: dict[str, StacAsset] = {}
+         for name, asset in item.get("assets", {}).items():
+             assets[name] = StacAsset(
+                 href=asset["href"],
+                 title=asset.get("title"),
+                 type=asset.get("type"),
+                 roles=asset.get("roles"),
+             )
+
+         # Parse time range.
+         time_range: tuple[datetime, datetime] | None = None
+         if "start_datetime" in properties and "end_datetime" in properties:
+             time_range = (
+                 datetime.fromisoformat(properties["start_datetime"]),
+                 datetime.fromisoformat(properties["end_datetime"]),
+             )
+         elif "datetime" in properties:
+             ts = datetime.fromisoformat(properties["datetime"])
+             time_range = (ts, ts)
+
+         return cls(
+             id=item["id"],
+             properties=properties,
+             collection=item.get("collection"),
+             bbox=bbox,
+             geometry=item.get("geometry"),
+             assets=assets,
+             time_range=time_range,
+         )
+
+
+ class StacClient:
+     """Limited functionality client for STAC APIs."""
+
+     def __init__(self, endpoint: str):
+         """Create a new StacClient.
+
+         Args:
+             endpoint: the STAC endpoint (base URL)
+         """
+         self.endpoint = endpoint
+         self.session = requests.Session()
+
+     def search(
+         self,
+         collections: list[str] | None = None,
+         bbox: Bbox | None = None,
+         intersects: dict[str, Any] | None = None,
+         date_time: datetime | tuple[datetime, datetime] | None = None,
+         ids: list[str] | None = None,
+         limit: int | None = None,
+         query: dict[str, Any] | None = None,
+     ) -> list[StacItem]:
+         """Execute a STAC item search.
+
+         We use the JSON POST API. Pagination is handled so the returned items are
+         concatenated across all available pages.
+
+         Args:
+             collections: only search within the provided collection(s).
+             bbox: only return features intersecting the provided bounding box.
+             intersects: only return features intersecting this GeoJSON geometry.
+             date_time: only return features that have a temporal property
+                 intersecting the provided time range or timestamp.
+             ids: only return the provided item IDs.
+             limit: number of items per page. We will read all the pages.
+             query: query dict, if the STAC query extension is supported by this
+                 API. See https://github.com/stac-api-extensions/query.
+
+         Returns:
+             list of matching STAC items.
+         """
+         # Build JSON request data.
+         request_data: dict[str, Any] = {}
+         if collections is not None:
+             request_data["collections"] = collections
+         if bbox is not None:
+             request_data["bbox"] = bbox
+         if intersects is not None:
+             request_data["intersects"] = intersects
+         if date_time is not None:
+             if isinstance(date_time, tuple):
+                 start_time = date_time[0].isoformat().replace("+00:00", "Z")
+                 end_time = date_time[1].isoformat().replace("+00:00", "Z")
+                 request_data["datetime"] = f"{start_time}/{end_time}"
+             else:
+                 request_data["datetime"] = date_time.isoformat().replace("+00:00", "Z")
+         if ids is not None:
+             request_data["ids"] = ids
+         if limit is not None:
+             request_data["limit"] = limit
+         if query is not None:
+             request_data["query"] = query
+
+         # Handle pagination.
+         cur_url = self.endpoint + "/search"
+         items: list[StacItem] = []
+         while True:
+             logger.debug("Reading STAC items from %s", cur_url)
+             response = self.session.post(url=cur_url, json=request_data)
+             response.raise_for_status()
+             data = response.json()
+             for item_dict in data["features"]:
+                 items.append(StacItem.from_dict(item_dict))
+
+             next_link = None
+             next_request_data: dict[str, Any] = {}
+             for link in data.get("links", []):
+                 if "rel" not in link or link["rel"] != "next":
+                     continue
+                 assert link["method"] == "POST"
+                 next_link = link["href"]
+                 next_request_data = link["body"]
+                 break
+
+             if next_link is None:
+                 break
+
+             cur_url = next_link
+             request_data = next_request_data
+
+         return items
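Finally, a minimal sketch of calling the new client directly. The endpoint and collection below are illustrative, and whether the query extension is honored depends on the server; the call signature follows the search method above.

    # Hypothetical direct use of StacClient (endpoint and collection assumed).
    from datetime import datetime, timezone

    from rslearn.utils.stac import StacClient

    client = StacClient("https://earth-search.aws.element84.com/v1")
    items = client.search(
        collections=["sentinel-2-l2a"],
        bbox=(-122.4, 47.5, -122.2, 47.7),
        date_time=(
            datetime(2024, 6, 1, tzinfo=timezone.utc),
            datetime(2024, 7, 1, tzinfo=timezone.utc),
        ),
        query={"eo:cloud_cover": {"lt": 10}},  # requires query extension support
        limit=100,
    )
    for item in items[:5]:
        print(item.id, item.time_range, sorted(item.assets or {}))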