SURE-tools 4.0.1__py3-none-any.whl → 4.0.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
SURE/SURE.py CHANGED
@@ -79,7 +79,7 @@ class SURE(nn.Module):
79
79
  Number of features (e.g., genes, peaks, proteins, etc.) per cell.
80
80
  codebook_size
81
81
  Number of metacells.
82
- covariate_size
82
+ covariate_sizes
83
83
  Number of cell-level factors.
84
84
  transforms
85
85
  Number of neural spline flows
@@ -111,7 +111,7 @@ class SURE(nn.Module):
111
111
  codebook_size: int,
112
112
  context_sizes: list = [0],
113
113
  perturb_size: int = 0,
114
- covariate_size: int = 0,
114
+ covariate_sizes: int = 0,
115
115
  method: Literal['flow','vae'] = 'vae',
116
116
  transforms: int = 1,
117
117
  z_dim: int = 30,
@@ -137,7 +137,7 @@ class SURE(nn.Module):
137
137
  codebook_size=codebook_size,
138
138
  context_sizes=context_sizes,
139
139
  perturb_size=perturb_size,
140
- covariate_size=covariate_size,
140
+ covariate_sizes=covariate_sizes,
141
141
  transforms=transforms,
142
142
  z_dim=z_dim,
143
143
  z_dist=z_dist,
@@ -159,7 +159,7 @@ class SURE(nn.Module):
159
159
  codebook_size=codebook_size,
160
160
  context_sizes=context_sizes,
161
161
  perturb_size=perturb_size,
162
- covariate_size=covariate_size,
162
+ covariate_sizes=covariate_sizes,
163
163
  z_dim=z_dim,
164
164
  z_dist=z_dist,
165
165
  loss_func=loss_func,
@@ -245,7 +245,7 @@ class SURE(nn.Module):
245
245
  def fit(self, xs:np.array,
246
246
  css:list = None,
247
247
  ps:np.array = None,
248
- fs:np.array = None,
248
+ fss:list = None,
249
249
  num_epochs: int = 100,
250
250
  learning_rate: float = 0.0001,
251
251
  use_mask: bool = False,
@@ -291,7 +291,7 @@ class SURE(nn.Module):
291
291
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
292
292
  the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
293
293
  """
294
- self.engine.fit(xs=xs, css=css, ps=ps, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
294
+ self.engine.fit(xs=xs, css=css, ps=ps, fss=fss, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
295
295
  beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
296
296
  use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
297
297
  monitor=monitor)
SURE/SURE_nsf.py CHANGED
@@ -89,7 +89,7 @@ class SURENF(nn.Module):
89
89
  Number of features (e.g., genes, peaks, proteins, etc.) per cell.
90
90
  codebook_size
91
91
  Number of metacells.
92
- covariate_size
92
+ covariate_sizes
93
93
  Number of cell-level factors.
94
94
  transforms
95
95
  Number of neural spline flows
@@ -121,7 +121,7 @@ class SURENF(nn.Module):
121
121
  codebook_size: int,
122
122
  context_sizes: list = [0],
123
123
  perturb_size: int = 0,
124
- covariate_size: int = 0,
124
+ covariate_sizes: int = 0,
125
125
  transforms: int = 1,
126
126
  z_dim: int = 50,
127
127
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
@@ -144,7 +144,7 @@ class SURENF(nn.Module):
144
144
  self.input_dim = input_dim
145
145
  self.context_sizes = context_sizes
146
146
  self.perturb_size = perturb_size
147
- self.covariate_size = covariate_size
147
+ self.covariate_sizes = covariate_sizes
148
148
  self.dispersion = dispersion
149
149
  self.latent_dim = z_dim
150
150
  self.latent_dist = z_dist
@@ -272,16 +272,19 @@ class SURENF(nn.Module):
272
272
  allow_broadcast=self.allow_broadcast,
273
273
  use_cuda=self.use_cuda,
274
274
  )
275
- if self.covariate_size>0:
276
- self.covariate_effect = ZeroBiasMLP2(
277
- [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
278
- activation=activate_fct,
279
- output_activation=None,
280
- post_layer_fct=post_layer_fct,
281
- post_act_fct=post_act_fct,
282
- allow_broadcast=self.allow_broadcast,
283
- use_cuda=self.use_cuda,
284
- )
275
+ if np.sum(self.covariate_sizes)>0:
276
+ self.covariate_effects = nn.ModuleList()
277
+ for covariate_size in self.covariate_sizes:
278
+ self.covariate_effects.append(ZeroBiasMLP2(
279
+ [covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
280
+ activation=activate_fct,
281
+ output_activation=None,
282
+ post_layer_fct=post_layer_fct,
283
+ post_act_fct=post_act_fct,
284
+ allow_broadcast=self.allow_broadcast,
285
+ use_cuda=self.use_cuda,
286
+ )
287
+ )
285
288
 
286
289
  self.decoder_log_mu = MLP(
287
290
  [self.latent_dim+self.latent_dim+self.latent_dim] + self.decoder_hidden_layers + [self.input_dim],
@@ -421,8 +424,13 @@ class SURENF(nn.Module):
421
424
  zps = self.perturb_effect([ps, zs+zcs])
422
425
  else:
423
426
  zps = torch.zeros_like(zs)
424
- if (self.covariate_size>0) and (fs is not None):
425
- zfs = self.covariate_effect(fs)
427
+ if (self.covariate_sizes>0) and (fs is not None):
428
+ zfs = torch.zeros_like(zs)
429
+ shift = 0
430
+ for i, covariate_size in enumerate(self.covariate_sizes):
431
+ fs_i = fs[:,shift:(shift+covariate_size)]
432
+ zfs += self.covariate_effects[i](fs_i)
433
+ shift += covariate_size
426
434
  else:
427
435
  zfs = torch.zeros_like(zs)
428
436
 
@@ -891,7 +899,7 @@ class SURENF(nn.Module):
891
899
  def fit(self, xs:np.array,
892
900
  css:list = None,
893
901
  ps: np.array = None,
894
- fs:np.array = None,
902
+ fss:list = None,
895
903
  num_epochs: int = 100,
896
904
  learning_rate: float = 0.0001,
897
905
  use_mask: bool = False,
@@ -957,7 +965,8 @@ class SURENF(nn.Module):
957
965
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
958
966
  if ps is not None:
959
967
  ps = convert_to_tensor(ps, dtype=self.dtype, device='cpu')
960
- if fs is not None:
968
+ if fss is not None:
969
+ fs = np.hstack(fss)
961
970
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
962
971
 
963
972
  dataset = CustomDataset(xs)
@@ -1019,7 +1028,7 @@ class SURENF(nn.Module):
1019
1028
  batch_p = None
1020
1029
  else:
1021
1030
  batch_p = ps[idx].to(self.get_device())
1022
- if fs is None:
1031
+ if fss is None:
1023
1032
  batch_f = None
1024
1033
  else:
1025
1034
  batch_f = fs[idx].to(self.get_device())
SURE/SURE_vae.py CHANGED
@@ -121,7 +121,7 @@ class SUREVAE(nn.Module):
121
121
  codebook_size: int,
122
122
  context_sizes: list = [0],
123
123
  perturb_size: int = 0,
124
- covariate_size: int = 0,
124
+ covariate_sizes: int = 0,
125
125
  transforms: int = 1,
126
126
  z_dim: int = 50,
127
127
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
@@ -144,7 +144,7 @@ class SUREVAE(nn.Module):
144
144
  self.input_dim = input_dim
145
145
  self.context_sizes = context_sizes
146
146
  self.perturb_size = perturb_size
147
- self.covariate_size = covariate_size
147
+ self.covariate_sizes = covariate_sizes
148
148
  self.dispersion = dispersion
149
149
  self.latent_dim = z_dim
150
150
  self.latent_dist = z_dist
@@ -273,7 +273,7 @@ class SUREVAE(nn.Module):
273
273
 
274
274
  if self.perturb_size>0:
275
275
  self.perturb_effect = ZeroBiasMLP3(
276
- [self.perturb_size+self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
276
+ [self.perturb_size+self.latent_dim+self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
277
277
  activation=activate_fct,
278
278
  output_activation=None,
279
279
  post_layer_fct=post_layer_fct,
@@ -282,16 +282,19 @@ class SUREVAE(nn.Module):
282
282
  use_cuda=self.use_cuda,
283
283
  )
284
284
 
285
- if self.covariate_size>0:
286
- self.covariate_effect = ZeroBiasMLP2(
287
- [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
288
- activation=activate_fct,
289
- output_activation=None,
290
- post_layer_fct=post_layer_fct,
291
- post_act_fct=post_act_fct,
292
- allow_broadcast=self.allow_broadcast,
293
- use_cuda=self.use_cuda,
294
- )
285
+ if np.sum(self.covariate_sizes)>0:
286
+ self.covariate_effects = nn.ModuleList()
287
+ for covariate_size in self.covariate_sizes:
288
+ self.covariate_effects.append(ZeroBiasMLP2(
289
+ [covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
290
+ activation=activate_fct,
291
+ output_activation=None,
292
+ post_layer_fct=post_layer_fct,
293
+ post_act_fct=post_act_fct,
294
+ allow_broadcast=self.allow_broadcast,
295
+ use_cuda=self.use_cuda,
296
+ )
297
+ )
295
298
 
296
299
  self.decoder_log_mu = MLP(
297
300
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_dim],
@@ -428,11 +431,16 @@ class SUREVAE(nn.Module):
428
431
  else:
429
432
  zcs = torch.zeros_like(zs)
430
433
  if (self.perturb_size>0) and (ps is not None):
431
- zps = self.perturb_effect([ps, zs+zcs])
434
+ zps = self.perturb_effect([ps, zs, zcs])
432
435
  else:
433
436
  zps = torch.zeros_like(zs)
434
- if (self.covariate_size>0) and (fs is not None):
435
- zfs = self.covariate_effect(fs)
437
+ if (np.sum(self.covariate_sizes)>0) and (fs is not None):
438
+ zfs = torch.zeros_like(zs)
439
+ shift = 0
440
+ for i, covariate_size in enumerate(self.covariate_sizes):
441
+ fs_i = fs[:,shift:(shift+covariate_size)]
442
+ zfs += self.covariate_effects[i](fs_i)
443
+ shift += covariate_size
436
444
  else:
437
445
  zfs = torch.zeros_like(zs)
438
446
 
@@ -640,7 +648,7 @@ class SUREVAE(nn.Module):
640
648
  C_batch = zcs[idx].to(self.get_device())
641
649
  P_batch = ps[idx].to(self.get_device())
642
650
 
643
- dzs = self.perturb_effect([P_batch,Z_batch+C_batch])
651
+ dzs = self.perturb_effect([P_batch,Z_batch,C_batch])
644
652
 
645
653
  A.append(tensor_to_numpy(dzs))
646
654
  pbar.update(1)
@@ -693,7 +701,7 @@ class SUREVAE(nn.Module):
693
701
 
694
702
  if ps is not None:
695
703
  P_batch = ps[idx].to(self.get_device())
696
- zps = self.perturb_effect([P_batch,z_basal+zcs])
704
+ zps = self.perturb_effect([P_batch,z_basal,zcs])
697
705
  else:
698
706
  zps = torch.zeros_like(z_basal)
699
707
 
@@ -825,7 +833,7 @@ class SUREVAE(nn.Module):
825
833
  def fit(self, xs: np.array,
826
834
  css: list = None,
827
835
  ps: np.array = None,
828
- fs: np.array = None,
836
+ fss: list = None,
829
837
  num_epochs: int = 100,
830
838
  learning_rate: float = 0.0001,
831
839
  use_mask: bool = False,
@@ -891,7 +899,8 @@ class SUREVAE(nn.Module):
891
899
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
892
900
  if ps is not None:
893
901
  ps = convert_to_tensor(ps, dtype=self.dtype, device='cpu')
894
- if fs is not None:
902
+ if fss is not None:
903
+ fs = np.hstack(fss)
895
904
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
896
905
 
897
906
  dataset = CustomDataset(xs)
@@ -953,7 +962,7 @@ class SUREVAE(nn.Module):
953
962
  batch_p = None
954
963
  else:
955
964
  batch_p = ps[idx].to(self.get_device())
956
- if fs is None:
965
+ if fss is None:
957
966
  batch_f = None
958
967
  else:
959
968
  batch_f = fs[idx].to(self.get_device())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 4.0.1
3
+ Version: 4.0.3
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,7 +1,7 @@
1
- SURE/SURE.py,sha256=8E39np6zhLbT1cp1xYOg5xLwzFHBupKIm1ydLxNKtqM,16654
1
+ SURE/SURE.py,sha256=YcJ6Lxt4ziQXZ5Yk02Ss-fJRl_M1Z0QghLY_yeHt5Pc,16659
2
2
  SURE/SUREMO.py,sha256=hN0G0ZEBNQdmj0gGlBIy1wjKGKqAMUemDjQeICzvNUY,47644
3
- SURE/SURE_nsf.py,sha256=VR6YgEiIfu7mRH0XLbovfjd2X3WFuom2J4_AgHrq2dM,49040
4
- SURE/SURE_vae.py,sha256=b3xluu49kmuDlcpZnTAJc92SvQI8MPDN9ntBfnKTJOQ,45850
3
+ SURE/SURE_nsf.py,sha256=uKpau3nDdlX58lLzPzecZULKobyOYB7xKtcYo6vwbak,49541
4
+ SURE/SURE_vae.py,sha256=X4aNb6Wv9Xn-w1Xc-wUuio6YsxX0dIwdJqsWG44hCa0,46375
5
5
  SURE/SURE_vanilla.py,sha256=I1RHHCpzk8ml1vMdH_gITOzAFXrHYUA7IAwbVSoxmBo,27327
6
6
  SURE/__init__.py,sha256=eJN0vlGblWir1JHaoiQqbQHzS_C1PNYcA_ls7UviTqc,444
7
7
  SURE/atac/__init__.py,sha256=3smP8IKHfwNCd1G_sZH3pKHXuLkLpFuLtjUTUSy7_As,34
@@ -15,8 +15,8 @@ SURE/utils/custom_mlp.py,sha256=XvviNUYeaZ5D1evqETyWLxgtL56mmFQR_pxsYpKw0yY,1170
15
15
  SURE/utils/label.py,sha256=joKO1mSkjZXeLvSeC7GluQk4-_qgGgPqlwWixdcbKMQ,4648
16
16
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
17
17
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
18
- sure_tools-4.0.1.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
19
- sure_tools-4.0.1.dist-info/METADATA,sha256=m4l4x8KnzdyfZquikfwrSqZXegaeRl6X9plr-fcuGrc,1661
20
- sure_tools-4.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
- sure_tools-4.0.1.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
22
- sure_tools-4.0.1.dist-info/RECORD,,
18
+ sure_tools-4.0.3.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
19
+ sure_tools-4.0.3.dist-info/METADATA,sha256=DlTyOxPOevVRibU7zjiIorXK-Te5I5uMI3zec3H94zw,1661
20
+ sure_tools-4.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
+ sure_tools-4.0.3.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
22
+ sure_tools-4.0.3.dist-info/RECORD,,