drn 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
drn/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # __init__.py for the drn package
2
+ from .models import *
3
+ from .distributions import *
4
+ from .interpretability import DRNExplainer, KernelSHAP_DRN
5
+ from .train import train, split_and_preprocess
6
+ from .metrics import crps, quantile_score, quantile_losses, rmse
@@ -0,0 +1,4 @@
1
+ from .histogram import Histogram
2
+ from .extended_histogram import ExtendedHistogram
3
+
4
+ __all__ = ["Histogram", "ExtendedHistogram"]
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional
4
+ from torch.distributions import Distribution
5
+
6
+ from .histogram import Histogram
7
+
8
+
9
class ExtendedHistogram(Distribution):
    """
    Splicing of a baseline distribution with a histogram distribution.

    The histogram part is defined by K regions with boundaries
    -infty < c_0 < c_1 < ... < c_K < infty, shared by every member of the batch.
    Outside [c_0, c_K) the density is the baseline's density. Inside, it is the
    histogram density multiplied by the baseline probability of [c_0, c_K], so the
    mass the baseline assigns to that interval is redistributed across the bins.
    """

    def __init__(
        self,
        baseline: Distribution,
        cutpoints: torch.Tensor,
        pmf: torch.Tensor,
        baseline_probs: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            baseline: the original distribution; `baseline.cdf` evaluated at a single
                cutpoint must return shape (n,) (asserted below)
            cutpoints: the bin boundaries (shape: (K+1,))
            pmf: the refined (conditional) probability of landing in each region (shape: (n, K))
            baseline_probs: optional precomputed baseline probability of landing in each
                region (shape: (n, K)); computed lazily when omitted
        """
        self.baseline = baseline
        self.cutpoints = cutpoints
        self.prob_masses = pmf
        self.baseline_probs = baseline_probs
        self.histogram = Histogram(cutpoints, pmf)
        # Baseline probability of [c_0, c_K]; the histogram part is scaled by it so the
        # spliced distribution keeps the baseline's mass outside the cutpoints.
        self.scale_down_hist = baseline.cdf(cutpoints[-1]) - baseline.cdf(cutpoints[0])

        assert self.scale_down_hist.shape == torch.Size([self.histogram.batch_shape[0]])

        super().__init__(batch_shape=self.histogram.batch_shape, validate_args=False)

    def baseline_prob_between_cutpoints(self) -> torch.Tensor:
        """
        Baseline probability of landing in each region [c_k, c_{k+1}).

        The result (shape: (n, K)) is cached in `self.baseline_probs` on first use.
        """
        if self.baseline_probs is None:
            # cutpoints of shape (K+1, 1) broadcast against the baseline batch to
            # (K+1, n); transpose so each row belongs to one observation.
            baseline_cdfs = self.baseline.cdf(self.cutpoints.unsqueeze(-1)).T
            self.baseline_probs = torch.diff(baseline_cdfs, dim=1)

        return self.baseline_probs

    def real_adjustments(self) -> torch.Tensor:
        """
        Calculate the real adjustment factors a_k = pmf_k / baseline_prob_k (shape: (n, K)).
        """
        return self.prob_masses / self.baseline_prob_between_cutpoints()

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the probability densities of `value`.

        If the last dimension of `value` does not match the batch size, each row is
        treated as a single value evaluated under every distribution in the batch.
        The result is always 2-D (shape: (m, n)).
        """
        n = self.batch_shape[0]

        # Ensure the last dimension of value matches the batch_shape
        if value.shape[-1] != n:
            if value.ndim == 1:
                value = value.unsqueeze(-1)
            value = value.expand(-1, n)

        # Ensure value is 2D
        if value.ndim == 1:
            value = value.unsqueeze(0)

        baseline_prob = torch.exp(self.baseline.log_prob(value))
        # The 1e-10 offset on the scale factor is kept from the original implementation
        # (cdf() uses the unshifted factor).
        hist_prob = self.histogram.prob(value) * (self.scale_down_hist + 1e-10)

        in_hist = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])
        return torch.where(in_hist, hist_prob, baseline_prob)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log probability densities of `value` (log of `prob`).
        """
        return torch.log(self.prob(value))

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the given values.

        Inside [c_0, c_K) the cdf is F_baseline(c_0) plus the scaled histogram cdf;
        elsewhere it is the baseline cdf.
        """
        baseline_cdf = self.baseline.cdf(value)
        hist_cdf = self.histogram.cdf(value) * self.scale_down_hist
        in_hist = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])

        lower_cdf = self.baseline.cdf(self.cutpoints[0])
        # torch.where broadcasts the mask, e.g. a (m, 1) mask against (m, n) cdfs.
        return torch.where(in_hist, lower_cdf + hist_cdf, baseline_cdf)

    def cdf_at_cutpoints(self) -> torch.Tensor:
        """
        Calculate the cumulative distribution function at each cutpoint (shape: (K+1, n)).
        """
        hist_at_cutpoints = (
            self.histogram.cdf_at_cutpoints() * self.scale_down_hist.unsqueeze(0)
        )
        lower_cdf = self.baseline.cdf(self.cutpoints[0]).unsqueeze(0)
        return lower_cdf + hist_at_cutpoints

    @property
    def mean(self) -> torch.Tensor:
        """
        Mean of the histogram part, using the bin midpoints.

        NOTE(review): this ignores the baseline tails below c_0 / above c_K and the
        `scale_down_hist` factor, so it equals the true mean only when the baseline
        puts (essentially) all of its mass inside [c_0, c_K] — confirm this is intended.

        Returns:
            the mean (shape: (batch_shape,))
        """
        middle_of_bins = (self.cutpoints[1:] + self.cutpoints[:-1]) / 2
        return torch.sum(self.prob_masses * middle_of_bins, dim=1)

    def _bound_row(self, bound, num_observations: int) -> torch.Tensor:
        """Turn a scalar or per-observation bound into a (1, num_observations) tensor."""
        bound = torch.as_tensor(
            bound, dtype=self.cutpoints.dtype, device=self.cutpoints.device
        )
        return bound.reshape(1, -1).expand(1, num_observations)

    def icdf(self, p, l=None, u=None, max_iter=1000, tolerance=1e-7) -> torch.Tensor:
        """
        Calculate the inverse CDF (quantiles) of the distribution by bisection.

        Args:
            p: cumulative probability at which to evaluate the icdf (shared by the batch)
            l: lower bound for the quantile search; a scalar or one value per
                observation (default: 0)
            u: upper bound for the quantile search; a scalar or one value per
                observation (default: c_K + (c_K - c_0))
            max_iter: maximum number of iterations permitted for the quantile search
            tolerance: stopping criteria for the search (precision)

        Returns:
            A tensor of shape (1, batch_shape) containing the inverse CDF values.
        """
        num_observations = self.batch_shape[0]
        device = self.cutpoints.device
        percentiles_tensor = torch.full(
            (1, num_observations), fill_value=float(p), dtype=torch.float32, device=device
        )

        if l is None:
            l = 0.0
        if u is None:
            u = self.cutpoints[-1] + (self.cutpoints[-1] - self.cutpoints[0])
        lower_bounds = self._bound_row(l, num_observations)
        upper_bounds = self._bound_row(u, num_observations)

        for _ in range(max_iter):
            mid_points = (lower_bounds + upper_bounds) / 2
            below_target = self.cdf(mid_points) < percentiles_tensor

            new_lower = torch.where(below_target, mid_points, lower_bounds)
            new_upper = torch.where(below_target, upper_bounds, mid_points)

            converged = bool(torch.max(new_upper - new_lower) < tolerance)
            # Once the bounds stop moving (e.g. adjacent float32 values wider than
            # `tolerance`), every further iteration is identical, so stop early.
            stalled = torch.equal(new_lower, lower_bounds) and torch.equal(
                new_upper, upper_bounds
            )
            lower_bounds, upper_bounds = new_lower, new_upper
            if converged or stalled:
                break

        # Use the midpoint between the final bounds as the quantile estimate
        return (lower_bounds + upper_bounds) / 2

    def quantiles(
        self, percentiles: list, l=None, u=None, max_iter=1000, tolerance=1e-7
    ) -> torch.Tensor:
        """
        Calculate the quantile values for the given percentiles (cumulative probabilities * 100).

        The search bounds default to c_0 - (c_K - c_0) and c_K + (c_K - c_0).

        Returns:
            A tensor of shape (len(percentiles), batch_shape).
        """
        if l is None:
            l = self.cutpoints[0] - (self.cutpoints[-1] - self.cutpoints[0])
        if u is None:
            u = self.cutpoints[-1] + (self.cutpoints[-1] - self.cutpoints[0])

        quantile_rows = [
            self.icdf(torch.tensor(percentile / 100.0), l, u, max_iter, tolerance)
            for percentile in percentiles
        ]
        return torch.stack(quantile_rows, dim=1)[0]
@@ -0,0 +1,406 @@
1
+ import torch
2
+ from torch.distributions import Distribution
3
+
4
+
5
class Histogram(Distribution):
    """
    A histogram distribution: a composite of uniform distributions over K bins.

    Bin k covers [c_k, c_{k+1}) for k = 0, ..., K-1, where c_0 < ... < c_K are
    cutpoints shared by the whole batch. Each of the n batch members has its own
    probability vector over the bins.
    """

    def __init__(self, cutpoints: torch.Tensor, prob_masses: torch.Tensor):
        """
        Args:
            cutpoints: the bin boundaries (shape: (K+1,))
            prob_masses: the probability of landing in each bin, one row per batch
                member (shape: (n, K)); every row must sum to 1
        """
        # Constructed regions T_k = [c_k, c_{k+1}) for all k in {0, ..., K-1}
        self.cutpoints = cutpoints
        self.num_regions = len(self.cutpoints) - 1

        # Predicted probabilities (Pr(c_0 <= Y < c_1 | x), ..., Pr(c_{K-1} <= Y < c_K | x))
        self.prob_masses = prob_masses
        assert torch.allclose(
            torch.sum(self.prob_masses, dim=1),
            # Match dtype/device of the masses so float64 inputs are accepted.
            torch.ones(
                self.prob_masses.shape[0],
                dtype=self.prob_masses.dtype,
                device=self.prob_masses.device,
            ),
        )
        assert self.prob_masses.shape[1] == self.num_regions

        # Bin widths (c_1 - c_0, ..., c_K - c_{K-1}) (shape: (K,))
        self.bin_widths = self.cutpoints[1:] - self.cutpoints[:-1]
        assert torch.all(self.bin_widths > 0)

        # Density inside each bin: bin probability divided by bin width (shape: (n, K))
        self.prob_densities = self.prob_masses / self.bin_widths

        super().__init__(
            batch_shape=torch.Size([prob_masses.shape[0]]), validate_args=False
        )

    def _as_batch_matrix(self, value: torch.Tensor) -> torch.Tensor:
        """
        Reshape `value` to (m, n).

        If the last dimension does not match the batch size, each row becomes a single
        value repeated across the batch. A matching 1-D input becomes one row.
        """
        n = self.batch_shape[0]
        if value.shape[-1] != n:
            if value.ndim == 1:
                value = value.unsqueeze(-1)
            value = value.expand(-1, n)
        if value.ndim == 1:
            value = value.unsqueeze(0)
        return value

    def _bin_index(self, value: torch.Tensor) -> torch.Tensor:
        """Index k of the bin [c_k, c_{k+1}) holding each entry, clamped to [0, K-1]."""
        # right=True gives i with c_{i-1} <= x < c_i, so the bin index is i - 1.
        idx = torch.bucketize(
            value.detach().to(self.cutpoints.dtype).contiguous(),
            self.cutpoints.contiguous(),
            right=True,
        )
        return (idx - 1).clamp(0, self.num_regions - 1)

    def _interpolated_cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Evaluate the piecewise-linear cdf on a (m, n) tensor, one column per distribution."""
        idx = self._bin_index(value)
        # F(c_k) and p_k for the bin containing each entry, gathered per column.
        cdf_left = torch.gather(self.cdf_at_cutpoints(), 0, idx)
        mass = torch.gather(self.prob_masses.T, 0, idx)
        bin_fraction = (value - self.cutpoints[idx]) / self.bin_widths[idx]
        cdf_values = cdf_left + mass * bin_fraction

        # Pin the cdf to exactly 0 / 1 outside the support.
        cdf_values = torch.where(
            value >= self.cutpoints[-1], torch.ones_like(cdf_values), cdf_values
        )
        cdf_values = torch.where(
            value <= self.cutpoints[0], torch.zeros_like(cdf_values), cdf_values
        )
        return torch.clamp(cdf_values, min=0.0, max=1.0)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the probability densities of `value`.

        Values outside [c_0, c_K) have density 0. A 1-D input is returned as 1-D
        when it yields a single row.
        """
        orig_ndim = value.ndim
        value = self._as_batch_matrix(value)

        in_support = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])
        densities = torch.gather(self.prob_densities.T, 0, self._bin_index(value))
        probabilities = torch.where(in_support, densities, torch.zeros_like(value))

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and probabilities.ndim == 2:
            probabilities = probabilities.squeeze(0)

        return probabilities

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log probability densities of `values`.
        """
        return torch.log(self.prob(value))

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the given values.

        If the last dimension of `value` matches the batch size, column j is evaluated
        under distribution j. Otherwise this defers to `cdf_same_eval`.
        """
        if value.shape[-1] != self.batch_shape[0]:
            return self.cdf_same_eval(value)

        orig_ndim = value.ndim
        cdf_values = self._interpolated_cdf(self._as_batch_matrix(value))

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and cdf_values.ndim == 2:
            cdf_values = cdf_values.squeeze(0)

        return cdf_values

    def cdf_same_eval(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the same value across the batch.

        For each row of `value`, only its first entry is used and is evaluated under
        every distribution in the batch (result shape: (m, n)).
        """
        orig_ndim = value.ndim
        value = self._as_batch_matrix(value)
        value = value[:, :1].expand(-1, self.batch_shape[0])

        cdf_values = self._interpolated_cdf(value)

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and cdf_values.ndim == 2:
            cdf_values = cdf_values.squeeze(0)

        return cdf_values

    def cdf_at_cutpoints(self) -> torch.Tensor:
        """
        Calculate the cumulative distribution function at each cutpoint (shape: (K+1, n)).
        """
        # cumsum of the bin masses with a leading 0 for F(c_0)
        leading_zeros = torch.zeros(
            self.prob_masses.shape[0],
            1,
            dtype=self.prob_masses.dtype,
            device=self.prob_masses.device,
        )
        return torch.cat(
            [leading_zeros, torch.cumsum(self.prob_masses, dim=1)], dim=1
        ).T

    @property
    def mean(self) -> torch.Tensor:
        """
        Calculate the mean of the distribution (uniform within bins, so bin midpoints).

        Returns:
            the mean (shape: (batch_shape,))
        """
        middle_of_bins = (self.cutpoints[1:] + self.cutpoints[:-1]) / 2
        return torch.sum(self.prob_masses * middle_of_bins, dim=1)

    def _bound_row(self, bound, num_observations: int) -> torch.Tensor:
        """Turn a scalar or per-observation bound into a (1, num_observations) tensor."""
        bound = torch.as_tensor(
            bound, dtype=self.cutpoints.dtype, device=self.cutpoints.device
        )
        return bound.reshape(1, -1).expand(1, num_observations)

    def icdf(self, p, l=None, u=None, max_iter=1000, tolerance=1e-7) -> torch.Tensor:
        """
        Calculate the inverse CDF (quantiles) of the distribution by bisection.

        Args:
            p: cumulative probability at which to evaluate the icdf (shared by the batch)
            l: lower bound for the quantile search; a scalar or one value per
                observation (default: min(c_0, 0), i.e. 0 for non-negative support)
            u: upper bound for the quantile search; a scalar or one value per
                observation (default: c_K + (c_K - c_0))
            max_iter: maximum number of iterations permitted for the quantile search
            tolerance: stopping criteria for the search (precision)

        Returns:
            A tensor of shape (1, batch_shape) containing the inverse CDF values.
        """
        num_observations = self.batch_shape[0]
        device = self.cutpoints.device
        percentiles_tensor = torch.full(
            (1, num_observations), fill_value=float(p), dtype=torch.float32, device=device
        )

        if l is None:
            # 0 was the historical default; extend it downwards when the support is negative.
            l = torch.clamp(self.cutpoints[0], max=0.0)
        if u is None:
            u = self.cutpoints[-1] + (self.cutpoints[-1] - self.cutpoints[0])
        lower_bounds = self._bound_row(l, num_observations)
        upper_bounds = self._bound_row(u, num_observations)

        for _ in range(max_iter):
            mid_points = (lower_bounds + upper_bounds) / 2
            below_target = self.cdf(mid_points) < percentiles_tensor

            new_lower = torch.where(below_target, mid_points, lower_bounds)
            new_upper = torch.where(below_target, upper_bounds, mid_points)

            converged = bool(torch.max(new_upper - new_lower) < tolerance)
            # Once the bounds stop moving (e.g. adjacent float32 values wider than
            # `tolerance`), every further iteration is identical, so stop early.
            stalled = torch.equal(new_lower, lower_bounds) and torch.equal(
                new_upper, upper_bounds
            )
            lower_bounds, upper_bounds = new_lower, new_upper
            if converged or stalled:
                break

        # Use the midpoint between the final bounds as the quantile estimate
        return (lower_bounds + upper_bounds) / 2

    def quantiles(
        self, percentiles: list, l=None, u=None, max_iter=1000, tolerance=1e-7
    ) -> torch.Tensor:
        """
        Calculate the quantile values for the given percentiles (cumulative probabilities * 100).

        `l` and `u` are forwarded to `icdf`; when omitted, icdf's defaults apply.

        Returns:
            A tensor of shape (len(percentiles), batch_shape).
        """
        quantile_rows = [
            self.icdf(torch.tensor(percentile / 100.0), l, u, max_iter, tolerance)
            for percentile in percentiles
        ]
        return torch.stack(quantile_rows, dim=1)[0]