gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/gnn/distribution.py
@@ -0,0 +1,498 @@
import warnings

import torch
import torch.nn.functional as F
from torch.distributions import Distribution, Gamma, Poisson, constraints
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


def log_zinb_positive(
    x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
):
    """
    From scVI.
    Log likelihood (scalar) of a minibatch according to a zinb model.

    Parameters
    ----------
    x
        Data
    mu
        mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
    theta
        inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
    pi
        logit of the dropout parameter (real support) (shape: minibatch x vars)
    eps
        numerical stability constant

    Notes
    -----
    We parametrize the bernoulli using the logits, hence the softplus functions appearing.
    """
    # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
    if theta.ndimension() == 1:
        theta = theta.view(
            1, theta.size(0)
        )  # In this case, we reshape theta for broadcasting

    softplus_pi = F.softplus(-pi)  # uses log(sigmoid(x)) = -softplus(-x)
    log_theta_eps = torch.log(theta + eps)
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)

    case_non_zero = (
        -softplus_pi
        + pi_theta_log
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero

    return res

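# Illustrative note (not part of the packaged file): written out, the quantity computed
# above is the zero-inflated negative binomial log-likelihood with dropout probability
# rho = sigmoid(pi):
#   x == 0:  log( rho + (1 - rho) * (theta / (theta + mu)) ** theta )
#   x  > 0:  log(1 - rho) + lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
#            + theta * log(theta / (theta + mu)) + x * log(mu / (theta + mu))
# A quick check of the zero branch with arbitrary values, assuming the imports above:
#   mu, theta, pi = torch.tensor([2.0]), torch.tensor([5.0]), torch.tensor([0.3])
#   rho = torch.sigmoid(pi)
#   manual = torch.log(rho + (1 - rho) * (theta / (theta + mu)) ** theta)
#   torch.allclose(log_zinb_positive(torch.zeros(1), mu, theta, pi), manual, atol=1e-4)  # True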

def log_nb_positive(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps=1e-8):
    """
    From scVI.
    Log likelihood (scalar) of a minibatch according to a nb model.

    Parameters
    ----------
    x
        data
    mu
        mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
    theta
        inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
    eps
        numerical stability constant

    Notes
    -----
    We parametrize the bernoulli using the logits, hence the softplus functions appearing.

    """
    if theta.ndimension() == 1:
        theta = theta.view(
            1, theta.size(0)
        )  # In this case, we reshape theta for broadcasting

    log_theta_mu_eps = torch.log(theta + mu + eps)

    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )

    return res


def log_mixture_nb(
    x: torch.Tensor,
    mu_1: torch.Tensor,
    mu_2: torch.Tensor,
    theta_1: torch.Tensor,
    theta_2: torch.Tensor,
    pi_logits: torch.Tensor,
    eps=1e-8,
):
    """
    From scVI.
    Log likelihood (scalar) of a minibatch according to a mixture nb model.

    pi_logits is the probability (logits) to be in the first component.
    For totalVI, the first component should be background.

    Parameters
    ----------
    x
        Observed data
    mu_1
        Mean of the first negative binomial component (has to be positive support) (shape: minibatch x features)
    mu_2
        Mean of the second negative binomial (has to be positive support) (shape: minibatch x features)
    theta_1
        First inverse dispersion parameter (has to be positive support) (shape: minibatch x features)
    theta_2
        Second inverse dispersion parameter (has to be positive support) (shape: minibatch x features)
        If None, assume one shared inverse dispersion parameter.
    pi_logits
        Probability of belonging to mixture component 1 (logits scale)
    eps
        Numerical stability constant
    """
    if theta_2 is not None:
        log_nb_1 = log_nb_positive(x, mu_1, theta_1)
        log_nb_2 = log_nb_positive(x, mu_2, theta_2)
    # this is intended to reduce repeated computations
    else:
        theta = theta_1
        if theta.ndimension() == 1:
            theta = theta.view(
                1, theta.size(0)
            )  # In this case, we reshape theta for broadcasting

        log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
        log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
        lgamma_x_theta = torch.lgamma(x + theta)
        lgamma_theta = torch.lgamma(theta)
        lgamma_x_plus_1 = torch.lgamma(x + 1)

        log_nb_1 = (
            theta * (torch.log(theta + eps) - log_theta_mu_1_eps)
            + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps)
            + lgamma_x_theta
            - lgamma_theta
            - lgamma_x_plus_1
        )
        log_nb_2 = (
            theta * (torch.log(theta + eps) - log_theta_mu_2_eps)
            + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps)
            + lgamma_x_theta
            - lgamma_theta
            - lgamma_x_plus_1
        )

    logsumexp = torch.logsumexp(torch.stack(
        (log_nb_1, log_nb_2 - pi_logits)), dim=0)
    softplus_pi = F.softplus(-pi_logits)

    log_mixture_nb = logsumexp - softplus_pi

    return log_mixture_nb


def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
    r"""
    From scVI.
    NB parameterizations conversion.

    Parameters
    ----------
    mu
        mean of the NB distribution.
    theta
        inverse overdispersion.
    eps
        constant used for numerical log stability. (Default value = 1e-6)

    Returns
    -------
    type
        the number of failures until the experiment is stopped
        and the success probability.
    """
    if not (mu is None) == (theta is None):
        raise ValueError(
            "If using the mu/theta NB parameterization, both parameters must be specified"
        )
    logits = (mu + eps).log() - (theta + eps).log()
    total_count = theta
    return total_count, logits


def _convert_counts_logits_to_mean_disp(total_count, logits):
    """
    From scVI.
    NB parameterizations conversion.

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    logits
        success logits.

    Returns
    -------
    type
        the mean and inverse overdispersion of the NB distribution.

    """
    theta = total_count
    mu = logits.exp() * theta
    return mu, theta


def _gamma(theta, mu):
    concentration = theta
    rate = theta / mu
    # Important remark: Gamma is parametrized by the rate = 1/scale!
    gamma_d = Gamma(concentration=concentration, rate=rate)
    return gamma_d

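# Illustrative note (not part of the packaged file): the two conversion helpers above are
# inverses of each other. _convert_mean_disp_to_counts_logits maps (mu, theta) to the
# (total_count, logits) parameterization via logits = log(mu) - log(theta), and
# _convert_counts_logits_to_mean_disp maps back via mu = exp(logits) * theta. Under that
# mapping, log_nb_positive should agree (up to eps) with PyTorch's built-in distribution:
#   x = torch.tensor([0.0, 3.0, 7.0])
#   mu, theta = torch.tensor([2.0]), torch.tensor([5.0])
#   total_count, logits = _convert_mean_disp_to_counts_logits(mu, theta, eps=0.0)
#   ref = torch.distributions.NegativeBinomial(total_count=total_count, logits=logits)
#   torch.allclose(log_nb_positive(x, mu, theta), ref.log_prob(x), atol=1e-4)  # True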

class NegativeBinomial(Distribution):
    r"""
    From scVI.
    Negative binomial distribution.

    One of the following parameterizations must be provided:

    (1), (`total_count`, `probs`) where `total_count` is the number of failures until
    the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
    parameterization, which is the one used by scvi-tools. These parameters respectively
    control the mean and inverse dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    mu
        Mean of the distribution.
    theta
        Inverse dispersion.
    scale
        Normalized mean expression of the distribution.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        logits: torch.Tensor | None = None,
        mu: torch.Tensor | None = None,
        theta: torch.Tensor | None = None,
        scale: torch.Tensor | None = None,
        validate_args: bool = False,
    ):
        self._eps = 1e-8
        if (mu is None) == (total_count is None):
            raise ValueError(
                "Please use one of the two possible parameterizations. Refer to the documentation for more information."
            )

        using_param_1 = total_count is not None and (
            logits is not None or probs is not None
        )
        if using_param_1:
            logits = logits if logits is not None else probs_to_logits(probs)
            total_count = total_count.type_as(logits)
            total_count, logits = broadcast_all(total_count, logits)
            mu, theta = _convert_counts_logits_to_mean_disp(
                total_count, logits)
        else:
            mu, theta = broadcast_all(mu, theta)
        self.mu = mu
        self.theta = theta
        self.scale = scale
        super().__init__(validate_args=validate_args)

    @property
    def mean(self):
        return self.mu

    @property
    def variance(self):
        return self.mean + (self.mean**2) / self.theta

    @torch.inference_mode()
    def sample(
        self,
        sample_shape: torch.Size | tuple | None = None,
    ) -> torch.Tensor:
        """Sample from the distribution."""
        sample_shape = sample_shape or torch.Size()
        gamma_d = self._gamma()
        p_means = gamma_d.sample(sample_shape)

        # Clamping as distributions objects can have buggy behaviors when
        # their parameters are too high
        l_train = torch.clamp(p_means, max=1e8)
        # Shape : (n_samples, n_cells_batch, n_vars)
        counts = Poisson(l_train).sample()
        return counts

    @torch.inference_mode()
    def rsample(self, sample_shape: torch.Size | tuple | None = None):
        """Sample from the distribution."""
        return self.sample(sample_shape=sample_shape)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                    UserWarning, stacklevel=2,
                )

        return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)

    def _gamma(self):
        return _gamma(self.theta, self.mu)

    def pearson_residuals(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute the Pearson residuals.

        Parameters
        ----------
        x
            Observed data.

        Returns
        -------
        type
            Pearson residuals.
        """
        mean = self.mean
        variance = self.variance
        return (x - mean) / torch.sqrt(variance)

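# Illustrative note (not part of the packaged file): minimal usage of the class above in
# its (mu, theta) parameterization (values are arbitrary):
#   nb = NegativeBinomial(mu=torch.full((4, 10), 3.0), theta=torch.full((4, 10), 2.0))
#   nb.mean                       # equals mu
#   nb.variance                   # equals mu + mu**2 / theta
#   counts = nb.sample()          # Gamma(theta, theta/mu) -> Poisson draw, shape (4, 10)
#   ll = nb.log_prob(counts)      # delegates to log_nb_positive
# Supplying both mu and total_count, or neither, raises the ValueError above.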

class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""
    From scVI.
    Zero-inflated negative binomial distribution.

    One of the following parameterizations must be provided:

    (1), (`total_count`, `probs`) where `total_count` is the number of failures until
    the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
    parameterization, which is the one used by scvi-tools. These parameters respectively
    control the mean and inverse dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    mu
        Mean of the distribution.
    theta
        Inverse dispersion.
    zi_logits
        Logits scale of zero inflation probability.
    scale
        Normalized mean expression of the distribution.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
        "zi_probs": constraints.half_open_interval(0.0, 1.0),
        "zi_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        logits: torch.Tensor | None = None,
        mu: torch.Tensor | None = None,
        theta: torch.Tensor | None = None,
        zi_logits: torch.Tensor | None = None,
        scale: torch.Tensor | None = None,
        validate_args: bool = False,
    ):
        super().__init__(
            total_count=total_count,
            probs=probs,
            logits=logits,
            mu=mu,
            theta=theta,
            scale=scale,
            validate_args=validate_args,
        )
        self.zi_logits, self.mu, self.theta = broadcast_all(
            zi_logits, self.mu, self.theta
        )

    @property
    def mean(self):
        pi = self.zi_probs
        return (1 - pi) * self.mu

    @property
    def variance(self):
        pi = self.zi_probs
        return (1 - pi) * self.mu * (1 + self.mu * (pi + 1 / self.theta))

    @lazy_property
    def zi_logits(self) -> torch.Tensor:
        """ZI logits."""
        return probs_to_logits(self.zi_probs, is_binary=True)

    @lazy_property
    def zi_probs(self) -> torch.Tensor:
        return logits_to_probs(self.zi_logits, is_binary=True)

    @torch.inference_mode()
    def sample(
        self,
        sample_shape: torch.Size | tuple | None = None,
    ) -> torch.Tensor:
        """Sample from the distribution."""
        sample_shape = sample_shape or torch.Size()
        samp = super().sample(sample_shape=sample_shape)
        is_zero = torch.rand_like(samp) <= self.zi_probs
        samp_ = torch.where(
            is_zero, torch.tensor(0.0, dtype=torch.float32, device=samp.device), samp
        )
        return samp_

    @torch.inference_mode()
    def rsample(  # type: ignore
        self,
        sample_shape: torch.Size | tuple | None = None,
    ) -> torch.Tensor:
        """Sample from the distribution."""
        sample_shape = sample_shape or torch.Size()
        samp = super().rsample(sample_shape=sample_shape)
        is_zero = torch.rand_like(samp) <= self.zi_probs
        samp_ = torch.where(is_zero, torch.tensor(
            0.0, dtype=torch.float32, device=samp.device), samp)
        return samp_

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log probability."""
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning, stacklevel=2,
            )
        return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08)
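The two distribution classes above provide the (zero-inflated) negative binomial likelihoods typically used as reconstruction losses. A minimal usage sketch, assuming the package is installed; the tensor shapes and values below are illustrative assumptions, not gsMap defaults:

    import torch
    from gsMap.find_latent.gnn.distribution import ZeroInflatedNegativeBinomial

    # Pretend decoder outputs for 8 cells x 200 genes (illustrative shapes).
    mu = torch.rand(8, 200) * 5 + 0.1        # NB mean (positive)
    theta = torch.rand(8, 200) * 3 + 0.5     # inverse dispersion (positive)
    zi_logits = torch.randn(8, 200)          # dropout logits (real-valued)

    zinb = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=zi_logits)
    x = zinb.sample()                              # zero-inflated counts, shape (8, 200)
    nll = -zinb.log_prob(x).sum(dim=-1).mean()     # a typical reconstruction loss term
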
gsMap/find_latent/gnn/encoder_decoder.py
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

from .gene_former import GeneModuleFormer


def full_block(in_dim, out_dim, p_drop=0.1):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(),
        nn.Dropout(p=p_drop),
    )


class transform(nn.Module):
    """
    batch transform encoder
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 batch_emb_size,
                 module_dim,
                 hidden_gmf,
                 n_modules,
                 nhead,
                 n_enc_layer,
                 use_tf):

        super().__init__()
        self.use_tf = use_tf

        if self.use_tf:
            self.input_size = hidden_gmf + batch_emb_size
            self.gmf = GeneModuleFormer(input_dim=input_size,
                                        module_dim=module_dim,
                                        hidden_dim=hidden_gmf,
                                        n_modules=n_modules,
                                        nhead=nhead,
                                        n_enc_layer=n_enc_layer
                                        )
            self.transform = full_block(self.input_size, hidden_size)
        else:
            self.input_size = input_size + batch_emb_size
            self.transform = full_block(self.input_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, batch):
        if self.use_tf:
            x = self.gmf(x)
            x = self.transform(torch.cat([x, batch], dim=1))
        else:
            x = self.transform(torch.cat([x, batch], dim=1))
        return self.norm(x)


class Encoder(nn.Module):
    """
    GCN encoder
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 emb_size,
                 batch_emb_size,
                 module_dim,
                 hidden_gmf,
                 n_modules,
                 nhead,
                 n_enc_layer,
                 use_tf,
                 variational=True):

        super().__init__()
        self.variational = variational

        self.tf = transform(
            input_size,
            hidden_size,
            batch_emb_size,
            module_dim,
            hidden_gmf,
            n_modules,
            nhead,
            n_enc_layer,
            use_tf
        )

        self.mlp = nn.Sequential(full_block(hidden_size, hidden_size),
                                 full_block(hidden_size, hidden_size))

        self.fc_mean = nn.Linear(hidden_size, emb_size)
        self.fc_var = nn.Linear(hidden_size, emb_size)

    def forward(self, x, batch):
        xtf = self.tf(x, batch)
        h = self.mlp(xtf)
        if not self.variational:
            mu = self.fc_mean(h)
            return mu

        mu = self.fc_mean(h)
        logvar = self.fc_var(h)
        self.mu = mu
        self.sigma = logvar.exp().sqrt()
        self.dist = Normal(self.mu, self.sigma)
        return self.dist.rsample()

    def kl_loss(self):
        if not hasattr(self, "dist"):
            return 0

        mean = torch.zeros_like(self.mu)
        scale = torch.ones_like(self.sigma)
        kl_loss = kl(self.dist, Normal(mean, scale))
        return kl_loss.mean()

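# Illustrative note (not part of the packaged file): with variational=True the forward pass
# above is a reparameterized Gaussian encoder. It caches self.mu and
# self.sigma = exp(logvar / 2), draws z with Normal(mu, sigma).rsample(), and kl_loss()
# returns the mean KL divergence to a standard normal prior, so a training objective can be
# assembled roughly as (names are placeholders):
#   z = encoder(x, batch_emb)                        # populates encoder.dist
#   loss = reconstruction_loss + kl_weight * encoder.kl_loss()
# kl_loss() falls back to 0 if no variational forward pass has been run.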

class Decoder(nn.Module):
    """
    Shared decoder
    """

    def __init__(self,
                 out_put_size,
                 hidden_size,
                 emb_size,
                 batch_emb_size,
                 class_size,
                 decoder_type,
                 distribution,
                 n_layers=3):
        super().__init__()

        self.decoder_type = decoder_type
        self.mlp = nn.ModuleList()

        # Set initial input size
        if decoder_type == 'reconstruction':
            input_size = emb_size + batch_emb_size
        elif decoder_type == 'classification':
            input_size = emb_size * 2 + batch_emb_size
        else:
            raise ValueError(f"Unknown decoder_type: {decoder_type}")

        # Build MLP layers with batch embedding concat at each step
        if isinstance(n_layers, int):
            n_layers = [hidden_size] * n_layers

        for hidden_size in n_layers:
            self.mlp.append(full_block(input_size, hidden_size))
            input_size = hidden_size + batch_emb_size  # update for next layer input

        # Final output layer
        if decoder_type == 'reconstruction':
            self.zi_logit = nn.Linear(input_size, out_put_size)
            self.fc_rec = nn.Linear(input_size, out_put_size)
        elif decoder_type == 'classification':
            self.fc_class = nn.Linear(input_size, class_size)

        if distribution in ['nb', 'zinb']:
            self.act = nn.Softmax(dim=-1)
        else:
            self.act = nn.Identity()

    def forward(self, z, batch):
        x = torch.cat([z, batch], dim=1)

        for layer in self.mlp:
            x = layer(x)
            x = torch.cat([x, batch], dim=1)  # concat batch after each layer

        if self.decoder_type == 'reconstruction':
            x_hat = self.act(self.fc_rec(x))
            zi_logit = self.zi_logit(x)
            return x_hat, zi_logit

        elif self.decoder_type == 'classification':
            x_class = self.fc_class(x)
            return x_class

        else:
            raise ValueError(f"Unknown decoder_type: {self.decoder_type}")
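To make the tensor flow through these modules concrete, here is a minimal forward-pass sketch. All sizes are illustrative assumptions rather than gsMap defaults, and use_tf=False is chosen so the GeneModuleFormer branch is bypassed:

    import torch
    from gsMap.find_latent.gnn.encoder_decoder import Decoder, Encoder

    n_cells, n_genes, batch_emb_size, emb_size = 16, 200, 8, 32

    enc = Encoder(input_size=n_genes, hidden_size=64, emb_size=emb_size,
                  batch_emb_size=batch_emb_size, module_dim=16, hidden_gmf=64,
                  n_modules=8, nhead=4, n_enc_layer=1, use_tf=False, variational=True)
    dec = Decoder(out_put_size=n_genes, hidden_size=64, emb_size=emb_size,
                  batch_emb_size=batch_emb_size, class_size=10,
                  decoder_type='reconstruction', distribution='zinb', n_layers=3)

    x = torch.rand(n_cells, n_genes)                  # expression features
    batch_emb = torch.rand(n_cells, batch_emb_size)   # per-cell batch embedding

    z = enc(x, batch_emb)                 # (n_cells, emb_size) reparameterized latent sample
    x_hat, zi_logit = dec(z, batch_emb)   # softmax-normalized means and dropout logits
    kl_term = enc.kl_loss()               # KL regularizer for the ELBO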