SURE-tools 2.1.34__py3-none-any.whl → 2.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/{PerturbFlow.py → DensityFlow.py} +216 -127
- SURE/SURE.py +33 -41
- SURE/__init__.py +3 -3
- SURE/flow/flow_stats.py +37 -0
- SURE/perturb/__init__.py +1 -1
- SURE/perturb/perturb.py +84 -2
- SURE/utils/__init__.py +1 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/METADATA +2 -1
- sure_tools-2.2.24.dist-info/RECORD +25 -0
- sure_tools-2.1.34.dist-info/RECORD +0 -25
- {sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/top_level.txt +0 -0
SURE/{PerturbFlow.py → DensityFlow.py} CHANGED

@@ -54,26 +54,27 @@ def set_random_seed(seed):
     # Set seed for Pyro
     pyro.set_rng_seed(seed)
 
-class PerturbFlow(nn.Module):
+class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
                  codebook_size: int = 200,
                  cell_factor_size: int = 0,
+                 turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
                  z_dim: int = 10,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '…
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '…
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
                  inverse_dispersion: float = 10.0,
-                 use_zeroinflate: bool = …
-                 hidden_layers: list = […
+                 use_zeroinflate: bool = False,
+                 hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
                  post_act_fct: list = None,
                  config_enum: str = 'parallel',
-                 use_cuda: bool = …
+                 use_cuda: bool = True,
                  seed: int = 42,
-                 zero_bias: bool = True,
+                 zero_bias: bool|list = True,
                  dtype = torch.float32, # type: ignore
                  ):
         super().__init__()
 
@@ -97,7 +98,12 @@ class PerturbFlow(nn.Module):
         self.post_layer_fct = post_layer_fct
         self.post_act_fct = post_act_fct
         self.hidden_layer_activation = hidden_layer_activation
-        self.use_bias = not zero_bias
+        if type(zero_bias) == list:
+            self.use_bias = [not x for x in zero_bias]
+        else:
+            self.use_bias = [not zero_bias] * self.cell_factor_size
+        #self.use_bias = not zero_bias
+        self.turn_off_cell_specific = turn_off_cell_specific
 
         self.codebook_weights = None
 
@@ -198,38 +204,62 @@ class PerturbFlow(nn.Module):
         if self.cell_factor_size>0:
             self.cell_factor_effect = nn.ModuleList()
             for i in np.arange(self.cell_factor_size):
-                if self.use_bias:
-                    self.…
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
+                if self.use_bias[i]:
+                    if self.turn_off_cell_specific:
+                        self.cell_factor_effect.append(MLP(
+                            [1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
+                        )
+                    else:
+                        self.cell_factor_effect.append(MLP(
+                            [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
                         )
-                    )
                 else:
-                    self.…
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
+                    if self.turn_off_cell_specific:
+                        self.cell_factor_effect.append(ZeroBiasMLP(
+                            [1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
+                        )
+                    else:
+                        self.cell_factor_effect.append(ZeroBiasMLP(
+                            [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
                         )
-                    )
 
         self.decoder_concentrate = MLP(
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+            [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
+            activation=activate_fct,
+            output_activation=None,
+            post_layer_fct=post_layer_fct,
+            post_act_fct=post_act_fct,
+            allow_broadcast=self.allow_broadcast,
+            use_cuda=self.use_cuda,
+            )
 
         if self.latent_dist == 'studentt':
             self.codebook = MLP(
 
@@ -304,7 +334,7 @@ class PerturbFlow(nn.Module):
         return xs
 
     def model1(self, xs):
-        pyro.module('PerturbFlow', self)
+        pyro.module('DensityFlow', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
 
@@ -318,7 +348,7 @@ class PerturbFlow(nn.Module):
         gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
 
         acs_scale = pyro.param("codebook_scale", xs.new_ones(self.latent_dim), constraint=constraints.positive)
-        …
+
         I = torch.eye(self.code_size)
         if self.latent_dist=='studentt':
             acs_dof,acs_loc = self.codebook(I)
 
@@ -347,12 +377,13 @@ class PerturbFlow(nn.Module):
 
             zs = zns
             concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                 log_theta = concentrate
             else:
                 rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
 
@@ -374,14 +405,15 @@ class PerturbFlow(nn.Module):
 
     def guide1(self, xs):
         with pyro.plate('data'):
-            zn_loc, zn_scale = self.encoder_zn(xs)
+            #zn_loc, zn_scale = self.encoder_zn(xs)
+            zn_loc, zn_scale = self._get_basal_embedding(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
 
            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
    def model2(self, xs, us=None):
-        pyro.module('PerturbFlow', self)
+        pyro.module('DensityFlow', self)
 
        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
 
@@ -423,23 +455,19 @@ class PerturbFlow(nn.Module):
            zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
            if self.cell_factor_size>0:
-                zus = None
-                for i in np.arange(self.cell_factor_size):
-                    if i==0:
-                        zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                    else:
-                        zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = self._total_effects(zns, us)
                zs = zns+zus
            else:
                zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -461,14 +489,15 @@ class PerturbFlow(nn.Module):
 
    def guide2(self, xs, us=None):
        with pyro.plate('data'):
-            zn_loc, zn_scale = self.encoder_zn(xs)
+            #zn_loc, zn_scale = self.encoder_zn(xs)
+            zn_loc, zn_scale = self._get_basal_embedding(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
 
            alpha = self.encoder_n(zns)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
    def model3(self, xs, ys, embeds=None):
-        pyro.module('PerturbFlow', self)
+        pyro.module('DensityFlow', self)
 
        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
 
@@ -528,12 +557,13 @@ class PerturbFlow(nn.Module):
            zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -556,11 +586,14 @@ class PerturbFlow(nn.Module):
    def guide3(self, xs, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
-                zn_loc, zn_scale = self.encoder_zn(xs)
+                #zn_loc, zn_scale = self.encoder_zn(xs)
+                zn_loc, zn_scale = self._get_basal_embedding(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
+            else:
+                zns = embeds
 
    def model4(self, xs, us, ys, embeds=None):
-        pyro.module('PerturbFlow', self)
+        pyro.module('DensityFlow', self)
 
        eps = torch.finfo(xs.dtype).eps
        batch_size = xs.size(0)
 
@@ -618,23 +651,25 @@ class PerturbFlow(nn.Module):
            zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
 
            if self.cell_factor_size>0:
-                zus = None
-                for i in np.arange(self.cell_factor_size):
-                    if i==0:
-                        zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                    else:
-                        zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                #zus = None
+                #for i in np.arange(self.cell_factor_size):
+                #    if i==0:
+                #        zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                #    else:
+                #        zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = self._total_effects(zns, us)
                zs = zns+zus
            else:
                zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -657,9 +692,32 @@ class PerturbFlow(nn.Module):
    def guide4(self, xs, us, ys, embeds=None):
        with pyro.plate('data'):
            if embeds is None:
-                zn_loc, zn_scale = self.encoder_zn(xs)
+                #zn_loc, zn_scale = self.encoder_zn(xs)
+                zn_loc, zn_scale = self._get_basal_embedding(xs)
                zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
-
+            else:
+                zns = embeds
+
+    def _total_effects(self, zns, us):
+        zus = None
+        for i in np.arange(self.cell_factor_size):
+            if i==0:
+                #if self.turn_off_cell_specific:
+                #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+                #else:
+                #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
+            else:
+                #if self.turn_off_cell_specific:
+                #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+                #else:
+                #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
+        return zus
+
+    def _get_codebook_identity(self):
+        return torch.eye(self.code_size, **self.options)
+
    def _get_codebook(self):
        I = torch.eye(self.code_size, **self.options)
        if self.latent_dist=='studentt':
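
The new _total_effects helper centralizes what model2 and model4 previously inlined: the total latent shift is the sum of per-factor responses. A self-contained sketch of that accumulation with toy stand-in networks (shapes illustrative; in the package the per-factor modules are the MLP/ZeroBiasMLP instances built above, dispatched through _cell_response):

    import torch
    import torch.nn as nn

    z_dim, n_factors, batch = 10, 3, 5
    # toy stand-ins for self.cell_factor_effect; each sees [zns, u_i]
    effects = nn.ModuleList([nn.Linear(z_dim + 1, z_dim) for _ in range(n_factors)])

    def total_effects(zns, us):
        # sum of per-factor latent shifts, as in DensityFlow._total_effects
        zus = torch.zeros_like(zns)
        for i in range(n_factors):
            zus = zus + effects[i](torch.cat([zns, us[:, i:i+1]], dim=-1))
        return zus

    zns = torch.randn(batch, z_dim)
    us = torch.randn(batch, n_factors)
    print(total_effects(zns, us).shape)   # torch.Size([5, 10])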
@@ -672,13 +730,13 @@ class PerturbFlow(nn.Module):
        """
        Return the mean part of metacell codebook
        """
-        cb = self.…
+        cb = self._get_codebook()
        cb = tensor_to_numpy(cb)
        return cb
 
    def _get_basal_embedding(self, xs):
-        …
-        return …
+        loc, scale = self.encoder_zn(xs)
+        return loc, scale
 
    def get_basal_embedding(self,
                            xs,
 
@@ -705,7 +763,7 @@ class PerturbFlow(nn.Module):
        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
            for X_batch, _ in dataloader:
-                zns = self._get_basal_embedding(X_batch)
+                zns,_ = self._get_basal_embedding(X_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)
 
@@ -716,7 +774,8 @@ class PerturbFlow(nn.Module):
        if self.supervised_mode:
            alpha = self.encoder_n(xs)
        else:
-            zns,_ = self.encoder_zn(xs)
+            #zns,_ = self.encoder_zn(xs)
+            zns,_ = self._get_basal_embedding(xs)
            alpha = self.encoder_n(zns)
        return alpha
 
@@ -785,46 +844,80 @@ class PerturbFlow(nn.Module):
        A = np.concatenate(A)
        return A
 
-    def …
-    …
+    def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
+        perturbs_reference = np.array(perturbs_reference)
+
+        # basal embedding
+        zs = self.get_basal_embedding(xs)
+        for pert in perturbs_predict:
+            pert_idx = int(np.where(perturbs_reference==pert)[0])
+            us_i = us[:,pert_idx].reshape(-1,1)
+
+            # factor effect of xs
+            dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
+
+            # perturbation effect
+            ps = np.ones_like(us_i)
+            if np.sum(np.abs(ps-us_i))>=1:
+                dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
+                zs = zs + dzs0 + dzs
+            else:
+                zs = zs + dzs0
+
+        if library_sizes is None:
+            library_sizes = np.sum(xs, axis=1, keepdims=True)
+        elif type(library_sizes) == list:
+            library_sizes = np.array(library_sizes)
+            library_sizes = library_sizes.reshape(-1,1)
+        elif len(library_sizes.shape)==1:
+            library_sizes = library_sizes.reshape(-1,1)
+
+        counts = self.get_counts(zs, library_sizes=library_sizes)
+
+        return counts, zs
+
+    def _cell_response(self, zs, perturb_idx, perturb):
+        #zns,_ = self.encoder_zn(xs)
+        #zns,_ = self._get_basal_embedding(xs)
+        zns = zs
        if perturb.ndim==2:
-            …
+            if self.turn_off_cell_specific:
+                ms = self.cell_factor_effect[perturb_idx](perturb)
+            else:
+                ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
        else:
-            …
+            if self.turn_off_cell_specific:
+                ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
+            else:
+                ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
 
        return ms
 
    def get_cell_response(self,
-                          …
-                          …
-                          …
+                          zs,
+                          perturb_idx,
+                          perturb_us,
                          batch_size: int = 1024):
        """
        Return cells' changes in the latent space induced by specific perturbation of a factor
 
        """
-        xs = self.preprocess(xs)
-        …
-        ps = convert_to_tensor(…
-        dataset = CustomDataset2(…
+        #xs = self.preprocess(xs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        ps = convert_to_tensor(perturb_us, device=self.get_device())
+        dataset = CustomDataset2(zs,ps)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
        Z = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for …
-                zns = self._cell_response(…
+            for Z_batch, P_batch, _ in dataloader:
+                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
                Z.append(tensor_to_numpy(zns))
                pbar.update(1)
 
        Z = np.concatenate(Z)
        return Z
 
-    def get_metacell_response(self, factor_idx, perturb):
-        zs = self._get_codebook()
-        ps = convert_to_tensor(perturb, device=self.get_device())
-        ms = self.cell_factor_effect[factor_idx]([zs,ps])
-        return tensor_to_numpy(ms)
-
    def _get_expression_response(self, delta_zs):
        return self.decoder_concentrate(delta_zs)
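
From the hunk above, predict composes the basal embedding with the observed-dose response (dzs0) and, when the queried doses differ from an all-ones vector, an additional unit-dose response (dzs), then decodes counts via get_counts and returns (counts, zs). A minimal NumPy sketch of just that composition rule (function and data invented for illustration, not part of the package):

    import numpy as np

    def compose_shifts(zs, dzs_observed, dzs_unit, us_i):
        # apply the observed-dose shift; add the unit-dose shift only when
        # the queried doses differ from all-ones (mirrors predict()'s branch)
        ps = np.ones_like(us_i)
        if np.sum(np.abs(ps - us_i)) >= 1:
            return zs + dzs_observed + dzs_unit
        return zs + dzs_observed

    zs = np.zeros((4, 10))
    us_i = np.array([[0.0], [0.0], [1.0], [1.0]])
    out = compose_shifts(zs, np.ones((4, 10)), np.ones((4, 10)), us_i)
    print(out.shape, out[0, 0])   # (4, 10) 2.0, i.e. both shifts applied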
@@ -849,38 +942,35 @@ class PerturbFlow(nn.Module):
        R = np.concatenate(R)
        return R
 
-    def _count(self,concentrate):
-        if self.loss_func == 'bernoulli':
-            counts = self.sigmoid(concentrate)
-        else:
-            counts = concentrate.exp()
-        return counts
-
-    def _count_sample(self,concentrate):
+    def _count(self, concentrate, library_size=None):
        if self.loss_func == 'bernoulli':
-            …
-            counts = dist.Bernoulli(logits=…
+            #counts = self.sigmoid(concentrate)
+            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
        else:
-            …
-            …
+            rate = concentrate.exp()
+            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+            counts = theta * library_size
        return counts
 
-    def get_counts(self, zs, …
-                   batch_size: int = 1024…
-                   use_sampler: bool = False):
+    def get_counts(self, zs, library_sizes,
+                   batch_size: int = 1024):
 
        zs = convert_to_tensor(zs, device=self.get_device())
-        …
+
+        if type(library_sizes) == list:
+            library_sizes = np.array(library_sizes).reshape(-1,1)
+        elif len(library_sizes.shape)==1:
+            library_sizes = library_sizes.reshape(-1,1)
+        ls = convert_to_tensor(library_sizes, device=self.get_device())
+
+        dataset = CustomDataset2(zs,ls)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
        E = []
        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for Z_batch, _ in dataloader:
-                concentrate = self.…
-                …
-                    counts = self._count_sample(concentrate)
-                else:
-                    counts = self._count(concentrate)
+            for Z_batch, L_batch, _ in dataloader:
+                concentrate = self._get_expression_response(Z_batch)
+                counts = self._count(concentrate, L_batch)
                E.append(tensor_to_numpy(counts))
                pbar.update(1)
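
The reworked _count no longer returns raw exp(concentrate): for the non-Bernoulli losses it normalizes the decoder rates into per-gene proportions and scales them by an explicit library size. The DirichletMultinomial(total_count=1, ...).mean used here is simply the normalized concentration, as this small check illustrates (values invented):

    import torch
    import pyro.distributions as dist

    concentrate = torch.tensor([[0.0, 1.0, 2.0]])   # stand-in decoder output
    rate = concentrate.exp()
    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
    # the mean equals rate normalized to sum to one along genes
    assert torch.allclose(theta, rate / rate.sum(-1, keepdim=True))
    counts = theta * 1000.0                          # library size of 1000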
@@ -903,7 +993,7 @@ class PerturbFlow(nn.Module):
            us = None,
            ys = None,
            zs = None,
-            num_epochs: int = …
+            num_epochs: int = 500,
            learning_rate: float = 0.0001,
            batch_size: int = 256,
            algo: Literal['adam','rmsprop','adamw'] = 'adam',
 
@@ -912,9 +1002,9 @@ class PerturbFlow(nn.Module):
            decay_rate: float = 0.9,
            config_enum: str = 'parallel',
            threshold: int = 0,
-            use_jax: bool = …
+            use_jax: bool = True):
        """
-        Train the PerturbFlow model.
+        Train the DensityFlow model.
 
        Parameters
        ----------
 
@@ -940,7 +1030,7 @@ class PerturbFlow(nn.Module):
            Parameter for optimization.
        use_jax
            If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-            the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
+            the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
        """
        xs = self.preprocess(xs, threshold=threshold)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
 
@@ -1025,7 +1115,7 @@ class PerturbFlow(nn.Module):
                    # Update progress bar
                    pbar.set_postfix({'loss': str_loss})
                    pbar.update(1)
-
+
    @classmethod
    def save_model(cls, model, file_path, compression=False):
        """Save the model to the specified file path."""
 
@@ -1058,12 +1148,12 @@ class PerturbFlow(nn.Module):
 
 
 EXAMPLE_RUN = (
-    "example run: PerturbFlow --help"
+    "example run: DensityFlow --help"
 )
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="PerturbFlow\n{}".format(EXAMPLE_RUN))
+        description="DensityFlow\n{}".format(EXAMPLE_RUN))
 
     parser.add_argument(
         "--cuda", action="store_true", help="use GPU(s) to speed up training"
 
@@ -1250,7 +1340,7 @@ def main():
    cell_factor_size = 0 if us is None else us.shape[1]
 
    ###########################################
-    PerturbFlow = PerturbFlow(
+    DensityFlow = DensityFlow(
        input_size=input_size,
        cell_factor_size=cell_factor_size,
        inverse_dispersion=args.inverse_dispersion,
 
@@ -1269,7 +1359,7 @@ def main():
        dtype=dtype,
    )
 
-    PerturbFlow.fit(xs, us=us,
+    DensityFlow.fit(xs, us=us,
                    num_epochs=args.num_epochs,
                    learning_rate=args.learning_rate,
                    batch_size=args.batch_size,
 
@@ -1281,12 +1371,11 @@ def main():
 
    if args.save_model is not None:
        if args.save_model.endswith('gz'):
-            PerturbFlow.save_model(PerturbFlow, args.save_model, compression=True)
+            DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
        else:
-            PerturbFlow.save_model(PerturbFlow, args.save_model)
+            DensityFlow.save_model(DensityFlow, args.save_model)
 
 
 
 if __name__ == "__main__":
-
    main()
SURE/SURE.py
CHANGED

@@ -99,19 +99,18 @@ class SURE(nn.Module):
                 cell_factor_size: int = 0,
                 supervised_mode: bool = False,
                 z_dim: int = 10,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '…
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '…
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
                 inverse_dispersion: float = 10.0,
                 use_zeroinflate: bool = True,
-                 hidden_layers: list = […
+                 hidden_layers: list = [500],
                 hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                 nn_dropout: float = 0.1,
                 post_layer_fct: list = ['layernorm'],
                 post_act_fct: list = None,
                 config_enum: str = 'parallel',
-                 use_cuda: bool = …
+                 use_cuda: bool = True,
                 seed: int = 42,
-                 zero_bias: bool = True,
                 dtype = torch.float32, # type: ignore
                 ):
        super().__init__()
 
@@ -135,7 +134,6 @@ class SURE(nn.Module):
        self.post_layer_fct = post_layer_fct
        self.post_act_fct = post_act_fct
        self.hidden_layer_activation = hidden_layer_activation
-        self.use_bias = not zero_bias
 
        self.codebook_weights = None
 
@@ -234,26 +232,16 @@ class SURE(nn.Module):
            )
 
        if self.cell_factor_size>0:
-            …
-            self.…
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            else:
-                self.cell_factor_effect = ZeroBiasMLP(
-                    [self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
-                    activation=activate_fct,
-                    output_activation=None,
-                    post_layer_fct=post_layer_fct,
-                    post_act_fct=post_act_fct,
-                    allow_broadcast=self.allow_broadcast,
-                    use_cuda=self.use_cuda,
-                )
+            self.cell_factor_effect = MLP(
+                [self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
+                activation=activate_fct,
+                output_activation=None,
+                post_layer_fct=post_layer_fct,
+                post_act_fct=post_act_fct,
+                allow_broadcast=self.allow_broadcast,
+                use_cuda=self.use_cuda,
+            )
+
 
        self.decoder_concentrate = MLP(
            [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
 
@@ -381,12 +369,13 @@ class SURE(nn.Module):
 
            zs = zns
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -463,12 +452,13 @@ class SURE(nn.Module):
            zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -557,12 +547,13 @@ class SURE(nn.Module):
            zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
 
@@ -653,13 +644,14 @@ class SURE(nn.Module):
            zs = zns
 
            concentrate = self.decoder_concentrate(zs)
-            if self.loss_func…
+            if self.loss_func in ['bernoulli']:
                log_theta = concentrate
            else:
                rate = concentrate.exp()
-                …
-                …
-                …
+                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
+                if self.loss_func == 'poisson':
+                    rate = theta * torch.sum(xs, dim=1, keepdim=True)
+
            if self.loss_func == 'negbinomial':
                if self.use_zeroinflate:
                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
 
@@ -825,7 +817,7 @@ class SURE(nn.Module):
            us = None,
            ys = None,
            zs = None,
-            num_epochs: int = …
+            num_epochs: int = 500,
            learning_rate: float = 0.0001,
            batch_size: int = 256,
            algo: Literal['adam','rmsprop','adamw'] = 'adam',
 
@@ -834,7 +826,7 @@ class SURE(nn.Module):
            decay_rate: float = 0.9,
            config_enum: str = 'parallel',
            threshold: int = 0,
-            use_jax: bool = …
+            use_jax: bool = True):
        """
        Train the SURE model.
SURE/__init__.py
CHANGED

@@ -1,12 +1,12 @@
 from .SURE import SURE
-from .PerturbFlow import PerturbFlow
+from .DensityFlow import DensityFlow
 
 from . import utils
 from . import codebook
 from . import SURE
-from . import PerturbFlow
+from . import DensityFlow
 from . import atac
 from . import flow
 from . import perturb
 
-__all__ = ['SURE', '…
+__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
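
For downstream code, the visible effect of the rename is the import path. A minimal sketch (constructor arguments shown only for illustration):

    # 2.1.34:
    #   from SURE import PerturbFlow
    # 2.2.24:
    from SURE import DensityFlow

    model = DensityFlow(input_size=2000, use_cuda=False)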
SURE/flow/flow_stats.py
CHANGED

@@ -1,4 +1,5 @@
 import numpy as np
+from scipy.interpolate import griddata
 from scipy.spatial.distance import pdist, squareform
 from sklearn.decomposition import PCA
 from scipy.stats import pearsonr
 
@@ -16,6 +17,42 @@ class VectorFieldEval:
 
     def momentum_flow_metric(self, vectors, masses=None):
         return momentum_flow_metric(vectors=vectors, masses=masses)
+
+    def divergence(self, points, vectors, grid_resolution=30):
+        # extract coordinates and vector components
+        x_coords = points[:, 0]
+        y_coords = points[:, 1]
+        u_components = vectors[:, 0]  # x-direction component
+        v_components = vectors[:, 1]  # y-direction component
+
+        # build a regular grid
+        x_grid = np.linspace(x_coords.min(), x_coords.max(), grid_resolution)
+        y_grid = np.linspace(y_coords.min(), y_coords.max(), grid_resolution)
+        X, Y = np.meshgrid(x_grid, y_grid)
+
+        # interpolate onto the grid
+        U_grid = griddata((x_coords, y_coords), u_components, (X, Y), method='linear')
+        V_grid = griddata((x_coords, y_coords), v_components, (X, Y), method='linear')
+
+        # compute the divergence
+        dU_dx = np.gradient(U_grid, x_grid, axis=1)
+        dV_dy = np.gradient(V_grid, y_grid, axis=0)
+        divergence = dU_dx + dV_dy
+        divergence[np.isnan(divergence)] = 0
+
+        return divergence
+
+    def movement_stats(self,vectors):
+        return calculate_movement_stats(vectors)
+
+    def direction_stats(self, vectors):
+        return calculate_direction_stats(vectors)
+
+    def movement_energy(self, vectors, masses=None):
+        return calculate_movement_energy(vectors, masses)
+
+    def movement_divergence(self, positions, vectors):
+        return calculate_movement_divergence(positions, vectors)
 
 
 def calculate_movement_stats(vectors):
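
A small usage sketch of the new divergence method, assuming VectorFieldEval can be constructed without arguments (random 2D points and vectors, purely illustrative):

    import numpy as np
    from SURE.flow.flow_stats import VectorFieldEval

    rng = np.random.default_rng(0)
    points = rng.normal(size=(500, 2))    # 2D embedding coordinates
    vectors = rng.normal(size=(500, 2))   # flow vectors at those points

    ev = VectorFieldEval()
    div = ev.divergence(points, vectors, grid_resolution=30)
    print(div.shape)                       # (30, 30) grid of divergence values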
SURE/perturb/__init__.py
CHANGED

@@ -1 +1 @@
-from .perturb import LabelMatrix
+from .perturb import LabelMatrix,DoseMatrix
SURE/perturb/perturb.py
CHANGED

@@ -1,5 +1,7 @@
 import re
 import numpy as np
+import pandas as pd
+from numba import njit
 from itertools import chain
 from joblib import Parallel, delayed
 from typing import Literal
 
@@ -7,8 +9,10 @@ from typing import Literal
 class LabelMatrix:
     def __init__(self):
         self.labels_ = None
+        self.control_label = None
+        self.sep_pattern = None
 
-    def fit_transform(self, labels, sep_pattern=r'[…
+    def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
         if speedup=='none':
             mat, self.labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
         elif speedup=='vectorize':
 
@@ -17,10 +21,53 @@ class LabelMatrix:
             mat, self.labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
 
         self.labels_ = np.array(self.labels_)
-        …
+
+        if control_label is not None:
+            idx = np.where(self.labels_==control_label)[0]
+            mat = np.delete(mat, idx, axis=1)
+            self.labels_ = np.delete(self.labels_, idx)
 
+        self.control_label = control_label
+        self.sep_pattern=sep_pattern
+
+        return mat
+
+    def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+        sep_pattern = self.sep_pattern
+        if speedup=='none':
+            mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='vectorize':
+            mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='parallel':
+            mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+        mat_df = pd.DataFrame(mat, columns=labels_)
+
+        labels_valid = [x for x in labels_ if x in self.labels_]
+        mat_df = mat_df[labels_valid]
+
+        mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+        mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+        mat_valid_df[labels_valid] = mat_df
+
+        return mat_valid_df.values
+
     def inverse_transform(self, matrix):
         return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
+
+class DoseMatrix:
+    def __init__(self):
+        self.labels_ = None
+
+    def fit_transform(self, labels, label_dose, control_label=None):
+        mat, self.labels_ = dose_to_matrix(labels, label_dose)
+
+        if control_label is not None:
+            idx = np.where(self.labels_==control_label)[0]
+            mat = np.delete(mat, idx, axis=1)
+            self.labels_ = np.delete(self.labels_, idx)
+
+        return mat
 
 def label_to_matrix(labels, sep_pattern=r'[;_\-\s]'):
     """
 
@@ -85,3 +132,38 @@ def parallel_label_to_matrix(labels, sep_pattern=r'[;_\-\s]', n_jobs=4):
 def matrix_to_labels(matrix, unique_labels):
     return [';'.join([unique_labels[i] for i in np.where(row)[0]])
             for row in matrix]
+
+
+
+
+
+
+
+@njit(parallel=True)
+def _numba_fill_matrix(dose_matrix, label_indices, label_doses):
+    """Numba-accelerated matrix fill."""
+    for i in range(len(label_indices)):
+        dose_matrix[i, label_indices[i]] = label_doses[i]
+
+def dose_to_matrix(labels, label_dose, all_labels=None):
+    """
+    Numba-accelerated version (requires numba to be installed).
+    """
+    if all_labels is None:
+        all_labels = sorted(set().union(labels))
+
+    label_to_idx = {label: idx for idx, label in enumerate(all_labels)}
+    n_samples = len(labels)
+    n_labels = len(all_labels)
+    dose_matrix = np.zeros((n_samples, n_labels), dtype=np.float64)
+
+    # preprocess into a Numba-compatible format
+    label_indices = []
+    label_doses = []
+    for i, label in enumerate(labels):
+        label_indices.append(label_to_idx[label])
+        label_doses.append(label_dose[i])
+
+    # call the Numba-accelerated fill
+    _numba_fill_matrix(dose_matrix, label_indices, label_doses)
+    return dose_matrix,np.array(all_labels)
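
A hypothetical end-to-end sketch of the expanded perturb API (labels and doses invented for illustration; control_label drops the control column at fit time, and transform reindexes new data onto the fitted columns):

    from SURE.perturb import LabelMatrix, DoseMatrix

    lm = LabelMatrix()
    # fit: learn label columns, drop the control column, remember sep_pattern
    train_mat = lm.fit_transform(['ctrl', 'drugA', 'drugA;drugB'],
                                 control_label='ctrl')
    # transform: labels unseen at fit time are dropped, and fitted labels
    # absent from the new data become zero columns
    test_mat = lm.transform(['drugB', 'ctrl'])

    dm = DoseMatrix()
    dose_mat = dm.fit_transform(['drugA', 'drugB'], label_dose=[0.1, 1.0])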
SURE/utils/__init__.py
CHANGED
SURE/utils/custom_mlp.py
CHANGED

@@ -239,6 +239,43 @@ class ZeroBiasMLP(nn.Module):
     def forward(self, x):
         y = self.mlp(x)
         mask = torch.zeros_like(y)
-        …
+        if len(y.shape)==2:
+            mask[x[1][:,0]>0,:] = 1
+        elif len(y.shape)==3:
+            mask[:,x[1][:,0]>0,:] = 1
         return y*mask
-…
+
+
+class HDMLP(nn.Module):
+    def __init__(
+        self,
+        input_size,
+        hidden_sizes,
+        output_depth,
+        activation=nn.ReLU,
+        output_activation=None,
+        post_layer_fct=lambda layer_ix, total_layers, layer: None,
+        post_act_fct=lambda layer_ix, total_layers, layer: None,
+        allow_broadcast=False,
+        use_cuda=False,
+    ):
+        # init the module object
+        super().__init__()
+        self.mlp = MLP(mlp_sizes=[1] + hidden_sizes + [output_depth],
+                       activation=activation,
+                       output_activation=output_activation,
+                       post_layer_fct=post_layer_fct,
+                       post_act_fct=post_act_fct,
+                       allow_broadcast=allow_broadcast,
+                       use_cuda=use_cuda,
+                       bias=True)
+        self.input_size=input_size
+        self.output_depth=output_depth
+
+    # pass through our sequential for the output!
+    def forward(self, x):
+        batch_size, n = x.shape
+        x = x.view(batch_size * n, 1)
+        out = self.mlp(x)
+        out = out.view(batch_size, n, self.output_depth)
+        return out
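
HDMLP applies one shared scalar-to-vector MLP to every entry of a (batch, n) input and returns a (batch, n, output_depth) tensor. A self-contained sketch of the same reshape trick, using a plain nn.Sequential in place of the package's MLP (hyperparameters invented):

    import torch
    import torch.nn as nn

    class TinyHDMLP(nn.Module):
        """Sketch of the HDMLP idea: one shared 1-in network applied
        to every scalar entry of a (batch, n) input."""
        def __init__(self, hidden, output_depth):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(),
                                     nn.Linear(hidden, output_depth))
            self.output_depth = output_depth

        def forward(self, x):
            batch_size, n = x.shape
            out = self.net(x.view(batch_size * n, 1))   # flatten entries
            return out.view(batch_size, n, self.output_depth)

    out = TinyHDMLP(32, 8)(torch.randn(16, 64))
    print(out.shape)   # torch.Size([16, 64, 8])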
{sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.1.34
+Version: 2.2.24
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
 
@@ -20,6 +20,7 @@ Requires-Dist: numpy
 Requires-Dist: scikit-learn
 Requires-Dist: pandas
 Requires-Dist: pyro-ppl
+Requires-Dist: jax[cuda12]
 Requires-Dist: leidenalg
 Requires-Dist: python-igraph
 Requires-Dist: networkx
sure_tools-2.2.24.dist-info/RECORD
ADDED

@@ -0,0 +1,25 @@
+SURE/DensityFlow.py,sha256=IpObVzq3pb2GAYt8f0rCkR8d9YdsRg1RwPruvNKHoHM,56132
+SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
+SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
+SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
+SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
+SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
+SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
+SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
+SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
+SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
+SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
+SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
+SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
+SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
+SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
+SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
+SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
+SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
+SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
+sure_tools-2.2.24.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.2.24.dist-info/METADATA,sha256=oQslRmRo5_NDhapldLCnsck6dXrGEEHj-VAEl4XzWNU,2678
+sure_tools-2.2.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.2.24.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.2.24.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.2.24.dist-info/RECORD,,

sure_tools-2.1.34.dist-info/RECORD
DELETED

@@ -1,25 +0,0 @@
-SURE/PerturbFlow.py,sha256=BoaNDubCKpsYJcwipZxrSCpol4nVvCttP28MizHffzY,51650
-SURE/SURE.py,sha256=ghagk4vO3xrAXwdyYTIv7y0X2KXr1R2baXH8lqvUl7k,48094
-SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
-SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
-SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
-SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
-SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
-SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
-SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
-SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
-SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
-SURE/flow/flow_stats.py,sha256=cBBsPEDpWNMpbzlyQ3f0385RSrX6_5RCH2caOyi4ihM,9908
-SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
-SURE/perturb/__init__.py,sha256=ouxShhbxZM4r5Gf7GmKiutrsmtyq7QL8rHjhgF0BU08,32
-SURE/perturb/perturb.py,sha256=CqO3xPfNA3cG175tadDidKvGsTu_yKfJRRLn_93awKM,3303
-SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
-SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
-SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
-SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.1.34.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
-sure_tools-2.1.34.dist-info/METADATA,sha256=EV5AA3dO2YrSVqVC0wytPt0RkCIb7nuws0FSUhNcXuE,2651
-sure_tools-2.1.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-sure_tools-2.1.34.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
-sure_tools-2.1.34.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
-sure_tools-2.1.34.dist-info/RECORD,,

{sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/WHEEL
File without changes

{sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/entry_points.txt
File without changes

{sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/licenses/LICENSE
File without changes

{sure_tools-2.1.34.dist-info → sure_tools-2.2.24.dist-info}/top_level.txt
File without changes