SURE-tools 2.1.14-py3-none-any.whl → 2.1.16-py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/PerturbFlow.py CHANGED
@@ -835,6 +835,44 @@ class PerturbFlow(nn.Module):
835
835
  R = np.concatenate(R)
836
836
  return R
837
837
 
838
+ def _count(self,concentrate):
839
+ if self.loss_func == 'bernoulli':
840
+ counts = self.sigmoid(concentrate)
841
+ else:
842
+ counts = concentrate.exp()
843
+ return counts
844
+
845
+ def _count_sample(self,concentrate):
846
+ if self.loss_func == 'bernoulli':
847
+ logits = concentrate
848
+ counts = dist.Bernoulli(logits=logits).to_event(1).sample()
849
+ else:
850
+ counts = self._count(concentrate=concentrate)
851
+ counts = dist.Poisson(rate=counts).to_event(1).sample()
852
+ return counts
853
+
854
+ def get_counts(self, zs,
855
+ batch_size: int = 1024,
856
+ use_sampler: bool = False):
857
+
858
+ zs = convert_to_tensor(zs, device=self.get_device())
859
+ dataset = CustomDataset(zs)
860
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
861
+
862
+ E = []
863
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
864
+ for Z_batch, _ in dataloader:
865
+ concentrate = self._expression(Z_batch)
866
+ if use_sampler:
867
+ counts = self._count_sample(concentrate)
868
+ else:
869
+ counts = self._count(concentrate)
870
+ E.append(tensor_to_numpy(counts))
871
+ pbar.update(1)
872
+
873
+ E = np.concatenate(E)
874
+ return E
875
+
838
876
  def preprocess(self, xs, threshold=0):
839
877
  if self.loss_func == 'bernoulli':
840
878
  ad = sc.AnnData(xs)
SURE/utils/custom_mlp.py CHANGED
@@ -85,6 +85,7 @@ class MLP(nn.Module):
85
85
  post_act_fct=lambda layer_ix, total_layers, layer: None,
86
86
  allow_broadcast=False,
87
87
  use_cuda=False,
88
+ bias=True,
88
89
  ):
89
90
  # init the module object
90
91
  super().__init__()
@@ -114,11 +115,12 @@ class MLP(nn.Module):
114
115
  assert type(layer_size) == int, "Hidden layer sizes must be ints"
115
116
 
116
117
  # get our nn layer module (in this case nn.Linear by default)
117
- cur_linear_layer = nn.Linear(last_layer_size, layer_size)
118
+ cur_linear_layer = nn.Linear(last_layer_size, layer_size, bias=bias)
118
119
 
119
120
  # for numerical stability -- initialize the layer properly
120
121
  cur_linear_layer.weight.data.normal_(0, 0.001)
121
- cur_linear_layer.bias.data.normal_(0, 0.001)
122
+ if bias:
123
+ cur_linear_layer.bias.data.normal_(0, 0.001)
122
124
 
123
125
  # use GPUs to share data during training (if available)
124
126
  if use_cuda:
@@ -160,7 +162,7 @@ class MLP(nn.Module):
160
162
  ), "output_size must be int, list, tuple"
161
163
 
162
164
  if type(output_size) == int:
163
- all_modules.append(nn.Linear(last_layer_size, output_size))
165
+ all_modules.append(nn.Linear(last_layer_size, output_size, bias=bias))
164
166
  if output_activation is not None:
165
167
  all_modules.append(
166
168
  call_nn_op(output_activation)
@@ -179,7 +181,7 @@ class MLP(nn.Module):
179
181
  split_layer = []
180
182
 
181
183
  # we have an activation function
182
- split_layer.append(nn.Linear(last_layer_size, out_size))
184
+ split_layer.append(nn.Linear(last_layer_size, out_size, bias=bias))
183
185
 
184
186
  # then we get our output activation (either we repeat all or we index into a same sized array)
185
187
  act_out_fct = (
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.14
3
+ Version: 2.1.16
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/PerturbFlow.py,sha256=eW_RUuNd-D4UUY-YhCHklWm0TMdOdfyuzfvwRWwfwAc,49553
1
+ SURE/PerturbFlow.py,sha256=NSNXAct139286O8f8k_NaONIPYB_2sdFmtQ19gtwO6Y,50925
2
2
  SURE/SURE.py,sha256=xMD6VBYsgk-bZ_xBWzpdGyxEAleonNRoPkZAxAX467s,47444
3
3
  SURE/SURE2.py,sha256=8wlnMwb1xuf9QUksNkWdWx5ZWq-xIy9NLx8RdUnE82o,48501
4
4
  SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
@@ -16,12 +16,12 @@ SURE/flow/quiver.py,sha256=_euFqSaRrDoZ_oOabOx20LOoUTJ__XPhLW-vzLNQfAo,1859
16
16
  SURE/perturb/__init__.py,sha256=ouxShhbxZM4r5Gf7GmKiutrsmtyq7QL8rHjhgF0BU08,32
17
17
  SURE/perturb/perturb.py,sha256=CqO3xPfNA3cG175tadDidKvGsTu_yKfJRRLn_93awKM,3303
18
18
  SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
19
- SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
19
+ SURE/utils/custom_mlp.py,sha256=n0GJj8GKSAqvyokCL_6LugbeedzqLdfBijq24B_I9dA,7113
20
20
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
21
21
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
22
- sure_tools-2.1.14.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
23
- sure_tools-2.1.14.dist-info/METADATA,sha256=b19udVA1bQF9_lsgyWblCnvlp1P3gX-kd_iB7E64hLE,2651
24
- sure_tools-2.1.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
- sure_tools-2.1.14.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
26
- sure_tools-2.1.14.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
27
- sure_tools-2.1.14.dist-info/RECORD,,
22
+ sure_tools-2.1.16.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
23
+ sure_tools-2.1.16.dist-info/METADATA,sha256=Va8l2Qh9jLwOLeiCUDF6Oudx_MqPzN-Opxd6-IMjnRs,2651
24
+ sure_tools-2.1.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ sure_tools-2.1.16.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
26
+ sure_tools-2.1.16.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
27
+ sure_tools-2.1.16.dist-info/RECORD,,