rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,331 @@
1
+ """Data source for SoilGrids via the `soilgrids` Python package.
2
+
3
+ This source is intended to be used with `ingest: false` (direct materialization),
4
+ since data is fetched on-demand per window.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import tempfile
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ import rasterio
15
+ import rasterio.warp
16
+ import shapely
17
+ from rasterio.crs import CRS
18
+ from rasterio.enums import Resampling
19
+ from upath import UPath
20
+
21
+ from rslearn.config import LayerConfig, QueryConfig
22
+ from rslearn.dataset import Window
23
+ from rslearn.dataset.materialize import RasterMaterializer
24
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
25
+ from rslearn.utils import PixelBounds, Projection, STGeometry
26
+ from rslearn.utils.geometry import get_global_geometry
27
+ from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
28
+
29
+ from .data_source import DataSource, DataSourceContext, Item
30
+ from .utils import match_candidate_items_to_window
31
+
32
# GEOTIFF_INT16 responses from the SoilGrids WCS use this fill value for
# missing pixels; it is also the fallback nodata value used when a response
# does not declare one.
SOILGRIDS_NODATA_VALUE = -32768.0
34
+
35
+
36
def _crs_to_rasterio(crs: str) -> CRS:
    """Best-effort conversion of CRS strings used by `soilgrids` to rasterio CRS."""
    try:
        return CRS.from_string(crs)
    except Exception:
        # rasterio could not parse the string directly. If the string embeds an
        # EPSG code, salvage the trailing integer and build the CRS from that;
        # otherwise re-raise the original parsing error.
        digit_tokens = [tok for tok in crs.replace(":", " ").split() if tok.isdigit()]
        if not digit_tokens:
            raise
        return CRS.from_epsg(int(digit_tokens[-1]))
47
+
48
+
49
+ def _crs_to_soilgrids_urn(crs: str) -> str:
50
+ """Convert common CRS spellings to the URN form expected by `soilgrids`.
51
+
52
+ The `soilgrids` package compares CRS strings against the supported CRS URNs from
53
+ OWSLib (e.g. "urn:ogc:def:crs:EPSG::3857"). This helper allows users to specify
54
+ simpler forms like "EPSG:3857" while still working.
55
+ """
56
+ s = crs.strip()
57
+
58
+ # If already an EPSG URN, canonicalize to the form soilgrids expects.
59
+ if s.lower().startswith("urn:ogc:def:crs:") and "epsg" in s.lower():
60
+ parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
61
+ if parts:
62
+ return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
63
+ return s
64
+
65
+ # Accept "EPSG:3857", "epsg:3857", or other strings containing an EPSG code.
66
+ if "epsg" in s.lower():
67
+ parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
68
+ if parts:
69
+ return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
70
+
71
+ return s
72
+
73
+
74
class SoilGrids(DataSource, TileStore):
    """Access SoilGrids coverages as an rslearn raster data source.

    Each raster read issues one WCS request via the `soilgrids` package, so
    this source is intended for direct materialization (`ingest: false`);
    ingest is intentionally unsupported.
    """

    def __init__(
        self,
        service_id: str,
        coverage_id: str,
        crs: str = "EPSG:3857",
        width: int | None = None,
        height: int | None = None,
        resx: float | None = None,
        resy: float | None = None,
        response_crs: str | None = None,
        band_names: list[str] | None = None,
        context: DataSourceContext = DataSourceContext(),
    ):
        """Create a new SoilGrids data source.

        Args:
            service_id: SoilGrids map service id (e.g., "clay", "phh2o").
            coverage_id: coverage id within the service (e.g., "clay_0-5cm_mean").
            crs: request CRS string passed through to `soilgrids.SoilGrids`,
                typically a URN like "urn:ogc:def:crs:EPSG::4326" or
                "urn:ogc:def:crs:EPSG::152160".
            width: optional WCS WIDTH parameter. Required by SoilGrids WCS when
                CRS is EPSG:4326.
            height: optional WCS HEIGHT parameter.
            resx: optional WCS RESX parameter (projection units / pixel).
            resy: optional WCS RESY parameter (projection units / pixel).
            response_crs: optional response CRS (defaults to `crs`).
            band_names: band names exposed to rslearn; defaults to ["B1"]. For
                a single coverage, this should have length 1.
            context: rslearn data source context.
        """
        # Use a None sentinel instead of a mutable default argument so the
        # default list is not shared across instances.
        if band_names is None:
            band_names = ["B1"]
        if len(band_names) != 1:
            raise ValueError("SoilGrids currently supports only single-band coverages")
        if (width is None) != (height is None):
            raise ValueError("width and height must be specified together")
        if (resx is None) != (resy is None):
            raise ValueError("resx and resy must be specified together")
        if width is not None and resx is not None:
            raise ValueError("specify either width/height or resx/resy, not both")

        self.service_id = service_id
        self.coverage_id = coverage_id
        self.crs = crs
        self.width = width
        self.height = height
        self.resx = resx
        self.resy = resy
        self.response_crs = response_crs
        self.band_names = band_names

        # Represent the coverage as a single item that matches all windows.
        item_name = f"{self.service_id}:{self.coverage_id}"
        self._items = [Item(item_name, get_global_geometry(time_range=None))]

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get item groups matching each requested geometry.

        Args:
            geometries: the spatiotemporal geometries to match.
            query_config: the query configuration.

        Returns:
            for each geometry, the list of item groups (the single global item
            matches every window).
        """
        groups = []
        for geometry in geometries:
            cur_groups = match_candidate_items_to_window(
                geometry, self._items, query_config
            )
            groups.append(cur_groups)
        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserialize an item from JSON-decoded data."""
        return Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest is not supported (direct materialization only).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "SoilGrids is intended for direct materialization; set data_source.ingest=false."
        )

    def is_raster_ready(
        self, layer_name: str, item_name: str, bands: list[str]
    ) -> bool:
        """Return whether the requested raster is ready (always true for direct reads)."""
        return True

    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
        """Return the band sets available for this coverage."""
        return [self.band_names]

    def get_raster_bounds(
        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
    ) -> PixelBounds:
        """Return (approximate) bounds for this raster in the requested projection."""
        # We don't know bounds without an extra metadata request; treat as "very
        # large" so materialization always attempts reads for windows.
        return (-(10**9), -(10**9), 10**9, 10**9)

    def _download_geotiff(
        self,
        west: float,
        south: float,
        east: float,
        north: float,
        output: str,
        width: int | None,
        height: int | None,
        resx: float | None,
        resy: float | None,
    ) -> None:
        """Download a GeoTIFF for the given bounding box via the soilgrids client.

        Args:
            west: west edge of the bounding box in the request CRS.
            south: south edge of the bounding box in the request CRS.
            east: east edge of the bounding box in the request CRS.
            north: north edge of the bounding box in the request CRS.
            output: local filename to write the GeoTIFF to.
            width: optional WCS WIDTH (must be paired with height).
            height: optional WCS HEIGHT.
            resx: optional WCS RESX (must be paired with resy).
            resy: optional WCS RESY.
        """
        # Lazy import so the package is only required when this source is used.
        from soilgrids import SoilGrids as SoilGridsClient

        client = SoilGridsClient()
        kwargs: dict[str, Any] = dict(
            service_id=self.service_id,
            coverage_id=self.coverage_id,
            crs=_crs_to_soilgrids_urn(self.crs),
            west=west,
            south=south,
            east=east,
            north=north,
            output=output,
        )
        # The WCS accepts either an explicit output size or a resolution, not
        # both.
        if width is not None and height is not None:
            kwargs["width"] = width
            kwargs["height"] = height
        elif resx is not None and resy is not None:
            kwargs["resx"] = resx
            kwargs["resy"] = resy

        if self.response_crs is not None:
            kwargs["response_crs"] = _crs_to_soilgrids_urn(self.response_crs)

        client.get_coverage_data(**kwargs)

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read and reproject a SoilGrids coverage subset into the requested grid.

        Args:
            layer_name: the layer name (unused; one coverage per source).
            item_name: the item name (unused; one coverage per source).
            bands: the requested bands, which must equal self.band_names.
            projection: the projection of the output grid.
            bounds: the pixel bounds of the output grid.
            resampling: the resampling method for reprojection.

        Returns:
            float32 CHW array of shape (1, height, width) with nodata pixels
            filled with the source nodata value (or SOILGRIDS_NODATA_VALUE).

        Raises:
            ValueError: if bands does not match the configured band names.
        """
        if bands != self.band_names:
            raise ValueError(
                f"expected request for bands {self.band_names} but got {bands}"
            )

        # Compute bounding box in CRS coordinates for the request.
        request_crs = _crs_to_rasterio(self.crs)
        request_projection = Projection(request_crs, 1.0, 1.0)
        request_geom = STGeometry(projection, shapely.box(*bounds), None).to_projection(
            request_projection
        )
        west, south, east, north = request_geom.shp.bounds

        # Determine output grid for the WCS request.
        #
        # If the user explicitly configured an output grid (width/height or
        # resx/resy), we respect it.
        #
        # Otherwise, default to requesting at ~250 m resolution in the request
        # CRS (when it is projected), and then reprojecting to the window grid.
        #
        # For EPSG:4326 requests, SoilGrids WCS requires WIDTH/HEIGHT, so we
        # default to matching the window pixel size.
        window_width = bounds[2] - bounds[0]
        window_height = bounds[3] - bounds[1]

        out_width = self.width
        out_height = self.height
        out_resx = self.resx
        out_resy = self.resy

        if request_crs.to_epsg() == 4326 and out_width is None:
            # Required by the SoilGrids WCS for EPSG:4326; resx/resy is not accepted.
            out_width = window_width
            out_height = window_height
            out_resx = None
            out_resy = None
        elif out_width is None and out_resx is None:
            # Default to native-ish SoilGrids resolution (~250 m) in projected CRSs.
            out_resx = 250.0
            out_resy = 250.0

        with tempfile.TemporaryDirectory(prefix="rslearn_soilgrids_") as tmpdir:
            output_path = str(UPath(tmpdir) / "coverage.tif")
            self._download_geotiff(
                west=west,
                south=south,
                east=east,
                north=north,
                output=output_path,
                width=out_width,
                height=out_height,
                resx=out_resx,
                resy=out_resy,
            )

            with rasterio.open(output_path) as src:
                src_array = src.read(1).astype(np.float32)
                src_nodata = src.nodata
                scale = float(src.scales[0]) if src.scales else 1.0
                offset = float(src.offsets[0]) if src.offsets else 0.0

                # Apply the GeoTIFF scale/offset to valid pixels only so nodata
                # values survive unchanged for the reprojection below.
                if src_nodata is not None:
                    valid_mask = src_array != float(src_nodata)
                    src_array[valid_mask] = src_array[valid_mask] * scale + offset
                    dst_nodata = float(src_nodata)
                    src_nodata_val = dst_nodata
                else:
                    src_array = src_array * scale + offset
                    dst_nodata = SOILGRIDS_NODATA_VALUE
                    src_nodata_val = None

                src_chw = src_array[None, :, :]
                dst = np.full(
                    (1, bounds[3] - bounds[1], bounds[2] - bounds[0]),
                    dst_nodata,
                    dtype=np.float32,
                )
                dst_transform = get_transform_from_projection_and_bounds(
                    projection, bounds
                )

                rasterio.warp.reproject(
                    source=src_chw,
                    src_crs=src.crs,
                    src_transform=src.transform,
                    src_nodata=src_nodata_val,
                    destination=dst,
                    dst_crs=projection.crs,
                    dst_transform=dst_transform,
                    dst_nodata=dst_nodata,
                    resampling=resampling,
                )
        return dst

    def materialize(
        self,
        window: Window,
        item_groups: list[list[Item]],
        layer_name: str,
        layer_cfg: LayerConfig,
    ) -> None:
        """Materialize a window by reading from SoilGrids on-demand.

        Args:
            window: the window to materialize.
            item_groups: the item groups to materialize.
            layer_name: the name of the layer being materialized.
            layer_cfg: the layer configuration.
        """
        RasterMaterializer().materialize(
            TileStoreWithLayer(self, layer_name),
            window,
            layer_name,
            layer_cfg,
            item_groups,
        )
@@ -0,0 +1,275 @@
1
+ """A partial data source implementation providing get_items using a STAC API."""
2
+
3
+ import json
4
+ from datetime import datetime
5
+ from typing import Any
6
+
7
+ import shapely
8
+ from upath import UPath
9
+
10
+ from rslearn.config import QueryConfig
11
+ from rslearn.const import WGS84_PROJECTION
12
+ from rslearn.data_sources.data_source import Item, ItemLookupDataSource
13
+ from rslearn.data_sources.utils import match_candidate_items_to_window
14
+ from rslearn.log_utils import get_logger
15
+ from rslearn.utils.geometry import STGeometry
16
+ from rslearn.utils.stac import StacClient, StacItem
17
+
18
+ logger = get_logger(__name__)
19
+
20
+
21
class SourceItem(Item):
    """An item in the StacDataSource data source."""

    def __init__(
        self,
        name: str,
        geometry: STGeometry,
        asset_urls: dict[str, str],
        properties: dict[str, str],
    ):
        """Creates a new SourceItem.

        Args:
            name: unique name of the item
            geometry: the spatial and temporal extent of the item
            asset_urls: map from asset key to the unsigned asset URL.
            properties: properties requested by the data source implementation.
        """
        super().__init__(name, geometry)
        self.asset_urls = asset_urls
        self.properties = properties

    def serialize(self) -> dict[str, Any]:
        """Serializes the item to a JSON-encodable dictionary."""
        serialized = super().serialize()
        serialized.update(
            asset_urls=self.asset_urls,
            properties=self.properties,
        )
        return serialized

    @staticmethod
    def deserialize(d: dict[str, Any]) -> "SourceItem":
        """Deserializes an item from a JSON-decoded dictionary."""
        # Decode the base fields first, then attach the extra ones.
        base_item = super(SourceItem, SourceItem).deserialize(d)
        return SourceItem(
            base_item.name,
            base_item.geometry,
            d["asset_urls"],
            d["properties"],
        )
60
+
61
+
62
class StacDataSource(ItemLookupDataSource[SourceItem]):
    """A partial data source implementing get_items using a STAC API.

    This is a helper class that full implementations can extend to not have to worry
    about the get_items and get_item_by_name implementation.
    """

    def __init__(
        self,
        endpoint: str,
        collection_name: str,
        query: dict[str, Any] | None = None,
        sort_by: str | None = None,
        sort_ascending: bool = True,
        required_assets: list[str] | None = None,
        cache_dir: UPath | None = None,
        limit: int = 100,
        properties_to_record: list[str] | None = None,
    ):
        """Create a new StacDataSource.

        Args:
            endpoint: the STAC endpoint to use.
            collection_name: the STAC collection name.
            query: optional STAC query dict to include in searches, e.g.
                {"eo:cloud_cover": {"lt": 50}}.
            sort_by: sort results by this STAC property.
            sort_ascending: if sort_by is set, sort in ascending order (default).
                Otherwise sort in descending order.
            required_assets: if set, we ignore items that do not have all of these
                asset keys.
            cache_dir: optional cache directory to cache items. This is recommended if
                allowing direct materialization from the data source, since it will
                likely be necessary to make lots of get_item_by_name calls during
                materialization. TODO: give direct materialization access to the Item
                object.
            limit: limit to pass to search queries.
            properties_to_record: if these properties on the STAC item exist, they
                are retained in the SourceItem when we initialize it. Defaults to
                recording no properties.
        """
        self.client = StacClient(endpoint)
        self.collection_name = collection_name
        self.query = query
        self.sort_by = sort_by
        self.sort_ascending = sort_ascending
        self.required_assets = required_assets
        self.cache_dir = cache_dir
        self.limit = limit
        # Use a None sentinel instead of a mutable default argument so the
        # default list is not shared across instances.
        self.properties_to_record = (
            properties_to_record if properties_to_record is not None else []
        )

    def _stac_item_to_item(self, stac_item: StacItem) -> SourceItem:
        """Convert a StacItem into a SourceItem.

        Raises:
            ValueError: if the STAC item is missing geometry, time range, or assets.
        """
        # Make sure geometry, time range, and assets are set.
        if stac_item.geometry is None:
            raise ValueError("got unexpected item with no geometry")
        if stac_item.time_range is None:
            raise ValueError("got unexpected item with no time range")
        if stac_item.assets is None:
            raise ValueError("got unexpected item with no assets")

        shp = shapely.geometry.shape(stac_item.geometry)
        geom = STGeometry(WGS84_PROJECTION, shp, stac_item.time_range)
        asset_urls = {
            asset_key: asset_obj.href
            for asset_key, asset_obj in stac_item.assets.items()
        }

        # Keep any properties requested by the data source implementation.
        properties = {}
        for prop_name in self.properties_to_record:
            if prop_name not in stac_item.properties:
                continue
            properties[prop_name] = stac_item.properties[prop_name]

        return SourceItem(stac_item.id, geom, asset_urls, properties)

    def _get_search_time_range(
        self, geometry: STGeometry
    ) -> datetime | tuple[datetime, datetime] | None:
        """Get time range to include in STAC API search.

        By default, we filter STAC searches to the window's time range. Subclasses can
        override this to disable time filtering for "static" datasets.

        Args:
            geometry: the geometry we are searching for.

        Returns:
            the time range (or timestamp) to pass to the STAC search, or None to avoid
            temporal filtering in the search request.
        """
        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
        return geometry.time_range

    def get_item_by_name(self, name: str) -> SourceItem:
        """Gets an item by name.

        Args:
            name: the name of the item to get

        Returns:
            the item object

        Raises:
            ValueError: if the item is not found, or the ID matches multiple items.
        """
        # If cache_dir is set, we cache the item. First here we check if it is already
        # in the cache.
        cache_fname: UPath | None = None
        if self.cache_dir:
            cache_fname = self.cache_dir / f"{name}.json"
        if cache_fname is not None and cache_fname.exists():
            with cache_fname.open() as f:
                return SourceItem.deserialize(json.load(f))

        # No cache or not in cache, so we need to make the STAC request.
        logger.debug(f"Getting STAC item {name}")
        stac_items = self.client.search(ids=[name], collections=[self.collection_name])

        if len(stac_items) == 0:
            raise ValueError(
                f"Item {name} not found in collection {self.collection_name}"
            )
        if len(stac_items) > 1:
            raise ValueError(
                f"Multiple items found for ID {name} in collection {self.collection_name}"
            )

        stac_item = stac_items[0]
        item = self._stac_item_to_item(stac_item)

        # Finally we cache it if cache_dir is set.
        if cache_fname is not None:
            with cache_fname.open("w") as f:
                json.dump(item.serialize(), f)

        return item

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[SourceItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            # Get potentially relevant items from the collection by performing one
            # search for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
            search_time_range = self._get_search_time_range(wgs84_geometry)
            stac_items = self.client.search(
                collections=[self.collection_name],
                intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
                date_time=search_time_range,
                query=self.query,
                limit=self.limit,
            )
            logger.debug("STAC search yielded %d items", len(stac_items))

            if self.required_assets is not None:
                # Filter out items that are missing any of the asset keys in
                # self.required_assets.
                good_stac_items = []
                for stac_item in stac_items:
                    if stac_item.assets is None:
                        raise ValueError(f"got STAC item {stac_item.id} with no assets")

                    good = True
                    for asset_key in self.required_assets:
                        if asset_key in stac_item.assets:
                            continue
                        good = False
                        break
                    if good:
                        good_stac_items.append(stac_item)
                logger.debug(
                    "required_assets filter from %d to %d items",
                    len(stac_items),
                    len(good_stac_items),
                )
                stac_items = good_stac_items

            if self.sort_by is not None:
                # Bind to a local so the sort key lambda does not capture self.
                sort_by = self.sort_by
                stac_items.sort(
                    key=lambda stac_item: stac_item.properties[sort_by],
                    reverse=not self.sort_ascending,
                )

            candidate_items = [
                self._stac_item_to_item(stac_item) for stac_item in stac_items
            ]

            # Since we made the STAC request, might as well save these to the cache.
            if self.cache_dir is not None:
                for item in candidate_items:
                    cache_fname = self.cache_dir / f"{item.name}.json"
                    if cache_fname.exists():
                        continue
                    with cache_fname.open("w") as f:
                        json.dump(item.serialize(), f)

            cur_groups = match_candidate_items_to_window(
                geometry, candidate_items, query_config
            )
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> SourceItem:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return SourceItem.deserialize(serialized_item)
rslearn/main.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  import argparse
4
4
  import multiprocessing
5
+ import os
5
6
  import random
6
7
  import sys
7
8
  import time
@@ -45,6 +46,7 @@ handler_registry = {}
45
46
  ItemType = TypeVar("ItemType", bound="Item")
46
47
 
47
48
  MULTIPROCESSING_CONTEXT = "forkserver"
49
+ MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"
48
50
 
49
51
 
50
52
  def register_handler(category: Any, command: str) -> Callable:
@@ -837,7 +839,8 @@ def model_predict() -> None:
837
839
  def main() -> None:
838
840
  """CLI entrypoint."""
839
841
  try:
840
- multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
842
+ mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
843
+ multiprocessing.set_start_method(mp_context)
841
844
  except RuntimeError as e:
842
845
  logger.error(
843
846
  f"Multiprocessing context already set to {multiprocessing.get_context()}: "
@@ -150,8 +150,11 @@ class AttentionPool(IntermediateComponent):
150
150
  D // self.num_heads
151
151
  )
152
152
  attn_weights = F.softmax(attn_scores, dim=-1)
153
- x = torch.matmul(attn_weights, v) # [B, head, 1, D_head]
154
- return x.reshape(B, D, H, W)
153
+ x = torch.matmul(attn_weights, v) # [B*H*W, num_heads, 1, D_head]
154
+ x = x.squeeze(-2) # [B*H*W, num_heads, D_head]
155
+ return rearrange(
156
+ x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W
157
+ ) # [B, D, H, W]
155
158
 
156
159
  def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
157
160
  """Forward pass for attention pooling linear probe.