SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +130 -65
- SURE/DensityFlow2.py +1422 -0
- SURE/DensityFlowLinear.py +1414 -0
- SURE/EfficientTranscriptomeDecoder.py +552 -0
- SURE/PerturbE.py +1300 -0
- SURE/PerturbationAwareDecoder.py +737 -0
- SURE/SimpleTranscriptomeDecoder.py +567 -0
- SURE/TranscriptomeDecoder.py +511 -0
- SURE/VirtualCellDecoder.py +658 -0
- SURE/__init__.py +17 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/METADATA +1 -1
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/RECORD +17 -9
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.24.dist-info → sure_tools-2.4.34.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 100,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int =
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
-
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
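The headline change in this hunk is the new `dispersion` constructor argument (default 8.0), which seeds the negative-binomial dispersion parameter used further down; the other defaults shown (`codebook_size=100`, `z_dim=50`, `z_dist='studentt'`, `loss_func='negbinomial'`, `hidden_layers=[1024]`) are the 2.4.x values (the pre-change values are truncated in this view). A minimal construction sketch against the new signature; the data shape and the `SURE` import path are assumptions, not taken from the diff:

    from SURE import DensityFlow  # assumed export; the class lives in SURE/DensityFlow.py

    # Hypothetical dataset: 2,000 genes, one cell-level factor (e.g. dose).
    df = DensityFlow(
        input_size=2000,           # number of genes
        cell_factor_size=1,        # one perturbation/covariate column
        dispersion=8.0,            # new in 2.4.x: initial NB dispersion
        loss_func='negbinomial',   # default likelihood
        z_dim=50,                  # default latent dimension
    )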
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,8 +108,16 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Latent Dimension: {self.latent_dim}")
+        print(f" - Gene Dimension: {self.input_size}")
+        print(f" - Hidden Dimensions: {self.hidden_layers}")
+        print(f" - Device: {self.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
 
     def setup_networks(self):
         latent_dim = self.latent_dim
@@ -251,7 +260,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -341,8 +350,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
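In the rewritten branch, `pyro.param` registers a learnable per-gene dispersion vector in Pyro's parameter store, initialized from `self.dispersion` and kept positive via the constraint; the first call creates the parameter and later calls with the same name reuse it, so SVI optimizes it alongside the network weights. A self-contained sketch of that pattern (the sizes here are illustrative):

    import torch
    import pyro
    from pyro.distributions import constraints

    n_genes = 4
    init = 8.0 * torch.ones(n_genes)  # mirrors self.dispersion * xs.new_ones(input_size)

    # Registered once under the name "dispersion"; positivity is enforced
    # under the hood by an unconstrained-to-positive transform.
    dispersion = pyro.param("dispersion", init, constraint=constraints.positive)
    print(dispersion)  # tensor([8., 8., 8., 8.], grad_fn=<...>)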
@@ -376,20 +385,25 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                           logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
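The parameterization is worth spelling out: with `logits = log(mu) - log(dispersion)`, `NegativeBinomial(total_count=dispersion, logits=logits)` has mean `dispersion * exp(logits) = mu` and variance `mu + mu**2 / dispersion`, so the decoder's `log_mu` output is the per-gene mean and the learned dispersion only controls overdispersion; the clamp to [-15, 15] guards against overflow at extreme means. A quick numerical check of that identity (self-contained, not package code):

    import torch
    from torch.distributions import NegativeBinomial

    mu = torch.tensor([0.5, 3.0, 20.0])          # decoder means, mu = log_mu.exp()
    dispersion = torch.tensor([8.0, 8.0, 8.0])   # learned per-gene dispersion

    logits = (mu.log() - dispersion.log()).clamp(min=-15, max=15)
    nb = NegativeBinomial(total_count=dispersion, logits=logits)

    print(nb.mean)      # tensor([ 0.5000,  3.0000, 20.0000]) -> recovers mu
    print(nb.variance)  # mu + mu**2 / dispersion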
@@ -420,8 +434,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -460,20 +474,23 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -504,8 +521,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -556,20 +573,24 @@ class DensityFlow(nn.Module):
 
             zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -600,8 +621,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-
-
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -662,20 +683,24 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns
 
-
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta =
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate =
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                   logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -706,13 +731,13 @@ class DensityFlow(nn.Module):
                 # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self.
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self.
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
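This refactor routes every per-factor effect through the single `_cell_shift` helper and sums the resulting latent displacements: the first factor initializes `zus`, and each subsequent factor adds its own shift. Schematically (a sketch of the accumulation pattern, not the package's exact loop):

    # zns: basal latent embedding; us: (n_cells, n_factors) cell-factor matrix
    zus = None
    for i in range(cell_factor_size):
        shift = model._cell_shift(zns, i, us[:, i].reshape(-1, 1))
        zus = shift if zus is None else zus + shift   # additive composition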
@@ -854,12 +879,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.
+        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.
+            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -873,10 +898,11 @@ class DensityFlow(nn.Module):
         library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts,
+        return counts, log_mu
 
-    def
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -893,7 +919,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def
+    def get_cell_shift(self,
                        zs,
                        perturb_idx,
                        perturb_us,
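After the rename, `get_cell_shift` is the public, batched entry point: it wraps `_cell_shift` over a `DataLoader` and returns the latent displacement as a NumPy array (see the next hunk). A hedged usage sketch, assuming `zs` is an `(n_cells, z_dim)` latent matrix and factor 0 is the perturbation of interest:

    us_i = us[:, 0].reshape(-1, 1)        # observed values of factor 0
    dzs = df.get_cell_shift(zs, perturb_idx=0, perturb_us=us_i)
    zs_shifted = zs + dzs                 # shifted latent states, mirroring zs + dzs0 above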
@@ -911,43 +937,43 @@ class DensityFlow(nn.Module):
             Z = []
             with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
                 for Z_batch, P_batch, _ in dataloader:
-                    zns = self.
+                    zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                     Z.append(tensor_to_numpy(zns))
                     pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def
-        return self.
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def
-        delta_zs,
-        batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-
-        dataset = CustomDataset(
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-                r = self.
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self,
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-
-
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
+        elif self.loss_func == 'multinomial':
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
+            counts = theta * library_size
         else:
-            rate =
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
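`get_log_mu` batch-decodes latent states into per-gene log-means via `decoder_log_mu`, and `_count` turns a log-mean into expected counts (Bernoulli mean, multinomial proportions times library size, or Dirichlet-multinomial proportions otherwise). A sketch of the round trip through the public API; the array names and shapes are assumptions:

    import numpy as np

    log_mu = df.get_log_mu(zs)                       # (n_cells, n_genes) log-means
    mu = np.exp(log_mu)                              # NB mean expression

    lib = xs_raw.sum(axis=1, keepdims=True)          # per-cell library sizes
    expected = df.get_counts(zs, library_sizes=lib)  # expected counts at those depths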
@@ -969,8 +995,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-
-                counts = self._count(
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1116,7 +1142,7 @@ class DensityFlow(nn.Module):
                 pbar.set_postfix({'loss': str_loss})
                 pbar.update(1)
 
-    @classmethod
+    '''@classmethod
     def save_model(cls, model, file_path, compression=False):
         """Save the model to the specified file path."""
         file_path = os.path.abspath(file_path)
@@ -1144,6 +1170,45 @@ class DensityFlow(nn.Module):
         with open(file_path, 'rb') as pickle_file:
             model = pickle.load(pickle_file)
 
+        return model'''
+
+    def save_model(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
         return model
 
 
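The pickle-based classmethods are commented out in favor of a `torch.save` checkpoint that bundles the `state_dict` with the constructor kwargs needed to rebuild the model; `load_model` re-instantiates `DensityFlow(**model_config)` and then restores the weights. A round-trip sketch based on those signatures:

    df.save_model('densityflow.pt')            # state_dict + model_config in one file

    df2 = DensityFlow.load_model('densityflow.pt')
    df2.eval()                                 # nn.Module with restored weights

Note that `main()` below still calls `DensityFlow.save_model(df, args.save_model, compression=True)`, which matches the commented-out classmethod rather than the new `save_model(self, path)` signature.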
@@ -1340,10 +1405,10 @@ def main():
    cell_factor_size = 0 if us is None else us.shape[1]
 
    ###########################################
-
+    df = DensityFlow(
        input_size=input_size,
        cell_factor_size=cell_factor_size,
-
+        dispersion=args.dispersion,
        z_dim=args.z_dim,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
@@ -1359,7 +1424,7 @@ def main():
        dtype=dtype,
    )
 
-
+    df.fit(xs, us=us,
           num_epochs=args.num_epochs,
           learning_rate=args.learning_rate,
           batch_size=args.batch_size,
@@ -1371,9 +1436,9 @@ def main():
 
    if args.save_model is not None:
        if args.save_model.endswith('gz'):
-           DensityFlow.save_model(
+           DensityFlow.save_model(df, args.save_model, compression=True)
        else:
-           DensityFlow.save_model(
+           DensityFlow.save_model(df, args.save_model)
 
 
 