pycontrails 0.56.0__cp312-cp312-win_amd64.whl → 0.57.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pycontrails might be problematic. Click here for more details.

@@ -0,0 +1,654 @@
1
+ """Support for Himawari-8/9 satellite data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import bz2
6
+ import collections
7
+ import datetime
8
+ import enum
9
+ import warnings
10
+ from collections.abc import Iterable
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import pandas as pd
16
+ import xarray as xr
17
+
18
+ from pycontrails.core import cache
19
+ from pycontrails.datalib import geo_utils
20
+ from pycontrails.datalib.himawari import header_struct
21
+ from pycontrails.utils import dependencies
22
+
23
+ if TYPE_CHECKING:
24
+ import cartopy.crs
25
+
26
+ try:
27
+ import s3fs
28
+ except ModuleNotFoundError as exc:
29
+ dependencies.raise_module_not_found_error(
30
+ name="goes module",
31
+ package_name="s3fs",
32
+ module_not_found_error=exc,
33
+ pycontrails_optional_package="sat",
34
+ )
35
+
36
+
37
+ #: Default bands to use if none are specified. These are the channels
38
+ #: required by the SEVIRI (MIT) ash color scheme.
39
+ DEFAULT_BANDS = "B11", "B14", "B15"
40
+
41
+ #: The date at which Himawari-9 was declared operational, replacing Himawari-8.
42
+ #: This is used to determine which S3 bucket to use if none is specified.
43
+ #: See the `documentation <https://www.data.jma.go.jp/mscweb/en/oper/switchover.html>`_
44
+ HIMAWARI_8_9_SWITCH_DATE = datetime.datetime(2022, 12, 13, 5, 0)
45
+
46
+
47
+ #: The S3 bucket for Himawari-8 data.
48
+ HIMAWARI_8_BUCKET = "noaa-himawari8"
49
+
50
+ #: The S3 bucket for Himawari-9 data.
51
+ HIMAWARI_9_BUCKET = "noaa-himawari9"
52
+
53
+
54
+ class HimawariRegion(enum.Enum):
55
+ """Himawari-8/9 regions."""
56
+
57
+ FLDK = enum.auto() # Full Disk
58
+ Japan = enum.auto()
59
+ Target = enum.auto()
60
+
61
+
62
+ def _check_time_resolution(
63
+ t: datetime.datetime,
64
+ region: HimawariRegion,
65
+ ) -> tuple[datetime.datetime, str]:
66
+ """Check that the time is at a valid Himawari time resolution.
67
+
68
+ Return the time and the scan type (FLDK, or JP01, JP02, JP03, JP04).
69
+ """
70
+ if t.microsecond:
71
+ raise ValueError("Microseconds are not supported in Himawari time.")
72
+
73
+ total_seconds = t.minute * 60 + t.second
74
+
75
+ if region == HimawariRegion.FLDK:
76
+ if total_seconds % 600:
77
+ raise ValueError("Himawari FLDK data is only available at 10-minute intervals.")
78
+ return t, "FLDK"
79
+
80
+ if total_seconds % 150:
81
+ raise ValueError("Himawari Japan or Target data is only available at 2.5-minute intervals.")
82
+
83
+ offset = (total_seconds // 150) % 4
84
+ t_floor = t - datetime.timedelta(minutes=t.minute % 10, seconds=t.second)
85
+
86
+ prefix = "JP0" if region == HimawariRegion.Japan else "R30"
87
+ scan_type = f"{prefix}{offset + 1}"
88
+ return t_floor, scan_type
89
+
90
+
91
+ def _parse_bands(bands: str | Iterable[str] | None) -> set[str]:
92
+ """Check that the bands are valid and return as a set.
93
+
94
+ This function is nearly identical to the GOES _parse_channels function.
95
+ """
96
+ if bands is None:
97
+ return set(DEFAULT_BANDS)
98
+
99
+ if isinstance(bands, str):
100
+ bands = (bands,)
101
+
102
+ available = {f"B{i:02d}" for i in range(1, 17)}
103
+ bands = {b.upper() for b in bands}
104
+ if not bands.issubset(available):
105
+ raise ValueError(f"bands must be in {sorted(available)}")
106
+ return bands
107
+
108
+
109
+ def _check_band_resolution(bands: Iterable[str]) -> None:
110
+ # https://www.data.jma.go.jp/mscweb/en/himawari89/space_segment/spsg_ahi.html
111
+ res = {
112
+ "B01": 1.0,
113
+ "B02": 1.0,
114
+ "B03": 1.0, # XXX: this actually has a resolution of 0.5 km, but we coarsen it to 1 km
115
+ "B04": 1.0,
116
+ "B05": 2.0,
117
+ "B06": 2.0,
118
+ "B07": 2.0,
119
+ "B08": 2.0,
120
+ "B09": 2.0,
121
+ "B10": 2.0,
122
+ "B11": 2.0,
123
+ "B12": 2.0,
124
+ "B13": 2.0,
125
+ "B14": 2.0,
126
+ "B15": 2.0,
127
+ "B16": 2.0,
128
+ }
129
+
130
+ found_res = {b: res[b] for b in bands}
131
+ unique_res = set(found_res.values())
132
+ if len(unique_res) > 1:
133
+ b0, r0 = found_res.popitem()
134
+ b1, r1 = next((b, r) for b, r in found_res.items() if r != r0)
135
+ raise ValueError(
136
+ "Bands must have a common horizontal resolution. "
137
+ f"Band {b0} has resolution {r0} km and band {b1} has resolution {r1} km."
138
+ )
139
+
140
+
141
+ def _parse_region(region: HimawariRegion | str) -> HimawariRegion:
142
+ """Parse region from string."""
143
+ if isinstance(region, HimawariRegion):
144
+ return region
145
+
146
+ region = region.upper().replace(" ", "").replace("_", "")
147
+
148
+ if region in ("F", "FLDK", "FULL", "FULLDISK"):
149
+ return HimawariRegion.FLDK
150
+ if region in ("J", "JAPAN"):
151
+ return HimawariRegion.Japan
152
+ if region in ("T", "TARGET", "MESOSCALE"):
153
+ return HimawariRegion.Target
154
+ raise ValueError(f"Region must be one of {HimawariRegion._member_names_}")
155
+
156
+
157
+ def _extract_band_from_rpath(rpath: str) -> str:
158
+ sep = "_B"
159
+ suffix = rpath.split(sep, maxsplit=1)[1]
160
+ return f"B{suffix[:2]}" # B??
161
+
162
+
163
+ def _mask_invalid(data: npt.NDArray[np.uint16], calib_info: dict) -> npt.NDArray[np.float32]:
164
+ """Mask invalid data."""
165
+ error_pixel = calib_info["count_error_pixels"]
166
+ outside_pixel = calib_info["count_outside_scan_area"]
167
+
168
+ mask = (data == error_pixel) | (data == outside_pixel)
169
+ return np.where(mask, np.float32(np.nan), data.astype(np.float32))
170
+
171
+
172
+ def _radiance_to_brightness_temperature(
173
+ radiance: npt.NDArray[np.float32],
174
+ calib_info: dict[str, Any],
175
+ ) -> npt.NDArray[np.float32]:
176
+ """Convert radiance to brightness temperature."""
177
+ radiance = np.where(radiance <= 0.0, np.float32(np.nan), radiance) # remove invalid
178
+ radiance_m = radiance * 1e6 # W/m^2/sr/um -> W/m^2/sr/m
179
+
180
+ lmbda = calib_info["central_wavelength"] * 1e-6 # um -> m
181
+ h = calib_info["planck_constant"]
182
+ c = calib_info["speed_of_light"]
183
+ k = calib_info["boltzmann_constant"]
184
+
185
+ term = (2 * h * c**2) / (lmbda**5 * radiance_m)
186
+ effective_bt = (h * c) / (k * lmbda * np.log1p(term))
187
+
188
+ c0 = calib_info["c0"]
189
+ c1 = calib_info["c1"]
190
+ c2 = calib_info["c2"]
191
+ return c0 + c1 * effective_bt + c2 * effective_bt**2
192
+
193
+
194
+ def _radiance_to_reflectance(
195
+ radiance: npt.NDArray[np.float32],
196
+ calib_info: dict[str, Any],
197
+ ) -> npt.NDArray[np.float32]:
198
+ """Convert radiance to reflectance."""
199
+ coeff = calib_info["coeff_c_prime"]
200
+ return radiance * coeff
201
+
202
+
203
+ def _load_raw_counts(content: bytes, metadata: dict[str, Any]) -> npt.NDArray[np.uint16]:
204
+ """Load raw counts from Himawari data."""
205
+ offset = metadata["basic_information"]["total_header_length"]
206
+ n_columns = metadata["data_information"]["num_columns"]
207
+ n_lines = metadata["data_information"]["num_lines"]
208
+ return np.frombuffer(content, dtype=np.uint16, offset=offset).reshape((n_lines, n_columns))
209
+
210
+
211
+ def _counts_to_radiance(
212
+ counts: npt.NDArray[np.float32],
213
+ calib_info: dict[str, Any],
214
+ ) -> npt.NDArray[np.float32]:
215
+ """Convert raw counts to radiance."""
216
+ gain = calib_info["gain"]
217
+ const = calib_info["constant"]
218
+ return counts * gain + const
219
+
220
+
221
+ def _load_image_data(content: bytes, metadata: dict) -> npt.NDArray[np.float32]:
222
+ counts = _load_raw_counts(content, metadata)
223
+
224
+ calib_info = metadata["calibration_information"]
225
+ masked_counts = _mask_invalid(counts, calib_info)
226
+ radiance = _counts_to_radiance(masked_counts, calib_info)
227
+
228
+ if calib_info["band_number"] <= 6: # visible/NIR
229
+ return _radiance_to_reflectance(radiance, calib_info)
230
+ return _radiance_to_brightness_temperature(radiance, calib_info)
231
+
232
+
233
+ def _ahi_fixed_grid(proj_info: dict, arr: np.ndarray) -> tuple[xr.DataArray, xr.DataArray]:
234
+ n_lines, n_columns = arr.shape
235
+
236
+ i = np.arange(n_columns, dtype=np.float32)
237
+ j = np.arange(n_lines, dtype=np.float32)
238
+
239
+ # See section 4.4.4 (scaling functions) of the CGMS LRIT/HRIT specification
240
+ # https://www.cgms-info.org/wp-content/uploads/2021/10/cgms-lrit-hrit-global-specification-(v2-8-of-30-oct-2013).pdf
241
+ x_deg = (i - proj_info["coff"]) / proj_info["cfac"] * 2**16
242
+ y_deg = -(j - proj_info["loff"]) / proj_info["lfac"] * 2**16 # positive y is north
243
+
244
+ x_rad = np.deg2rad(x_deg)
245
+ y_rad = np.deg2rad(y_deg)
246
+
247
+ x = xr.DataArray(
248
+ x_rad,
249
+ dims=("x",),
250
+ attrs={
251
+ "units": "rad",
252
+ "axis": "X",
253
+ "long_name": "AHI fixed grid projection x-coordinate",
254
+ "standard_name": "projection_x_coordinate",
255
+ },
256
+ )
257
+ y = xr.DataArray(
258
+ y_rad,
259
+ dims=("y",),
260
+ attrs={
261
+ "units": "rad",
262
+ "axis": "Y",
263
+ "long_name": "AHI fixed grid projection y-coordinate",
264
+ "standard_name": "projection_y_coordinate",
265
+ },
266
+ )
267
+
268
+ return x, y
269
+
270
+
271
+ def _himawari_proj4_string(proj_info: dict[str, Any]) -> str:
272
+ H = proj_info["dist_from_earth_center"] * 1000.0 # km -> m
273
+ a = proj_info["equatorial_radius"] * 1000.0 # km -> m
274
+ b = proj_info["polar_radius"] * 1000.0 # km -> m
275
+ lon = proj_info["sub_lon"]
276
+ h = H - a # height above surface
277
+ return f"+proj=geos +h={h} +a={a} +b={b} +lon_0={lon} +sweep=x +units=m +no_defs"
278
+
279
+
280
+ def _earth_disk_mask(proj_info: dict, x: xr.DataArray, y: xr.DataArray) -> npt.NDArray[np.bool_]:
281
+ """Return a boolean mask where True indicates pixels over the Earth disk."""
282
+ a = proj_info["equatorial_radius"] * 1000.0 # km -> m
283
+ b = proj_info["polar_radius"] * 1000.0 # km -> m
284
+ h = proj_info["dist_from_earth_center"] * 1000.0 # km -> m
285
+
286
+ # Precompute trig terms
287
+ cosx = np.cos(x.values[np.newaxis, :]) # shape (1, nx)
288
+ cosy = np.cos(y.values[:, np.newaxis]) # shape (ny, 1)
289
+ siny = np.sin(y.values[:, np.newaxis]) # shape (ny, 1)
290
+
291
+ # Form a ray from the satellite to each pixel (in the scan angle space). Compute the
292
+ # intersection of the ray with the ellipsoid gives a quadratic equation.
293
+ A = cosy**2 / a**2 + siny**2 / b**2
294
+ B = -2 * h * cosy * cosx / a**2
295
+ C = h**2 / a**2 - 1.0
296
+
297
+ discriminant = B**2 - 4 * A * C
298
+
299
+ # A positive discriminant indicates the ray from satellite intersects ellipsoid
300
+ # within Earth disk. Return True for valid Earth pixels.
301
+ return discriminant >= 0.0
302
+
303
+
304
+ def _parse_start_time(metadata: dict) -> datetime.datetime:
305
+ """Parse the start time from the metadata."""
306
+ mjd_value = metadata["basic_information"]["obs_start_time"]
307
+ mjd_epoch = datetime.datetime(1858, 11, 17)
308
+ return mjd_epoch + datetime.timedelta(days=mjd_value)
309
+
310
+
311
+ def _parse_s3_raw_data(raw_data: list[bytes]) -> xr.DataArray:
312
+ """Decode a list of Himawari bz2-compressed bytes to an xarray DataArray."""
313
+ arrays = []
314
+ proj_info = None
315
+ start_time = None
316
+
317
+ for data in raw_data:
318
+ content = bz2.decompress(data)
319
+ metadata = header_struct.parse_himawari_header(content)
320
+ proj_info = proj_info or metadata["projection_information"]
321
+ start_time = start_time or _parse_start_time(metadata)
322
+
323
+ arr = _load_image_data(content, metadata)
324
+
325
+ segment_number = metadata["segment_information"]["segment_seq_number"]
326
+ arrays.append((segment_number, arr))
327
+
328
+ # (This sorting isn't really necessary since s3fs.glob returns sorted results)
329
+ sorted_arrays = [arr for _, arr in sorted(arrays, key=lambda x: x[0])]
330
+ combined = np.vstack(sorted_arrays)
331
+
332
+ assert proj_info is not None
333
+ x, y = _ahi_fixed_grid(proj_info, combined)
334
+
335
+ mask = _earth_disk_mask(proj_info, x, y) # mask values outside Earth disk
336
+ combined[~mask] = np.float32(np.nan)
337
+
338
+ crs = _himawari_proj4_string(proj_info)
339
+ band = metadata["calibration_information"]["band_number"]
340
+ if band > 6:
341
+ long_name = "Advanced Himawari Imager (AHI) brightness temperature"
342
+ standard_name = "toa_brightness_temperature"
343
+ units = "K"
344
+ else:
345
+ long_name = "Advanced Himawari Imager (AHI) reflectance"
346
+ standard_name = "toa_reflectance"
347
+ units = ""
348
+
349
+ return xr.DataArray(
350
+ combined,
351
+ dims=("y", "x"),
352
+ coords={"x": x, "y": y, "t": start_time},
353
+ attrs={"crs": crs, "long_name": long_name, "standard_name": standard_name, "units": units},
354
+ ).expand_dims(band_id=np.array([band], dtype=np.int32)) # use band_id to match GOES
355
+
356
+
357
+ class Himawari:
358
+ """Support for Himawari-8/9 satellite data accessed via AWS S3.
359
+
360
+ This interface requires the ``s3fs`` package.
361
+
362
+ Parameters
363
+ ----------
364
+ region : HimawariRegion | str, optional
365
+ The Himawari-8/9 area to download. By default, :attr:`HimawariRegion.FLDK` (Full Disk).
366
+ bands : str | Iterable[str] | None, optional
367
+ The bands to download. The 16 possible bands are ``B01`` to ``B16``. For the SEVIRI
368
+ ash color scheme, bands ``B11``, ``B14``, and ``B15`` are required (default). For
369
+ the true color scheme, bands ``B01``, ``B02``, and ``B03`` are required.
370
+ See `here <https://www.data.jma.go.jp/mscweb/en/himawari89/space_segment/spsg_ahi.html#band>`_
371
+ for more information on the bands.
372
+ bucket : str | None, optional
373
+ The S3 bucket to use. By default, the bucket is chosen based on the time
374
+ (Himawari-8 before 2022-12-13, Himawari-9 after).
375
+ cachestore : cache.CacheStore | None, optional
376
+ The cache store to use. By default, a disk cache in the user cache directory
377
+ is used. If None, data is downloaded directly into memory from S3.
378
+
379
+ See Also
380
+ --------
381
+ pycontrails.datalib.goes.GOES
382
+ HimawariRegion
383
+ """
384
+
385
+ __marker = object()
386
+
387
+ def __init__(
388
+ self,
389
+ region: HimawariRegion | str = HimawariRegion.FLDK,
390
+ bands: str | Iterable[str] | None = None,
391
+ *,
392
+ bucket: str | None = None,
393
+ cachestore: cache.CacheStore | None = __marker, # type: ignore[assignment]
394
+ ) -> None:
395
+ self.region = _parse_region(region)
396
+ self.bands = _parse_bands(bands)
397
+ _check_band_resolution(self.bands)
398
+
399
+ self.bucket = bucket
400
+ self.fs = s3fs.S3FileSystem(anon=True)
401
+
402
+ if cachestore is self.__marker:
403
+ cache_root = cache._get_user_cache_dir()
404
+ cache_dir = f"{cache_root}/himawari"
405
+ cachestore = cache.DiskCacheStore(cache_dir=cache_dir)
406
+ self.cachestore = cachestore
407
+
408
+ def __repr__(self) -> str:
409
+ """Return string representation."""
410
+ return (
411
+ f"Himawari(region='{self.region.name}', bands={sorted(self.bands)}, "
412
+ f"bucket={self.bucket})"
413
+ )
414
+
415
+ def s3_rpaths(self, time: datetime.datetime) -> dict[str, list[str]]:
416
+ """Return S3 remote paths for a given time."""
417
+ t, scan_type = _check_time_resolution(time, self.region)
418
+
419
+ if self.bucket is None:
420
+ bucket = HIMAWARI_8_BUCKET if t < HIMAWARI_8_9_SWITCH_DATE else HIMAWARI_9_BUCKET
421
+ else:
422
+ bucket = self.bucket
423
+
424
+ sat_number = bucket.removeprefix("noaa-himawari") # Will not work for custom buckets
425
+
426
+ # Get all bands for the time
427
+ prefix = f"{bucket}/AHI-L1b-{self.region.name}/{t:%Y/%m/%d/%H%M}/HS_H0{sat_number}_{t:%Y%m%d_%H%M}_B??_{scan_type}" # noqa: E501
428
+ rpaths = self.fs.glob(f"{prefix}*")
429
+
430
+ out = collections.defaultdict(list)
431
+ for rpath in rpaths:
432
+ band = _extract_band_from_rpath(rpath)
433
+ if band in self.bands:
434
+ out[band].append(rpath)
435
+
436
+ return out
437
+
438
+ def _lpaths(self, time: datetime.datetime) -> dict[str, str]:
439
+ """Construct names for local netcdf files using the :attr:`cachestore`.
440
+
441
+ Returns dictionary of the form ``{band: local_path}``.
442
+
443
+ Implementation is copied directly from :meth:`GOES._lpaths`.
444
+ """
445
+ if not self.cachestore:
446
+ raise ValueError("cachestore must be set to use _lpaths")
447
+
448
+ t_str = time.strftime("%Y%m%d%H%M%S")
449
+
450
+ out = {}
451
+ for band in self.bands:
452
+ if self.bucket:
453
+ name = f"{self.bucket}_{self.region.name}_{t_str}_{band}.nc"
454
+ else:
455
+ name = f"{self.region.name}_{t_str}_{band}.nc"
456
+
457
+ lpath = self.cachestore.path(name)
458
+ out[band] = lpath
459
+
460
+ return out
461
+
462
+ def get(self, time: datetime.datetime | str) -> xr.DataArray:
463
+ """Get Himawari-8/9 data for a given time."""
464
+ t = pd.Timestamp(time).to_pydatetime()
465
+
466
+ if self.cachestore is not None:
467
+ return self._get_with_cache(t)
468
+ return self._get_without_cache(t)
469
+
470
+ def _get_with_cache(self, time: datetime.datetime) -> xr.DataArray:
471
+ """Get Himawari-8/9 data for a given time, using the cache if available."""
472
+ if self.cachestore is None:
473
+ raise ValueError("cachestore must be set to use get_with_cache")
474
+
475
+ lpaths = self._lpaths(time)
476
+
477
+ missing_bands = [b for b, p in lpaths.items() if not self.cachestore.exists(p)]
478
+ if missing_bands:
479
+ rpaths_all_bands = self.s3_rpaths(time)
480
+ for band in missing_bands:
481
+ rpaths = rpaths_all_bands[band]
482
+ if not rpaths:
483
+ raise ValueError(f"No data found for band {band} at time {time}")
484
+ raw_data = list(self.fs.cat(rpaths).values())
485
+ da = _parse_s3_raw_data(raw_data)
486
+ da.to_dataset(name="CMI").to_netcdf(lpaths[band]) # only using CMI to match GOES
487
+
488
+ kwargs = {
489
+ "concat_dim": "band_id",
490
+ "combine": "nested",
491
+ "combine_attrs": "override",
492
+ "coords": "minimal",
493
+ "compat": "override",
494
+ }
495
+ if len(lpaths) == 1 or "B03" not in lpaths:
496
+ return xr.open_mfdataset(lpaths.values(), **kwargs)["CMI"].sortby("band_id") # type: ignore[arg-type]
497
+
498
+ lpath03 = lpaths.pop("B03")
499
+ da1 = xr.open_mfdataset(lpaths.values(), **kwargs)["CMI"] # type: ignore[arg-type]
500
+ da03 = xr.open_dataset(lpath03)["CMI"]
501
+ return geo_utils._coarsen_then_concat(da1, da03).sortby("band_id")
502
+
503
+ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
504
+ """Get Himawari-8/9 data for a given time, without using the cache."""
505
+ all_rpaths = self.s3_rpaths(time)
506
+
507
+ da_dict = {}
508
+ for band, rpaths in all_rpaths.items():
509
+ if len(rpaths) == 0:
510
+ raise ValueError(f"No data found for band {band} at time {time}")
511
+
512
+ raw_data = list(self.fs.cat(rpaths).values())
513
+ da = _parse_s3_raw_data(raw_data).rename("CMI") # only using CMI to match GOES
514
+ da_dict[band] = da
515
+
516
+ kwargs = {
517
+ "dim": "band_id",
518
+ "coords": "minimal",
519
+ "compat": "override",
520
+ }
521
+ if len(da_dict) == 1 or "B03" not in da_dict:
522
+ return xr.concat(da_dict.values(), **kwargs).sortby("band_id") # type: ignore[call-overload]
523
+ da03 = da_dict.pop("B03")
524
+ da1 = xr.concat(da_dict.values(), **kwargs) # type: ignore[call-overload]
525
+ return geo_utils._coarsen_then_concat(da1, da03).sortby("band_id")
526
+
527
+
528
+ def _cartopy_crs(proj4_string: str) -> cartopy.crs.Geostationary:
529
+ try:
530
+ import pyproj
531
+ except ModuleNotFoundError as exc:
532
+ dependencies.raise_module_not_found_error(
533
+ name="Himawari visualization",
534
+ package_name="pyproj",
535
+ module_not_found_error=exc,
536
+ pycontrails_optional_package="sat",
537
+ )
538
+ try:
539
+ from cartopy import crs as ccrs
540
+ except ModuleNotFoundError as exc:
541
+ dependencies.raise_module_not_found_error(
542
+ name="Himawari visualization",
543
+ package_name="cartopy",
544
+ module_not_found_error=exc,
545
+ pycontrails_optional_package="sat",
546
+ )
547
+
548
+ crs_obj = pyproj.CRS(proj4_string)
549
+
550
+ with warnings.catch_warnings():
551
+ # pyproj warns that to_dict is lossy, but we built it ourselves so it's fine
552
+ warnings.filterwarnings("ignore", category=UserWarning)
553
+ crs_dict = crs_obj.to_dict()
554
+
555
+ globe = ccrs.Globe(
556
+ semimajor_axis=crs_dict["a"],
557
+ semiminor_axis=crs_dict["b"],
558
+ )
559
+ return ccrs.Geostationary(
560
+ central_longitude=crs_dict["lon_0"],
561
+ satellite_height=crs_dict["h"],
562
+ sweep_axis=crs_dict["sweep"],
563
+ globe=globe,
564
+ )
565
+
566
+
567
+ def extract_visualization(
568
+ da: xr.DataArray,
569
+ color_scheme: str = "ash",
570
+ ash_convention: str = "SEVIRI",
571
+ gamma: float = 2.2,
572
+ ) -> tuple[npt.NDArray[np.float32], cartopy.crs.Geostationary, tuple[float, float, float, float]]:
573
+ """Extract artifacts for visualizing Himawari data with the given color scheme.
574
+
575
+ Parameters
576
+ ----------
577
+ da : xr.DataArray
578
+ DataArray of Himawari data as returned by :meth:`Himawari.get`. Must have the channels
579
+ required by :func:`to_ash`.
580
+ color_scheme : str
581
+ Color scheme to use for visualization. Must be one of {"true", "ash"}.
582
+ If "true", the ``da`` must contain channels B01, B02, and B03.
583
+ If "ash", the ``da`` must contain channels B11, B14, and B15 (SEVIRI convention)
584
+ or channels B11, B13, B14, and B15 (standard convention).
585
+ ash_convention : str
586
+ Passed into :func:`to_ash`. Only used if ``color_scheme="ash"``. Must be one
587
+ of {"SEVIRI", "standard"}. By default, "SEVIRI" is used.
588
+ gamma : float
589
+ Passed into :func:`to_true_color`. Only used if ``color_scheme="true"``. By
590
+ default, 2.2 is used.
591
+
592
+ Returns
593
+ -------
594
+ rgb : npt.NDArray[np.float32]
595
+ 3D RGB array of shape ``(height, width, 3)``. Any nan values are replaced with 0.
596
+ src_crs : cartopy.crs.Geostationary
597
+ The Geostationary projection built from the Himawari metadata.
598
+ src_extent : tuple[float, float, float, float]
599
+ Extent of Himawari data in the Geostationary projection
600
+ """
601
+ proj4_string = da.attrs["crs"]
602
+ src_crs = _cartopy_crs(proj4_string)
603
+
604
+ if color_scheme == "true":
605
+ rgb = to_true_color(da, gamma)
606
+ elif color_scheme == "ash":
607
+ rgb = geo_utils.to_ash(da, ash_convention)
608
+ else:
609
+ raise ValueError(f"Color scheme must be 'true' or 'ash', not '{color_scheme}'")
610
+
611
+ np.nan_to_num(rgb, copy=False)
612
+
613
+ x = da["x"].values
614
+ y = da["y"].values
615
+
616
+ # Multiply extremes by the satellite height
617
+ h = src_crs.proj4_params["h"]
618
+ src_extent = h * x.min(), h * x.max(), h * y.min(), h * y.max()
619
+
620
+ return rgb, src_crs, src_extent
621
+
622
+
623
+ def to_true_color(da: xr.DataArray, gamma: float = 2.2) -> npt.NDArray[np.floating]:
624
+ """Compute 3d RGB array for the true color scheme.
625
+
626
+ Parameters
627
+ ----------
628
+ da : xr.DataArray
629
+ DataArray of GOES data with channels B01, B02, B03.
630
+ gamma : float, optional
631
+ Gamma correction for the RGB channels.
632
+
633
+ Returns
634
+ -------
635
+ npt.NDArray[np.floating]
636
+ 3d RGB array with true color scheme.
637
+
638
+ References
639
+ ----------
640
+ - https://www.jma.go.jp/jma/jma-eng/satellite/VLab/QG/RGB_QG_TrueColor_en.pdf
641
+ """
642
+ if not np.all(np.isin([1, 2, 3], da["band_id"])):
643
+ msg = "DataArray must contain bands 1, 2, and 3 for true color"
644
+ raise ValueError(msg)
645
+
646
+ red = da.sel(band_id=3).values
647
+ green = da.sel(band_id=2).values
648
+ blue = da.sel(band_id=1).values
649
+
650
+ red = geo_utils._clip_and_scale(red, 0.0, 1.0)
651
+ green = geo_utils._clip_and_scale(green, 0.0, 1.0)
652
+ blue = geo_utils._clip_and_scale(blue, 0.0, 1.0)
653
+
654
+ return np.dstack((red, green, blue)) ** (1 / gamma)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pycontrails
3
- Version: 0.56.0
3
+ Version: 0.57.0
4
4
  Summary: Python library for modeling aviation climate impacts
5
5
  Author-email: "Contrails.org" <py@contrails.org>
6
6
  License-Expression: Apache-2.0
@@ -74,7 +74,6 @@ Requires-Dist: google-cloud-storage>=2.1; extra == "gcp"
74
74
  Requires-Dist: platformdirs>=3.0; extra == "gcp"
75
75
  Requires-Dist: tqdm>=4.61; extra == "gcp"
76
76
  Provides-Extra: gfs
77
- Requires-Dist: boto3>=1.20; extra == "gfs"
78
77
  Requires-Dist: cfgrib>=0.9; extra == "gfs"
79
78
  Requires-Dist: eccodes>=2.38; extra == "gfs"
80
79
  Requires-Dist: netcdf4>=1.6.1; extra == "gfs"
@@ -94,6 +93,7 @@ Requires-Dist: google-cloud-bigquery-storage>=2.25; extra == "sat"
94
93
  Requires-Dist: pillow>=10.3; extra == "sat"
95
94
  Requires-Dist: pyproj>=3.5; extra == "sat"
96
95
  Requires-Dist: rasterio>=1.3; extra == "sat"
96
+ Requires-Dist: s3fs>=2022.3; extra == "sat"
97
97
  Requires-Dist: scikit-image>=0.18; extra == "sat"
98
98
  Requires-Dist: shapely>=2.0; extra == "sat"
99
99
  Provides-Extra: open3d