giga-spatial 0.6.9__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -145,11 +145,18 @@ class GeometryBasedZonalViewGenerator(ZonalViewGenerator[T]):
             gpd.GeoDataFrame: A GeoDataFrame with 'zone_id' and 'geometry' columns.
                 The zone_id column is renamed from the original zone_id_column if different.
         """
-        # If we already have a GeoDataFrame, just rename the ID column if needed
-        result = self._zone_gdf.copy()
-        if self.zone_id_column != "zone_id":
-            result = result.rename(columns={self.zone_id_column: "zone_id"})
-        return result
+        # Since _zone_gdf is already created with 'zone_id' column in the constructor,
+        # we just need to return a copy of it
+        return self._zone_gdf.copy()
+
+    @property
+    def zone_gdf(self) -> gpd.GeoDataFrame:
+        """Override the base class zone_gdf property to ensure correct column names.
+
+        Returns:
+            gpd.GeoDataFrame: A GeoDataFrame with 'zone_id' and 'geometry' columns.
+        """
+        return self._zone_gdf.copy()

     def map_built_s(
         self,
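
The overridden zone_gdf property means callers no longer rename an ID column themselves; the returned copy always carries 'zone_id' and 'geometry'. A minimal sketch (the generator variable below is any already constructed GeometryBasedZonalViewGenerator and is illustrative, not part of the diff):

    # Sketch: zone_gdf returns a standardized copy, so mutating it does not
    # touch the generator's internal _zone_gdf.
    zones = generator.zone_gdf
    assert {"zone_id", "geometry"} <= set(zones.columns)
    zones["zone_id_str"] = zones["zone_id"].astype(str)  # safe: operates on a copy
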
@@ -254,13 +261,16 @@ class GeometryBasedZonalViewGenerator(ZonalViewGenerator[T]):
             f"Mapping {handler.config.product} data (year: {handler.config.year}, resolution: {handler.config.resolution}m)"
         )
         tif_processors = handler.load_data(
-            self.zone_gdf, ensure_available=self.config.ensure_available
+            self.zone_gdf,
+            ensure_available=self.config.ensure_available,
+            merge_rasters=True,
+            **kwargs,
         )

         self.logger.info(
             f"Sampling {handler.config.product} data using '{stat}' statistic"
         )
-        sampled_values = self.map_rasters(tif_processors=tif_processors, stat=stat)
+        sampled_values = self.map_rasters(raster_data=tif_processors, stat=stat)

         column_name = (
             output_column
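
Two call-site details change in this hunk: load_data can now return a single merged TifProcessor when merge_rasters=True is forwarded, and map_rasters takes its input through the renamed raster_data keyword instead of tif_processors. A hedged sketch of calling the same pair of methods directly (the handler and generator objects are illustrative):

    # Sketch only: mirrors the calls in the hunk above.
    tif_processors = handler.load_data(
        generator.zone_gdf,
        ensure_available=True,   # stands in for self.config.ensure_available
        merge_rasters=True,      # one merged TifProcessor instead of a list
    )
    values = generator.map_rasters(raster_data=tif_processors, stat="mean")
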
@@ -481,57 +491,97 @@ class GeometryBasedZonalViewGenerator(ZonalViewGenerator[T]):
         self,
         country: Union[str, List[str]],
         resolution=1000,
-        predicate: Literal["intersects", "fractional"] = "intersects",
+        predicate: Literal[
+            "centroid_within", "intersects", "fractional"
+        ] = "intersects",
         output_column: str = "population",
         **kwargs,
     ):
-        if isinstance(country, str):
-            country = [country]
+
+        # Ensure country is always a list for consistent handling
+        countries_list = [country] if isinstance(country, str) else country

         handler = WPPopulationHandler(
-            project="pop", resolution=resolution, data_store=self.data_store, **kwargs
+            resolution=resolution,
+            data_store=self.data_store,
+            **kwargs,
         )

+        # Restrict to single country for age_structures project
+        if handler.config.project == "age_structures" and len(countries_list) > 1:
+            raise ValueError(
+                "For 'age_structures' project, only a single country can be processed at a time."
+            )
+
         self.logger.info(
             f"Mapping WorldPop Population data (year: {handler.config.year}, resolution: {handler.config.resolution}m)"
         )

-        if predicate == "fractional":
-            if resolution == 100:
-                self.logger.warning(
-                    "Fractional aggregations only supported for datasets with 1000m resolution. Using `intersects` as predicate"
+        if predicate == "fractional" and resolution == 100:
+            self.logger.warning(
+                "Fractional aggregations only supported for datasets with 1000m resolution. Using `intersects` as predicate"
+            )
+            predicate = "intersects"
+
+        if predicate == "centroid_within":
+            if handler.config.project == "age_structures":
+                # Load individual tif processors for the single country
+                all_tif_processors = handler.load_data(
+                    countries_list[0],
+                    ensure_available=self.config.ensure_available,
+                    **kwargs,
                 )
-                predicate = "intersects"
+
+                # Sum results from each tif_processor separately
+                all_results_by_zone = {
+                    zone_id: 0 for zone_id in self.get_zone_identifiers()
+                }
+                self.logger.info(
+                    f"Sampling individual age_structures rasters using 'sum' statistic and summing per zone."
+                )
+                for tif_processor in all_tif_processors:
+                    single_raster_result = self.map_rasters(
+                        raster_data=tif_processor, stat="sum"
+                    )
+                    for zone_id, value in single_raster_result.items():
+                        all_results_by_zone[zone_id] += value
+                result = all_results_by_zone
             else:
-                gdf_pop = pd.concat(
-                    [
-                        handler.load_into_geodataframe(
-                            c, ensure_available=self.config.ensure_available
+                # Existing behavior for non-age_structures projects or if merging is fine
+                tif_processors = []
+                for c in countries_list:
+                    tif_processors.extend(
+                        handler.load_data(
+                            c,
+                            ensure_available=self.config.ensure_available,
+                            **kwargs,
                         )
-                        for c in country
-                    ],
-                    ignore_index=True,
-                )
-
-                result = self.map_polygons(
-                    gdf_pop,
-                    value_columns="pixel_value",
-                    aggregation="sum",
-                    predicate=predicate,
+                    )
+                self.logger.info(
+                    f"Sampling WorldPop Population data using 'sum' statistic"
                 )
-
-                self.add_variable_to_view(result, output_column)
-                return self.view
-
-        tif_processors = []
-        for c in country:
-            tif_processors.extend(
-                handler.load_data(c, ensure_available=self.config.ensure_available)
+                result = self.map_rasters(raster_data=tif_processors, stat="sum")
+        else:
+            gdf_pop = pd.concat(
+                [
+                    handler.load_into_geodataframe(
+                        c,
+                        ensure_available=self.config.ensure_available,
+                        **kwargs,
+                    )
+                    for c in countries_list
+                ],
+                ignore_index=True,
             )

-        self.logger.info(f"Sampling WorldPop Population data using 'sum' statistic")
-        sampled_values = self.map_rasters(tif_processors=tif_processors, stat="sum")
+            self.logger.info(f"Aggregating WorldPop Population data to the zones.")
+            result = self.map_polygons(
+                gdf_pop,
+                value_columns="pixel_value",
+                aggregation="sum",
+                predicate=predicate,
+            )

-        self.add_variable_to_view(sampled_values, output_column)
+        self.add_variable_to_view(result, output_column)

         return self.view
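
The reworked WorldPop mapping accepts a new centroid_within predicate and either one country or a list of countries. A hypothetical call sketch follows; the method name map_wp_pop is assumed from context (the hunk does not show it), and the argument values are illustrative:

    # Assumed method name; parameters follow the signature shown in the hunk.
    view = generator.map_wp_pop(
        country=["KEN", "UGA"],          # a single ISO code string also works
        resolution=1000,
        predicate="centroid_within",     # new option: per-zone raster sums
        output_column="population",
    )
    # Notes grounded in the hunk: predicate="fractional" with resolution=100 falls
    # back to "intersects" with a warning, and the age_structures project accepts
    # only one country per call.
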
gigaspatial/grid/h3.py ADDED
@@ -0,0 +1,417 @@
+import pandas as pd
+import geopandas as gpd
+import h3
+from shapely.geometry import Polygon, Point, shape
+from shapely.geometry.base import BaseGeometry
+from shapely.strtree import STRtree
+import json
+from pathlib import Path
+from pydantic import BaseModel, Field
+from typing import List, Union, Iterable, Optional, Tuple, ClassVar, Literal
+import pycountry
+
+from gigaspatial.core.io.data_store import DataStore
+from gigaspatial.core.io.local_data_store import LocalDataStore
+from gigaspatial.config import config
+
+
+class H3Hexagons(BaseModel):
+    resolution: int = Field(..., ge=0, le=15)
+    hexagons: List[str] = Field(default_factory=list)
+    data_store: DataStore = Field(default_factory=LocalDataStore, exclude=True)
+    logger: ClassVar = config.get_logger("H3Hexagons")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @classmethod
+    def from_hexagons(cls, hexagons: List[str]):
+        """Create H3Hexagons from list of H3 cell IDs."""
+        if not hexagons:
+            cls.logger.warning("No hexagons provided to from_hexagons.")
+            return cls(resolution=0, hexagons=[])
+
+        cls.logger.info(
+            f"Initializing H3Hexagons from {len(hexagons)} provided hexagons."
+        )
+        # Get resolution from first hexagon
+        resolution = h3.get_resolution(hexagons[0])
+        return cls(resolution=resolution, hexagons=list(set(hexagons)))
+
+    @classmethod
+    def from_bounds(
+        cls, xmin: float, ymin: float, xmax: float, ymax: float, resolution: int
+    ):
+        """Create H3Hexagons from boundary coordinates."""
+        cls.logger.info(
+            f"Creating H3Hexagons from bounds: ({xmin}, {ymin}, {xmax}, {ymax}) at resolution: {resolution}"
+        )
+
+        # Create a LatLong bounding box polygon
+        latlong_bbox_coords = [
+            [ymin, xmin],
+            [ymax, xmin],
+            [ymax, xmax],
+            [ymin, xmax],
+            [ymin, xmin],
+        ]
+
+        # Get H3 cells that intersect with the bounding box
+        poly = h3.LatLngPoly(latlong_bbox_coords)
+        hexagons = h3.h3shape_to_cells(poly, res=resolution)
+
+        return cls(resolution=resolution, hexagons=list(hexagons))
+
+    @classmethod
+    def from_spatial(
+        cls,
+        source: Union[
+            BaseGeometry,
+            gpd.GeoDataFrame,
+            List[Union[Point, Tuple[float, float]]],  # points
+        ],
+        resolution: int,
+        contain: Literal["center", "full", "overlap", "bbox_overlap"] = "overlap",
+        **kwargs,
+    ):
+        cls.logger.info(
+            f"Creating H3Hexagons from spatial source (type: {type(source)}) at resolution: {resolution} with predicate: {contain}"
+        )
+        if isinstance(source, gpd.GeoDataFrame):
+            if source.crs != "EPSG:4326":
+                source = source.to_crs("EPSG:4326")
+
+            is_point_series = source.geometry.geom_type == "Point"
+            all_are_points = is_point_series.all()
+
+            if all_are_points:
+                source = source.geometry.to_list()
+            else:
+                source = source.geometry.unary_union
+
+        if isinstance(source, BaseGeometry):
+            return cls.from_geometry(
+                geometry=source, resolution=resolution, contain=contain, **kwargs
+            )
+        elif isinstance(source, Iterable) and all(
+            isinstance(pt, Point) or len(pt) == 2 for pt in source
+        ):
+            return cls.from_points(points=source, resolution=resolution, **kwargs)
+        else:
+            raise ValueError("Unsupported source type for H3Hexagons.from_spatial")
+
+    @classmethod
+    def from_geometry(
+        cls,
+        geometry: BaseGeometry,
+        resolution: int,
+        contain: Literal["center", "full", "overlap", "bbox_overlap"] = "overlap",
+        **kwargs,
+    ):
+        """Create H3Hexagons from a geometry."""
+        cls.logger.info(
+            f"Creating H3Hexagons from geometry (bounds: {geometry.bounds}) at resolution: {resolution} with predicate: {contain}"
+        )
+
+        if isinstance(geometry, Point):
+            return cls.from_points([geometry])
+
+        # Convert shapely geometry to GeoJSON-like format
+        if hasattr(geometry, "__geo_interface__"):
+            geojson_geom = geometry.__geo_interface__
+        else:
+            # Fallback for complex geometries
+            import json
+            from shapely.geometry import mapping
+
+            geojson_geom = mapping(geometry)
+
+        h3_geom = h3.geo_to_h3shape(geojson_geom)
+
+        hexagons = h3.h3shape_to_cells_experimental(
+            h3_geom, resolution, contain=contain
+        )
+
+        cls.logger.info(
+            f"Generated {len(hexagons)} hexagons using `{contain}` spatial predicate."
+        )
+        return cls(resolution=resolution, hexagons=list(hexagons), **kwargs)
+
+    @classmethod
+    def from_points(
+        cls, points: List[Union[Point, Tuple[float, float]]], resolution: int, **kwargs
+    ) -> "H3Hexagons":
+        """Create H3Hexagons from a list of points or lat-lon pairs."""
+        cls.logger.info(
+            f"Creating H3Hexagons from {len(points)} points at resolution: {resolution}"
+        )
+        hexagons = set(cls.get_hexagons_from_points(points, resolution))
+        cls.logger.info(f"Generated {len(hexagons)} unique hexagons from points.")
+        return cls(resolution=resolution, hexagons=list(hexagons), **kwargs)
+
+    @classmethod
+    def from_json(
+        cls, data_store: DataStore, file: Union[str, Path], **kwargs
+    ) -> "H3Hexagons":
+        """Load H3Hexagons from a JSON file."""
+        cls.logger.info(
+            f"Loading H3Hexagons from JSON file: {file} using data store: {type(data_store).__name__}"
+        )
+        with data_store.open(str(file), "r") as f:
+            data = json.load(f)
+            if isinstance(data, list):  # If file contains only hexagon IDs
+                # Get resolution from first hexagon if available
+                resolution = h3.get_resolution(data[0]) if data else 0
+                data = {
+                    "resolution": resolution,
+                    "hexagons": data,
+                    **kwargs,
+                }
+            else:
+                data.update(kwargs)
+            instance = cls(**data)
+            instance.data_store = data_store
+            cls.logger.info(
+                f"Successfully loaded {len(instance.hexagons)} hexagons from JSON file."
+            )
+            return instance
+
+    @property
+    def average_hexagon_area(self):
+        return h3.average_hexagon_area(self.resolution)
+
+    @property
+    def average_hexagon_edge_length(self):
+        return h3.average_hexagon_edge_length(self.resolution)
+
+    def filter_hexagons(self, hexagons: Iterable[str]) -> "H3Hexagons":
+        """Filter hexagons by a given set of hexagon IDs."""
+        original_count = len(self.hexagons)
+        incoming_count = len(
+            list(hexagons)
+        )  # Convert to list to get length if it's an iterator
+
+        self.logger.info(
+            f"Filtering {original_count} hexagons with an incoming set of {incoming_count} hexagons."
+        )
+        filtered_hexagons = list(set(self.hexagons) & set(hexagons))
+        self.logger.info(f"Resulting in {len(filtered_hexagons)} filtered hexagons.")
+        return H3Hexagons(
+            resolution=self.resolution,
+            hexagons=filtered_hexagons,
+        )
+
+    def to_dataframe(self) -> pd.DataFrame:
+        """Convert to pandas DataFrame with hexagon ID and centroid coordinates."""
+        self.logger.info(
+            f"Converting {len(self.hexagons)} hexagons to pandas DataFrame."
+        )
+        if not self.hexagons:
+            self.logger.warning(
+                "No hexagons to convert to DataFrame. Returning empty DataFrame."
+            )
+            return pd.DataFrame(columns=["hexagon", "latitude", "longitude"])
+
+        centroids = [h3.cell_to_latlng(hex_id) for hex_id in self.hexagons]
+
+        self.logger.info(f"Successfully converted to DataFrame.")
+
+        return pd.DataFrame(
+            {
+                "hexagon": self.hexagons,
+                "latitude": [c[0] for c in centroids],
+                "longitude": [c[1] for c in centroids],
+            }
+        )
+
+    def to_geoms(self) -> List[Polygon]:
+        """Convert hexagons to shapely Polygon geometries."""
+        self.logger.info(
+            f"Converting {len(self.hexagons)} hexagons to shapely Polygon geometries."
+        )
+        return [shape(h3.cells_to_geo([hex_id])) for hex_id in self.hexagons]
+
+    def to_geodataframe(self) -> gpd.GeoDataFrame:
+        """Convert to GeoPandas GeoDataFrame."""
+        return gpd.GeoDataFrame(
+            {"h3": self.hexagons, "geometry": self.to_geoms()}, crs="EPSG:4326"
+        )
+
+    @staticmethod
+    def get_hexagons_from_points(
+        points: List[Union[Point, Tuple[float, float]]], resolution: int
+    ) -> List[str]:
+        """Get list of H3 hexagon IDs for the provided points at specified resolution.
+
+        Args:
+            points: List of points as either shapely Points or (lon, lat) tuples
+            resolution: H3 resolution level
+
+        Returns:
+            List of H3 hexagon ID strings
+        """
+        hexagons = []
+        for p in points:
+            if isinstance(p, Point):
+                # Shapely Point has x=lon, y=lat
+                hex_id = h3.latlng_to_cell(p.y, p.x, resolution)
+            else:
+                # Assume tuple is (lon, lat) - convert to (lat, lon) for h3
+                hex_id = h3.latlng_to_cell(p[1], p[0], resolution)
+            hexagons.append(hex_id)
+        return hexagons
+
+    def get_neighbors(self, k: int = 1) -> "H3Hexagons":
+        """Get k-ring neighbors of all hexagons.
+
+        Args:
+            k: Distance of neighbors (1 for immediate neighbors, 2 for neighbors of neighbors, etc.)
+
+        Returns:
+            New H3Hexagons instance with neighbors included
+        """
+        self.logger.info(
+            f"Getting k-ring neighbors (k={k}) for {len(self.hexagons)} hexagons."
+        )
+
+        all_neighbors = set()
+        for hex_id in self.hexagons:
+            neighbors = h3.grid_ring(hex_id, k)
+            all_neighbors.update(neighbors)
+
+        self.logger.info(
+            f"Found {len(all_neighbors)} total hexagons including neighbors."
+        )
+        return H3Hexagons(resolution=self.resolution, hexagons=list(all_neighbors))
+
+    def get_compact_representation(self) -> "H3Hexagons":
+        """Get compact representation by merging adjacent hexagons into parent cells where possible."""
+        self.logger.info(f"Compacting {len(self.hexagons)} hexagons.")
+
+        # Convert to set for h3.compact
+        hex_set = set(self.hexagons)
+        compacted = h3.compact_cells(hex_set)
+
+        self.logger.info(f"Compacted to {len(compacted)} hexagons.")
+
+        # Note: compacted representation may have mixed resolutions
+        # We'll keep the original resolution as the "target" resolution
+        return H3Hexagons(resolution=self.resolution, hexagons=list(compacted))
+
+    def get_children(self, target_resolution: int) -> "H3Hexagons":
+        """Get children hexagons at higher resolution.
+
+        Args:
+            target_resolution: Target resolution (must be higher than current)
+
+        Returns:
+            New H3Hexagons instance with children at target resolution
+        """
+        if target_resolution <= self.resolution:
+            raise ValueError("Target resolution must be higher than current resolution")
+
+        self.logger.info(
+            f"Getting children at resolution {target_resolution} for {len(self.hexagons)} hexagons."
+        )
+
+        all_children = []
+        for hex_id in self.hexagons:
+            children = h3.cell_to_children(hex_id, target_resolution)
+            all_children.extend(children)
+
+        self.logger.info(f"Generated {len(all_children)} children hexagons.")
+        return H3Hexagons(resolution=target_resolution, hexagons=all_children)
+
+    def get_parents(self, target_resolution: int) -> "H3Hexagons":
+        """Get parent hexagons at lower resolution.
+
+        Args:
+            target_resolution: Target resolution (must be lower than current)
+
+        Returns:
+            New H3Hexagons instance with parents at target resolution
+        """
+        if target_resolution >= self.resolution:
+            raise ValueError("Target resolution must be lower than current resolution")
+
+        self.logger.info(
+            f"Getting parents at resolution {target_resolution} for {len(self.hexagons)} hexagons."
+        )
+
+        parents = set()
+        for hex_id in self.hexagons:
+            parent = h3.cell_to_parent(hex_id, target_resolution)
+            parents.add(parent)
+
+        self.logger.info(f"Generated {len(parents)} parent hexagons.")
+        return H3Hexagons(resolution=target_resolution, hexagons=list(parents))
+
+    def save(self, file: Union[str, Path], format: str = "json") -> None:
+        """Save H3Hexagons to file in specified format."""
+        with self.data_store.open(str(file), "wb" if format == "parquet" else "w") as f:
+            if format == "parquet":
+                self.to_geodataframe().to_parquet(f, index=False)
+            elif format == "geojson":
+                f.write(self.to_geodataframe().to_json(drop_id=True))
+            elif format == "json":
+                json.dump(self.hexagons, f)
+            else:
+                raise ValueError(f"Unsupported format: {format}")
+
+    def __len__(self) -> int:
+        return len(self.hexagons)
+
+
+class CountryH3Hexagons(H3Hexagons):
+    """H3Hexagons specialized for country-level operations.
+
+    This class extends H3Hexagons to work specifically with country boundaries.
+    It can only be instantiated through the create() classmethod.
+    """
+
+    country: str = Field(..., exclude=True)
+
+    def __init__(self, *args, **kwargs):
+        raise TypeError(
+            "CountryH3Hexagons cannot be instantiated directly. "
+            "Use CountryH3Hexagons.create() instead."
+        )
+
+    @classmethod
+    def create(
+        cls,
+        country: str,
+        resolution: int,
+        contain: Literal["center", "full", "overlap", "bbox_overlap"] = "overlap",
+        data_store: Optional[DataStore] = None,
+        country_geom_path: Optional[Union[str, Path]] = None,
+    ):
+        """Create CountryH3Hexagons for a specific country."""
+        from gigaspatial.handlers.boundaries import AdminBoundaries
+
+        instance = super().__new__(cls)
+        super(CountryH3Hexagons, instance).__init__(
+            resolution=resolution,
+            hexagons=[],
+            data_store=data_store or LocalDataStore(),
+            country=pycountry.countries.lookup(country).alpha_3,
+        )
+
+        cls.logger.info(
+            f"Initializing H3 hexagons for country: {country} at resolution {resolution}"
+        )
+
+        country_geom = (
+            AdminBoundaries.create(
+                country_code=country,
+                data_store=data_store,
+                path=country_geom_path,
+            )
+            .boundaries[0]
+            .geometry
+        )
+
+        hexagons = H3Hexagons.from_geometry(country_geom, resolution, contain=contain)
+
+        instance.hexagons = hexagons.hexagons
+        return instance
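
A short usage sketch of the new grid module, using only constructors and methods defined in h3.py above (the bounding-box coordinates and file name are illustrative):

    from gigaspatial.grid.h3 import H3Hexagons, CountryH3Hexagons

    # Build a hexagon set from a lon/lat bounding box, then inspect and export it.
    hexes = H3Hexagons.from_bounds(xmin=29.0, ymin=-1.5, xmax=31.0, ymax=0.5, resolution=6)
    gdf = hexes.to_geodataframe()                    # columns: 'h3', 'geometry' (EPSG:4326)
    finer = hexes.get_children(target_resolution=8)
    hexes.save("hexagons.json", format="json")

    # Country-level grids must go through the create() classmethod; direct
    # instantiation raises TypeError by design.
    rwa_hexes = CountryH3Hexagons.create(country="RWA", resolution=6, contain="overlap")
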
@@ -77,7 +77,7 @@ class MercatorTiles(BaseModel):
                 geometry=source, zoom_level=zoom_level, predicate=predicate, **kwargs
             )
         elif isinstance(source, Iterable) and all(
-            len(pt) == 2 or isinstance(pt, Point) for pt in source
+            isinstance(pt, Point) or len(pt) == 2 for pt in source
         ):
             return cls.from_points(geometry=source, zoom_level=zoom_level, **kwargs)
         else:
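
Presumably the motivation for reordering the operands: a shapely Point does not implement len(), so the isinstance check must run first and short-circuit before len(pt) can raise. A tiny sketch of the difference:

    from shapely.geometry import Point

    pt = Point(10.0, 20.0)
    isinstance(pt, Point) or len(pt) == 2    # True; len() is never evaluated
    # len(pt) == 2 or isinstance(pt, Point)  # TypeError: Point has no len()
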
@@ -328,21 +328,33 @@ class BaseHandlerReader(ABC):
         )

     def _load_raster_data(
-        self, raster_paths: List[Union[str, Path]]
-    ) -> List[TifProcessor]:
+        self,
+        raster_paths: List[Union[str, Path]],
+        merge_rasters: bool = False,
+        **kwargs,
+    ) -> Union[List[TifProcessor], TifProcessor]:
         """
         Load raster data from file paths.

         Args:
             raster_paths (List[Union[str, Path]]): List of file paths to raster files.
+            merge_rasters (bool): If True, all rasters will be merged into a single TifProcessor.
+                Defaults to False.

         Returns:
-            List[TifProcessor]: List of TifProcessor objects for accessing the raster data.
+            Union[List[TifProcessor], TifProcessor]: List of TifProcessor objects or a single
+                TifProcessor if merge_rasters is True.
         """
-        return [
-            TifProcessor(data_path, self.data_store, mode="single")
-            for data_path in raster_paths
-        ]
+        if merge_rasters and len(raster_paths) > 1:
+            self.logger.info(
+                f"Merging {len(raster_paths)} rasters into a single TifProcessor."
+            )
+            return TifProcessor(raster_paths, self.data_store, **kwargs)
+        else:
+            return [
+                TifProcessor(data_path, self.data_store, **kwargs)
+                for data_path in raster_paths
+            ]

     def _load_tabular_data(
         self, file_paths: List[Union[str, Path]], read_function: Callable = read_dataset
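
With the new merge_rasters flag the return type of this internal helper becomes a union: a list of TifProcessor objects by default, or a single TifProcessor when merging applies. A hedged sketch (reader stands for any BaseHandlerReader subclass instance; calling the underscore-prefixed method directly is for illustration only):

    # Default: one TifProcessor per path.
    processors = reader._load_raster_data(raster_paths)

    # Merged: a single TifProcessor built from all paths; extra kwargs are now
    # forwarded to TifProcessor instead of the previously hard-coded mode="single".
    merged = reader._load_raster_data(raster_paths, merge_rasters=True)
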
@@ -619,7 +631,9 @@ class BaseHandler(ABC):
         # Download logic
         if data_units is not None:
             # Map data_units to their paths and select only those that are missing
-            unit_to_path = dict(zip(data_paths,data_units)) #units might be dicts, cannot be used as key
+            unit_to_path = dict(
+                zip(data_paths, data_units)
+            )  # units might be dicts, cannot be used as key
             if force_download:
                 # Download all units if force_download
                 self.downloader.download_data_units(data_units, **kwargs)
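
The reformatted comment points at the reason for the mapping direction: data units may be dicts, which are unhashable and cannot serve as dictionary keys, so the path strings key the lookup instead. A small illustration with made-up values:

    data_units = [{"tile": "X10_Y20"}, {"tile": "X11_Y20"}]   # illustrative units
    data_paths = ["ghsl/X10_Y20.tif", "ghsl/X11_Y20.tif"]     # illustrative paths
    unit_to_path = dict(zip(data_paths, data_units))          # path -> unit
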