disdrodb 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. disdrodb/__init__.py +4 -0
  2. disdrodb/_version.py +2 -2
  3. disdrodb/accessor/methods.py +14 -0
  4. disdrodb/api/checks.py +8 -7
  5. disdrodb/api/io.py +81 -29
  6. disdrodb/api/path.py +17 -14
  7. disdrodb/api/search.py +15 -18
  8. disdrodb/cli/disdrodb_open_products_options.py +38 -0
  9. disdrodb/cli/disdrodb_run.py +2 -2
  10. disdrodb/cli/disdrodb_run_station.py +4 -4
  11. disdrodb/configs.py +1 -1
  12. disdrodb/data_transfer/download_data.py +70 -1
  13. disdrodb/etc/configs/attributes.yaml +62 -8
  14. disdrodb/etc/configs/encodings.yaml +28 -0
  15. disdrodb/etc/products/L2M/MODELS/GAMMA_GS_ND_SSE.yaml +8 -0
  16. disdrodb/etc/products/L2M/MODELS/GAMMA_ML.yaml +1 -1
  17. disdrodb/etc/products/L2M/MODELS/LOGNORMAL_GS_LOG_ND_SSE.yaml +8 -0
  18. disdrodb/etc/products/L2M/MODELS/LOGNORMAL_GS_ND_SSE.yaml +8 -0
  19. disdrodb/etc/products/L2M/MODELS/LOGNORMAL_ML.yaml +1 -1
  20. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_LOG_ND_SSE.yaml +8 -0
  21. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_ND_SSE.yaml +8 -0
  22. disdrodb/etc/products/L2M/global.yaml +4 -4
  23. disdrodb/fall_velocity/graupel.py +8 -8
  24. disdrodb/fall_velocity/hail.py +2 -2
  25. disdrodb/fall_velocity/rain.py +33 -5
  26. disdrodb/issue/checks.py +1 -1
  27. disdrodb/l0/l0_reader.py +1 -1
  28. disdrodb/l0/l0a_processing.py +2 -2
  29. disdrodb/l0/l0b_nc_processing.py +5 -5
  30. disdrodb/l0/l0b_processing.py +20 -24
  31. disdrodb/l0/l0c_processing.py +18 -13
  32. disdrodb/l0/readers/LPM/SLOVENIA/ARSO.py +4 -0
  33. disdrodb/l0/readers/PARSIVEL2/VIETNAM/IGE_PARSIVEL2.py +239 -0
  34. disdrodb/l0/template_tools.py +13 -13
  35. disdrodb/l1/classification.py +10 -6
  36. disdrodb/l2/empirical_dsd.py +25 -15
  37. disdrodb/l2/processing.py +32 -14
  38. disdrodb/metadata/download.py +1 -1
  39. disdrodb/metadata/geolocation.py +4 -4
  40. disdrodb/metadata/reader.py +3 -3
  41. disdrodb/metadata/search.py +10 -8
  42. disdrodb/psd/__init__.py +4 -0
  43. disdrodb/psd/fitting.py +2660 -592
  44. disdrodb/psd/gof_metrics.py +389 -0
  45. disdrodb/psd/grid_search.py +1066 -0
  46. disdrodb/psd/models.py +1281 -145
  47. disdrodb/routines/l2.py +6 -6
  48. disdrodb/routines/options_validation.py +8 -8
  49. disdrodb/scattering/axis_ratio.py +70 -2
  50. disdrodb/scattering/permittivity.py +13 -10
  51. disdrodb/scattering/routines.py +10 -10
  52. disdrodb/summary/routines.py +23 -20
  53. disdrodb/utils/archiving.py +29 -22
  54. disdrodb/utils/attrs.py +6 -4
  55. disdrodb/utils/dataframe.py +4 -4
  56. disdrodb/utils/encoding.py +3 -1
  57. disdrodb/utils/event.py +9 -9
  58. disdrodb/utils/logger.py +4 -7
  59. disdrodb/utils/manipulations.py +2 -2
  60. disdrodb/utils/subsetting.py +1 -1
  61. disdrodb/utils/time.py +8 -7
  62. disdrodb/viz/plots.py +25 -17
  63. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/METADATA +44 -33
  64. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/RECORD +68 -66
  65. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/WHEEL +1 -1
  66. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/entry_points.txt +1 -0
  67. disdrodb/etc/products/L2M/MODELS/GAMMA_GS_ND_MAE.yaml +0 -6
  68. disdrodb/etc/products/L2M/MODELS/LOGNORMAL_GS_LOG_ND_MAE.yaml +0 -6
  69. disdrodb/etc/products/L2M/MODELS/LOGNORMAL_GS_ND_MAE.yaml +0 -6
  70. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_LOG_ND_MAE.yaml +0 -6
  71. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_ND_MAE.yaml +0 -6
  72. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_R_MAE.yaml +0 -6
  73. disdrodb/etc/products/L2M/MODELS/NGAMMA_GS_Z_MAE.yaml +0 -6
  74. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/licenses/LICENSE +0 -0
  75. {disdrodb-0.4.0.dist-info → disdrodb-0.5.1.dist-info}/top_level.txt +0 -0
disdrodb/psd/grid_search.py
@@ -0,0 +1,1066 @@
# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2026 DISDRODB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------.
"""Routines for grid search optimization."""

import numpy as np

DISTRIBUTION_TARGETS = {"N(D)", "H(x)"}
MOMENTS = {"M0", "M1", "M2", "M3", "M4", "M5", "M6"}
INTEGRAL_TARGETS = {"Z", "R", "LWC"} | MOMENTS
TARGETS = DISTRIBUTION_TARGETS | INTEGRAL_TARGETS

TRANSFORMATIONS = {"identity", "log", "sqrt"}
CENSORING = {"none", "left", "right", "both"}

DISTRIBUTION_METRICS = {"SSE", "SAE", "MAE", "MSE", "RMSE", "relMAE", "KLDiv", "WD", "JSD", "KS"}
INTEGRAL_METRICS = {"SE", "AE"}
ERROR_METRICS = DISTRIBUTION_METRICS | INTEGRAL_METRICS


def check_target(target):
    """Check valid target argument."""
    valid_targets = TARGETS
    if target not in valid_targets:
        raise ValueError(f"Invalid 'target' {target}. Valid targets are {valid_targets}.")
    return target


def check_censoring(censoring):
    """Check valid censoring argument."""
    valid_censoring = CENSORING
    if censoring not in valid_censoring:
        raise ValueError(f"Invalid 'censoring' {censoring}. Valid options are {valid_censoring}.")
    return censoring


def check_transformation(transformation):
    """Check valid transformation argument."""
    valid_transformation = TRANSFORMATIONS
    if transformation not in valid_transformation:
        raise ValueError(
            f"Invalid 'transformation' {transformation}. Valid options are {valid_transformation}.",
        )
    return transformation


def check_loss(loss, valid_metrics=ERROR_METRICS):
    """Check valid loss argument."""
    if loss not in valid_metrics:
        raise ValueError(f"Invalid 'loss' {loss}. Valid options are {valid_metrics}.")
    return loss


def check_valid_loss(loss, target):
    """Check if loss is valid for the given target.

    For distribution targets (N(D), H(x)), any distribution error metric is valid.
    For integral targets (Z, R, LWC, M0, ..., M6), only the integral error metrics (SE, AE) are valid.

    Parameters
    ----------
    loss : str
        The error metric to validate.
    target : str
        The target variable type.

    Returns
    -------
    str
        The validated loss.

    Raises
    ------
    ValueError
        If loss is not valid for the given target.
    """
    if target in {"N(D)", "H(x)"}:
        return check_loss(loss, valid_metrics=DISTRIBUTION_METRICS)
    # Integral N(D) targets (Z, R, LWC, M0, ..., M6)
    return check_loss(loss, valid_metrics=INTEGRAL_METRICS)


def check_loss_weight(loss_weight):
    """Check valid loss_weight argument."""
    if loss_weight <= 0:
        raise ValueError(f"Invalid 'loss_weight' {loss_weight}. Must be greater than 0.")
    return loss_weight


def check_objective(objective):
    """Check objective validity."""
    # Check required keys are present
    required_keys = {"target", "transformation", "censoring", "loss"}
    missing_keys = required_keys - set(objective.keys())
    if missing_keys:
        raise ValueError(
            f"Objective {objective} is missing required keys: {missing_keys}. "
            f"Required keys are: {required_keys}",
        )

    # Validate target
    objective["target"] = check_target(objective["target"])

    # Validate transformation
    objective["transformation"] = check_transformation(objective["transformation"])

    # Validate censoring
    objective["censoring"] = check_censoring(objective["censoring"])

    # Validate loss and check compatibility with target
    objective["loss"] = check_loss(objective["loss"])
    objective["loss"] = check_valid_loss(objective["loss"], objective["target"])
    return objective


def check_objectives(objectives):
    """Validate and normalize objectives for grid search optimization.

    Parameters
    ----------
    objectives : list of dict
        List of objective dictionaries, each containing:
        - 'target' : str, Target variable (N(D), H(x), R, Z, LWC, or M<p>)
        - 'transformation' : str, Transformation type (identity, log, sqrt)
        - 'censoring' : str, Censoring type (none, left, right, both)
        - 'loss' : str, Error metric (SSE, SAE, MAE, MSE, RMSE, etc.)
        - 'loss_weight' : float, Weight for weighted optimization.
          Required when multiple objectives are used; must be omitted for a
          single objective (a weight of 1 is applied downstream).

    Returns
    -------
    list of dict
        Validated objectives. Loss weights are not normalized.
    """
    if objectives is None:
        return None
    if not isinstance(objectives, list) or len(objectives) == 0:
        raise TypeError("'objectives' must be a non-empty list of dictionaries.")

    if not all(isinstance(obj, dict) for obj in objectives):
        raise TypeError("All items in 'objectives' must be dictionaries.")

    # Validate each objective
    for idx, objective in enumerate(objectives):
        objectives[idx] = check_objective(objective)

    # Handle loss_weight
    if len(objectives) == 1:
        # Single objective: 'loss_weight' must not be specified (a weight of 1 is used downstream)
        if "loss_weight" in objectives[0]:
            raise ValueError(
                "'loss_weight' should be specified only if multiple objectives are used.",
            )
    else:
        # Multiple objectives: verify all have a valid 'loss_weight'
        for objective in objectives:
            if "loss_weight" not in objective:
                raise ValueError(
                    f"Objective {objective} is missing 'loss_weight'. "
                    "When using multiple objectives, all must have 'loss_weight' specified.",
                )
            objective["loss_weight"] = check_loss_weight(objective["loss_weight"])
    return objectives


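For reference, a minimal sketch of the objectives schema this validator accepts (toy weights and choices, assuming the functions above are importable from disdrodb.psd.grid_search):

from disdrodb.psd.grid_search import check_objectives

objectives = check_objectives(
    [
        {"target": "N(D)", "transformation": "log", "censoring": "both", "loss": "SSE", "loss_weight": 2.0},
        {"target": "R", "transformation": "identity", "censoring": "none", "loss": "AE", "loss_weight": 1.0},
    ],
)
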
####---------------------------------------------------------------------------
#### Targets


def compute_rain_rate(ND, D, dD, V):
    """Compute rain rate from drop size distribution.

    Parameters
    ----------
    ND : numpy.ndarray
        Drop size distribution [m⁻³ mm⁻¹]. Can be 1D [n_bins] or 2D [n_samples, n_bins].
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]

    Returns
    -------
    numpy.ndarray
        Rain rate [mm/h]
    """
    axis = 1 if ND.ndim == 2 else None
    rain_rate = np.pi / 6 * np.sum(ND * V * (D / 1000) ** 3 * dD, axis=axis) * 3600 * 1000
    return rain_rate  # mm/h


def compute_lwc(ND, D, dD, rho_w=1000):
    """Compute liquid water content from drop size distribution.

    Parameters
    ----------
    ND : numpy.ndarray
        Drop size distribution [m⁻³ mm⁻¹]. Can be 1D [n_bins] or 2D [n_samples, n_bins].
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    rho_w : float, optional
        Water density [kg/m3]. Default is 1000.

    Returns
    -------
    numpy.ndarray
        Liquid water content [g/m3]
    """
    axis = 1 if ND.ndim == 2 else None
    # rho_w * 1000 converts the water density from kg/m3 to g/m3
    lwc = np.pi / 6.0 * (rho_w * 1000) * np.sum((D / 1000) ** 3 * ND * dD, axis=axis)
    return lwc  # g/m3


def compute_moment(ND, order, D, dD):
    """Compute a moment of the drop size distribution.

    Parameters
    ----------
    ND : numpy.ndarray
        Drop size distribution [m⁻³ mm⁻¹]. Can be 1D [n_bins] or 2D [n_samples, n_bins].
    order : int
        Moment order.
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]

    Returns
    -------
    numpy.ndarray
        Moment of the specified order [mm^order·m⁻³]
    """
    axis = 1 if ND.ndim == 2 else None
    return np.sum((D**order * ND * dD), axis=axis)  # mm**order·m⁻³


def compute_z(ND, D, dD):
    """Compute radar reflectivity from drop size distribution.

    Parameters
    ----------
    ND : numpy.ndarray
        Drop size distribution [m⁻³ mm⁻¹]. Can be 1D [n_bins] or 2D [n_samples, n_bins].
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]

    Returns
    -------
    numpy.ndarray
        Radar reflectivity factor [dBZ]
    """
    z = compute_moment(ND, order=6, D=D, dD=dD)  # mm⁶·m⁻³
    Z = 10 * np.log10(np.where(z > 0, z, np.nan))  # dBZ
    return Z


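As a quick sanity check of these integral targets, here is a toy two-bin spectrum (made-up values, not from the package):

import numpy as np

from disdrodb.psd.grid_search import compute_lwc, compute_rain_rate, compute_z

D = np.array([0.5, 1.5])        # bin centers [mm]
dD = np.array([0.5, 0.5])       # bin widths [mm]
V = np.array([2.0, 5.0])        # terminal velocities [m/s]
ND = np.array([1000.0, 100.0])  # N(D) [m⁻³ mm⁻¹]

compute_rain_rate(ND, D=D, dD=dD, V=V)  # ~1.83 mm/h
compute_lwc(ND, D=D, dD=dD)             # ~0.12 g/m3
compute_z(ND, D=D, dD=dD)               # ~27.6 dBZ
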
def compute_target_variable(
    target,
    ND_obs,
    ND_preds,
    D,
    dD,
    V,
):
    """Compute target variable from drop size distribution.

    Parameters
    ----------
    target : str
        Target variable type. Can be 'Z', 'R', 'LWC', moments ('M0'-'M6'), 'N(D)', or 'H(x)'.
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]

    Returns
    -------
    tuple
        (obs, pred) where obs is 1D [n_bins] or scalar, and pred is 2D [n_samples, n_bins] or 1D [n_samples]
    """
    # Compute observed and predicted target variables
    if target == "Z":
        obs = compute_z(ND_obs, D=D, dD=dD)
        pred = compute_z(ND_preds, D=D, dD=dD)
    elif target == "R":
        obs = compute_rain_rate(ND_obs, D=D, dD=dD, V=V)
        pred = compute_rain_rate(ND_preds, D=D, dD=dD, V=V)
    elif target == "LWC":
        obs = compute_lwc(ND_obs, D=D, dD=dD)
        pred = compute_lwc(ND_preds, D=D, dD=dD)
    elif target in MOMENTS:
        order = int(target[1])
        obs = compute_moment(ND_obs, order=order, D=D, dD=dD)
        pred = compute_moment(ND_preds, order=order, D=D, dD=dD)
    else:  # N(D) or H(x)
        obs = ND_obs
        pred = ND_preds
    return obs, pred


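The shape contract is easiest to see on random toy arrays (a sketch; the fall-speed power law is just a placeholder):

rng = np.random.default_rng(0)
ND_obs = rng.uniform(0.0, 100.0, size=8)         # (n_bins,)
ND_preds = rng.uniform(0.0, 100.0, size=(5, 8))  # (n_samples, n_bins)
D = np.linspace(0.25, 3.75, 8)
dD = np.full(8, 0.5)
V = 3.78 * D**0.67                               # toy rain fall-speed power law [m/s]

obs, pred = compute_target_variable("R", ND_obs, ND_preds, D=D, dD=dD, V=V)
# obs -> scalar, pred -> shape (5,)
obs, pred = compute_target_variable("N(D)", ND_obs, ND_preds, D=D, dD=dD, V=V)
# obs -> shape (8,), pred -> shape (5, 8)
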
####---------------------------------------------------------------------------
#### Censoring


def left_truncate_bins(ND_obs, ND_preds, D, dD, V):
    """Truncate left side of bins (smallest diameters) to the first non-zero bin.

    Parameters
    ----------
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]

    Returns
    -------
    tuple or None
        (ND_obs_trunc, ND_preds_trunc, D_trunc, dD_trunc, V_trunc) or None if all zeros
    """
    if np.all(ND_obs == 0):  # all zeros
        return None
    idx = np.argmax(ND_obs > 0)
    return (
        ND_obs[idx:],
        ND_preds[:, idx:],
        D[idx:],
        dD[idx:],
        V[idx:],
    )


def right_truncate_bins(ND_obs, ND_preds, D, dD, V):
    """Truncate right side of bins (largest diameters) to the last non-zero bin.

    Parameters
    ----------
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]

    Returns
    -------
    tuple or None
        (ND_obs_trunc, ND_preds_trunc, D_trunc, dD_trunc, V_trunc) or None if all zeros
    """
    if np.all(ND_obs == 0):  # all zeros
        return None
    idx = len(ND_obs) - np.argmax(ND_obs[::-1] > 0)
    return (
        ND_obs[:idx],
        ND_preds[:, :idx],
        D[:idx],
        dD[:idx],
        V[:idx],
    )


def truncate_bin_edges(
    ND_obs,
    ND_preds,
    D,
    dD,
    V,
    left_censored=False,
    right_censored=False,
):
    """Truncate bin edges based on censoring strategy.

    Parameters
    ----------
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]
    left_censored : bool, optional
        If True, truncate from the left (remove small diameter bins). Default is False.
    right_censored : bool, optional
        If True, truncate from the right (remove large diameter bins). Default is False.

    Returns
    -------
    tuple or None
        (ND_obs_trunc, ND_preds_trunc, D_trunc, dD_trunc, V_trunc) or None if all zeros
    """
    data = (ND_obs, ND_preds, D, dD, V)
    if left_censored:
        data = left_truncate_bins(*data)
        if data is None:
            return None
    if right_censored:
        data = right_truncate_bins(*data)
        if data is None:
            return None
    return data


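For instance, with empty bins at both edges of a toy spectrum, only the contiguous non-zero range survives:

ND_obs = np.array([0.0, 5.0, 8.0, 0.0])
ND_preds = np.ones((3, 4))
D = np.array([0.25, 0.75, 1.25, 1.75])
dD = np.full(4, 0.5)
V = np.array([1.0, 2.0, 3.0, 4.0])

ND_obs, ND_preds, D, dD, V = truncate_bin_edges(
    ND_obs, ND_preds, D, dD, V, left_censored=True, right_censored=True,
)
# D -> array([0.75, 1.25]); ND_preds -> shape (3, 2)
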
####---------------------------------------------------------------------------
#### Transformation


def apply_transformation(obs, pred, transformation):
    """Apply transformation to observed and predicted values.

    Parameters
    ----------
    obs : numpy.ndarray
        Observed values
    pred : numpy.ndarray
        Predicted values
    transformation : str
        Transformation type: 'identity', 'log', or 'sqrt'.
        The 'log' option applies log(1 + x) so that zero values remain finite.

    Returns
    -------
    tuple
        (obs_transformed, pred_transformed)
    """
    if transformation == "log":
        return np.log(obs + 1), np.log(pred + 1)
    if transformation == "sqrt":
        return np.sqrt(obs), np.sqrt(pred)
    # transformation == "identity"
    return obs, pred


####---------------------------------------------------------------------------
#### Loss metrics


def _compute_kl(p_k, q_k, eps=1e-12):
    """Compute Kullback-Leibler divergence.

    Parameters
    ----------
    p_k : numpy.ndarray
        Reference probability distribution [n_samples, n_bins] or [1, n_bins]
    q_k : numpy.ndarray
        Comparison probability distribution [n_samples, n_bins]
    eps : float, optional
        Small value for numerical stability. Default is 1e-12.

    Returns
    -------
    numpy.ndarray
        KL divergence for each sample [n_samples]
    """
    q_safe = np.maximum(q_k, eps)
    kl = np.sum(
        p_k * np.log(p_k / q_safe),
        axis=1,
        where=(p_k > 0),  # sum only where p > 0
    )
    # Clip tiny negative values to 0
    kl = np.maximum(kl, 0.0)

    # Set to NaN if probability mass is all 0
    pk_mass = p_k.sum()
    qk_mass = q_k.sum(axis=1)
    kl = np.where(pk_mass > 0, kl, np.nan)
    kl = np.where(qk_mass > 0, kl, np.nan)
    return kl


def compute_kl_divergence(obs, pred, dD, eps=1e-12):
    """Compute Kullback-Leibler divergence between observed and predicted N(D).

    Parameters
    ----------
    obs : numpy.ndarray
        Observed N(D) values with shape (n_bins,) and units [m⁻³ mm⁻¹]
    pred : numpy.ndarray
        Predicted N(D) values with shape (n_samples, n_bins) and units [m⁻³ mm⁻¹]
    dD : numpy.ndarray
        Diameter bin width in mm with shape (n_bins,)
    eps : float, optional
        Small value for numerical stability. Default is 1e-12.

    Returns
    -------
    numpy.ndarray
        KL divergence for each sample with shape (n_samples,)
    """
    # Convert N(D) to probabilities (normalize by bin width and total)
    # pdf = N(D) * dD / sum(N(D) * dD)
    p_k = (obs * dD) / (np.sum(obs * dD) + eps)
    q_k = (pred * dD[None, :]) / (np.sum(pred * dD[None, :], axis=1, keepdims=True) + eps)

    # KL(P||Q) = sum(P * log(P/Q))
    kl = _compute_kl(p_k=p_k[None, :], q_k=q_k, eps=eps)
    return kl


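A sanity check: comparing a distribution with itself yields a (numerically) zero divergence:

ND = np.array([200.0, 100.0, 50.0])
dD = np.full(3, 0.5)
compute_kl_divergence(ND, ND[None, :], dD=dD)  # -> array([~0.])
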
def compute_jensen_shannon_distance(obs, pred, dD, eps=1e-12):
    """Compute Jensen-Shannon distance between observed and predicted N(D).

    The Jensen-Shannon distance is the square root of the Jensen-Shannon divergence.
    With the natural logarithm, values range between 0 and np.sqrt(np.log(2)) ≈ 0.83256.

    Vectorized implementation for multiple predictions.

    Parameters
    ----------
    obs : numpy.ndarray
        Observed N(D) values with shape (n_bins,) and units [m⁻³ mm⁻¹]
    pred : numpy.ndarray
        Predicted N(D) values with shape (n_samples, n_bins) and units [m⁻³ mm⁻¹]
    dD : numpy.ndarray
        Diameter bin width in mm with shape (n_bins,)
    eps : float, optional
        Small value for numerical stability. Default is 1e-12.

    Returns
    -------
    numpy.ndarray
        Jensen-Shannon distance for each sample [n_samples]
    """
    # Convert N(D) to probability distributions
    obs_prob = (obs * dD) / (np.sum(obs * dD) + eps)
    pred_prob = (pred * dD[None, :]) / (np.sum(pred * dD[None, :], axis=1, keepdims=True) + eps)

    # Mixture distribution
    M = 0.5 * (obs_prob[None, :] + pred_prob)

    # Compute KL divergences
    # - KL(P||M)
    kl_obs = _compute_kl(p_k=obs_prob[None, :], q_k=M, eps=eps)

    # - KL(Q||M)
    kl_pred = _compute_kl(p_k=pred_prob, q_k=M, eps=eps)

    # Compute Jensen-Shannon divergence
    js_div = 0.5 * (kl_obs + kl_pred)
    js_div = np.maximum(js_div, 0.0)  # clip tiny negative values to zero (numerical safety)

    # Jensen-Shannon distance
    js_distance = np.sqrt(js_div)
    return js_distance


def compute_wasserstein_distance(obs, pred, D, dD, eps=1e-12, integration="bin"):
    """Compute Wasserstein distance (Earth Mover's Distance) between observed and predicted N(D).

    Vectorized implementation for multiple predictions.

    Parameters
    ----------
    obs : numpy.ndarray
        Observed N(D) values with shape (n_bins,) and units [m⁻³ mm⁻¹]
    pred : numpy.ndarray
        Predicted N(D) values with shape (n_samples, n_bins) and units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm with shape (n_bins,)
    dD : numpy.ndarray
        Diameter bin width in mm with shape (n_bins,)
    eps : float, optional
        Small value for numerical stability. Default is 1e-12.
    integration : str, optional
        Integration scheme used to compute the Wasserstein integral.
        Supported options are ``"bin"`` and ``"left_riemann"``.

        ``"bin"`` computes the histogram-based Wasserstein distance. N(D) is interpreted as a
        piecewise-constant density over bins of width ``dD``. The distance is
        computed by integrating the difference between cumulative distribution
        functions over each bin. This is the default.

        ``"left_riemann"`` computes the discrete-support Wasserstein distance. Probability mass
        is assumed to be concentrated at the bin centers ``D``, and the integral is approximated
        using the spacing between support points, consistent with :func:`scipy.stats.wasserstein_distance`.

    Returns
    -------
    numpy.ndarray
        Wasserstein distance for each sample [n_samples]
    """
    # With integration="left_riemann", this matches (per sample):
    # from scipy.stats import wasserstein_distance
    # wasserstein_distance(u_values=D, v_values=D, u_weights=obs_prob, v_weights=pred_prob[0])

    # Convert N(D) to probabilities (normalize by bin width and total)
    # pdf = N(D) * dD / sum(N(D) * dD)
    obs_prob = (obs * dD) / (np.sum(obs * dD) + eps)
    pred_prob = (pred * dD[None, :]) / (np.sum(pred * dD[None, :], axis=1, keepdims=True) + eps)

    # Compute cumulative distributions
    obs_cdf = np.cumsum(obs_prob)
    pred_cdf = np.cumsum(pred_prob, axis=1)

    # Wasserstein distance = integral of |CDF_obs - CDF_pred| over D
    # - Compute difference between CDFs
    obs_cdf_expanded = obs_cdf[None, :]  # [1, n_bins]
    diff = np.abs(obs_cdf_expanded - pred_cdf)  # [n_samples, n_bins]

    if integration == "bin":
        wd = np.sum(diff * dD[None, :], axis=1)
    else:
        # Integrate using left Riemann sum (as scipy.stats.wasserstein_distance)
        dx = np.diff(D)
        wd = np.sum(diff[:, :-1] * dx[None, :], axis=1)

    # Clip to 0
    wd = np.maximum(wd, 0.0)

    # Set to NaN if probability mass is all 0
    obs_mass = obs_prob.sum()
    pred_mass = pred_prob.sum(axis=1)
    wd = np.where(obs_mass > 0, wd, np.nan)
    wd = np.where(pred_mass > 0, wd, np.nan)
    return wd


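A toy case that makes the metric interpretable: moving half of the probability mass one bin (0.5 mm) to the right yields a distance of 0.25 mm. With uniform bins the two integration schemes agree:

obs = np.array([0.0, 300.0, 100.0, 0.0])
pred = np.array([[0.0, 100.0, 300.0, 0.0]])
D = np.array([0.25, 0.75, 1.25, 1.75])
dD = np.full(4, 0.5)
compute_wasserstein_distance(obs, pred, D=D, dD=dD)                              # -> array([0.25])
compute_wasserstein_distance(obs, pred, D=D, dD=dD, integration="left_riemann")  # -> array([0.25])
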
def compute_kolmogorov_smirnov_distance(obs, pred, dD, eps=1e-12):
    """Compute Kolmogorov-Smirnov (KS) distance between observed and predicted N(D).

    The Kolmogorov-Smirnov (KS) distance is bounded between 0 and 1,
    where 0 indicates that the two distributions are identical.
    The associated KS test p-value ranges from 0 to 1,
    with a value of 1 indicating no evidence against the null hypothesis that the distributions are identical.
    When the p-value is smaller than the significance level (e.g. < 0.05), the model is rejected.

    If model parameters are estimated from the same data to which the model is compared,
    the standard KS p-values are invalid.
    The solution is to use a parametric bootstrap:

    1. Fit the model to your data.
    2. Simulate many datasets from that fitted model.
    3. Refit the model to each simulated dataset.
    4. Compute the KS statistic each time.
    5. Compare your observed KS statistic to the bootstrap distribution.

    Vectorized implementation for multiple predictions.

    Parameters
    ----------
    obs : numpy.ndarray
        Observed N(D) values with shape (n_bins,) and units [m⁻³ mm⁻¹]
    pred : numpy.ndarray
        Predicted N(D) values with shape (n_samples, n_bins) and units [m⁻³ mm⁻¹]
    dD : numpy.ndarray
        Diameter bin width in mm with shape (n_bins,)
    eps : float, optional
        Small value for numerical stability. Default is 1e-12.

    Returns
    -------
    numpy.ndarray
        KS statistic for each sample [n_samples].
        If 0, the two distributions are identical.
    numpy.ndarray
        KS p-value for each sample [n_samples].
        A p-value of 0 means strong evidence against equality;
        a p-value of 1 means no evidence against equality.
        Identical distributions show a p-value of 1,
        and similar distributions show a p-value close to 1.
    """
    # Convert N(D) to probability mass
    obs_prob = (obs * dD) / (np.sum(obs * dD) + eps)
    pred_prob = (pred * dD[None, :]) / (np.sum(pred * dD[None, :], axis=1, keepdims=True) + eps)

    # Compute CDFs
    obs_cdf = np.cumsum(obs_prob)  # (n_bins,)
    pred_cdf = np.cumsum(pred_prob, axis=1)  # (n_samples, n_bins)

    # KS statistic = max |CDF_obs - CDF_pred|
    ks = np.max(np.abs(pred_cdf - obs_cdf[None, :]), axis=1)

    # Compute effective sample sizes (from probabilities)
    n_eff_obs = 1.0 / np.sum(obs_prob**2)
    n_eff_pred = 1.0 / np.sum(pred_prob**2, axis=1)
    n_eff_ks = (n_eff_obs * n_eff_pred) / (n_eff_obs + n_eff_pred)

    # Compute KS p-value (asymptotic approximation)
    p_value = 2.0 * np.exp(-2.0 * (ks * np.sqrt(n_eff_ks)) ** 2)
    p_value = np.clip(p_value, 0.0, 1.0)

    # Set to NaN if probability mass is all 0
    obs_mass = obs_prob.sum()
    pred_mass = pred_prob.sum(axis=1)
    ks = np.where(obs_mass > 0, ks, np.nan)
    ks = np.where(pred_mass > 0, ks, np.nan)
    p_value = np.where(obs_mass > 0, p_value, np.nan)
    p_value = np.where(pred_mass > 0, p_value, np.nan)
    return ks, p_value


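A quick check of the two return values on toy spectra (with so few bins, the asymptotic p-values saturate near 1):

obs = np.array([100.0, 200.0, 50.0])
pred = np.stack([obs, obs[::-1]])
dD = np.full(3, 0.5)
ks, p_value = compute_kolmogorov_smirnov_distance(obs, pred, dD=dD)
# ks -> array([0.   , 0.143]): identical spectra give ks ~ 0 (and p-value 1)
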
####---------------------------------------------------------------------------
#### Wrappers


def compute_errors(obs, pred, loss, D=None, dD=None):  # noqa: PLR0911
    """Compute error between observed and predicted values.

    The function is entirely vectorized and can handle multiple predictions at once.

    Parameters
    ----------
    obs : numpy.ndarray
        Observed values.
        A scalar if the specified target is an integral variable.
        A 1D array of size [n_bins] if the target is a distribution.
    pred : numpy.ndarray
        Predicted values. Can be 1D [n_samples] or 2D [n_samples, n_bins].
        1D when the specified target is an integral variable.
        2D when the specified target is a distribution.
    loss : str
        Error metric to compute. See supported metrics in ERROR_METRICS.
    D : numpy.ndarray, optional
        Diameter bin centers in mm [n_bins]. Required for the 'WD' metric. Default is None.
    dD : numpy.ndarray, optional
        Diameter bin width in mm [n_bins]. Required for distribution metrics. Default is None.

    Returns
    -------
    numpy.ndarray
        Computed error(s) [n_samples].
    """
    # Handle scalar obs case (from integral targets like Z, R, LWC)
    if np.isscalar(obs):
        obs = np.asarray(obs)

    # Compute SE or AE (for integral targets)
    if obs.size == 1:
        if loss == "AE":
            return np.abs(obs - pred)
        # "SE"
        return (obs - pred) ** 2

    # Compute KLDiv, JSD, WD or KS if asked (obs is expanded internally to save computations)
    if loss == "KLDiv":
        return compute_kl_divergence(obs, pred, dD=dD)
    if loss == "JSD":
        return compute_jensen_shannon_distance(obs, pred, dD=dD)
    if loss == "WD":
        return compute_wasserstein_distance(obs, pred, D=D, dD=dD)
    if loss == "KS":
        return compute_kolmogorov_smirnov_distance(obs, pred, dD=dD)[0]  # select distance
    # if loss == "KS_pvalue":
    #     return compute_kolmogorov_smirnov_distance(obs, pred, dD=dD)[1]  # select p_value

    # Broadcast obs to match pred shape if needed (when target is N(D) or H(x))
    # If obs is 1D and pred is 2D, add dimension to obs
    if pred.ndim > obs.ndim:
        obs = obs[None, :]

    # Compute error metrics
    if loss == "SSE":
        return np.sum((obs - pred) ** 2, axis=1)
    if loss == "SAE":
        return np.sum(np.abs(obs - pred), axis=1)
    if loss == "MAE":
        return np.mean(np.abs(obs - pred), axis=1)
    if loss == "relMAE":
        return np.mean(np.abs(obs - pred) / (np.abs(obs) + 1e-12), axis=1)
    if loss == "MSE":
        return np.mean((obs - pred) ** 2, axis=1)
    if loss == "RMSE":
        return np.sqrt(np.mean((obs - pred) ** 2, axis=1))
    raise NotImplementedError(f"Error metric '{loss}' is not implemented.")


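For the plain element-wise metrics, no bin geometry is needed (a toy check):

obs = np.array([10.0, 20.0, 30.0])
pred = np.array([[10.0, 20.0, 30.0], [12.0, 18.0, 33.0]])
compute_errors(obs, pred, loss="SSE")                  # -> array([ 0., 17.])
compute_errors(5.0, np.array([4.0, 7.0]), loss="AE")   # scalar target -> array([1., 2.])
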
def normalize_errors(errors):
    """Normalize errors to scale the minimum-error region to O(1).

    Scaling by the median value of the p0-p10 region normalizes errors in
    the minimum region to approximately O(1). Scaling by the p95-p5 range is
    not used because, when the tails span orders of magnitude, it normalizes
    the spread rather than the minimum region, suppressing the minimum region
    and amplifying the bad region.

    Parameters
    ----------
    errors : numpy.ndarray
        Error values to normalize [n_samples]

    Returns
    -------
    numpy.ndarray
        Normalized errors (the original errors are returned if the scale is zero).
    """
    p10 = np.nanpercentile(errors, q=10)
    scale = np.nanmedian(errors[errors <= p10])

    ## Investigate normalization
    # plt.hist(errors[errors < p10], bins=100)
    # errors_norm = errors / scale
    # p_norm10 = np.nanpercentile(errors_norm, q=10)
    # plt.hist(errors_norm[errors_norm < p_norm10], bins=100)

    # scale = np.diff(np.nanpercentile(errors, q=[1, 99])) + 1e-12
    if scale != 0:
        errors = errors / scale
    return errors


def compute_loss(
    ND_obs,
    ND_preds,
    D,
    dD,
    V,
    target,
    censoring,
    transformation,
    loss,
    check_arguments=True,
):
    """Compute loss.

    Computes the loss between observed and predicted drop size distributions,
    with optional censoring, transformation, and target variable specification.

    Parameters
    ----------
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]
    target : str
        Target variable: 'Z', 'R', 'LWC', moments ('M0'-'M6'), 'N(D)', or 'H(x)'.
    censoring : str
        Censoring strategy: 'none', 'left', 'right', or 'both'.
    transformation : str
        Transformation: 'identity', 'log', or 'sqrt'.
    loss : str
        Loss function.
        If target is ``"N(D)"`` or ``"H(x)"``, valid options are:

        - ``SSE``: Sum of Squared Errors
        - ``SAE``: Sum of Absolute Errors
        - ``MAE``: Mean Absolute Error
        - ``MSE``: Mean Squared Error
        - ``RMSE``: Root Mean Squared Error
        - ``relMAE``: Relative Mean Absolute Error
        - ``KLDiv``: Kullback-Leibler Divergence
        - ``WD``: Wasserstein Distance
        - ``JSD``: Jensen-Shannon Distance
        - ``KS``: Kolmogorov-Smirnov Statistic

        If target is one of ``"R"``, ``"Z"``, ``"LWC"``, or ``"M<p>"``, valid options are:

        - ``AE``: Absolute Error
        - ``SE``: Squared Error
    check_arguments : bool, optional
        If True, validate input arguments. Default is True.

    Returns
    -------
    numpy.ndarray
        Computed errors [n_samples]. Values are NaN where computation failed.
    """
    # Check input
    if check_arguments:
        target = check_target(target)
        transformation = check_transformation(transformation)
        censoring = check_censoring(censoring)
        loss = check_valid_loss(loss, target=target)

    # Clip N(D) < 1e-3 to 0
    ND_obs = np.where(ND_obs < 1e-3, 0.0, ND_obs)
    ND_preds = np.where(ND_preds < 1e-3, 0.0, ND_preds)

    # Truncate if asked
    left_censored = censoring in {"left", "both"}
    right_censored = censoring in {"right", "both"}
    if left_censored or right_censored:
        truncated = truncate_bin_edges(
            ND_obs,
            ND_preds,
            D,
            dD,
            V,
            left_censored=left_censored,
            right_censored=right_censored,
        )
        if truncated is None:
            # Grid search logic expects inf so it can be turned into NaN later
            return np.full(ND_preds.shape[0], np.inf)
        ND_obs, ND_preds, D, dD, V = truncated

    # Compute target variable
    obs, pred = compute_target_variable(target, ND_obs, ND_preds, D=D, dD=dD, V=V)

    # Apply transformation
    obs, pred = apply_transformation(obs, pred, transformation=transformation)

    # Compute errors
    errors = compute_errors(obs, pred, loss=loss, D=D, dD=dD)

    # Replace inf with NaN
    errors[~np.isfinite(errors)] = np.nan
    return errors


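In a grid search, ND_preds holds one candidate spectrum per parameter combination, and the best candidate is the row minimizing the returned errors (a hedged sketch with random stand-ins for real candidates):

rng = np.random.default_rng(42)
D = np.linspace(0.25, 3.75, 8)
dD = np.full(8, 0.5)
V = 3.78 * D**0.67                                # toy fall-speed power law [m/s]
ND_obs = rng.uniform(0.0, 200.0, size=8)
ND_preds = rng.uniform(0.0, 200.0, size=(50, 8))  # 50 candidate spectra

errors = compute_loss(
    ND_obs, ND_preds, D=D, dD=dD, V=V,
    target="N(D)", censoring="both", transformation="log", loss="SSE",
)
best = int(np.nanargmin(errors))  # index of the best candidate
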
def compute_weighted_loss(ND_obs, ND_preds, D, dD, V, objectives, Nc=None):
    """Compute weighted loss between observed and predicted particle size distributions.

    Parameters
    ----------
    ND_obs : numpy.ndarray
        Observed drop size distribution of shape (n_bins,) with units [m⁻³ mm⁻¹]
    ND_preds : numpy.ndarray
        Predicted drop size distributions of shape (n_samples, n_bins) with units [m⁻³ mm⁻¹]
    D : numpy.ndarray
        Diameter bin centers in mm [n_bins]
    dD : numpy.ndarray
        Diameter bin width in mm [n_bins]
    V : numpy.ndarray
        Terminal velocity in m/s [n_bins]
    objectives : list of dict
        Each objective dictionary accepts the following keys.

        target : str
            Target quantity to optimize. Valid options:

            - ``"N(D)"`` : Drop number concentration [m⁻³ mm⁻¹]
            - ``"H(x)"`` : Normalized drop number concentration [-]
            - ``"R"`` : Rain rate [mm h⁻¹]
            - ``"Z"`` : Radar reflectivity factor [dBZ]
            - ``"LWC"`` : Liquid water content [g m⁻³]
            - ``"M<p>"`` : Moment of order p

        transformation : str
            Transformation applied to the target quantity before computing the loss.
            Valid options:

            - ``"identity"`` : No transformation
            - ``"log"`` : Logarithmic transformation
            - ``"sqrt"`` : Square root transformation

        censoring : str
            Specifies whether the observed particle size distribution (PSD) is
            treated as censored at the edges of the diameter range due to
            instrumental sensitivity limits:

            - ``"none"`` : No censoring is applied. All diameter bins are used.
            - ``"left"`` : Left-censored PSD. Diameter bins at the lower end of
              the spectrum where the observed number concentration is zero are
              removed prior to cost-function evaluation.
            - ``"right"`` : Right-censored PSD. Diameter bins at the upper end of
              the spectrum where the observed number concentration is zero are
              removed prior to cost-function evaluation.
            - ``"both"`` : Both left- and right-censored PSD. Only the contiguous
              range of diameter bins with non-zero observed concentrations is
              retained.

        loss : str
            Loss function.
            If target is ``"N(D)"`` or ``"H(x)"``, valid options are:

            - ``SSE``: Sum of Squared Errors
            - ``SAE``: Sum of Absolute Errors
            - ``MAE``: Mean Absolute Error
            - ``MSE``: Mean Squared Error
            - ``RMSE``: Root Mean Squared Error
            - ``relMAE``: Relative Mean Absolute Error
            - ``KLDiv``: Kullback-Leibler Divergence
            - ``WD``: Wasserstein Distance
            - ``JSD``: Jensen-Shannon Distance
            - ``KS``: Kolmogorov-Smirnov Statistic

            If target is one of ``"R"``, ``"Z"``, ``"LWC"``, or ``"M<p>"``, valid options are:

            - ``AE``: Absolute Error
            - ``SE``: Squared Error

        loss_weight : float, optional
            Weight of this objective when multiple objectives are used.
            Must be specified if more than one objective is specified.
    Nc : float, optional
        Normalization constant for the H(x) target.
        If provided, N(D) will be divided by Nc.

    Returns
    -------
    numpy.ndarray
        Computed errors [n_samples]. Values are NaN where computation failed.
    """
    # Compute weighted loss across all targets
    total_loss = np.zeros(ND_preds.shape[0])
    total_loss_weights = 0
    for objective in objectives:
        # Extract target configuration
        target = objective["target"]
        loss = objective["loss"]
        censoring = objective["censoring"]
        transformation = objective["transformation"]
        if len(objectives) > 1:
            loss_weight = objective["loss_weight"]
            normalize_loss = True  # objective["normalize_loss"]
        else:
            loss_weight = 1
            normalize_loss = False  # objective["normalize_loss"]

        # Prepare observed and predicted variables
        # - Compute normalized H(x) if Nc is provided and target is H(x)
        if Nc is not None:
            obs = ND_obs / Nc if target == "H(x)" else ND_obs
            preds = ND_preds / Nc if target == "H(x)" else ND_preds
        else:
            obs = ND_obs
            preds = ND_preds

        # Compute errors for this target
        loss_values = compute_loss(
            ND_obs=obs,
            ND_preds=preds,
            D=D,
            dD=dD,
            V=V,
            target=target,
            transformation=transformation,
            loss=loss,
            censoring=censoring,
        )

        # Normalize loss
        if normalize_loss:
            loss_values = normalize_errors(loss_values)

        # Accumulate weighted loss
        total_loss += loss_weight * loss_values
        total_loss_weights += loss_weight

    # Normalize by total weight
    total_loss = total_loss / total_loss_weights

    # Replace inf with NaN
    total_loss[~np.isfinite(total_loss)] = np.nan
    return total_loss
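
Putting the pieces together, a two-objective sketch (reusing the toy arrays from the previous example):

objectives = check_objectives(
    [
        {"target": "N(D)", "transformation": "log", "censoring": "both", "loss": "SSE", "loss_weight": 2.0},
        {"target": "Z", "transformation": "identity", "censoring": "none", "loss": "SE", "loss_weight": 1.0},
    ],
)
total_loss = compute_weighted_loss(ND_obs, ND_preds, D=D, dD=dD, V=V, objectives=objectives)
best = int(np.nanargmin(total_loss))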