drn 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drn/__init__.py +6 -0
- drn/distributions/__init__.py +4 -0
- drn/distributions/extended_histogram.py +219 -0
- drn/distributions/histogram.py +406 -0
- drn/interpretability.py +1717 -0
- drn/metrics.py +114 -0
- drn/models/__init__.py +39 -0
- drn/models/cann.py +195 -0
- drn/models/ddr.py +111 -0
- drn/models/drn.py +231 -0
- drn/models/glm.py +324 -0
- drn/models/mdn.py +236 -0
- drn/py.typed +0 -0
- drn/train.py +244 -0
- drn-0.0.1.dist-info/METADATA +516 -0
- drn-0.0.1.dist-info/RECORD +18 -0
- drn-0.0.1.dist-info/WHEEL +4 -0
- drn-0.0.1.dist-info/licenses/LICENSE.md +21 -0
drn/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
# __init__.py for distributionalforecasting package
|
|
2
|
+
from .models import *
|
|
3
|
+
from .distributions import *
|
|
4
|
+
from .interpretability import DRNExplainer, KernelSHAP_DRN
|
|
5
|
+
from .train import train, split_and_preprocess
|
|
6
|
+
from .metrics import crps, quantile_score, quantile_losses, rmse
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from torch.distributions import Distribution
|
|
5
|
+
|
|
6
|
+
from .histogram import Histogram
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ExtendedHistogram(Distribution):
    """
    This class represents a splicing of a supplied distribution with a histogram distribution.
    The histogram part is defined by K regions with boundaries -infty < c_0 < c_1 < ... < c_K < infty.
    The final density before c_0 & after c_K is the same as the original distribution.
    The density between c_k & c_{k+1} is defined by the histogram distribution, whose
    conditional probabilities are multiplied by the baseline's own probability of landing
    in [c_0, c_K) (`scale_down_hist`).
    """

    def __init__(
        self,
        baseline: Distribution,
        cutpoints: torch.Tensor,
        pmf: torch.Tensor,
        baseline_probs: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            baseline: the original distribution
            cutpoints: the bin boundaries (shape: (K+1,))
            pmf: the refined (cond.) probability for landing in each region (shape: (n, K))
            baseline_probs: the baseline's probability for landing in each region (shape: (n, K));
                when omitted it is computed lazily by `baseline_prob_between_cutpoints`
        """
        self.baseline = baseline
        self.cutpoints = cutpoints
        self.prob_masses = pmf
        self.baseline_probs = baseline_probs
        self.histogram = Histogram(cutpoints, pmf)
        # Baseline probability mass between c_0 and c_K, one value per batch member.
        self.scale_down_hist = baseline.cdf(cutpoints[-1]) - baseline.cdf(cutpoints[0])

        assert self.scale_down_hist.shape == torch.Size([self.histogram.batch_shape[0]])

        super(ExtendedHistogram, self).__init__(
            batch_shape=self.histogram.batch_shape, validate_args=False
        )

    def baseline_prob_between_cutpoints(self) -> torch.Tensor:
        """
        Calculate the baseline probability vector.

        Returns:
            the baseline's probability of each region [c_k, c_{k+1}) (shape: (n, K)).
            The result is cached in `self.baseline_probs` on first use.
        """
        if self.baseline_probs is None:
            # cdf at every cutpoint for every batch member, arranged as (n, K+1)
            baseline_cdfs = self.baseline.cdf(self.cutpoints.unsqueeze(-1)).T
            self.baseline_probs = torch.diff(baseline_cdfs, dim=1)

        return self.baseline_probs

    def real_adjustments(self) -> torch.Tensor:
        """
        Calculate the real adjustment factors a_k's, i.e. the ratio of the refined
        region probabilities to the baseline region probabilities (shape: (n, K)).
        """
        return self.prob_masses / self.baseline_prob_between_cutpoints()

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the probability densities of `values`.

        `value` may be (n,) or (m, n) (column i evaluated under member i), or (m,) / (m, 1)
        with m != n (every entry evaluated under every member).

        Returns:
            densities of shape (m, n) (a 1-D input of length n gives (1, n)).
        """
        num_members = self.batch_shape[0]

        # Bring `value` to a 2-D (m, n) layout aligned with the batch.
        if value.shape[-1] != num_members:
            if value.ndim == 1:
                value = value.unsqueeze(-1)
            value = value.expand(-1, num_members)
        if value.ndim == 1:
            value = value.unsqueeze(0)

        baseline_prob = torch.exp(self.baseline.log_prob(value))
        # NOTE(review): the 1e-10 presumably guards against a zero scale factor
        # (baseline with no mass in [c_0, c_K)); kept for backward compatibility.
        hist_prob = self.histogram.prob(value) * (self.scale_down_hist + 1e-10)

        in_hist = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])
        return torch.where(in_hist, hist_prob, baseline_prob)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log of `prob(value)`."""
        return torch.log(self.prob(value))

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the given values.

        Outside [c_0, c_K) this is the baseline cdf; inside it is the baseline cdf at c_0
        plus the scaled histogram cdf.
        """
        baseline_cdf = self.baseline.cdf(value)
        hist_cdf = self.histogram.cdf(value) * self.scale_down_hist
        lower_cdf = self.baseline.cdf(self.cutpoints[0])

        in_hist = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])
        # torch.where broadcasts the mask (e.g. (m, 1) or (n,)) against the cdf values.
        return torch.where(in_hist, lower_cdf + hist_cdf, baseline_cdf)

    def cdf_at_cutpoints(self) -> torch.Tensor:
        """
        Calculate the cumulative distribution function at each cutpoint.

        Returns:
            a tensor of shape (K+1, n).
        """
        hist_at_cutpoints = (
            self.histogram.cdf_at_cutpoints() * self.scale_down_hist.unsqueeze(0)
        )
        lower_cdf = self.baseline.cdf(self.histogram.cutpoints[0]).unsqueeze(0)
        out = lower_cdf + hist_at_cutpoints
        return out

    @property
    def mean(self) -> torch.Tensor:
        """
        Calculate the mean of the distribution.

        NOTE(review): this is the mean of the histogram component only (bin midpoints weighted
        by the conditional masses). The baseline tails outside [c_0, c_K) are not included, so it
        matches the full mean only when the baseline puts (almost) all its mass inside [c_0, c_K).

        Returns:
            the mean (shape: (batch_shape,))
        """
        middle_of_bins = (self.cutpoints[1:] + self.cutpoints[:-1]) / 2
        return torch.sum(self.prob_masses * middle_of_bins, dim=1)

    def _search_bound(self, bound, num_observations: int) -> torch.Tensor:
        """Convert a scalar or per-observation bound into a (1, num_observations) float tensor."""
        bound = torch.as_tensor(bound, device=self.cutpoints.device)
        if not bound.is_floating_point():
            bound = bound.to(torch.get_default_dtype())
        return bound.reshape(1, -1).expand(1, num_observations)

    def icdf(self, p, l=None, u=None, max_iter=1000, tolerance=1e-7) -> torch.Tensor:
        """
        Calculate the inverse CDF (quantiles) of the distribution for the given cumulative probability.

        Args:
            p: cumulative probability value at which to evaluate icdf
            l: lower bound for the quantile search (scalar or one per observation; default 0)
            u: upper bound for the quantile search (scalar or one per observation;
                default c_K + (c_K - c_0))
            max_iter: maximum number of iterations permitted for the quantile search
            tolerance: stopping criteria for the search (precision)

        Returns:
            A tensor of shape (1, batch_shape) containing the inverse CDF values.
        """
        device = self.cutpoints.device
        num_observations = self.batch_shape[0]
        percentiles_tensor = torch.full(
            (1, num_observations), fill_value=float(p), dtype=torch.float32, device=device
        )

        if l is None:
            l = 0.0
        if u is None:
            u = self.cutpoints[-1] + (self.cutpoints[-1] - self.cutpoints[0])
        lower_bounds = self._search_bound(l, num_observations)
        upper_bounds = self._search_bound(u, num_observations)
        common_dtype = torch.promote_types(lower_bounds.dtype, upper_bounds.dtype)
        lower_bounds = lower_bounds.to(common_dtype)
        upper_bounds = upper_bounds.to(common_dtype)

        for _ in range(max_iter):
            mid_points = (lower_bounds + upper_bounds) / 2

            # Bisection: the target quantile is above the midpoint wherever cdf(mid) < p.
            go_right = self.cdf(mid_points) < percentiles_tensor
            new_lower = torch.where(go_right, mid_points, lower_bounds)
            new_upper = torch.where(go_right, upper_bounds, mid_points)

            # Floating-point precision can stall the bounds above `tolerance`; once nothing
            # moves, later iterations cannot change the result, so stop early.
            stalled = torch.equal(new_lower, lower_bounds) and torch.equal(
                new_upper, upper_bounds
            )
            lower_bounds, upper_bounds = new_lower, new_upper

            if stalled or torch.max(upper_bounds - lower_bounds) < tolerance:
                break

        # Use the midpoint between the final bounds as the quantile estimate
        return (lower_bounds + upper_bounds) / 2

    def quantiles(
        self, percentiles: list, l=None, u=None, max_iter=1000, tolerance=1e-7
    ) -> torch.Tensor:
        """
        Calculate the quantile values for the given observations and percentiles (cumulative probabilities * 100).

        Args:
            percentiles: percentiles in [0, 100]
            l: lower search bound (default c_0 - (c_K - c_0)); a number or tensor
            u: upper search bound (default c_K + (c_K - c_0)); a number or tensor

        Returns:
            a tensor of shape (len(percentiles), batch_shape).
        """
        span = self.cutpoints[-1] - self.cutpoints[0]
        if l is None:
            l = self.cutpoints[0] - span
        if u is None:
            u = self.cutpoints[-1] + span
        quantiles = [
            self.icdf(percentile / 100.0, l, u, max_iter, tolerance)
            for percentile in percentiles
        ]

        # Each icdf result is (1, n); stacking rows gives (P, n).
        return torch.cat(quantiles, dim=0)
|
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.distributions import Distribution
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Histogram(Distribution):
    """
    This class represents a histogram distribution.
    Basically, the distribution is a composite of uniform distributions over the bins.

    All n batch members share the cutpoints c_0 < c_1 < ... < c_K; each member has its own
    probability mass on every bin [c_k, c_{k+1}) and a uniform density inside each bin.

    `prob` and `cdf` accept `value` of shape (n,) or (m, n) (column i is evaluated under
    member i), or of shape (m,) / (m, 1) with m != n (each entry is evaluated under every member).
    """

    def __init__(self, cutpoints: torch.Tensor, prob_masses: torch.Tensor):
        """
        Args:
            cutpoints: the bin boundaries (shape: (K+1,)), strictly increasing
            prob_masses: the probability for landing in each region (shape: (n, K)),
                each row summing to one
        """
        # Constructed regions T_k = [c_k, c_{k+1}) for all k in {0, ..., K-1}
        self.cutpoints = cutpoints
        self.num_regions = len(self.cutpoints) - 1

        # Predicted probabilities vector:
        # (Pr(c_0 <= Y < c_1 | x), ..., Pr(c_{K-1} <= Y < c_K | x))
        self.prob_masses = prob_masses
        row_totals = torch.sum(self.prob_masses, dim=1)
        # ones_like matches the masses' dtype and device, so e.g. float64 masses validate.
        assert torch.allclose(row_totals, torch.ones_like(row_totals))
        assert self.prob_masses.shape[1] == self.num_regions

        # Bin widths (c_1 - c_0, ..., c_K - c_{K-1}); all must be positive.
        self.bin_widths = self.cutpoints[1:] - self.cutpoints[:-1]
        assert torch.all(self.bin_widths > 0)

        # Uniform density inside each bin: Pr(Y in T_k | x) / |T_k|  (shape: (n, K))
        self.prob_densities = self.prob_masses / self.bin_widths

        super(Histogram, self).__init__(
            batch_shape=torch.Size([prob_masses.shape[0]]), validate_args=False
        )

    def _as_batch_matrix(self, value: torch.Tensor) -> torch.Tensor:
        """Reshape `value` to (m, n) so that column i is evaluated under batch member i."""
        num_members = self.batch_shape[0]
        # A trailing size other than n means "same value for every member": broadcast it.
        if value.shape[-1] != num_members:
            if value.ndim == 1:
                value = value.unsqueeze(-1)
            value = value.expand(-1, num_members)
        if value.ndim == 1:
            value = value.unsqueeze(0)
        return value

    def _bin_index(self, value: torch.Tensor) -> torch.Tensor:
        """
        Index k of the bin [c_k, c_{k+1}) containing each entry of `value`.

        Entries below c_0 map to bin 0 and entries at or above c_K map to bin K-1, so the
        result is always a valid index into the bin axis.
        """
        # Compare in the promoted dtype, as elementwise `>=` / `<` against the cutpoints would.
        dtype = torch.promote_types(value.dtype, self.cutpoints.dtype)
        idx = torch.bucketize(
            value.detach().to(dtype).contiguous(),
            self.cutpoints.detach().to(dtype).contiguous(),
            right=True,  # c_{i-1} <= v < c_i  ->  i
        )
        return (idx - 1).clamp(0, self.num_regions - 1)

    def _cdf_matrix(self, value: torch.Tensor) -> torch.Tensor:
        """Piecewise-linear cdf for a (m, n) `value` aligned with the batch."""
        k = self._bin_index(value)
        members = torch.arange(self.batch_shape[0], device=k.device)

        # cdf at the upper edge of every bin, with the last bin pinned to exactly 1,
        # and at the lower edge (0 for the first bin).
        upper = torch.cumsum(self.prob_masses, dim=1)
        upper = torch.cat([upper[:, :-1], torch.ones_like(upper[:, -1:])], dim=1)
        lower = torch.cat([torch.zeros_like(upper[:, :1]), upper[:, :-1]], dim=1)

        # members (n,) broadcasts with k (m, n): entry [j, i] reads member i, bin k[j, i].
        last_cdfs = lower[members, k]
        next_cdfs = upper[members, k]
        bin_fraction = (value - self.cutpoints[k]) / self.bin_widths[k]

        cdf_values = last_cdfs + (next_cdfs - last_cdfs) * bin_fraction
        cdf_values = torch.clamp(cdf_values, max=1.0, min=0.0)
        # Pin the support edges exactly (guards against rounding just below 1).
        cdf_values = torch.where(
            value <= self.cutpoints[0], torch.zeros_like(cdf_values), cdf_values
        )
        cdf_values = torch.where(
            value >= self.cutpoints[-1], torch.ones_like(cdf_values), cdf_values
        )
        return cdf_values

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the probability densities of `values`.

        Values outside [c_0, c_K) have density 0. A 1-D input of length n returns shape (n,).
        """
        orig_ndim = value.ndim
        value = self._as_batch_matrix(value)

        k = self._bin_index(value)
        members = torch.arange(self.batch_shape[0], device=k.device)
        densities = self.prob_densities[members, k]

        in_support = (value >= self.cutpoints[0]) & (value < self.cutpoints[-1])
        probabilities = torch.where(in_support, densities, torch.zeros_like(densities))

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and probabilities.ndim == 2:
            probabilities = probabilities.squeeze(0)
        return probabilities

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log probability densities of `values`.
        """
        return torch.log(self.prob(value))

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the given values.

        The cdf is linear inside each bin, 0 at or below c_0 and 1 at or above c_K.
        """
        orig_ndim = value.ndim
        # A trailing size other than the batch size means "same value for every member".
        if value.shape[-1] != self.batch_shape[0]:
            return self.cdf_same_eval(value)

        cdf_values = self._cdf_matrix(self._as_batch_matrix(value))

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and cdf_values.ndim == 2:
            cdf_values = cdf_values.squeeze(0)
        return cdf_values

    def cdf_same_eval(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the cumulative distribution function for the same value across the batch.

        Row j of the result evaluates every batch member at the first entry of row j of
        the (m, n)-aligned `value`.
        """
        orig_ndim = value.ndim
        value = self._as_batch_matrix(value)
        value = value[:, :1].expand(-1, self.batch_shape[0])

        cdf_values = self._cdf_matrix(value)

        # If we added a leading dimension, remove it
        if orig_ndim == 1 and cdf_values.ndim == 2:
            cdf_values = cdf_values.squeeze(0)
        return cdf_values

    def cdf_at_cutpoints(self) -> torch.Tensor:
        """
        Calculate the cumulative distribution function at each cutpoint.

        Returns:
            a tensor of shape (K+1, n) whose first row is 0.
        """
        # cumsum of the masses with a 0 prepended for c_0
        return torch.cat(
            [
                torch.zeros(
                    self.prob_masses.shape[0], 1, device=self.prob_masses.device
                ),
                torch.cumsum(self.prob_masses, dim=1),
            ],
            dim=1,
        ).T

    @property
    def mean(self) -> torch.Tensor:
        """
        Calculate the mean of the distribution.
        Returns:
            the mean (shape: (batch_shape,))
        """
        middle_of_bins = (self.cutpoints[1:] + self.cutpoints[:-1]) / 2
        return torch.sum(self.prob_masses * middle_of_bins, dim=1)

    def _search_bound(self, bound, num_observations: int) -> torch.Tensor:
        """Convert a scalar or per-observation bound into a (1, num_observations) float tensor."""
        bound = torch.as_tensor(bound, device=self.cutpoints.device)
        if not bound.is_floating_point():
            bound = bound.to(torch.get_default_dtype())
        return bound.reshape(1, -1).expand(1, num_observations)

    def icdf(self, p, l=None, u=None, max_iter=1000, tolerance=1e-7) -> torch.Tensor:
        """
        Calculate the inverse CDF (quantiles) of the distribution for the given cumulative probability.

        Args:
            p: cumulative probability value at which to evaluate icdf
            l: lower bound for the quantile search (scalar or one per observation; default 0)
            u: upper bound for the quantile search (scalar or one per observation;
                default c_K + (c_K - c_0))
            max_iter: maximum number of iterations permitted for the quantile search
            tolerance: stopping criteria for the search (precision)

        Returns:
            A tensor of shape (1, batch_shape) containing the inverse CDF values.
        """
        device = self.cutpoints.device
        num_observations = self.batch_shape[0]
        percentiles_tensor = torch.full(
            (1, num_observations), fill_value=float(p), dtype=torch.float32, device=device
        )

        if l is None:
            l = 0.0
        if u is None:
            u = self.cutpoints[-1] + (self.cutpoints[-1] - self.cutpoints[0])
        lower_bounds = self._search_bound(l, num_observations)
        upper_bounds = self._search_bound(u, num_observations)
        common_dtype = torch.promote_types(lower_bounds.dtype, upper_bounds.dtype)
        lower_bounds = lower_bounds.to(common_dtype)
        upper_bounds = upper_bounds.to(common_dtype)

        for _ in range(max_iter):
            mid_points = (lower_bounds + upper_bounds) / 2

            # Bisection: the target quantile is above the midpoint wherever cdf(mid) < p.
            go_right = self.cdf(mid_points) < percentiles_tensor
            new_lower = torch.where(go_right, mid_points, lower_bounds)
            new_upper = torch.where(go_right, upper_bounds, mid_points)

            # Floating-point precision can stall the bounds above `tolerance`; once nothing
            # moves, later iterations cannot change the result, so stop early.
            stalled = torch.equal(new_lower, lower_bounds) and torch.equal(
                new_upper, upper_bounds
            )
            lower_bounds, upper_bounds = new_lower, new_upper

            if stalled or torch.max(upper_bounds - lower_bounds) < tolerance:
                break

        # Use the midpoint between the final bounds as the quantile estimate
        return (lower_bounds + upper_bounds) / 2

    def quantiles(
        self, percentiles: list, l=None, u=None, max_iter=1000, tolerance=1e-7
    ) -> torch.Tensor:
        """
        Calculate the quantile values for the given observations and percentiles (cumulative probabilities * 100).

        Args:
            percentiles: percentiles in [0, 100]
            l: lower search bound (default 0); a number or tensor, now honoured when given
            u: upper search bound (default c_K + (c_K - c_0)); a number or tensor

        Returns:
            a tensor of shape (len(percentiles), batch_shape).
        """
        quantiles = [
            self.icdf(percentile / 100.0, l, u, max_iter, tolerance)
            for percentile in percentiles
        ]

        # Each icdf result is (1, n); stacking rows gives (P, n).
        return torch.cat(quantiles, dim=0)
|