SURE-tools 2.2.19__py3-none-any.whl → 2.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged as potentially problematic. See the registry's advisory for this release for more details.
- SURE/DensityFlow.py +19 -12
- SURE/PerturbE.py +1293 -0
- SURE/__init__.py +2 -1
- SURE/utils/custom_mlp.py +39 -2
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/METADATA +1 -1
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/RECORD +10 -9
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/WHEEL +0 -0
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.2.19.dist-info → sure_tools-2.3.2.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -396,7 +396,8 @@ class DensityFlow(nn.Module):
|
|
|
396
396
|
else:
|
|
397
397
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
398
398
|
elif self.loss_func == 'multinomial':
|
|
399
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
399
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
400
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
400
401
|
elif self.loss_func == 'bernoulli':
|
|
401
402
|
if self.use_zeroinflate:
|
|
402
403
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -480,7 +481,8 @@ class DensityFlow(nn.Module):
|
|
|
480
481
|
else:
|
|
481
482
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
482
483
|
elif self.loss_func == 'multinomial':
|
|
483
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
484
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
485
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
484
486
|
elif self.loss_func == 'bernoulli':
|
|
485
487
|
if self.use_zeroinflate:
|
|
486
488
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -576,7 +578,8 @@ class DensityFlow(nn.Module):
|
|
|
576
578
|
else:
|
|
577
579
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
578
580
|
elif self.loss_func == 'multinomial':
|
|
579
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
581
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
582
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
580
583
|
elif self.loss_func == 'bernoulli':
|
|
581
584
|
if self.use_zeroinflate:
|
|
582
585
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -682,7 +685,8 @@ class DensityFlow(nn.Module):
|
|
|
682
685
|
else:
|
|
683
686
|
pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
|
|
684
687
|
elif self.loss_func == 'multinomial':
|
|
685
|
-
pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
688
|
+
#pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
|
|
689
|
+
pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
|
|
686
690
|
elif self.loss_func == 'bernoulli':
|
|
687
691
|
if self.use_zeroinflate:
|
|
688
692
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -706,13 +710,13 @@ class DensityFlow(nn.Module):
|
|
|
706
710
|
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
711
|
#else:
|
|
708
712
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
-
zus = self._cell_response(zns, i, us)
|
|
713
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
710
714
|
else:
|
|
711
715
|
#if self.turn_off_cell_specific:
|
|
712
716
|
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
717
|
#else:
|
|
714
718
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
-
zus = zus + self._cell_response(zns, i, us)
|
|
719
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
716
720
|
return zus
|
|
717
721
|
|
|
718
722
|
def _get_codebook_identity(self):
|
|
@@ -854,12 +858,12 @@ class DensityFlow(nn.Module):
|
|
|
854
858
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
855
859
|
|
|
856
860
|
# factor effect of xs
|
|
857
|
-
dzs0 = self.get_cell_response(zs,
|
|
861
|
+
dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
858
862
|
|
|
859
863
|
# perturbation effect
|
|
860
864
|
ps = np.ones_like(us_i)
|
|
861
865
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
862
|
-
dzs = self.get_cell_response(zs,
|
|
866
|
+
dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
863
867
|
zs = zs + dzs0 + dzs
|
|
864
868
|
else:
|
|
865
869
|
zs = zs + dzs0
|
|
@@ -946,6 +950,9 @@ class DensityFlow(nn.Module):
|
|
|
946
950
|
if self.loss_func == 'bernoulli':
|
|
947
951
|
#counts = self.sigmoid(concentrate)
|
|
948
952
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
953
|
+
elif self.loss_func == 'multinomial':
|
|
954
|
+
theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
|
|
955
|
+
counts = theta * library_size
|
|
949
956
|
else:
|
|
950
957
|
rate = concentrate.exp()
|
|
951
958
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
@@ -1340,7 +1347,7 @@ def main():
|
|
|
1340
1347
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1341
1348
|
|
|
1342
1349
|
###########################################
|
|
1343
|
-
|
|
1350
|
+
df = DensityFlow(
|
|
1344
1351
|
input_size=input_size,
|
|
1345
1352
|
cell_factor_size=cell_factor_size,
|
|
1346
1353
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1359,7 +1366,7 @@ def main():
|
|
|
1359
1366
|
dtype=dtype,
|
|
1360
1367
|
)
|
|
1361
1368
|
|
|
1362
|
-
|
|
1369
|
+
df.fit(xs, us=us,
|
|
1363
1370
|
num_epochs=args.num_epochs,
|
|
1364
1371
|
learning_rate=args.learning_rate,
|
|
1365
1372
|
batch_size=args.batch_size,
|
|
@@ -1371,9 +1378,9 @@ def main():
|
|
|
1371
1378
|
|
|
1372
1379
|
if args.save_model is not None:
|
|
1373
1380
|
if args.save_model.endswith('gz'):
|
|
1374
|
-
DensityFlow.save_model(
|
|
1381
|
+
DensityFlow.save_model(df, args.save_model, compression=True)
|
|
1375
1382
|
else:
|
|
1376
|
-
DensityFlow.save_model(
|
|
1383
|
+
DensityFlow.save_model(df, args.save_model)
|
|
1377
1384
|
|
|
1378
1385
|
|
|
1379
1386
|
|