SURE-tools 2.1.46__tar.gz → 2.1.48__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools has been flagged as potentially problematic. See the package registry for more details.

Files changed (30)
  1. {sure_tools-2.1.46 → sure_tools-2.1.48}/PKG-INFO +1 -1
  2. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/PerturbFlow.py +13 -9
  3. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/perturb/perturb.py +1 -1
  4. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/__init__.py +1 -1
  5. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/custom_mlp.py +35 -1
  6. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/PKG-INFO +1 -1
  7. {sure_tools-2.1.46 → sure_tools-2.1.48}/setup.py +1 -1
  8. {sure_tools-2.1.46 → sure_tools-2.1.48}/LICENSE +0 -0
  9. {sure_tools-2.1.46 → sure_tools-2.1.48}/README.md +0 -0
  10. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/SURE.py +0 -0
  11. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/__init__.py +0 -0
  12. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/flow/flow_stats.py +0 -0
  21. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/flow/plot_quiver.py +0 -0
  22. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/perturb/__init__.py +0 -0
  23. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.46 → sure_tools-2.1.48}/setup.cfg +0 -0
{sure_tools-2.1.46 → sure_tools-2.1.48}/PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.46
3
+ Version: 2.1.48
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
{sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/PerturbFlow.py
@@ -378,7 +378,8 @@ class PerturbFlow(nn.Module):
378
378
 
379
379
  def guide1(self, xs):
380
380
  with pyro.plate('data'):
381
- zn_loc, zn_scale = self.encoder_zn(xs)
381
+ #zn_loc, zn_scale = self.encoder_zn(xs)
382
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
382
383
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
383
384
 
384
385
  alpha = self.encoder_n(zns)
@@ -466,7 +467,8 @@ class PerturbFlow(nn.Module):
466
467
 
467
468
  def guide2(self, xs, us=None):
468
469
  with pyro.plate('data'):
469
- zn_loc, zn_scale = self.encoder_zn(xs)
470
+ #zn_loc, zn_scale = self.encoder_zn(xs)
471
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
470
472
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
471
473
 
472
474
  alpha = self.encoder_n(zns)
@@ -561,7 +563,8 @@ class PerturbFlow(nn.Module):
561
563
  def guide3(self, xs, ys, embeds=None):
562
564
  with pyro.plate('data'):
563
565
  if embeds is None:
564
- zn_loc, zn_scale = self.encoder_zn(xs)
566
+ #zn_loc, zn_scale = self.encoder_zn(xs)
567
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
565
568
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
566
569
 
567
570
  def model4(self, xs, us, ys, embeds=None):
@@ -663,7 +666,8 @@ class PerturbFlow(nn.Module):
663
666
  def guide4(self, xs, us, ys, embeds=None):
664
667
  with pyro.plate('data'):
665
668
  if embeds is None:
666
- zn_loc, zn_scale = self.encoder_zn(xs)
669
+ #zn_loc, zn_scale = self.encoder_zn(xs)
670
+ zn_loc, zn_scale = self._get_basal_embedding(xs)
667
671
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
668
672
 
669
673
  def _total_effects(self, zns, us):
@@ -692,8 +696,8 @@ class PerturbFlow(nn.Module):
692
696
  return cb
693
697
 
694
698
  def _get_basal_embedding(self, xs):
695
- zns, _ = self.encoder_zn(xs)
696
- return zns
699
+ loc, scale = self.encoder_zn(xs)
700
+ return loc, scale
697
701
 
698
702
  def get_basal_embedding(self,
699
703
  xs,
@@ -720,7 +724,7 @@ class PerturbFlow(nn.Module):
720
724
  Z = []
721
725
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
722
726
  for X_batch, _ in dataloader:
723
- zns = self._get_basal_embedding(X_batch)
727
+ zns,_ = self._get_basal_embedding(X_batch)
724
728
  Z.append(tensor_to_numpy(zns))
725
729
  pbar.update(1)
726
730
 
@@ -732,7 +736,7 @@ class PerturbFlow(nn.Module):
732
736
  alpha = self.encoder_n(xs)
733
737
  else:
734
738
  #zns,_ = self.encoder_zn(xs)
735
- zns = self._get_basal_embedding(xs)
739
+ zns,_ = self._get_basal_embedding(xs)
736
740
  alpha = self.encoder_n(zns)
737
741
  return alpha
738
742
 
@@ -803,7 +807,7 @@ class PerturbFlow(nn.Module):
803
807
 
804
808
  def _cell_response(self, xs, factor_idx, perturb):
805
809
  #zns,_ = self.encoder_zn(xs)
806
- zns = self._get_basal_embedding(xs)
810
+ zns,_ = self._get_basal_embedding(xs)
807
811
  if perturb.ndim==2:
808
812
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
809
813
  else:
{sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/perturb/perturb.py
@@ -9,7 +9,7 @@ class LabelMatrix:
9
9
  def __init__(self):
10
10
  self.labels_ = None
11
11
 
12
- def fit_transform(self, labels, control_label=None, sep_pattern=r'[;_\-\s]', speedup: Literal['none','vectorize','parallel']='none'):
12
+ def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
13
13
  if speedup=='none':
14
14
  mat, self.labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
15
15
  elif speedup=='vectorize':
{sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/__init__.py
@@ -7,7 +7,7 @@ from .utils import find_partitions_greedy
7
7
 
8
8
  from .queue import PriorityQueue
9
9
 
10
- from .custom_mlp import MLP, Exp, ZeroBiasMLP
10
+ from .custom_mlp import MLP, Exp, ZeroBiasMLP, HDMLP
11
11
 
12
12
  # Importing modules
13
13
  #from . import utils
{sure_tools-2.1.46 → sure_tools-2.1.48}/SURE/utils/custom_mlp.py
@@ -241,4 +241,38 @@ class ZeroBiasMLP(nn.Module):
241
241
  mask = torch.zeros_like(y)
242
242
  mask[x[1][:,0]>0,:] = 1
243
243
  return y*mask
244
-
244
+
245
+
246
+ class HDMLP(nn.Module):
247
+ def __init__(
248
+ self,
249
+ input_size,
250
+ hidden_sizes,
251
+ output_depth,
252
+ activation=nn.ReLU,
253
+ output_activation=None,
254
+ post_layer_fct=lambda layer_ix, total_layers, layer: None,
255
+ post_act_fct=lambda layer_ix, total_layers, layer: None,
256
+ allow_broadcast=False,
257
+ use_cuda=False,
258
+ ):
259
+ # init the module object
260
+ super().__init__()
261
+ self.mlp = MLP(mlp_sizes=[1] + hidden_sizes + [output_depth],
262
+ activation=activation,
263
+ output_activation=output_activation,
264
+ post_layer_fct=post_layer_fct,
265
+ post_act_fct=post_act_fct,
266
+ allow_broadcast=allow_broadcast,
267
+ use_cuda=use_cuda,
268
+ bias=True)
269
+ self.input_size=input_size
270
+ self.output_depth=output_depth
271
+
272
+ # pass through our sequential for the output!
273
+ def forward(self, x):
274
+ batch_size, n = x.shape
275
+ x = x.view(batch_size * n, 1)
276
+ out = self.mlp(x)
277
+ out = out.view(batch_size, n, self.output_depth)
278
+ return out
{sure_tools-2.1.46 → sure_tools-2.1.48}/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.46
3
+ Version: 2.1.48
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
{sure_tools-2.1.46 → sure_tools-2.1.48}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.46',
8
+ version='2.1.48',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes