giga-spatial 0.6.6-py3-none-any.whl → 0.6.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

giga_spatial-0.6.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: giga-spatial
- Version: 0.6.6
+ Version: 0.6.8
  Summary: A package for spatial data download & processing
  Home-page: https://github.com/unicef/giga-spatial
  Author: Utku Can Ozturk

giga_spatial-0.6.8.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
- giga_spatial-0.6.6.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
- gigaspatial/__init__.py,sha256=I3h5MyD10PkOUQEBnR6L9ja7s4WeTEg8rRjRKTCWYWQ,22
+ giga_spatial-0.6.8.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ gigaspatial/__init__.py,sha256=wwbrOIx2rQA0YHGob_KGFY89qGDsh20rh2M3y3Ua458,22
  gigaspatial/config.py,sha256=pLbxGc08OHT2IfTBzZVuIJTPR2vvg3KTFfvciOtRswk,9304
  gigaspatial/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  gigaspatial/core/io/__init__.py,sha256=stlpgEeHf5KIb2ZW8yEbdJK5iq6n_wX4DPmKyR9PK-w,317
@@ -27,24 +27,24 @@ gigaspatial/handlers/boundaries.py,sha256=jtWyQt3iAzS77mbAOi7mjh3cv_YCV3uB_r1h56
  gigaspatial/handlers/ghsl.py,sha256=aSEVQVANzJf8O8TiQYmfwyeM43ZaO65VJHmiuLSQfLs,30524
  gigaspatial/handlers/giga.py,sha256=F5ZfcE37a24X-c6Xhyt72C9eZZbyN_gV7w_InxKFMQQ,28348
  gigaspatial/handlers/google_open_buildings.py,sha256=Liqk7qJhDtB4Ia4uhBe44LFcf-XVKBjRfj-pWlE5erY,16594
- gigaspatial/handlers/hdx.py,sha256=LTEs_xZF1yPhD8dAdZ_YN8Vcan7iB5_tZ8NjF_ip6u0,18001
+ gigaspatial/handlers/hdx.py,sha256=1m6oG1DeEC_RLFtb6CrTReWpbQ5uG2e8EIt-IUkZbaI,18122
  gigaspatial/handlers/mapbox_image.py,sha256=M_nkJ_b1PD8FG1ajVgSycCb0NRTAI_SLpHdzszNetKA,7786
  gigaspatial/handlers/maxar_image.py,sha256=kcc8uGljQB0Yh0MKBA7lT7KwBbNZwFzuyBklR3db1P4,10204
  gigaspatial/handlers/microsoft_global_buildings.py,sha256=bQ5WHIv3v0wWrZZUbZkKPRjgdlqIxlK7CV_0zSvdrTw,20292
  gigaspatial/handlers/ookla_speedtest.py,sha256=EcvSAxJZ9GPfzYnT_C85Qgy2ecc9ndf70Pklk53OdC8,6506
  gigaspatial/handlers/opencellid.py,sha256=KuJqd-5-RO5ZzyDaBSrTgCK2ib5N_m3RUcPlX5heWwI,10683
- gigaspatial/handlers/osm.py,sha256=sLNMkOVh1v50jrWw7Z0-HILY5QTQjgKCHCeAfXj5jA8,14084
+ gigaspatial/handlers/osm.py,sha256=vUbdUm6lO2f8YyU7o4qUSkWMxlZElp7EPBFlneRaeo0,16641
  gigaspatial/handlers/overture.py,sha256=lKeNw00v5Qia7LdWORuYihnlKEqxE9m38tdeRrvag9k,4218
  gigaspatial/handlers/rwi.py,sha256=eAaplDysVeBhghJusYUKZYbKL5hW-klWvi8pWhILQkY,4962
  gigaspatial/handlers/unicef_georepo.py,sha256=ODYNvkU_UKgOHXT--0MqmJ4Uk6U1_mp9xgehbTzKpX8,31924
- gigaspatial/handlers/worldpop.py,sha256=pkTmqb0k0vpa58t6tM3jfcpMHt1YuayLPFEFEULlrLs,30156
+ gigaspatial/handlers/worldpop.py,sha256=jV166EP02Xdj8jiT8aQi4sexds8Qd3KRGHXqq70_Sdk,30161
  gigaspatial/processing/__init__.py,sha256=QDVL-QbLCrIb19lrajP7LrHNdGdnsLeGcvAs_jQpdRM,183
  gigaspatial/processing/algorithms.py,sha256=6fBCwbZrI_ISWJ7UpkH6moq1vw-7dBy14yXSLHZprqY,6591
  gigaspatial/processing/geo.py,sha256=8kD7-LQdGzKVfuZDWr3zK5uQhPzgxbZ3JBPosLRBJ5M,41390
  gigaspatial/processing/sat_images.py,sha256=YUbH5MFNzl6NX49Obk14WaFcr1s3SyGJIOk-kRpbBNg,1429
- gigaspatial/processing/tif_processor.py,sha256=QLln9D-_zBhdYQL9NAL_bmo0bmmxE3sxDUQEglYQK94,27490
+ gigaspatial/processing/tif_processor.py,sha256=dZRhMGj5r7DIu8Bop31NPbN1IdOK1syIlCOFTjTiiyo,40024
  gigaspatial/processing/utils.py,sha256=HC85vGKQakxlkoQAkZmeAXWHsenAwTIRn7jPKUA7x20,1500
- giga_spatial-0.6.6.dist-info/METADATA,sha256=ZKoXmthabbL_5xJYHdQfk3ev4Dz02tWU6RAtpv0vWSU,7537
- giga_spatial-0.6.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- giga_spatial-0.6.6.dist-info/top_level.txt,sha256=LZsccgw6H4zXT7m6Y4XChm-Y5LjHAwZ2hkGN_B3ExmI,12
- giga_spatial-0.6.6.dist-info/RECORD,,
+ giga_spatial-0.6.8.dist-info/METADATA,sha256=f9MSxVRX6yhfkeoGhrsO5CdbAmVVHfhq9T4Ip7CRac4,7537
+ giga_spatial-0.6.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ giga_spatial-0.6.8.dist-info/top_level.txt,sha256=LZsccgw6H4zXT7m6Y4XChm-Y5LjHAwZ2hkGN_B3ExmI,12
+ giga_spatial-0.6.8.dist-info/RECORD,,

gigaspatial/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.6.6"
+ __version__ = "0.6.8"

gigaspatial/handlers/hdx.py CHANGED
@@ -247,7 +247,10 @@ class HDXConfig(BaseHandlerConfig):
              # If source is a dict, use it directly as a filter
              return self.get_dataset_resources(filter=source, **kwargs)
          else:
-             raise ValueError(f"Unsupported source type: {type(source)}")
+             raise ValueError(
+                 f"Unsupported source type: {type(source)}"
+                 "Please use country-based filtering or direct resource filtering instead."
+             )

      def get_relevant_data_units_by_geometry(
          self, geometry: Union[BaseGeometry, gpd.GeoDataFrame], **kwargs
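
The dict branch shown in the context lines forwards the mapping verbatim to get_dataset_resources as a resource filter; anything else now raises the more descriptive ValueError. A minimal sketch of that call path, assuming `config` is an already-constructed HDXConfig instance (its construction arguments are not shown in this diff):

    # `config` is an existing HDXConfig; a dict source acts as a direct resource filter.
    filters = {"format": "geojson"}
    resources = config.get_dataset_resources(filter=filters)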

gigaspatial/handlers/osm.py CHANGED
@@ -1,7 +1,8 @@
  import requests
  import pandas as pd
  from typing import List, Dict, Union, Optional, Literal
- from dataclasses import dataclass
+ from pydantic.dataclasses import dataclass
+ from pydantic import Field
  from time import sleep
  from concurrent.futures import ThreadPoolExecutor
  from requests.exceptions import RequestException
@@ -20,8 +21,10 @@ class OSMLocationFetcher:
      shops, and other POI categories.
      """

-     country: str
-     location_types: Union[List[str], Dict[str, List[str]]]
+     country: Optional[str] = None
+     admin_level: Optional[int] = None
+     admin_value: Optional[str] = None
+     location_types: Union[List[str], Dict[str, List[str]]] = Field(...)
      base_url: str = "http://overpass-api.de/api/interpreter"
      timeout: int = 600
      max_retries: int = 3
@@ -29,10 +32,6 @@ class OSMLocationFetcher:

      def __post_init__(self):
          """Validate inputs, normalize location_types, and set up logging."""
-         try:
-             self.country = pycountry.countries.lookup(self.country).alpha_2
-         except LookupError:
-             raise ValueError(f"Invalid country code provided: {self.country}")

          # Normalize location_types to always be a dictionary
          if isinstance(self.location_types, list):
@@ -44,6 +43,75 @@ class OSMLocationFetcher:

          self.logger = config.get_logger(self.__class__.__name__)

+         # Validate area selection
+         if self.admin_level is not None and self.admin_value is not None:
+             self.area_query = f'area["admin_level"={self.admin_level}]["name"="{self.admin_value}"]->.searchArea;'
+             self.logger.info(
+                 f"Using admin_level={self.admin_level}, name={self.admin_value} for area selection."
+             )
+         elif self.country is not None:
+             try:
+                 self.country = pycountry.countries.lookup(self.country).alpha_2
+             except LookupError:
+                 raise ValueError(f"Invalid country code provided: {self.country}")
+             self.area_query = f'area["ISO3166-1"={self.country}]->.searchArea;'
+             self.logger.info(f"Using country={self.country} for area selection.")
+         else:
+             raise ValueError(
+                 "Either country or both admin_level and admin_value must be provided."
+             )
+
+     @staticmethod
+     def get_admin_names(
+         admin_level: int, country: Optional[str] = None, timeout: int = 120
+     ) -> List[str]:
+         """
+         Fetch all admin area names for a given admin_level (optionally within a country).
+
+         Args:
+             admin_level (int): The OSM admin_level to search for (e.g., 4 for states, 6 for counties).
+             country (str, optional): Country name or ISO code to filter within.
+             timeout (int): Timeout for the Overpass API request.
+
+         Returns:
+             List[str]: List of admin area names.
+         """
+
+         # Build area filter for country if provided
+         if country:
+             try:
+                 country_code = pycountry.countries.lookup(country).alpha_2
+             except LookupError:
+                 raise ValueError(f"Invalid country code or name: {country}")
+             area_filter = f'area["ISO3166-1"="{country_code}"]->.countryArea;'
+             area_ref = "(area.countryArea)"
+         else:
+             area_filter = ""
+             area_ref = ""
+
+         # Overpass QL to get all admin areas at the specified level
+         query = f"""
+         [out:json][timeout:{timeout}];
+         {area_filter}
+         (
+             relation["admin_level"="{admin_level}"]{area_ref};
+         );
+         out tags;
+         """
+
+         url = "http://overpass-api.de/api/interpreter"
+         response = requests.get(url, params={"data": query}, timeout=timeout)
+         response.raise_for_status()
+         data = response.json()
+
+         names = []
+         for el in data.get("elements", []):
+             tags = el.get("tags", {})
+             name = tags.get("name")
+             if name:
+                 names.append(name)
+         return sorted(set(names))
+
      def _build_queries(self, since_year: Optional[int] = None) -> List[str]:
          """
          Construct separate Overpass QL queries for different element types and categories.
@@ -68,7 +136,7 @@ class OSMLocationFetcher:

          nodes_relations_query = f"""
          [out:json][timeout:{self.timeout}];
-         area["ISO3166-1"={self.country}]->.searchArea;
+         {self.area_query}
          (
              {nodes_relations_queries}
          );
@@ -86,7 +154,7 @@ class OSMLocationFetcher:

          ways_query = f"""
          [out:json][timeout:{self.timeout}];
-         area["ISO3166-1"={self.country}]->.searchArea;
+         {self.area_query}
          (
              {ways_queries}
          );
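
Taken together, the osm.py changes let the fetcher target either a whole country or a single administrative area, and add a static helper for discovering admin area names. A usage sketch based only on the fields and signatures visible in this diff; the admin name and location types are illustrative values, not package defaults:

    from gigaspatial.handlers.osm import OSMLocationFetcher

    # Country-wide fetcher: the country is resolved to an ISO alpha-2 code.
    country_fetcher = OSMLocationFetcher(country="Kenya", location_types=["school"])

    # Sub-national fetcher: select an Overpass area by admin_level + name instead.
    county_fetcher = OSMLocationFetcher(
        admin_level=6, admin_value="Nairobi", location_types=["school"]
    )

    # New static helper: list candidate admin area names before choosing one.
    names = OSMLocationFetcher.get_admin_names(admin_level=6, country="Kenya")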

gigaspatial/handlers/worldpop.py CHANGED
@@ -611,7 +611,7 @@ class WPPopulationDownloader(BaseHandlerDownloader):
          total_size = int(response.headers.get("content-length", 0))
          file_path = self.config.get_data_unit_path(url)

-         with self.data_store.open(file_path, "wb") as file:
+         with self.data_store.open(str(file_path), "wb") as file:
              with tqdm(
                  total=total_size,
                  unit="B",

gigaspatial/processing/tif_processor.py CHANGED
@@ -9,9 +9,13 @@ from shapely.geometry import box, Polygon, MultiPolygon
  from pathlib import Path
  import rasterio
  from rasterio.mask import mask
+ from rasterio.merge import merge
+ from rasterio.warp import calculate_default_transform, reproject, Resampling
  from functools import partial
  import multiprocessing
  from tqdm import tqdm
+ import tempfile
+ import os

  from gigaspatial.core.io.data_store import DataStore
  from gigaspatial.core.io.local_data_store import LocalDataStore
@@ -22,20 +26,34 @@ from gigaspatial.config import config
  class TifProcessor:
      """
      A class to handle tif data processing, supporting single-band, RGB, RGBA, and multi-band data.
+     Can merge multiple rasters into one during initialization.
      """

-     dataset_path: Union[Path, str]
+     dataset_path: Union[Path, str, List[Union[Path, str]]]
      data_store: Optional[DataStore] = None
      mode: Literal["single", "rgb", "rgba", "multi"] = "single"
+     merge_method: Literal["first", "last", "min", "max", "mean"] = "first"
+     target_crs: Optional[str] = None  # For reprojection if needed
+     resampling_method: Resampling = Resampling.nearest

      def __post_init__(self):
-         """Validate inputs and set up logging."""
+         """Validate inputs, merge rasters if needed, and set up logging."""
          self.data_store = self.data_store or LocalDataStore()
          self.logger = config.get_logger(self.__class__.__name__)
          self._cache = {}
-
-         if not self.data_store.file_exists(self.dataset_path):
-             raise FileNotFoundError(f"Dataset not found at {self.dataset_path}")
+         self._merged_file_path = None
+         self._temp_dir = None
+
+         # Handle multiple dataset paths
+         if isinstance(self.dataset_path, list):
+             self.dataset_paths = [Path(p) for p in self.dataset_path]
+             self._validate_multiple_datasets()
+             self._merge_rasters()
+             self.dataset_path = self._merged_file_path
+         else:
+             self.dataset_paths = [Path(self.dataset_path)]
+             if not self.data_store.file_exists(self.dataset_path):
+                 raise FileNotFoundError(f"Dataset not found at {self.dataset_path}")

          self._load_metadata()
@@ -49,13 +67,298 @@ class TifProcessor:
          if self.mode == "multi" and self.count < 2:
              raise ValueError("Multi mode requires a TIF file with 2 or more bands")

+     def _validate_multiple_datasets(self):
+         """Validate that all datasets exist and have compatible properties."""
+         if len(self.dataset_paths) < 2:
+             raise ValueError("Multiple dataset paths required for merging")
+
+         # Check if all files exist
+         for path in self.dataset_paths:
+             if not self.data_store.file_exists(path):
+                 raise FileNotFoundError(f"Dataset not found at {path}")
+
+         # Load first dataset to get reference properties
+         with self.data_store.open(self.dataset_paths[0], "rb") as f:
+             with rasterio.MemoryFile(f.read()) as memfile:
+                 with memfile.open() as ref_src:
+                     ref_count = ref_src.count
+                     ref_dtype = ref_src.dtypes[0]
+                     ref_crs = ref_src.crs
+                     ref_transform = ref_src.transform
+                     ref_nodata = ref_src.nodata
+
+         # Validate all other datasets against reference
+         for i, path in enumerate(self.dataset_paths[1:], 1):
+             with self.data_store.open(path, "rb") as f:
+                 with rasterio.MemoryFile(f.read()) as memfile:
+                     with memfile.open() as src:
+                         if src.count != ref_count:
+                             raise ValueError(
+                                 f"Dataset {i} has {src.count} bands, expected {ref_count}"
+                             )
+                         if src.dtypes[0] != ref_dtype:
+                             raise ValueError(
+                                 f"Dataset {i} has dtype {src.dtypes[0]}, expected {ref_dtype}"
+                             )
+                         if self.target_crs is None and src.crs != ref_crs:
+                             raise ValueError(
+                                 f"Dataset {i} has CRS {src.crs}, expected {ref_crs}. Consider setting target_crs parameter."
+                             )
+                         if self.target_crs is None and not self._transforms_compatible(
+                             src.transform, ref_transform
+                         ):
+                             self.logger.warning(
+                                 f"Dataset {i} has different resolution. Resampling may be needed."
+                             )
+                         if src.nodata != ref_nodata:
+                             self.logger.warning(
+                                 f"Dataset {i} has different nodata value: {src.nodata} vs {ref_nodata}"
+                             )
+
+     def _transforms_compatible(self, transform1, transform2, tolerance=1e-6):
+         """Check if two transforms have compatible pixel sizes."""
+         return (
+             abs(transform1.a - transform2.a) < tolerance
+             and abs(transform1.e - transform2.e) < tolerance
+         )
+
+     def _merge_rasters(self):
+         """Merge multiple rasters into a single raster."""
+         self.logger.info(f"Merging {len(self.dataset_paths)} rasters...")
+
+         # Create temporary directory for merged file
+         self._temp_dir = tempfile.mkdtemp()
+         merged_filename = "merged_raster.tif"
+         self._merged_file_path = os.path.join(self._temp_dir, merged_filename)
+
+         # Open all datasets and handle reprojection if needed
+         src_files = []
+         reprojected_files = []
+
+         try:
+             for path in self.dataset_paths:
+                 with self.data_store.open(path, "rb") as f:
+                     # Create temporary file for each dataset
+                     temp_file = tempfile.NamedTemporaryFile(suffix=".tif", delete=False)
+                     temp_file.write(f.read())
+                     temp_file.close()
+                     src_files.append(rasterio.open(temp_file.name))
+
+             # Handle reprojection if target_crs is specified
+             if self.target_crs:
+                 self.logger.info(f"Reprojecting rasters to {self.target_crs}...")
+                 processed_files = self._reproject_rasters(src_files, self.target_crs)
+                 reprojected_files = processed_files
+             else:
+                 processed_files = src_files
+
+             if self.merge_method == "mean":
+                 # For mean, we need to handle it manually
+                 merged_array, merged_transform = self._merge_with_mean(src_files)
+
+                 # Use first source as reference for metadata
+                 ref_src = src_files[0]
+                 profile = ref_src.profile.copy()
+                 profile.update(
+                     {
+                         "height": merged_array.shape[-2],
+                         "width": merged_array.shape[-1],
+                         "transform": merged_transform,
+                     }
+                 )
+
+                 # Write merged raster
+                 with rasterio.open(self._merged_file_path, "w", **profile) as dst:
+                     dst.write(merged_array)
+
+             else:
+                 # Use rasterio's merge function
+                 merged_array, merged_transform = merge(
+                     src_files,
+                     method=self.merge_method,
+                     resampling=self.resampling_method,
+                 )
+
+                 # Use first source as reference for metadata
+                 ref_src = src_files[0]
+                 profile = ref_src.profile.copy()
+                 profile.update(
+                     {
+                         "height": merged_array.shape[-2],
+                         "width": merged_array.shape[-1],
+                         "transform": merged_transform,
+                     }
+                 )
+
+                 if self.target_crs:
+                     profile["crs"] = self.target_crs
+
+                 # Write merged raster
+                 with rasterio.open(self._merged_file_path, "w", **profile) as dst:
+                     dst.write(merged_array)
+
+         finally:
+             # Clean up source files
+             for src in src_files:
+                 temp_path = src.name
+                 src.close()
+                 try:
+                     os.unlink(temp_path)
+                 except:
+                     pass
+
+             # Clean up reprojected files
+             for src in reprojected_files:
+                 if src not in src_files:  # Don't double-close
+                     temp_path = src.name
+                     src.close()
+                     try:
+                         os.unlink(temp_path)
+                     except:
+                         pass
+
+         self.logger.info("Raster merging completed!")
+
+     def _reproject_rasters(self, src_files, target_crs):
+         """Reproject all rasters to a common CRS before merging."""
+         reprojected_files = []
+
+         for i, src in enumerate(src_files):
+             if src.crs.to_string() == target_crs:
+                 # No reprojection needed
+                 reprojected_files.append(src)
+                 continue
+
+             # Calculate transform and dimensions for reprojection
+             transform, width, height = calculate_default_transform(
+                 src.crs,
+                 target_crs,
+                 src.width,
+                 src.height,
+                 *src.bounds,
+                 resolution=self.resolution if hasattr(self, "resolution") else None,
+             )
+
+             # Create temporary file for reprojected raster
+             temp_file = tempfile.NamedTemporaryFile(suffix=".tif", delete=False)
+             temp_file.close()
+
+             # Set up profile for reprojected raster
+             profile = src.profile.copy()
+             profile.update(
+                 {
+                     "crs": target_crs,
+                     "transform": transform,
+                     "width": width,
+                     "height": height,
+                 }
+             )
+
+             # Reproject and write to temporary file
+             with rasterio.open(temp_file.name, "w", **profile) as dst:
+                 for band_idx in range(1, src.count + 1):
+                     reproject(
+                         source=rasterio.band(src, band_idx),
+                         destination=rasterio.band(dst, band_idx),
+                         src_transform=src.transform,
+                         src_crs=src.crs,
+                         dst_transform=transform,
+                         dst_crs=target_crs,
+                         resampling=self.resampling_method,
+                     )
+
+             # Open reprojected file
+             reprojected_files.append(rasterio.open(temp_file.name))
+
+         return reprojected_files
+
+     def _merge_with_mean(self, src_files):
+         """Merge rasters using mean aggregation."""
+         # Get bounds and resolution for merged raster
+         bounds = src_files[0].bounds
+         transform = src_files[0].transform
+
+         for src in src_files[1:]:
+             bounds = rasterio.coords.BoundingBox(
+                 min(bounds.left, src.bounds.left),
+                 min(bounds.bottom, src.bounds.bottom),
+                 max(bounds.right, src.bounds.right),
+                 max(bounds.top, src.bounds.top),
+             )
+
+         # Calculate dimensions for merged raster
+         width = int((bounds.right - bounds.left) / abs(transform.a))
+         height = int((bounds.top - bounds.bottom) / abs(transform.e))
+
+         # Create new transform for merged bounds
+         merged_transform = rasterio.transform.from_bounds(
+             bounds.left, bounds.bottom, bounds.right, bounds.top, width, height
+         )
+
+         # Initialize arrays for sum and count
+         sum_array = np.zeros((src_files[0].count, height, width), dtype=np.float64)
+         count_array = np.zeros((height, width), dtype=np.int32)
+
+         # Process each source file
+         for src in src_files:
+             # Read data
+             data = src.read()
+
+             # Calculate offset in merged raster
+             src_bounds = src.bounds
+             col_off = int((src_bounds.left - bounds.left) / abs(transform.a))
+             row_off = int((bounds.top - src_bounds.top) / abs(transform.e))
+
+             # Get valid data mask
+             if src.nodata is not None:
+                 valid_mask = data[0] != src.nodata
+             else:
+                 valid_mask = np.ones(data[0].shape, dtype=bool)
+
+             # Add to sum and count arrays
+             end_row = row_off + data.shape[1]
+             end_col = col_off + data.shape[2]
+
+             sum_array[:, row_off:end_row, col_off:end_col] += np.where(
+                 valid_mask, data, 0
+             )
+             count_array[row_off:end_row, col_off:end_col] += valid_mask.astype(np.int32)
+
+         # Calculate mean
+         mean_array = np.divide(
+             sum_array,
+             count_array,
+             out=np.full_like(
+                 sum_array, src_files[0].nodata or 0, dtype=sum_array.dtype
+             ),
+             where=count_array > 0,
+         )
+
+         return mean_array.astype(src_files[0].dtypes[0]), merged_transform
+
+     def __del__(self):
+         """Cleanup temporary files."""
+         if self._temp_dir and os.path.exists(self._temp_dir):
+             try:
+                 import shutil
+
+                 shutil.rmtree(self._temp_dir)
+             except:
+                 pass
+
      @contextmanager
      def open_dataset(self):
          """Context manager for accessing the dataset"""
-         with self.data_store.open(self.dataset_path, "rb") as f:
-             with rasterio.MemoryFile(f.read()) as memfile:
-                 with memfile.open() as src:
-                     yield src
+         if self._merged_file_path:
+             # Open merged file directly
+             with rasterio.open(self._merged_file_path) as src:
+                 yield src
+         else:
+             # Original single file logic
+             with self.data_store.open(self.dataset_path, "rb") as f:
+                 with rasterio.MemoryFile(f.read()) as memfile:
+                     with memfile.open() as src:
+                         yield src

      def _load_metadata(self):
          """Load metadata from the TIF file if not already cached"""
@@ -73,6 +376,17 @@ class TifProcessor:
          self._cache["count"] = src.count
          self._cache["dtype"] = src.dtypes[0]

+     @property
+     def is_merged(self) -> bool:
+         """Check if this processor was created from multiple rasters."""
+         return len(self.dataset_paths) > 1
+
+     @property
+     def source_count(self) -> int:
+         """Get the number of source rasters."""
+         return len(self.dataset_paths)
+
+     # All other methods remain the same...
      @property
      def transform(self):
          """Get the transform from the TIF file"""
@@ -380,7 +694,7 @@ class TifProcessor:
          results = [item for sublist in batched_results for item in sublist]

          return np.array(results)
-
+ 
      def _initializer_worker(self):
          """
          Initializer function for each worker process.
@@ -727,9 +1041,7 @@ def sample_multiple_tifs_by_polygons(
      sampled_values = np.full(len(polygon_list), np.nan, dtype=np.float32)

      for tp in tif_processors:
-         values = tp.sample_by_polygons(
-             polygon_list=polygon_list, stat=stat
-         )
+         values = tp.sample_by_polygons(polygon_list=polygon_list, stat=stat)

          mask = np.isnan(sampled_values)  # replace all NaNs
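
For context, a usage sketch of the helper reformatted above: each processor samples the polygons in turn, and NaNs left by earlier processors are filled by later ones. The keyword names come from the hunk itself, while the file paths, polygon, and stat value are illustrative assumptions:

    from shapely.geometry import box
    from gigaspatial.processing.tif_processor import (
        TifProcessor,
        sample_multiple_tifs_by_polygons,
    )

    tps = [TifProcessor(dataset_path=p) for p in ("north.tif", "south.tif")]
    values = sample_multiple_tifs_by_polygons(
        tif_processors=tps,
        polygon_list=[box(36.7, -1.4, 36.9, -1.2)],
        stat="mean",  # assumed value; the accepted stats are not shown in this diff
    )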