SURE-tools 3.6.14__tar.gz → 3.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {sure_tools-3.6.14 → sure_tools-3.7.0}/PKG-INFO +1 -1
  2. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE.py +11 -11
  3. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_nsf.py +41 -24
  4. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_vae.py +40 -23
  5. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-3.6.14 → sure_tools-3.7.0}/setup.py +1 -1
  7. {sure_tools-3.6.14 → sure_tools-3.7.0}/LICENSE +0 -0
  8. {sure_tools-3.6.14 → sure_tools-3.7.0}/README.md +0 -0
  9. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SUREMO.py +0 -0
  10. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_vanilla.py +0 -0
  11. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/__init__.py +0 -0
  12. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/atac/utils.py +0 -0
  14. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/dist/__init__.py +0 -0
  15. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/dist/negbinomial.py +0 -0
  16. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/graph/__init__.py +0 -0
  17. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/graph/graph_utils.py +0 -0
  18. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/__init__.py +0 -0
  19. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/custom_mlp.py +0 -0
  20. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/label.py +0 -0
  21. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/queue.py +0 -0
  22. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/utils.py +0 -0
  23. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/SOURCES.txt +0 -0
  24. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
  25. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/requires.txt +0 -0
  26. {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/top_level.txt +0 -0
  27. {sure_tools-3.6.14 → sure_tools-3.7.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 3.6.14
3
+ Version: 3.7.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -109,7 +109,7 @@ class SURE(nn.Module):
109
109
  def __init__(self,
110
110
  input_dim: int,
111
111
  codebook_size: int,
112
- condition_size: int = 0,
112
+ condition_sizes: int = 0,
113
113
  covariate_size: int = 0,
114
114
  method: Literal['flow','vae'] = 'vae',
115
115
  transforms: int = 1,
@@ -134,7 +134,7 @@ class SURE(nn.Module):
134
134
  if method == 'flow':
135
135
  self.engine = SURENF(input_dim=input_dim,
136
136
  codebook_size=codebook_size,
137
- condition_size=condition_size,
137
+ condition_sizes=condition_sizes,
138
138
  covariate_size=covariate_size,
139
139
  transforms=transforms,
140
140
  z_dim=z_dim,
@@ -155,7 +155,7 @@ class SURE(nn.Module):
155
155
  elif method == 'vae':
156
156
  self.engine = SUREVAE(input_dim=input_dim,
157
157
  codebook_size=codebook_size,
158
- condition_size=condition_size,
158
+ condition_sizes=condition_sizes,
159
159
  covariate_size=covariate_size,
160
160
  z_dim=z_dim,
161
161
  z_dist=z_dist,
@@ -214,13 +214,13 @@ class SURE(nn.Module):
214
214
  """
215
215
  return self.engine.hard_assignments(xs=xs, batch_size=batch_size, show_progress=show_progress)
216
216
 
217
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
218
- return self.engine.get_condition_effects(xs, cs, batch_size=batch_size, show_progress=show_progress)
217
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
218
+ return self.engine.get_condition_effect(xs, cs, i, batch_size=batch_size, show_progress=show_progress)
219
219
 
220
220
  def predict_cluster(self, xs, batch_size=1024, show_progress=True):
221
221
  return self.engine.predict_cluster(xs, batch_size=batch_size, show_progress=show_progress)
222
222
 
223
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
223
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
224
224
  """
225
225
  Generate gene expression prediction from given cell data and covariates.
226
226
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -231,14 +231,14 @@ class SURE(nn.Module):
231
231
  :param batch_size: Data size per batch
232
232
  :param show_progress: Toggle on or off message output
233
233
  """
234
- return self.engine.predict(xs, cs, batch_size, show_progress)
234
+ return self.engine.predict(xs, cs_list, batch_size, show_progress)
235
235
 
236
236
  def preprocess(self, xs, threshold=0):
237
237
  return self.engine.preprocess(xs=xs, threshold=threshold)
238
238
 
239
- def fit(self, xs,
240
- cs = None,
241
- fs = None,
239
+ def fit(self, xs:np.array,
240
+ css:list = None,
241
+ fs:np.array = None,
242
242
  num_epochs: int = 100,
243
243
  learning_rate: float = 0.0001,
244
244
  use_mask: bool = False,
@@ -284,7 +284,7 @@ class SURE(nn.Module):
284
284
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
285
285
  the Python script or Jupyter notebook. It is OK if it is used when running SURE in the shell command.
286
286
  """
287
- self.engine.fit(xs=xs, cs=cs, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
287
+ self.engine.fit(xs=xs, css=css, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
288
288
  beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
289
289
  use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
290
290
  monitor=monitor)
@@ -119,12 +119,12 @@ class SURENF(nn.Module):
119
119
  def __init__(self,
120
120
  input_dim: int,
121
121
  codebook_size: int,
122
- condition_size: int = 0,
122
+ condition_sizes: list = [0],
123
123
  covariate_size: int = 0,
124
124
  transforms: int = 1,
125
125
  z_dim: int = 50,
126
126
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
127
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
127
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
128
128
  dispersion: float = 10.0,
129
129
  use_zeroinflate: bool = True,
130
130
  hidden_layers: list = [500],
@@ -141,7 +141,7 @@ class SURENF(nn.Module):
141
141
  super().__init__()
142
142
 
143
143
  self.input_dim = input_dim
144
- self.condition_size = condition_size
144
+ self.condition_sizes = condition_sizes
145
145
  self.covariate_size = covariate_size
146
146
  self.dispersion = dispersion
147
147
  self.latent_dim = z_dim
@@ -247,16 +247,19 @@ class SURENF(nn.Module):
247
247
  self.encoder_zn = zuko.flows.NSF(features=self.latent_dim, context=self.input_dim,
248
248
  transforms=self.transforms, hidden_features=self.flow_hidden_layers)
249
249
 
250
- if self.condition_size>0:
251
- self.condition_effect = ZeroBiasMLP3(
252
- [self.condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
253
- activation=activate_fct,
254
- output_activation=None,
255
- post_layer_fct=post_layer_fct,
256
- post_act_fct=post_act_fct,
257
- allow_broadcast=self.allow_broadcast,
258
- use_cuda=self.use_cuda,
259
- )
250
+ if np.sum(self.condition_sizes)>0:
251
+ self.condition_effects = nn.ModuleList()
252
+ for condition_size in self.condition_sizes:
253
+ self.condition_effects.append(ZeroBiasMLP3(
254
+ [condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
255
+ activation=activate_fct,
256
+ output_activation=None,
257
+ post_layer_fct=post_layer_fct,
258
+ post_act_fct=post_act_fct,
259
+ allow_broadcast=self.allow_broadcast,
260
+ use_cuda=self.use_cuda,
261
+ )
262
+ )
260
263
  if self.covariate_size>0:
261
264
  self.covariate_effect = ZeroBiasMLP2(
262
265
  [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
@@ -393,8 +396,13 @@ class SURENF(nn.Module):
393
396
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
394
397
 
395
398
  zs = zns
396
- if (self.covariate_size>0) and (cs is not None):
397
- zcs = self.condition_effect([cs,zns])
399
+ if (np.sum(self.condition_sizes)>0) and (cs is not None):
400
+ zcs = torch.zeros_like(zs)
401
+ shift = 0
402
+ for i,condition_size in enumerate(self.condition_sizes):
403
+ cs_i = cs[:,shift:(shift+condition_size)]
404
+ zcs += self.condition_effects[i]([cs_i,zns])
405
+ shift += condition_size
398
406
  else:
399
407
  zcs = torch.zeros_like(zs)
400
408
  if (self.covariate_size>0) and (fs is not None):
@@ -565,7 +573,7 @@ class SURENF(nn.Module):
565
573
  A = np.concatenate(A)
566
574
  return A
567
575
 
568
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
576
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
569
577
  xs = self.preprocess(xs)
570
578
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
571
579
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
@@ -581,7 +589,7 @@ class SURENF(nn.Module):
581
589
  C_batch = cs[idx].to(self.get_device())
582
590
 
583
591
  z_basal = self._get_cell_embedding(X_batch)
584
- dzs = self.condition_effect([C_batch,z_basal])
592
+ dzs = self.condition_effects[i]([C_batch,z_basal])
585
593
 
586
594
  A.append(tensor_to_numpy(dzs))
587
595
  pbar.update(1)
@@ -593,7 +601,7 @@ class SURENF(nn.Module):
593
601
  zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
594
602
  return self.kmeans.predict(zs)
595
603
 
596
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
604
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
597
605
  """
598
606
  Generate gene expression prediction from given cell data and covariates.
599
607
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -606,6 +614,7 @@ class SURENF(nn.Module):
606
614
  """
607
615
  xs = self.preprocess(xs)
608
616
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
617
+ cs = np.hstack(cs_list)
609
618
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
610
619
 
611
620
  dataset = CustomDataset(xs)
@@ -619,7 +628,14 @@ class SURENF(nn.Module):
619
628
  library_size = torch.sum(X_batch, 1)
620
629
 
621
630
  z_basal = self._get_cell_embedding(X_batch)
622
- zcs = self.condition_effect([C_batch, z_basal])
631
+
632
+ zcs = torch.zeros_like(z_basal)
633
+ shift = 0
634
+ for i, condition_size in enumerate(self.condition_sizes):
635
+ C_batch_i = C_batch[:, shift:(shift+condition_size)]
636
+ zcs += self.condition_effects[i]([C_batch_i,z_basal])
637
+ shift += condition_size
638
+
623
639
  zfs = torch.zeros_like(z_basal)
624
640
 
625
641
  log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
@@ -826,9 +842,9 @@ class SURENF(nn.Module):
826
842
  if name in param_store:
827
843
  param_store[name] = param'''
828
844
 
829
- def fit(self, xs,
830
- cs = None,
831
- fs = None,
845
+ def fit(self, xs:np.array,
846
+ css:list = None,
847
+ fs:np.array = None,
832
848
  num_epochs: int = 100,
833
849
  learning_rate: float = 0.0001,
834
850
  use_mask: bool = False,
@@ -889,7 +905,8 @@ class SURENF(nn.Module):
889
905
 
890
906
  xs = self.preprocess(xs, threshold=threshold)
891
907
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
892
- if cs is not None:
908
+ if css is not None:
909
+ cs = np.hstack(css)
893
910
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
894
911
  if fs is not None:
895
912
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
@@ -945,7 +962,7 @@ class SURENF(nn.Module):
945
962
  for batch_x, idx in dataloader:
946
963
  batch_x = batch_x.to(self.get_device())
947
964
  for loss_id in range(num_losses):
948
- if cs is None:
965
+ if css is None:
949
966
  batch_c = None
950
967
  else:
951
968
  batch_c = cs[idx].to(self.get_device())
@@ -119,7 +119,7 @@ class SUREVAE(nn.Module):
119
119
  def __init__(self,
120
120
  input_dim: int,
121
121
  codebook_size: int,
122
- condition_size: int = 0,
122
+ condition_sizes: list = [0],
123
123
  covariate_size: int = 0,
124
124
  transforms: int = 1,
125
125
  z_dim: int = 50,
@@ -141,7 +141,7 @@ class SUREVAE(nn.Module):
141
141
  super().__init__()
142
142
 
143
143
  self.input_dim = input_dim
144
- self.condition_size = condition_size
144
+ self.condition_sizes = condition_sizes
145
145
  self.covariate_size = covariate_size
146
146
  self.dispersion = dispersion
147
147
  self.latent_dim = z_dim
@@ -255,16 +255,19 @@ class SUREVAE(nn.Module):
255
255
  use_cuda=self.use_cuda,
256
256
  )
257
257
 
258
- if self.condition_size>0:
259
- self.condition_effect = ZeroBiasMLP3(
260
- [self.condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
261
- activation=activate_fct,
262
- output_activation=None,
263
- post_layer_fct=post_layer_fct,
264
- post_act_fct=post_act_fct,
265
- allow_broadcast=self.allow_broadcast,
266
- use_cuda=self.use_cuda,
267
- )
258
+ if np.sum(self.condition_sizes)>0:
259
+ self.condition_effects = nn.ModuleList()
260
+ for condition_size in self.condition_sizes:
261
+ self.condition_effects.append(ZeroBiasMLP3(
262
+ [condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
263
+ activation=activate_fct,
264
+ output_activation=None,
265
+ post_layer_fct=post_layer_fct,
266
+ post_act_fct=post_act_fct,
267
+ allow_broadcast=self.allow_broadcast,
268
+ use_cuda=self.use_cuda,
269
+ )
270
+ )
268
271
  if self.covariate_size>0:
269
272
  self.covariate_effect = ZeroBiasMLP2(
270
273
  [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
@@ -401,8 +404,13 @@ class SUREVAE(nn.Module):
401
404
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
402
405
 
403
406
  zs = zns
404
- if (self.condition_size>0) and (cs is not None):
405
- zcs = self.condition_effect([cs,zns])
407
+ if (np.sum(self.condition_sizes)>0) and (cs is not None):
408
+ zcs = torch.zeros_like(zs)
409
+ shift = 0
410
+ for i, condition_size in enumerate(self.condition_sizes):
411
+ cs_i = cs[:,shift:(shift+condition_size)]
412
+ zcs += self.condition_effects[i]([cs_i,zns])
413
+ shift += condition_size
406
414
  else:
407
415
  zcs = torch.zeros_like(zs)
408
416
  if (self.covariate_size>0) and (fs is not None):
@@ -578,7 +586,7 @@ class SUREVAE(nn.Module):
578
586
  A = np.concatenate(A)
579
587
  return A
580
588
 
581
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
589
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
582
590
  xs = self.preprocess(xs)
583
591
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
584
592
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
@@ -594,7 +602,7 @@ class SUREVAE(nn.Module):
594
602
  C_batch = cs[idx].to(self.get_device())
595
603
 
596
604
  z_basal = self._get_cell_embedding(X_batch)
597
- dzs = self.condition_effect([C_batch,z_basal])
605
+ dzs = self.condition_effects[i]([C_batch,z_basal])
598
606
 
599
607
  A.append(tensor_to_numpy(dzs))
600
608
  pbar.update(1)
@@ -606,7 +614,7 @@ class SUREVAE(nn.Module):
606
614
  zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
607
615
  return self.kmeans.predict(zs)
608
616
 
609
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
617
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
610
618
  """
611
619
  Generate gene expression prediction from given cell data and covariates.
612
620
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -619,6 +627,7 @@ class SUREVAE(nn.Module):
619
627
  """
620
628
  xs = self.preprocess(xs)
621
629
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
630
+ cs = np.hstack(cs_list)
622
631
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
623
632
 
624
633
  dataset = CustomDataset(xs)
@@ -632,7 +641,14 @@ class SUREVAE(nn.Module):
632
641
  library_size = torch.sum(X_batch, 1)
633
642
 
634
643
  z_basal = self._get_cell_embedding(X_batch)
635
- zcs = self.condition_effect([C_batch,z_basal])
644
+
645
+ zcs = torch.zeros_like(z_basal)
646
+ shift = 0
647
+ for i, condition_size in enumerate(self.condition_sizes):
648
+ C_batch_i = C_batch[:, shift:(shift+condition_size)]
649
+ zcs += self.condition_effects[i]([C_batch_i,z_basal])
650
+ shift += condition_size
651
+
636
652
  zfs = torch.zeros_like(z_basal)
637
653
 
638
654
  log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
@@ -758,9 +774,9 @@ class SUREVAE(nn.Module):
758
774
  pbar.set_postfix({'loss': str_loss})
759
775
  pbar.update(1)'''
760
776
 
761
- def fit(self, xs,
762
- cs = None,
763
- fs = None,
777
+ def fit(self, xs: np.array,
778
+ css: list = None,
779
+ fs: np.array = None,
764
780
  num_epochs: int = 100,
765
781
  learning_rate: float = 0.0001,
766
782
  use_mask: bool = False,
@@ -821,7 +837,8 @@ class SUREVAE(nn.Module):
821
837
 
822
838
  xs = self.preprocess(xs, threshold=threshold)
823
839
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
824
- if cs is not None:
840
+ if css is not None:
841
+ cs = np.hstack(css)
825
842
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
826
843
  if fs is not None:
827
844
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
@@ -877,7 +894,7 @@ class SUREVAE(nn.Module):
877
894
  for batch_x, idx in dataloader:
878
895
  batch_x = batch_x.to(self.get_device())
879
896
  for loss_id in range(num_losses):
880
- if cs is None:
897
+ if css is None:
881
898
  batch_c = None
882
899
  else:
883
900
  batch_c = cs[idx].to(self.get_device())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 3.6.14
3
+ Version: 3.7.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='3.6.14',
8
+ version='3.7.0',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes
File without changes