SURE-tools 2.2.4__tar.gz → 2.2.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.2.4 → sure_tools-2.2.10}/PKG-INFO +1 -1
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/DensityFlow.py +36 -41
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.4 → sure_tools-2.2.10}/setup.py +1 -1
- {sure_tools-2.2.4 → sure_tools-2.2.10}/LICENSE +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/README.md +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/SURE.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.4 → sure_tools-2.2.10}/setup.cfg +0 -0
|
@@ -62,7 +62,7 @@ class DensityFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
67
|
use_zeroinflate: bool = True,
|
|
68
68
|
hidden_layers: list = [500],
|
|
@@ -234,6 +234,16 @@ class DensityFlow(nn.Module):
|
|
|
234
234
|
allow_broadcast=self.allow_broadcast,
|
|
235
235
|
use_cuda=self.use_cuda,
|
|
236
236
|
)
|
|
237
|
+
if self.loss_func == 'negbinomial':
|
|
238
|
+
self.decoder_total_count = MLP(
|
|
239
|
+
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
240
|
+
activation=activate_fct,
|
|
241
|
+
output_activation=Exp,
|
|
242
|
+
post_layer_fct=post_layer_fct,
|
|
243
|
+
post_act_fct=post_act_fct,
|
|
244
|
+
allow_broadcast=self.allow_broadcast,
|
|
245
|
+
use_cuda=self.use_cuda,
|
|
246
|
+
)
|
|
237
247
|
|
|
238
248
|
if self.latent_dist == 'studentt':
|
|
239
249
|
self.codebook = MLP(
|
|
@@ -314,9 +324,9 @@ class DensityFlow(nn.Module):
|
|
|
314
324
|
batch_size = xs.size(0)
|
|
315
325
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
316
326
|
|
|
317
|
-
if self.loss_func=='negbinomial':
|
|
318
|
-
|
|
319
|
-
|
|
327
|
+
#if self.loss_func=='negbinomial':
|
|
328
|
+
# total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
329
|
+
# xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
320
330
|
|
|
321
331
|
if self.use_zeroinflate:
|
|
322
332
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -360,6 +370,7 @@ class DensityFlow(nn.Module):
|
|
|
360
370
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
361
371
|
|
|
362
372
|
if self.loss_func == 'negbinomial':
|
|
373
|
+
total_count = self.decoder_total_count(zs)
|
|
363
374
|
if self.use_zeroinflate:
|
|
364
375
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
365
376
|
else:
|
|
@@ -393,9 +404,9 @@ class DensityFlow(nn.Module):
|
|
|
393
404
|
batch_size = xs.size(0)
|
|
394
405
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
395
406
|
|
|
396
|
-
if self.loss_func=='negbinomial':
|
|
397
|
-
|
|
398
|
-
|
|
407
|
+
#if self.loss_func=='negbinomial':
|
|
408
|
+
# total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
409
|
+
# xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
399
410
|
|
|
400
411
|
if self.use_zeroinflate:
|
|
401
412
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -444,6 +455,7 @@ class DensityFlow(nn.Module):
|
|
|
444
455
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
445
456
|
|
|
446
457
|
if self.loss_func == 'negbinomial':
|
|
458
|
+
total_count = self.decoder_total_count(zs)
|
|
447
459
|
if self.use_zeroinflate:
|
|
448
460
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
449
461
|
else:
|
|
@@ -477,9 +489,9 @@ class DensityFlow(nn.Module):
|
|
|
477
489
|
batch_size = xs.size(0)
|
|
478
490
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
479
491
|
|
|
480
|
-
if self.loss_func=='negbinomial':
|
|
481
|
-
|
|
482
|
-
|
|
492
|
+
#if self.loss_func=='negbinomial':
|
|
493
|
+
# total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
494
|
+
# xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
483
495
|
|
|
484
496
|
if self.use_zeroinflate:
|
|
485
497
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -540,6 +552,7 @@ class DensityFlow(nn.Module):
|
|
|
540
552
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
541
553
|
|
|
542
554
|
if self.loss_func == 'negbinomial':
|
|
555
|
+
total_count = self.decoder_total_count(zs)
|
|
543
556
|
if self.use_zeroinflate:
|
|
544
557
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
545
558
|
else:
|
|
@@ -573,9 +586,9 @@ class DensityFlow(nn.Module):
|
|
|
573
586
|
batch_size = xs.size(0)
|
|
574
587
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
575
588
|
|
|
576
|
-
if self.loss_func=='negbinomial':
|
|
577
|
-
|
|
578
|
-
|
|
589
|
+
#if self.loss_func=='negbinomial':
|
|
590
|
+
# total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
591
|
+
# xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
579
592
|
|
|
580
593
|
if self.use_zeroinflate:
|
|
581
594
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -646,6 +659,7 @@ class DensityFlow(nn.Module):
|
|
|
646
659
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
647
660
|
|
|
648
661
|
if self.loss_func == 'negbinomial':
|
|
662
|
+
total_count = self.decoder_total_count(zs)
|
|
649
663
|
if self.use_zeroinflate:
|
|
650
664
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
651
665
|
else:
|
|
@@ -824,9 +838,11 @@ class DensityFlow(nn.Module):
|
|
|
824
838
|
|
|
825
839
|
# perturbation effect
|
|
826
840
|
ps = np.ones_like(us_i)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
841
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
842
|
+
dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
|
|
843
|
+
zs = zs + dzs0 + dzs
|
|
844
|
+
else:
|
|
845
|
+
zs = zs + dzs0
|
|
830
846
|
|
|
831
847
|
if library_sizes is None:
|
|
832
848
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -905,33 +921,18 @@ class DensityFlow(nn.Module):
|
|
|
905
921
|
R = np.concatenate(R)
|
|
906
922
|
return R
|
|
907
923
|
|
|
908
|
-
def _count(self,concentrate, library_size=None):
|
|
924
|
+
def _count(self, concentrate, library_size=None):
|
|
909
925
|
if self.loss_func == 'bernoulli':
|
|
910
926
|
#counts = self.sigmoid(concentrate)
|
|
911
927
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
912
928
|
else:
|
|
913
929
|
rate = concentrate.exp()
|
|
914
930
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
915
|
-
|
|
916
|
-
rate = theta * library_size
|
|
917
|
-
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
918
|
-
elif self.loss_func == 'negbinomial':
|
|
919
|
-
total_count = self.inverse_dispersion
|
|
920
|
-
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
|
|
921
|
-
return counts
|
|
922
|
-
|
|
923
|
-
def _count_sample(self,concentrate):
|
|
924
|
-
if self.loss_func == 'bernoulli':
|
|
925
|
-
logits = concentrate
|
|
926
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
927
|
-
else:
|
|
928
|
-
counts = self._count(concentrate=concentrate)
|
|
929
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
931
|
+
counts = theta * library_size
|
|
930
932
|
return counts
|
|
931
933
|
|
|
932
934
|
def get_counts(self, zs, library_sizes,
|
|
933
|
-
batch_size: int = 1024
|
|
934
|
-
use_sampler: bool = False):
|
|
935
|
+
batch_size: int = 1024):
|
|
935
936
|
|
|
936
937
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
937
938
|
|
|
@@ -948,10 +949,7 @@ class DensityFlow(nn.Module):
|
|
|
948
949
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
949
950
|
for Z_batch, L_batch, _ in dataloader:
|
|
950
951
|
concentrate = self._get_expression_response(Z_batch)
|
|
951
|
-
|
|
952
|
-
counts = self._count_sample(concentrate)
|
|
953
|
-
else:
|
|
954
|
-
counts = self._count(concentrate, L_batch)
|
|
952
|
+
counts = self._count(concentrate, L_batch)
|
|
955
953
|
E.append(tensor_to_numpy(counts))
|
|
956
954
|
pbar.update(1)
|
|
957
955
|
|
|
@@ -1096,9 +1094,6 @@ class DensityFlow(nn.Module):
|
|
|
1096
1094
|
# Update progress bar
|
|
1097
1095
|
pbar.set_postfix({'loss': str_loss})
|
|
1098
1096
|
pbar.update(1)
|
|
1099
|
-
|
|
1100
|
-
if self.loss_func == 'negbinomial':
|
|
1101
|
-
self.inverse_dispersion = pyro.param("inverse_dispersion")
|
|
1102
1097
|
|
|
1103
1098
|
@classmethod
|
|
1104
1099
|
def save_model(cls, model, file_path, compression=False):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|