SURE-tools 2.1.0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -97,7 +97,6 @@ class PerturbFlow(nn.Module):
97
97
  input_size: int,
98
98
  codebook_size: int = 200,
99
99
  cell_factor_size: int = 0,
100
- cell_factor_names: list = None,
101
100
  supervised_mode: bool = False,
102
101
  z_dim: int = 10,
103
102
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
@@ -135,7 +134,6 @@ class PerturbFlow(nn.Module):
135
134
  self.post_layer_fct = post_layer_fct
136
135
  self.post_act_fct = post_act_fct
137
136
  self.hidden_layer_activation = hidden_layer_activation
138
- self.cell_factor_names = cell_factor_names
139
137
 
140
138
  self.codebook_weights = None
141
139
 
@@ -813,11 +811,8 @@ class PerturbFlow(nn.Module):
813
811
  A = np.concatenate(A)
814
812
  return A
815
813
 
816
- def _cell_state_response(self, xs, factor_idx, perturb):
817
- zns,_ = self.encoder_zn(xs)
818
- if type(factor_idx) == str:
819
- factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
820
-
814
+ def _cell_response(self, xs, factor_idx, perturb):
815
+ zns,_ = self.encoder_zn(xs)
821
816
  if perturb.ndim==2:
822
817
  ms = self.cell_factor_effect[factor_idx]([zns, perturb])
823
818
  else:
@@ -825,7 +820,7 @@ class PerturbFlow(nn.Module):
825
820
 
826
821
  return ms
827
822
 
828
- def get_cell_state_response(self,
823
+ def get_cell_response(self,
829
824
  xs,
830
825
  factor_idx,
831
826
  perturb,
@@ -843,7 +838,7 @@ class PerturbFlow(nn.Module):
843
838
  Z = []
844
839
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
845
840
  for X_batch, P_batch, _ in dataloader:
846
- zns = self._cell_state_response(X_batch, factor_idx, P_batch)
841
+ zns = self._cell_response(X_batch, factor_idx, P_batch)
847
842
  Z.append(tensor_to_numpy(zns))
848
843
  pbar.update(1)
849
844
 
@@ -852,11 +847,7 @@ class PerturbFlow(nn.Module):
852
847
 
853
848
  def get_metacell_response(self, factor_idx, perturb):
854
849
  zs = self._get_codebook()
855
- ps = convert_to_tensor(perturb, device=self.get_device())
856
-
857
- if type(factor_idx) == str:
858
- factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
859
-
850
+ ps = convert_to_tensor(perturb, device=self.get_device())
860
851
  ms = self.cell_factor_effect[factor_idx]([zs,ps])
861
852
  return tensor_to_numpy(ms)
862
853
 
SURE/SURE.py CHANGED
@@ -97,7 +97,6 @@ class SURE(nn.Module):
97
97
  input_size: int,
98
98
  codebook_size: int = 200,
99
99
  cell_factor_size: int = 0,
100
- cell_factor_names: list = None,
101
100
  supervised_mode: bool = False,
102
101
  z_dim: int = 10,
103
102
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
@@ -135,7 +134,6 @@ class SURE(nn.Module):
135
134
  self.post_layer_fct = post_layer_fct
136
135
  self.post_act_fct = post_act_fct
137
136
  self.hidden_layer_activation = hidden_layer_activation
138
- self.cell_factor_names = cell_factor_names
139
137
 
140
138
  self.codebook_weights = None
141
139
 
@@ -234,18 +232,15 @@ class SURE(nn.Module):
234
232
  )
235
233
 
236
234
  if self.cell_factor_size>0:
237
- self.cell_factor_effect = nn.ModuleList()
238
- for i in np.arange(self.cell_factor_size):
239
- self.cell_factor_effect.append(MLP(
240
- [self.latent_dim+1] + hidden_sizes + [self.latent_dim],
241
- activation=activate_fct,
242
- output_activation=None,
243
- post_layer_fct=post_layer_fct,
244
- post_act_fct=post_act_fct,
245
- allow_broadcast=self.allow_broadcast,
246
- use_cuda=self.use_cuda,
247
- )
248
- )
235
+ self.cell_factor_effect = MLP(
236
+ [self.z_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.z_dim],
237
+ activation=activate_fct,
238
+ output_activation=None,
239
+ post_layer_fct=post_layer_fct,
240
+ post_act_fct=post_act_fct,
241
+ allow_broadcast=self.allow_broadcast,
242
+ use_cuda=self.use_cuda,
243
+ )
249
244
 
250
245
  self.decoder_concentrate = MLP(
251
246
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -449,13 +444,7 @@ class SURE(nn.Module):
449
444
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
450
445
 
451
446
  if self.cell_factor_size>0:
452
- #zus = self.decoder_undesired([zns,us])
453
- zus = None
454
- for i in np.arange(self.cell_factor_size):
455
- if i==0:
456
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
457
- else:
458
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
447
+ zus = self.cell_factor_effect([zns,us])
459
448
  zs = zns+zus
460
449
  else:
461
450
  zs = zns
@@ -645,13 +634,7 @@ class SURE(nn.Module):
645
634
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
646
635
 
647
636
  if self.cell_factor_size>0:
648
- #zus = self.decoder_undesired([zns,us])
649
- zus = None
650
- for i in np.arange(self.cell_factor_size):
651
- if i==0:
652
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
653
- else:
654
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
637
+ zus = self.decoder_undesired([zns,us])
655
638
  zs = zns+zus
656
639
  else:
657
640
  zs = zns
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.0
3
+ Version: 2.1.1
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,5 +1,5 @@
1
- SURE/PerturbFlow.py,sha256=ko6DcOoO-IUEm736n689jmtJSHM_s2hafG-yUmflNSU,52019
2
- SURE/SURE.py,sha256=OmwmA8RnhwrirXgnzqJMxv1BCr1jmjfHitvlmBKhHNA,49245
1
+ SURE/PerturbFlow.py,sha256=RQoIhYQJdQpHdY_sMeDuqurbwvm6IX1XH7SVWG6SmS0,51658
2
+ SURE/SURE.py,sha256=_ZOymj24DLQju0Lb90lKspHPmqIUDDzjIEr9t4qgqCI,48364
3
3
  SURE/SURE2.py,sha256=8wlnMwb1xuf9QUksNkWdWx5ZWq-xIy9NLx8RdUnE82o,48501
4
4
  SURE/__init__.py,sha256=xV10iBbh69g4mjBMb1cQxjuHe8e3Aq7pDzkZmx5G754,260
5
5
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -16,9 +16,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
16
16
  SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
17
17
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
18
18
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
19
- sure_tools-2.1.0.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
20
- sure_tools-2.1.0.dist-info/METADATA,sha256=bH8ARqqdk2bnLC5rk8MvPgBVAb8kaEKnzuHwyyDXpvE,2650
21
- sure_tools-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
- sure_tools-2.1.0.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
23
- sure_tools-2.1.0.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
24
- sure_tools-2.1.0.dist-info/RECORD,,
19
+ sure_tools-2.1.1.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
20
+ sure_tools-2.1.1.dist-info/METADATA,sha256=ET1LmoMzkRak6WiRpuGf2dcoY7cLgGoZtNmMkcqi6DU,2650
21
+ sure_tools-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ sure_tools-2.1.1.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
23
+ sure_tools-2.1.1.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
24
+ sure_tools-2.1.1.dist-info/RECORD,,