disdrodb 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. disdrodb/__init__.py +64 -34
  2. disdrodb/_config.py +5 -4
  3. disdrodb/_version.py +16 -3
  4. disdrodb/accessor/__init__.py +20 -0
  5. disdrodb/accessor/methods.py +125 -0
  6. disdrodb/api/checks.py +139 -9
  7. disdrodb/api/configs.py +4 -2
  8. disdrodb/api/info.py +10 -10
  9. disdrodb/api/io.py +237 -18
  10. disdrodb/api/path.py +81 -75
  11. disdrodb/api/search.py +6 -6
  12. disdrodb/cli/disdrodb_create_summary_station.py +91 -0
  13. disdrodb/cli/disdrodb_run_l0.py +1 -1
  14. disdrodb/cli/disdrodb_run_l0_station.py +1 -1
  15. disdrodb/cli/disdrodb_run_l0b.py +1 -1
  16. disdrodb/cli/disdrodb_run_l0b_station.py +1 -1
  17. disdrodb/cli/disdrodb_run_l0c.py +1 -1
  18. disdrodb/cli/disdrodb_run_l0c_station.py +1 -1
  19. disdrodb/cli/disdrodb_run_l2e_station.py +1 -1
  20. disdrodb/configs.py +149 -4
  21. disdrodb/constants.py +61 -0
  22. disdrodb/data_transfer/download_data.py +5 -5
  23. disdrodb/etc/configs/attributes.yaml +339 -0
  24. disdrodb/etc/configs/encodings.yaml +473 -0
  25. disdrodb/etc/products/L1/global.yaml +13 -0
  26. disdrodb/etc/products/L2E/10MIN.yaml +12 -0
  27. disdrodb/etc/products/L2E/1MIN.yaml +1 -0
  28. disdrodb/etc/products/L2E/global.yaml +22 -0
  29. disdrodb/etc/products/L2M/10MIN.yaml +12 -0
  30. disdrodb/etc/products/L2M/GAMMA_ML.yaml +8 -0
  31. disdrodb/etc/products/L2M/NGAMMA_GS_LOG_ND_MAE.yaml +6 -0
  32. disdrodb/etc/products/L2M/NGAMMA_GS_ND_MAE.yaml +6 -0
  33. disdrodb/etc/products/L2M/NGAMMA_GS_Z_MAE.yaml +6 -0
  34. disdrodb/etc/products/L2M/global.yaml +26 -0
  35. disdrodb/l0/__init__.py +13 -0
  36. disdrodb/l0/configs/LPM/l0b_cf_attrs.yml +4 -4
  37. disdrodb/l0/configs/PARSIVEL/l0b_cf_attrs.yml +1 -1
  38. disdrodb/l0/configs/PARSIVEL/l0b_encodings.yml +3 -3
  39. disdrodb/l0/configs/PARSIVEL/raw_data_format.yml +1 -1
  40. disdrodb/l0/configs/PARSIVEL2/l0b_cf_attrs.yml +5 -5
  41. disdrodb/l0/configs/PARSIVEL2/l0b_encodings.yml +3 -3
  42. disdrodb/l0/configs/PARSIVEL2/raw_data_format.yml +1 -1
  43. disdrodb/l0/configs/PWS100/l0b_cf_attrs.yml +4 -4
  44. disdrodb/l0/configs/PWS100/raw_data_format.yml +1 -1
  45. disdrodb/l0/l0a_processing.py +30 -30
  46. disdrodb/l0/l0b_nc_processing.py +108 -2
  47. disdrodb/l0/l0b_processing.py +4 -4
  48. disdrodb/l0/l0c_processing.py +5 -13
  49. disdrodb/l0/readers/LPM/NETHERLANDS/DELFT_LPM_NC.py +66 -0
  50. disdrodb/l0/readers/LPM/SLOVENIA/{CRNI_VRH.py → UL.py} +3 -0
  51. disdrodb/l0/readers/LPM/SWITZERLAND/INNERERIZ_LPM.py +195 -0
  52. disdrodb/l0/readers/PARSIVEL/GPM/PIERS.py +0 -2
  53. disdrodb/l0/readers/PARSIVEL/JAPAN/JMA.py +4 -1
  54. disdrodb/l0/readers/PARSIVEL/NCAR/PECAN_MOBILE.py +1 -1
  55. disdrodb/l0/readers/PARSIVEL/NCAR/VORTEX2_2009.py +1 -1
  56. disdrodb/l0/readers/PARSIVEL2/BELGIUM/ILVO.py +168 -0
  57. disdrodb/l0/readers/PARSIVEL2/DENMARK/DTU.py +165 -0
  58. disdrodb/l0/readers/PARSIVEL2/FINLAND/FMI_PARSIVEL2.py +69 -0
  59. disdrodb/l0/readers/PARSIVEL2/FRANCE/ENPC_PARSIVEL2.py +255 -134
  60. disdrodb/l0/readers/PARSIVEL2/FRANCE/OSUG.py +525 -0
  61. disdrodb/l0/readers/PARSIVEL2/FRANCE/SIRTA_PARSIVEL2.py +1 -1
  62. disdrodb/l0/readers/PARSIVEL2/GPM/GCPEX.py +9 -7
  63. disdrodb/l0/readers/PARSIVEL2/KIT/BURKINA_FASO.py +1 -1
  64. disdrodb/l0/readers/PARSIVEL2/KIT/TEAMX.py +123 -0
  65. disdrodb/l0/readers/PARSIVEL2/NASA/APU.py +120 -0
  66. disdrodb/l0/readers/PARSIVEL2/NCAR/FARM_PARSIVEL2.py +1 -0
  67. disdrodb/l0/readers/PARSIVEL2/NCAR/PECAN_FP3.py +1 -1
  68. disdrodb/l0/readers/PARSIVEL2/NCAR/PERILS_MIPS.py +126 -0
  69. disdrodb/l0/readers/PARSIVEL2/NCAR/PERILS_PIPS.py +165 -0
  70. disdrodb/l0/readers/PARSIVEL2/NCAR/VORTEX_SE_2016_P2.py +1 -1
  71. disdrodb/l0/readers/PARSIVEL2/NCAR/VORTEX_SE_2016_PIPS.py +20 -12
  72. disdrodb/l0/readers/PARSIVEL2/NETHERLANDS/DELFT_NC.py +2 -0
  73. disdrodb/l0/readers/PARSIVEL2/SPAIN/CENER.py +144 -0
  74. disdrodb/l0/readers/PARSIVEL2/SPAIN/CR1000DL.py +201 -0
  75. disdrodb/l0/readers/PARSIVEL2/SPAIN/LIAISE.py +137 -0
  76. disdrodb/l0/readers/PARSIVEL2/{NETHERLANDS/DELFT.py → USA/C3WE.py} +65 -85
  77. disdrodb/l0/readers/PWS100/FRANCE/ENPC_PWS100.py +105 -99
  78. disdrodb/l0/readers/PWS100/FRANCE/ENPC_PWS100_SIRTA.py +151 -0
  79. disdrodb/l0/routines.py +105 -14
  80. disdrodb/l1/__init__.py +5 -0
  81. disdrodb/l1/filters.py +34 -20
  82. disdrodb/l1/processing.py +45 -44
  83. disdrodb/l1/resampling.py +77 -66
  84. disdrodb/l1/routines.py +35 -43
  85. disdrodb/l1_env/routines.py +18 -3
  86. disdrodb/l2/__init__.py +7 -0
  87. disdrodb/l2/empirical_dsd.py +58 -10
  88. disdrodb/l2/event.py +27 -120
  89. disdrodb/l2/processing.py +267 -116
  90. disdrodb/l2/routines.py +618 -254
  91. disdrodb/metadata/standards.py +3 -1
  92. disdrodb/psd/fitting.py +463 -144
  93. disdrodb/psd/models.py +8 -5
  94. disdrodb/routines.py +3 -3
  95. disdrodb/scattering/__init__.py +16 -4
  96. disdrodb/scattering/axis_ratio.py +56 -36
  97. disdrodb/scattering/permittivity.py +486 -0
  98. disdrodb/scattering/routines.py +701 -159
  99. disdrodb/summary/__init__.py +17 -0
  100. disdrodb/summary/routines.py +4120 -0
  101. disdrodb/utils/attrs.py +68 -125
  102. disdrodb/utils/compression.py +30 -1
  103. disdrodb/utils/dask.py +59 -8
  104. disdrodb/utils/dataframe.py +61 -7
  105. disdrodb/utils/directories.py +35 -15
  106. disdrodb/utils/encoding.py +33 -19
  107. disdrodb/utils/logger.py +13 -6
  108. disdrodb/utils/manipulations.py +71 -0
  109. disdrodb/utils/subsetting.py +214 -0
  110. disdrodb/utils/time.py +165 -19
  111. disdrodb/utils/writer.py +20 -7
  112. disdrodb/utils/xarray.py +2 -4
  113. disdrodb/viz/__init__.py +13 -0
  114. disdrodb/viz/plots.py +327 -0
  115. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/METADATA +3 -2
  116. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/RECORD +121 -88
  117. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/entry_points.txt +1 -0
  118. disdrodb/l1/encoding_attrs.py +0 -642
  119. disdrodb/l2/processing_options.py +0 -213
  120. /disdrodb/l0/readers/PARSIVEL/SLOVENIA/{UL_FGG.py → UL.py} +0 -0
  121. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/WHEEL +0 -0
  122. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/licenses/LICENSE +0 -0
  123. {disdrodb-0.1.2.dist-info → disdrodb-0.1.3.dist-info}/top_level.txt +0 -0
disdrodb/utils/time.py CHANGED
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
 #### Sampling Interval Acronyms
 
 
-def seconds_to_acronym(seconds):
+def seconds_to_temporal_resolution(seconds):
     """
     Convert a duration in seconds to a readable string format (e.g., "1H30MIN", "1D2H").
 
@@ -57,27 +57,27 @@ def seconds_to_acronym(seconds):
         parts.append(f"{components.minutes}MIN")
     if components.seconds > 0:
         parts.append(f"{components.seconds}S")
-    acronym = "".join(parts)
-    return acronym
+    temporal_resolution = "".join(parts)
+    return temporal_resolution
 
 
-def get_resampling_information(sample_interval_acronym):
+def get_resampling_information(temporal_resolution):
     """
-    Extract resampling information from the sample interval acronym.
+    Extract resampling information from the temporal_resolution string.
 
     Parameters
     ----------
-    sample_interval_acronym: str
-        A string representing the sample interval: e.g., "1H30MIN", "ROLL1H30MIN".
+    temporal_resolution: str
+        A string representing the product temporal resolution: e.g., "1H30MIN", "ROLL1H30MIN".
 
     Returns
     -------
     sample_interval_seconds, rolling: tuple
         Sample interval in seconds and whether rolling is enabled.
     """
-    rolling = sample_interval_acronym.startswith("ROLL")
+    rolling = temporal_resolution.startswith("ROLL")
     if rolling:
-        sample_interval_acronym = sample_interval_acronym[4:]  # Remove "ROLL"
+        temporal_resolution = temporal_resolution[4:]  # Remove "ROLL"
 
     # Allowed pattern: one or more occurrences of "<number><unit>"
     # where unit is exactly one of D, H, MIN, or S.
@@ -85,15 +85,15 @@ def get_resampling_information(sample_interval_acronym):
     pattern = r"^(\d+(?:D|H|MIN|S))+$"
 
     # Check if the entire string matches the pattern
-    if not re.match(pattern, sample_interval_acronym):
+    if not re.match(pattern, temporal_resolution):
         raise ValueError(
-            f"Invalid sample interval acronym '{sample_interval_acronym}'. "
+            f"Invalid temporal resolution '{temporal_resolution}'. "
             "Must be composed of one or more <number><unit> groups, where unit is D, H, MIN, or S.",
         )
 
     # Regular expression to match duration components and extract all (value, unit) pairs
     pattern = r"(\d+)(D|H|MIN|S)"
-    matches = re.findall(pattern, sample_interval_acronym)
+    matches = re.findall(pattern, temporal_resolution)
 
     # Conversion factors for each unit
     unit_to_seconds = {
@@ -112,21 +112,21 @@ def get_resampling_information(sample_interval_acronym):
     return sample_interval, rolling
 
 
-def acronym_to_seconds(acronym):
+def temporal_resolution_to_seconds(temporal_resolution):
     """
-    Extract the interval in seconds from the duration acronym.
+    Extract the measurement interval in seconds from the temporal resolution string.
 
     Parameters
     ----------
-    acronym: str
-        A string representing a duration: e.g., "1H30MIN", "ROLL1H30MIN".
+    temporal_resolution: str
+        A string representing the product measurement interval: e.g., "1H30MIN", "ROLL1H30MIN".
 
     Returns
     -------
     seconds
         Duration in seconds.
     """
-    seconds, _ = get_resampling_information(acronym)
+    seconds, _ = get_resampling_information(temporal_resolution)
     return seconds
 
 
@@ -262,6 +262,7 @@ def regularize_dataset(
         Regularized dataset.
 
     """
+    attrs = xr_obj.attrs.copy()
     xr_obj = _check_time_sorted(xr_obj, time_dim=time_dim)
     start_time, end_time = get_dataset_start_end_time(xr_obj, time_dim=time_dim)
 
@@ -289,11 +290,14 @@
         # tolerance=tolerance,  # mismatch in seconds
         fill_value=fill_value,
     )
+
+    # Ensure attributes are preserved
+    xr_obj.attrs = attrs
     return xr_obj
 
 
 ####------------------------------------------
-#### Sampling interval utilities
+#### Interval utilities
 
 
 def ensure_sample_interval_in_seconds(sample_interval):  # noqa: PLR0911
@@ -376,7 +380,7 @@ def ensure_sample_interval_in_seconds(sample_interval):  # noqa: PLR0911
             raise TypeError("Float array sample_interval must contain only whole numbers.")
         return sample_interval.astype(int)
 
-    # Deal with xarray.DataArrayy of floats that are all integer-valued (with optionally some NaN)
+    # Deal with xarray.DataArray of floats that are all integer-valued (with optionally some NaN)
     if isinstance(sample_interval, xr.DataArray) and np.issubdtype(sample_interval.dtype, np.floating):
         arr = sample_interval.copy()
         data = arr.data
@@ -397,6 +401,17 @@ def ensure_sample_interval_in_seconds(sample_interval):  # noqa: PLR0911
     )
 
 
+def ensure_timedelta_seconds_interval(interval):
+    """Return interval as numpy.timedelta64 in seconds."""
+    if isinstance(interval, (xr.DataArray, np.ndarray)):
+        return ensure_sample_interval_in_seconds(interval).astype("m8[s]")
+    return np.array(ensure_sample_interval_in_seconds(interval), dtype="m8[s]")
+
+
+####------------------------------------------
+#### Sample Interval Utilities
+
+
 def infer_sample_interval(ds, robust=False, verbose=False, logger=None):
     """Infer the sample interval of a dataset.
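A small sketch of the new `ensure_timedelta_seconds_interval` helper, assuming integer-second inputs (non-integer values are rejected by `ensure_sample_interval_in_seconds`):

    import numpy as np
    from disdrodb.utils.time import ensure_timedelta_seconds_interval

    ensure_timedelta_seconds_interval(60)                  # 60 seconds as a timedelta64[s] (0-d array)
    ensure_timedelta_seconds_interval(np.array([30, 60]))  # array([30, 60], dtype='timedelta64[s]')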
 
@@ -655,3 +670,134 @@ def regularize_timesteps(ds, sample_interval, robust=False, add_quality_flag=Tru
     ds = ds.isel(time=idx_valid_timesteps)
     # Return dataset
     return ds
+
+
+####---------------------------------------------------------------------------------
+#### Time blocks
+
+
+def check_freq(freq: str) -> str:
+    """Check validity of freq argument."""
+    valid_freq = ["none", "year", "season", "quarter", "month", "day", "hour"]
+    if not isinstance(freq, str):
+        raise TypeError("'freq' must be a string.")
+    if freq not in valid_freq:
+        raise ValueError(
+            f"'freq' '{freq}' is not valid. Must be one of: {valid_freq}.",
+        )
+    return freq
+
+
+def generate_time_blocks(start_time: np.datetime64, end_time: np.datetime64, freq: str) -> np.ndarray:  # noqa: PLR0911
+    """Generate time blocks between `start_time` and `end_time` for a given frequency.
+
+    Parameters
+    ----------
+    start_time : numpy.datetime64
+        Inclusive start of the overall time range.
+    end_time : numpy.datetime64
+        Inclusive end of the overall time range.
+    freq : str
+        Frequency specifier. Accepted values are:
+        - 'none' : return a single block [start_time, end_time]
+        - 'hour' : split into hourly blocks
+        - 'day' : split into daily blocks
+        - 'month' : split into calendar months
+        - 'quarter' : split into calendar quarters
+        - 'year' : split into calendar years
+        - 'season' : split into meteorological seasons (MAM, JJA, SON, DJF)
+
+    Returns
+    -------
+    numpy.ndarray
+        Array of shape (n, 2) with dtype datetime64[s], where each row is [block_start, block_end].
+
+    """
+    freq = check_freq(freq)
+    if freq == "none":
+        return np.array([[start_time, end_time]], dtype="datetime64[s]")
+
+    if freq == "hour":
+        periods = pd.period_range(start=start_time, end=end_time, freq="h")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+
+    if freq == "day":
+        periods = pd.period_range(start=start_time, end=end_time, freq="d")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+
+    if freq == "month":
+        periods = pd.period_range(start=start_time, end=end_time, freq="M")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+
+    if freq == "year":
+        periods = pd.period_range(start=start_time, end=end_time, freq="Y")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+
+    if freq == "quarter":
+        periods = pd.period_range(start=start_time, end=end_time, freq="Q")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.floor("s").to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+
+    if freq == "season":
+        # Fiscal quarter frequency ending in Feb → seasons DJF, MAM, JJA, SON
+        periods = pd.period_range(start=start_time, end=end_time, freq="Q-FEB")
+        blocks = np.array(
+            [
+                [
+                    period.start_time.to_datetime64().astype("datetime64[s]"),
+                    period.end_time.to_datetime64().astype("datetime64[s]"),
+                ]
+                for period in periods
+            ],
+            dtype="datetime64[s]",
+        )
+        return blocks
+    raise NotImplementedError(f"Frequency '{freq}' is not implemented.")
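A hedged usage sketch of `generate_time_blocks` (block boundaries follow pandas period semantics, truncated to second precision):

    import numpy as np
    from disdrodb.utils.time import generate_time_blocks

    start = np.datetime64("2023-01-01 00:00:00")
    end = np.datetime64("2023-03-15 12:00:00")
    blocks = generate_time_blocks(start, end, freq="month")
    # blocks.shape == (3, 2); the first row spans
    # 2023-01-01T00:00:00 .. 2023-01-31T23:59:59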
disdrodb/utils/writer.py CHANGED
@@ -22,11 +22,29 @@ import os
 
 import xarray as xr
 
-from disdrodb.utils.attrs import set_disdrodb_attrs
+from disdrodb.utils.attrs import get_attrs_dict, set_attrs, set_disdrodb_attrs
 from disdrodb.utils.directories import create_directory, remove_if_exists
+from disdrodb.utils.encoding import get_encodings_dict, set_encodings
 
 
-def write_product(ds: xr.Dataset, filepath: str, product: str, force: bool = False) -> None:
+def finalize_product(ds, product=None) -> xr.Dataset:
+    """Finalize DISDRODB product."""
+    # Add variables attributes
+    attrs_dict = get_attrs_dict()
+    ds = set_attrs(ds, attrs_dict=attrs_dict)
+
+    # Add variables encoding
+    encodings_dict = get_encodings_dict()
+    ds = set_encodings(ds, encodings_dict=encodings_dict)
+
+    # Add DISDRODB global attributes
+    # - e.g. in generate_l2_radar it inherits from the input dataset!
+    if product is not None:
+        ds = set_disdrodb_attrs(ds, product=product)
+    return ds
+
+
+def write_product(ds: xr.Dataset, filepath: str, force: bool = False) -> None:
     """Save the xarray dataset into a NetCDF file.
 
     Parameters
@@ -35,8 +53,6 @@ def write_product(ds: xr.Dataset, filepath: str, product: str, force: bool = Fal
         Input xarray dataset.
     filepath : str
        Output file path.
-    product: str
-        DISDRODB product name.
    force : bool, optional
        Whether to overwrite existing data.
        If ``True``, overwrite existing data into destination directories.
@@ -50,8 +66,5 @@ def write_product(ds: xr.Dataset, filepath: str, product: str, force: bool = Fal
     # - If force=False --> Raise error
     remove_if_exists(filepath, force=force)
 
-    # Update attributes
-    ds = set_disdrodb_attrs(ds, product=product)
-
     # Write netcdf
     ds.to_netcdf(filepath, engine="netcdf4")
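With this split, callers finalize a dataset before writing it. A minimal sketch (`ds` stands for an in-memory DISDRODB product dataset; the output path is hypothetical):

    from disdrodb.utils.writer import finalize_product, write_product

    ds = finalize_product(ds, product="L2E")  # attach variable attributes, encodings, global attrs
    write_product(ds, filepath="/tmp/example_l2e.nc", force=True)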
disdrodb/utils/xarray.py CHANGED
@@ -21,6 +21,8 @@ import numpy as np
 import xarray as xr
 from xarray.core import dtypes
 
+from disdrodb.constants import DIAMETER_COORDS, VELOCITY_COORDS
+
 
 def xr_get_last_valid_idx(da_condition, dim, fill_value=None):
     """
@@ -246,13 +248,9 @@ def define_fill_value_dictionary(xr_obj):
 
 def remove_diameter_coordinates(xr_obj):
     """Drop diameter coordinates from xarray object."""
-    from disdrodb import DIAMETER_COORDS
-
     return xr_obj.drop_vars(DIAMETER_COORDS, errors="ignore")
 
 
 def remove_velocity_coordinates(xr_obj):
     """Drop velocity coordinates from xarray object."""
-    from disdrodb import VELOCITY_COORDS
-
     return xr_obj.drop_vars(VELOCITY_COORDS, errors="ignore")
disdrodb/viz/__init__.py CHANGED
@@ -15,3 +15,16 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 # -----------------------------------------------------------------------------.
 """DISDRODB Visualization Module."""
+from disdrodb.viz.plots import (
+    compute_dense_lines,
+    max_blend_images,
+    plot_nd,
+    to_rgba,
+)
+
+__all__ = [
+    "compute_dense_lines",
+    "max_blend_images",
+    "plot_nd",
+    "to_rgba",
+]
disdrodb/viz/plots.py CHANGED
@@ -15,3 +15,330 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 # -----------------------------------------------------------------------------.
 """DISDRODB Plotting Tools."""
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+from matplotlib.colors import LogNorm, Normalize
+
+
+def plot_nd(ds, var="drop_number_concentration", cmap=None, norm=None):
+    """Plot drop number concentration N(D) timeseries."""
+    # Check inputs
+    if var not in ds:
+        raise ValueError(f"{var} is not an xarray Dataset variable!")
+    # Check only time and diameter dimensions are specified
+    # TODO: DIAMETER_DIMENSION, "time"
+
+    # Select N(D)
+    ds_var = ds[[var]].compute()
+
+    # Regularize input
+    ds_var = ds_var.disdrodb.regularize()
+
+    # Set 0 values to np.nan
+    ds_var = ds_var.where(ds_var[var] > 0)
+
+    # Define cmap and norm
+    if cmap is None:
+        cmap = plt.get_cmap("Spectral_r").copy()
+
+    vmin = ds_var[var].min().item()
+    norm = LogNorm(vmin, None) if norm is None else norm
+
+    # Plot N(D)
+    p = ds_var[var].plot.pcolormesh(x="time", norm=norm, cmap=cmap)
+    p.axes.set_title("Drop number concentration (N(D))")
+    p.axes.set_ylabel("Drop diameter (mm)")
+    return p
+
+
+def normalize_array(arr, method="max"):
+    """Normalize a NumPy array according to the chosen method.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Input array.
+    method : str
+        Normalization method. Options:
+        - 'max'   : Divide by the maximum value.
+        - 'minmax': Scale to [0, 1] range.
+        - 'zscore': Standardize to mean 0, std 1.
+        - 'log'   : Apply log10 transform (shifted if min <= 0).
+        - 'none'  : No normalization (return original array).
+
+    Returns
+    -------
+    np.ndarray
+        Normalized array.
+    """
+    arr = np.asarray(arr, dtype=float)
+
+    if method == "max":
+        max_val = np.nanmax(arr)
+        return arr / max_val if max_val != 0 else arr
+
+    if method == "minmax":
+        min_val = np.nanmin(arr)
+        max_val = np.nanmax(arr)
+        return (arr - min_val) / (max_val - min_val) if max_val != min_val else np.zeros_like(arr)
+
+    if method == "zscore":
+        mean_val = np.nanmean(arr)
+        std_val = np.nanstd(arr)
+        return (arr - mean_val) / std_val if std_val != 0 else np.zeros_like(arr)
+
+    if method == "log":
+        min_val = np.nanmin(arr)
+        shifted = arr - min_val + 1e-12  # Shift to avoid log(0) or log of negative
+        return np.log10(shifted)
+
+    if method == "none":
+        return arr
+
+    raise ValueError(f"Unknown normalization method: {method}")
+
+
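A quick numeric sketch of `normalize_array`:

    import numpy as np
    from disdrodb.viz.plots import normalize_array

    arr = np.array([1.0, 5.0, 10.0])
    normalize_array(arr, method="max")     # array([0.1, 0.5, 1. ])
    normalize_array(arr, method="minmax")  # array([0.        , 0.44444444, 1.        ])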
+def _np_to_rgba_alpha(arr, cmap="viridis", cmap_norm=None, scaling="linear"):
+    """Convert a numpy array to an RGBA array with alpha based on array value.
+
+    Parameters
+    ----------
+    arr : numpy.ndarray
+        Array of counts or frequencies.
+    cmap : str or Colormap, optional
+        Matplotlib colormap to use for RGB channels.
+    cmap_norm : matplotlib.colors.Normalize
+        Norm used to scale data before assigning cmap colors.
+        The default is Normalize(vmin, vmax).
+    scaling : str, optional
+        Scaling type for alpha mapping:
+        - "linear" : min-max normalization
+        - "log" : logarithmic normalization (positive values only)
+        - "sqrt" : square-root (power-law with exponent=0.5)
+        - "exp" : exponential scaling
+        - "quantile" : percentile-based scaling
+        - "none" : full opacity (alpha=1)
+
+    Returns
+    -------
+    rgba : 3D numpy array (ny, nx, 4)
+        RGBA array.
+    """
+    # Ensure numpy array
+    arr = np.asarray(arr, dtype=float)
+    # Define mask with NaN pixels
+    mask_na = np.isnan(arr)
+    # Retrieve array shape
+    ny, nx = arr.shape
+
+    # Define colormap norm
+    if cmap_norm is None:
+        cmap_norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))
+
+    # Define alpha
+    if scaling == "linear":
+        norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))
+        alpha = norm(arr)
+    elif scaling == "log":
+        vals = np.where(arr > 0, arr, np.nan)  # mask non-positive
+        norm = LogNorm(vmin=np.nanmin(vals), vmax=np.nanmax(vals))
+        alpha = norm(arr)
+        alpha = np.nan_to_num(alpha, nan=0.0)
+    elif scaling == "sqrt":
+        alpha = np.sqrt(np.clip(arr, 0, None) / np.nanmax(arr))
+    elif scaling == "exp":
+        normed = np.clip(arr / np.nanmax(arr), 0, 1)
+        alpha = np.expm1(normed) / np.expm1(1)
+    elif scaling == "quantile":
+        flat = arr.ravel()
+        ranks = np.argsort(np.argsort(flat))  # rankdata without scipy
+        alpha = ranks / (len(flat) - 1)
+        alpha = alpha.reshape(arr.shape)
+    elif scaling == "none":
+        alpha = np.ones_like(arr, dtype=float)
+    else:
+        raise ValueError(f"Unknown scaling type: {scaling}")
+
+    # Map values to colors
+    cmap = plt.get_cmap(cmap).copy()
+    rgba = cmap(cmap_norm(arr))
+
+    # Set alpha channel
+    alpha[mask_na] = 0  # where input was NaN
+    rgba[..., -1] = np.clip(alpha, 0, 1)
+    return rgba
+
+
+def to_rgba(obj, cmap="viridis", norm=None, scaling="none"):
+    """Map an xarray DataArray (or numpy array) to RGBA with optional alpha-scaling."""
+    input_is_xarray = False
+    if isinstance(obj, xr.DataArray):
+        # Define template for RGBA DataArray
+        da_rgba = obj.copy()
+        da_rgba = da_rgba.expand_dims({"rgba": 4}).transpose(..., "rgba")
+        input_is_xarray = True
+
+        # Extract numpy array
+        obj = obj.to_numpy()
+
+    # Apply transparency
+    arr = _np_to_rgba_alpha(obj, cmap=cmap, cmap_norm=norm, scaling=scaling)
+
+    # Return xarray.DataArray
+    if input_is_xarray:
+        da_rgba.data = arr
+        return da_rgba
+    # Or numpy array otherwise
+    return arr
+
+
+def max_blend_images(ds_rgb, dim):
+    """Max-blend an RGBA DataArray across a sample dimension."""
+    # Ensure dimension to blend is in first position
+    ds_rgb = ds_rgb.transpose(dim, ...)
+    # Extract numpy array
+    stack = ds_rgb.data  # (N, H, W, 4)
+    # Extract alpha array
+    alphas = stack[..., 3]  # (N, H, W)
+    # Select the winning RGBA per pixel
+    idx = np.argmax(alphas, axis=0)  # (H, W), index of image with max alpha
+    idx4 = np.repeat(idx[np.newaxis, ..., np.newaxis], 4, axis=-1)  # (1, H, W, 4)
+    out = np.take_along_axis(stack, idx4, axis=0)[0]  # (H, W, 4)
+    # Create output RGBA array
+    da = ds_rgb.isel({dim: 0}).copy()
+    da.data = out
+    return da
+
+
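`_np_to_rgba_alpha` expects a single 2D field (it unpacks `ny, nx = arr.shape`), so blending several fields means converting each 2D slice with `to_rgba` and concatenating before `max_blend_images`. A minimal sketch with synthetic data:

    import numpy as np
    import xarray as xr
    from disdrodb.viz.plots import max_blend_images, to_rgba

    fields = xr.DataArray(np.random.rand(3, 32, 32), dims=("sample", "y", "x"))
    rgba = xr.concat(
        [to_rgba(fields.isel(sample=i), cmap="viridis", scaling="linear") for i in range(3)],
        dim="sample",
    )
    blended = max_blend_images(rgba, dim="sample")  # (y, x, rgba): per-pixel max-alpha winner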
+def compute_dense_lines(
+    da: xr.DataArray,
+    coord: str,
+    x_bins: list,
+    y_bins: list,
+    normalization="max",
+):
+    """
+    Compute a 2D density-of-lines histogram from an xarray.DataArray.
+
+    Parameters
+    ----------
+    da : xarray.DataArray
+        Input data array. One of its dimensions (named by ``coord``) is taken
+        as the horizontal coordinate. All other dimensions are collapsed into
+        “series,” so that each combination of the remaining dimension values
+        produces one 1D line along ``coord``.
+    coord : str
+        The name of the coordinate/dimension of the DataArray to bin over.
+        ``da.coords[coord]`` must be a 1D numeric array (monotonic is recommended).
+    x_bins : array_like of shape (nx+1,)
+        Bin edges to bin the coordinate/dimension.
+        Must be monotonically increasing.
+        The number of x-bins will be ``nx = len(x_bins) - 1``.
+    y_bins : array_like of shape (ny+1,)
+        Bin edges for the DataArray values.
+        Must be monotonically increasing.
+        The number of y-bins will be ``ny = len(y_bins) - 1``.
+    normalization : str, optional
+        If 'none', returns the raw histogram.
+        By default, the histogram is normalized by its global maximum ('max').
+        Log-normalization ('log') is also available.
+
+    Returns
+    -------
+    xr.DataArray
+        2D histogram of shape ``(ny, nx)``. Dimensions are ``(da.name, coord)``, where:
+
+        - ``coord``: the bin-center coordinate of ``x_bins`` (length ``nx``)
+        - ``da.name``: the bin-center coordinate of ``y_bins`` (length ``ny``)
+
+        Each element ``out.values[y_i, x_j]`` is the count (or normalized count) of how
+        many “series-values” from ``da`` fell into the rectangular bin
+        ``x_bins[j] ≤ x_value < x_bins[j+1]`` and
+        ``y_bins[i] ≤ data_value < y_bins[i+1]``.
+
+    References
+    ----------
+    Moritz, D., Fisher, D. (2018).
+    Visualizing a Million Time Series with the Density Line Chart.
+    https://doi.org/10.48550/arXiv.1808.06019
+    """
+    # Check DataArray name
+    if da.name is None or da.name == "":
+        raise ValueError("The DataArray must have a name.")
+
+    # Validate x_bins and y_bins
+    x_bins = np.asarray(x_bins)
+    y_bins = np.asarray(y_bins)
+    if x_bins.ndim != 1 or x_bins.size < 2:
+        raise ValueError("`x_bins` must be a 1D array with at least two edges.")
+    if y_bins.ndim != 1 or y_bins.size < 2:
+        raise ValueError("`y_bins` must be a 1D array with at least two edges.")
+    if not np.all(np.diff(x_bins) > 0):
+        raise ValueError("`x_bins` must be strictly increasing.")
+    if not np.all(np.diff(y_bins) > 0):
+        raise ValueError("`y_bins` must be strictly increasing.")
+
+    # Verify that `coord` exists as either a dimension or a coordinate
+    if coord not in (list(da.coords) + list(da.dims)):
+        raise ValueError(f"'{coord}' is not a dimension or coordinate of the DataArray.")
+    if coord not in da.dims:
+        if da[coord].ndim != 1:
+            raise ValueError(f"Coordinate '{coord}' must be 1D. Instead has dimensions {da[coord].dims}")
+        x_dim = da[coord].dims[0]
+    else:
+        x_dim = coord
+
+    # Define the x bin centers
+    x_values = (x_bins[0:-1] + x_bins[1:]) / 2
+
+    # Extract the array (samples, x)
+    other_dims = [d for d in da.dims if d != x_dim]
+    if len(other_dims) == 1:
+        arr = da.transpose(*other_dims, x_dim).to_numpy()
+    else:
+        arr = da.stack({"sample": other_dims}).transpose("sample", x_dim).to_numpy()
+
+    # Define the y bin centers
+    y_center = (y_bins[0:-1] + y_bins[1:]) / 2
+
+    # Prepare the 2D count grid of shape (ny, nx)
+    # - ny corresponds to the value of the timeseries at the nx points
+    nx = len(x_bins) - 1
+    ny = len(y_bins) - 1
+    nsamples = arr.shape[0]
+    grid = np.zeros((ny, nx), dtype=float)
+
+    # For each (series, x-index), find which y-bin it falls into:
+    # - np.searchsorted(y_bins, value) gives the insertion index in y_bins;
+    #   --> subtracting 1 yields the bin index.
+    # If a value is below the first edge, searchsorted returns 0, so idx = -1
+    indices = np.searchsorted(y_bins, arr) - 1  # (samples, nx)
+
+    # Assign 1 where a line passes through a bin
+    valid = (indices >= 0) & (indices < ny)
+    s_idx, x_idx = np.nonzero(valid)
+    y_idx = indices[valid]
+    grid_3d = np.zeros((nsamples, ny, nx), dtype=int)
+    grid_3d[s_idx, y_idx, x_idx] = 1
+
+    # Normalize by columns
+    col_sums = grid_3d.sum(axis=1, keepdims=True)
+    col_sums[col_sums == 0] = 1  # Avoid division by zero
+    grid_3d = grid_3d / col_sums
+
+    # Normalize over samples
+    grid = grid_3d.sum(axis=0)
+
+    # Normalize grid
+    grid = normalize_array(grid, method=normalization)
+
+    # Create DataArray
+    name = da.name
+    out = xr.DataArray(grid, dims=[name, coord], coords={coord: (coord, x_values), name: (name, y_center)})
+
+    # Mask values which are 0 with NaN
+    out = out.where(out > 0)
+
+    # Return 2D histogram
+    return out
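A usage sketch of `compute_dense_lines` on synthetic series. Because the implementation above indexes histogram columns by position along `coord`, the sketch picks `len(x_bins) - 1` equal to the number of points along `coord`:

    import numpy as np
    import xarray as xr
    from disdrodb.viz.plots import compute_dense_lines

    t = np.linspace(0, 2 * np.pi, 100)
    data = np.sin(t)[None, :] + 0.1 * np.random.randn(500, 100)
    da = xr.DataArray(data, dims=("series", "time"), coords={"time": t}, name="signal")

    hist = compute_dense_lines(
        da,
        coord="time",
        x_bins=np.linspace(0, 2 * np.pi, 101),  # 100 x-bins, one per time step
        y_bins=np.linspace(-1.5, 1.5, 41),      # 40 y-bins
    )
    hist.plot.pcolormesh(x="time")  # density-of-lines view of the 500 series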