gsMap3D 0.1.0a1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/gnn/distribution.py
@@ -0,0 +1,498 @@
+ import warnings
+
+ import torch
+ import torch.nn.functional as F
+ from torch.distributions import Distribution, Gamma, Poisson, constraints
+ from torch.distributions.utils import (
+     broadcast_all,
+     lazy_property,
+     logits_to_probs,
+     probs_to_logits,
+ )
+
+
+ def log_zinb_positive(
+     x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
+ ):
+     """
+     From scVI.
+     Log likelihood (scalar) of a minibatch according to a zinb model.
+
+     Parameters
+     ----------
+     x
+         Data
+     mu
+         mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
+     theta
+         inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
+     pi
+         logit of the dropout parameter (real support) (shape: minibatch x vars)
+     eps
+         numerical stability constant
+
+     Notes
+     -----
+     We parametrize the bernoulli using the logits, hence the softplus functions appearing.
+     """
+     # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
+     if theta.ndimension() == 1:
+         theta = theta.view(
+             1, theta.size(0)
+         ) # In this case, we reshape theta for broadcasting
+
+     softplus_pi = F.softplus(-pi) # uses log(sigmoid(x)) = -softplus(-x)
+     log_theta_eps = torch.log(theta + eps)
+     log_theta_mu_eps = torch.log(theta + mu + eps)
+     pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)
+
+     case_zero = F.softplus(pi_theta_log) - softplus_pi
+     mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)
+
+     case_non_zero = (
+         -softplus_pi
+         + pi_theta_log
+         + x * (torch.log(mu + eps) - log_theta_mu_eps)
+         + torch.lgamma(x + theta)
+         - torch.lgamma(theta)
+         - torch.lgamma(x + 1)
+     )
+     mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)
+
+     res = mul_case_zero + mul_case_non_zero
+
+     return res
+
+
+ def log_nb_positive(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps=1e-8):
+     """
+     From scVI.
+     Log likelihood (scalar) of a minibatch according to a nb model.
+
+     Parameters
+     ----------
+     x
+         data
+     mu
+         mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
+     theta
+         inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
+     eps
+         numerical stability constant
+
+     Notes
+     -----
+     We parametrize the bernoulli using the logits, hence the softplus functions appearing.
+
+     """
+     if theta.ndimension() == 1:
+         theta = theta.view(
+             1, theta.size(0)
+         ) # In this case, we reshape theta for broadcasting
+
+     log_theta_mu_eps = torch.log(theta + mu + eps)
+
+     res = (
+         theta * (torch.log(theta + eps) - log_theta_mu_eps)
+         + x * (torch.log(mu + eps) - log_theta_mu_eps)
+         + torch.lgamma(x + theta)
+         - torch.lgamma(theta)
+         - torch.lgamma(x + 1)
+     )
+
+     return res
+
+
+ def log_mixture_nb(
+     x: torch.Tensor,
+     mu_1: torch.Tensor,
+     mu_2: torch.Tensor,
+     theta_1: torch.Tensor,
+     theta_2: torch.Tensor,
+     pi_logits: torch.Tensor,
+     eps=1e-8,
+ ):
+     """
+     From scVI.
+     Log likelihood (scalar) of a minibatch according to a mixture nb model.
+
+     pi_logits is the probability (logits) to be in the first component.
+     For totalVI, the first component should be background.
+
+     Parameters
+     ----------
+     x
+         Observed data
+     mu_1
+         Mean of the first negative binomial component (has to be positive support) (shape: minibatch x features)
+     mu_2
+         Mean of the second negative binomial (has to be positive support) (shape: minibatch x features)
+     theta_1
+         First inverse dispersion parameter (has to be positive support) (shape: minibatch x features)
+     theta_2
+         Second inverse dispersion parameter (has to be positive support) (shape: minibatch x features)
+         If None, assume one shared inverse dispersion parameter.
+     pi_logits
+         Probability of belonging to mixture component 1 (logits scale)
+     eps
+         Numerical stability constant
+     """
+     if theta_2 is not None:
+         log_nb_1 = log_nb_positive(x, mu_1, theta_1)
+         log_nb_2 = log_nb_positive(x, mu_2, theta_2)
+         # this is intended to reduce repeated computations
+     else:
+         theta = theta_1
+         if theta.ndimension() == 1:
+             theta = theta.view(
+                 1, theta.size(0)
+             ) # In this case, we reshape theta for broadcasting
+
+         log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
+         log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
+         lgamma_x_theta = torch.lgamma(x + theta)
+         lgamma_theta = torch.lgamma(theta)
+         lgamma_x_plus_1 = torch.lgamma(x + 1)
+
+         log_nb_1 = (
+             theta * (torch.log(theta + eps) - log_theta_mu_1_eps)
+             + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps)
+             + lgamma_x_theta
+             - lgamma_theta
+             - lgamma_x_plus_1
+         )
+         log_nb_2 = (
+             theta * (torch.log(theta + eps) - log_theta_mu_2_eps)
+             + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps)
+             + lgamma_x_theta
+             - lgamma_theta
+             - lgamma_x_plus_1
+         )
+
+     logsumexp = torch.logsumexp(torch.stack(
+         (log_nb_1, log_nb_2 - pi_logits)), dim=0)
+     softplus_pi = F.softplus(-pi_logits)
+
+     log_mixture_nb = logsumexp - softplus_pi
+
+     return log_mixture_nb
+
+
+ def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
+     r"""
+     From scVI.
+     NB parameterizations conversion.
+
+     Parameters
+     ----------
+     mu
+         mean of the NB distribution.
+     theta
+         inverse overdispersion.
+     eps
+         constant used for numerical log stability. (Default value = 1e-6)
+
+     Returns
+     -------
+     type
+         the number of failures until the experiment is stopped
+         and the success probability.
+     """
+     if not (mu is None) == (theta is None):
+         raise ValueError(
+             "If using the mu/theta NB parameterization, both parameters must be specified"
+         )
+     logits = (mu + eps).log() - (theta + eps).log()
+     total_count = theta
+     return total_count, logits
+
+
+ def _convert_counts_logits_to_mean_disp(total_count, logits):
+     """
+     From scVI.
+     NB parameterizations conversion.
+
+     Parameters
+     ----------
+     total_count
+         Number of failures until the experiment is stopped.
+     logits
+         success logits.
+
+     Returns
+     -------
+     type
+         the mean and inverse overdispersion of the NB distribution.
+
+     """
+     theta = total_count
+     mu = logits.exp() * theta
+     return mu, theta
+
+
+ def _gamma(theta, mu):
+     concentration = theta
+     rate = theta / mu
+     # Important remark: Gamma is parametrized by the rate = 1/scale!
+     gamma_d = Gamma(concentration=concentration, rate=rate)
+     return gamma_d
+
+
+ class NegativeBinomial(Distribution):
+     r"""
+     From scVI.
+     Negative binomial distribution.
+
+     One of the following parameterizations must be provided:
+
+     (1), (`total_count`, `probs`) where `total_count` is the number of failures until
+     the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
+     parameterization, which is the one used by scvi-tools. These parameters respectively
+     control the mean and inverse dispersion of the distribution.
+
+     In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:
+
+     1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
+     2. :math:`x \sim \textrm{Poisson}(w)`
+
+     Parameters
+     ----------
+     total_count
+         Number of failures until the experiment is stopped.
+     probs
+         The success probability.
+     mu
+         Mean of the distribution.
+     theta
+         Inverse dispersion.
+     scale
+         Normalized mean expression of the distribution.
+     validate_args
+         Raise ValueError if arguments do not match constraints
+     """
+
+     arg_constraints = {
+         "mu": constraints.greater_than_eq(0),
+         "theta": constraints.greater_than_eq(0),
+     }
+     support = constraints.nonnegative_integer
+
+     def __init__(
+         self,
+         total_count: torch.Tensor | None = None,
+         probs: torch.Tensor | None = None,
+         logits: torch.Tensor | None = None,
+         mu: torch.Tensor | None = None,
+         theta: torch.Tensor | None = None,
+         scale: torch.Tensor | None = None,
+         validate_args: bool = False,
+     ):
+         self._eps = 1e-8
+         if (mu is None) == (total_count is None):
+             raise ValueError(
+                 "Please use one of the two possible parameterizations. Refer to the documentation for more information."
+             )
+
+         using_param_1 = total_count is not None and (
+             logits is not None or probs is not None
+         )
+         if using_param_1:
+             logits = logits if logits is not None else probs_to_logits(probs)
+             total_count = total_count.type_as(logits)
+             total_count, logits = broadcast_all(total_count, logits)
+             mu, theta = _convert_counts_logits_to_mean_disp(
+                 total_count, logits)
+         else:
+             mu, theta = broadcast_all(mu, theta)
+         self.mu = mu
+         self.theta = theta
+         self.scale = scale
+         super().__init__(validate_args=validate_args)
+
+     @property
+     def mean(self):
+         return self.mu
+
+     @property
+     def variance(self):
+         return self.mean + (self.mean**2) / self.theta
+
+     @torch.inference_mode()
+     def sample(
+         self,
+         sample_shape: torch.Size | tuple | None = None,
+     ) -> torch.Tensor:
+         """Sample from the distribution."""
+         sample_shape = sample_shape or torch.Size()
+         gamma_d = self._gamma()
+         p_means = gamma_d.sample(sample_shape)
+
+         # Clamping as distributions objects can have buggy behaviors when
+         # their parameters are too high
+         l_train = torch.clamp(p_means, max=1e8)
+         # Shape : (n_samples, n_cells_batch, n_vars)
+         counts = Poisson(l_train).sample()
+         return counts
+
+     @torch.inference_mode()
+     def rsample(self, sample_shape: torch.Size | tuple | None = None):
+         """Sample from the distribution."""
+         return self.sample(sample_shape=sample_shape)
+
+     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+         if self._validate_args:
+             try:
+                 self._validate_sample(value)
+             except ValueError:
+                 warnings.warn(
+                     "The value argument must be within the support of the distribution",
+                     UserWarning, stacklevel=2,
+                 )
+
+         return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)
+
+     def _gamma(self):
+         return _gamma(self.theta, self.mu)
+
+     def pearson_residuals(self, x: torch.Tensor) -> torch.Tensor:
+         r"""
+         Compute the Pearson residuals.
+
+         Parameters
+         ----------
+         x
+             Observed data.
+
+         Returns
+         -------
+         type
+             Pearson residuals.
+         """
+         mean = self.mean
+         variance = self.variance
+         return (x - mean) / torch.sqrt(variance)
+
+
+ class ZeroInflatedNegativeBinomial(NegativeBinomial):
+     r"""
+     From scVI.
+     Zero-inflated negative binomial distribution.
+
+     One of the following parameterizations must be provided:
+
+     (1), (`total_count`, `probs`) where `total_count` is the number of failures until
+     the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
+     parameterization, which is the one used by scvi-tools. These parameters respectively
+     control the mean and inverse dispersion of the distribution.
+
+     In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:
+
+     1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
+     2. :math:`x \sim \textrm{Poisson}(w)`
+
+     Parameters
+     ----------
+     total_count
+         Number of failures until the experiment is stopped.
+     probs
+         The success probability.
+     mu
+         Mean of the distribution.
+     theta
+         Inverse dispersion.
+     zi_logits
+         Logits scale of zero inflation probability.
+     scale
+         Normalized mean expression of the distribution.
+     validate_args
+         Raise ValueError if arguments do not match constraints
+     """
+
+     arg_constraints = {
+         "mu": constraints.greater_than_eq(0),
+         "theta": constraints.greater_than_eq(0),
+         "zi_probs": constraints.half_open_interval(0.0, 1.0),
+         "zi_logits": constraints.real,
+     }
+     support = constraints.nonnegative_integer
+
+     def __init__(
+         self,
+         total_count: torch.Tensor | None = None,
+         probs: torch.Tensor | None = None,
+         logits: torch.Tensor | None = None,
+         mu: torch.Tensor | None = None,
+         theta: torch.Tensor | None = None,
+         zi_logits: torch.Tensor | None = None,
+         scale: torch.Tensor | None = None,
+         validate_args: bool = False,
+     ):
+         super().__init__(
+             total_count=total_count,
+             probs=probs,
+             logits=logits,
+             mu=mu,
+             theta=theta,
+             scale=scale,
+             validate_args=validate_args,
+         )
+         self.zi_logits, self.mu, self.theta = broadcast_all(
+             zi_logits, self.mu, self.theta
+         )
+
+     @property
+     def mean(self):
+         pi = self.zi_probs
+         return (1 - pi) * self.mu
+
+     @property
+     def variance(self):
+         pi = self.zi_probs
+         return (1 - pi) * self.mu * (1 + self.mu * (pi + 1 / self.theta))
+
+     @lazy_property
+     def zi_logits(self) -> torch.Tensor:
+         """ZI logits."""
+         return probs_to_logits(self.zi_probs, is_binary=True)
+
+     @lazy_property
+     def zi_probs(self) -> torch.Tensor:
+         return logits_to_probs(self.zi_logits, is_binary=True)
+
+     @torch.inference_mode()
+     def sample(
+         self,
+         sample_shape: torch.Size | tuple | None = None,
+     ) -> torch.Tensor:
+         """Sample from the distribution."""
+         sample_shape = sample_shape or torch.Size()
+         samp = super().sample(sample_shape=sample_shape)
+         is_zero = torch.rand_like(samp) <= self.zi_probs
+         samp_ = torch.where(
+             is_zero, torch.tensor(0.0, dtype=torch.float32, device=samp.device), samp
+         )
+         return samp_
+
+     @torch.inference_mode()
+     def rsample( # type: ignore
+         self,
+         sample_shape: torch.Size | tuple | None = None,
+     ) -> torch.Tensor:
+         """Sample from the distribution."""
+         sample_shape = sample_shape or torch.Size()
+         samp = super().rsample(sample_shape=sample_shape)
+         is_zero = torch.rand_like(samp) <= self.zi_probs
+         samp_ = torch.where(is_zero, torch.tensor(
+             0.0, dtype=torch.float32, device=samp.device), samp)
+         return samp_
+
+     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+         """Log probability."""
+         try:
+             self._validate_sample(value)
+         except ValueError:
+             warnings.warn(
+                 "The value argument must be within the support of the distribution",
+                 UserWarning, stacklevel=2,
+             )
+         return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08)
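
For orientation, here is a minimal usage sketch of the ZeroInflatedNegativeBinomial class defined in the hunk above. It is not part of the package: the import path is inferred from the file list, and the shapes and parameter values are purely illustrative.

# Minimal sketch (illustrative, not from gsMap3D): the scVI-style (mu, theta, zi_logits)
# parameterization of the ZINB distribution defined above.
import torch
from gsMap.find_latent.gnn.distribution import ZeroInflatedNegativeBinomial

mu = torch.full((2, 5), 4.0)      # NB mean, shape: minibatch x vars
theta = torch.full((2, 5), 2.0)   # inverse dispersion
zi_logits = torch.zeros(2, 5)     # dropout logits; sigmoid(0) = 0.5 dropout probability

dist = ZeroInflatedNegativeBinomial(mu=mu, theta=theta, zi_logits=zi_logits)
counts = dist.sample()            # Gamma-Poisson draw, zeroed with probability sigmoid(zi_logits)
log_lik = dist.log_prob(counts)   # element-wise log ZINB likelihood via log_zinb_positive
print(counts.shape, log_lik.shape, dist.mean.shape)

Sampling follows the Gamma-Poisson construction described in the class docstring, and log_prob reduces to log_zinb_positive, so the returned log likelihood has the same minibatch x vars shape as the input.
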
gsMap/find_latent/gnn/encoder_decoder.py
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn as nn
+ from torch.distributions import Normal
+ from torch.distributions import kl_divergence as kl
+
+ from .gene_former import GeneModuleFormer
+
+
+ def full_block(in_dim, out_dim, p_drop=0.1):
+     return nn.Sequential(
+         nn.Linear(in_dim, out_dim),
+         nn.BatchNorm1d(out_dim),
+         nn.ReLU(),
+         nn.Dropout(p=p_drop),
+     )
+
+ class transform(nn.Module):
+     """
+     batch transform encoder
+     """
+     def __init__(self,
+                  input_size,
+                  hidden_size,
+                  batch_emb_size,
+                  module_dim,
+                  hidden_gmf,
+                  n_modules,
+                  nhead,
+                  n_enc_layer,
+                  use_tf):
+
+         super().__init__()
+         self.use_tf = use_tf
+
+         if self.use_tf:
+             self.input_size = hidden_gmf + batch_emb_size
+             self.gmf = GeneModuleFormer(input_dim=input_size,
+                                         module_dim=module_dim,
+                                         hidden_dim=hidden_gmf,
+                                         n_modules=n_modules,
+                                         nhead=nhead,
+                                         n_enc_layer=n_enc_layer
+                                         )
+             self.transform = full_block(self.input_size,hidden_size)
+         else:
+             self.input_size = input_size + batch_emb_size
+             self.transform = full_block(self.input_size,hidden_size)
+         self.norm = nn.LayerNorm(hidden_size)
+
+     def forward(self, x, batch):
+         if self.use_tf:
+             x = self.gmf(x)
+             x = self.transform(torch.cat([x,batch],dim=1))
+         else:
+             x = self.transform(torch.cat([x,batch],dim=1))
+         return self.norm(x)
+
+
+
+ class Encoder(nn.Module):
+     """
+     GCN encoder
+     """
+     def __init__(self,
+                  input_size,
+                  hidden_size,
+                  emb_size,
+                  batch_emb_size,
+                  module_dim,
+                  hidden_gmf,
+                  n_modules,
+                  nhead,
+                  n_enc_layer,
+                  use_tf,
+                  variational=True):
+
+         super().__init__()
+         self.variational = variational
+
+         self.tf = transform(
+             input_size,
+             hidden_size,
+             batch_emb_size,
+             module_dim,
+             hidden_gmf,
+             n_modules,
+             nhead,
+             n_enc_layer,
+             use_tf
+         )
+
+         self.mlp = nn.Sequential(full_block(hidden_size, hidden_size),
+                                  full_block(hidden_size,hidden_size))
+
+
+         self.fc_mean = nn.Linear(hidden_size,emb_size)
+         self.fc_var = nn.Linear(hidden_size,emb_size)
+
+     def forward(self, x, batch):
+
+         xtf = self.tf(x,batch)
+         h = self.mlp(xtf)
+         if not self.variational:
+             mu = self.fc_mean(h)
+             return mu
+
+         mu = self.fc_mean(h)
+         logvar = self.fc_var(h)
+         self.mu = mu
+         self.sigma = logvar.exp().sqrt()
+         self.dist = Normal(self.mu, self.sigma)
+         return self.dist.rsample()
+
+     def kl_loss(self):
+         if not hasattr(self, "dist"):
+             return 0
+
+         mean = torch.zeros_like(self.mu)
+         scale = torch.ones_like(self.sigma)
+         kl_loss = kl(self.dist, Normal(mean, scale))
+         return kl_loss.mean()
+
+ class Decoder(nn.Module):
+     """
+     Shared decoder
+     """
+     def __init__(self,
+                  out_put_size,
+                  hidden_size,
+                  emb_size,
+                  batch_emb_size,
+                  class_size,
+                  decoder_type,
+                  distribution,
+                  n_layers=3):
+         super().__init__()
+
+         self.decoder_type = decoder_type
+         self.mlp = nn.ModuleList()
+
+         # Set initial input size
+         if decoder_type == 'reconstruction':
+             input_size = emb_size + batch_emb_size
+         elif decoder_type == 'classification':
+             input_size = emb_size * 2 + batch_emb_size
+         else:
+             raise ValueError(f"Unknown decoder_type: {decoder_type}")
+
+         # Build MLP layers with batch embedding concat at each step
+         if isinstance(n_layers, int):
+             n_layers = [hidden_size] * n_layers
+
+         for hidden_size in n_layers:
+             self.mlp.append(full_block(input_size, hidden_size))
+             input_size = hidden_size + batch_emb_size # update for next layer input
+
+         # Final output layer
+         if decoder_type == 'reconstruction':
+             self.zi_logit = nn.Linear(input_size, out_put_size)
+             self.fc_rec = nn.Linear(input_size, out_put_size)
+         elif decoder_type == 'classification':
+             self.fc_class = nn.Linear(input_size, class_size)
+
+         if distribution in ['nb','zinb']:
+             self.act = nn.Softmax(dim=-1)
+         else:
+             self.act = nn.Identity()
+
+     def forward(self, z, batch):
+         x = torch.cat([z, batch], dim=1)
+
+         for layer in self.mlp:
+             x = layer(x)
+             x = torch.cat([x, batch], dim=1) # concat batch after each layer
+
+         if self.decoder_type == 'reconstruction':
+             x_hat = self.act(self.fc_rec(x))
+             zi_logit = self.zi_logit(x)
+             return x_hat, zi_logit
+
+         elif self.decoder_type == 'classification':
+             x_class = self.fc_class(x)
+             return x_class
+
+         else:
+             raise ValueError(f"Unknown decoder_type: {self.decoder_type}")