SURE-tools 3.6.13__tar.gz → 3.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. {sure_tools-3.6.13 → sure_tools-3.7.0}/PKG-INFO +1 -1
  2. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE.py +11 -11
  3. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_nsf.py +42 -28
  4. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_vae.py +40 -25
  5. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-3.6.13 → sure_tools-3.7.0}/setup.py +1 -1
  7. {sure_tools-3.6.13 → sure_tools-3.7.0}/LICENSE +0 -0
  8. {sure_tools-3.6.13 → sure_tools-3.7.0}/README.md +0 -0
  9. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SUREMO.py +0 -0
  10. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_vanilla.py +0 -0
  11. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/__init__.py +0 -0
  12. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/atac/utils.py +0 -0
  14. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/dist/__init__.py +0 -0
  15. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/dist/negbinomial.py +0 -0
  16. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/graph/__init__.py +0 -0
  17. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/graph/graph_utils.py +0 -0
  18. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/__init__.py +0 -0
  19. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/custom_mlp.py +0 -0
  20. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/label.py +0 -0
  21. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/queue.py +0 -0
  22. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/utils.py +0 -0
  23. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/SOURCES.txt +0 -0
  24. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
  25. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/requires.txt +0 -0
  26. {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/top_level.txt +0 -0
  27. {sure_tools-3.6.13 → sure_tools-3.7.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 3.6.13
3
+ Version: 3.7.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -109,7 +109,7 @@ class SURE(nn.Module):
109
109
  def __init__(self,
110
110
  input_dim: int,
111
111
  codebook_size: int,
112
- condition_size: int = 0,
112
+ condition_sizes: int = 0,
113
113
  covariate_size: int = 0,
114
114
  method: Literal['flow','vae'] = 'vae',
115
115
  transforms: int = 1,
@@ -134,7 +134,7 @@ class SURE(nn.Module):
134
134
  if method == 'flow':
135
135
  self.engine = SURENF(input_dim=input_dim,
136
136
  codebook_size=codebook_size,
137
- condition_size=condition_size,
137
+ condition_sizes=condition_sizes,
138
138
  covariate_size=covariate_size,
139
139
  transforms=transforms,
140
140
  z_dim=z_dim,
@@ -155,7 +155,7 @@ class SURE(nn.Module):
155
155
  elif method == 'vae':
156
156
  self.engine = SUREVAE(input_dim=input_dim,
157
157
  codebook_size=codebook_size,
158
- condition_size=condition_size,
158
+ condition_sizes=condition_sizes,
159
159
  covariate_size=covariate_size,
160
160
  z_dim=z_dim,
161
161
  z_dist=z_dist,
@@ -214,13 +214,13 @@ class SURE(nn.Module):
214
214
  """
215
215
  return self.engine.hard_assignments(xs=xs, batch_size=batch_size, show_progress=show_progress)
216
216
 
217
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
218
- return self.engine.get_condition_effects(xs, cs, batch_size=batch_size, show_progress=show_progress)
217
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
218
+ return self.engine.get_condition_effect(xs, cs, i, batch_size=batch_size, show_progress=show_progress)
219
219
 
220
220
  def predict_cluster(self, xs, batch_size=1024, show_progress=True):
221
221
  return self.engine.predict_cluster(xs, batch_size=batch_size, show_progress=show_progress)
222
222
 
223
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
223
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
224
224
  """
225
225
  Generate gene expression prediction from given cell data and covariates.
226
226
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -231,14 +231,14 @@ class SURE(nn.Module):
231
231
  :param batch_size: Data size per batch
232
232
  :param show_progress: Toggle on or off message output
233
233
  """
234
- return self.engine.predict(xs, cs, batch_size, show_progress)
234
+ return self.engine.predict(xs, cs_list, batch_size, show_progress)
235
235
 
236
236
  def preprocess(self, xs, threshold=0):
237
237
  return self.engine.preprocess(xs=xs, threshold=threshold)
238
238
 
239
- def fit(self, xs,
240
- cs = None,
241
- fs = None,
239
+ def fit(self, xs:np.array,
240
+ css:list = None,
241
+ fs:np.array = None,
242
242
  num_epochs: int = 100,
243
243
  learning_rate: float = 0.0001,
244
244
  use_mask: bool = False,
@@ -284,7 +284,7 @@ class SURE(nn.Module):
284
284
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
285
285
  the Python script or Jupyter notebook. It is OK if it is used when running SURE in the shell command.
286
286
  """
287
- self.engine.fit(xs=xs, cs=cs, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
287
+ self.engine.fit(xs=xs, css=css, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
288
288
  beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
289
289
  use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
290
290
  monitor=monitor)
@@ -119,12 +119,12 @@ class SURENF(nn.Module):
119
119
  def __init__(self,
120
120
  input_dim: int,
121
121
  codebook_size: int,
122
- condition_size: int = 0,
122
+ condition_sizes: list = [0],
123
123
  covariate_size: int = 0,
124
124
  transforms: int = 1,
125
125
  z_dim: int = 50,
126
126
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
127
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
127
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
128
128
  dispersion: float = 10.0,
129
129
  use_zeroinflate: bool = True,
130
130
  hidden_layers: list = [500],
@@ -141,7 +141,7 @@ class SURENF(nn.Module):
141
141
  super().__init__()
142
142
 
143
143
  self.input_dim = input_dim
144
- self.condition_size = condition_size
144
+ self.condition_sizes = condition_sizes
145
145
  self.covariate_size = covariate_size
146
146
  self.dispersion = dispersion
147
147
  self.latent_dim = z_dim
@@ -247,16 +247,19 @@ class SURENF(nn.Module):
247
247
  self.encoder_zn = zuko.flows.NSF(features=self.latent_dim, context=self.input_dim,
248
248
  transforms=self.transforms, hidden_features=self.flow_hidden_layers)
249
249
 
250
- if self.condition_size>0:
251
- self.condition_effect = ZeroBiasMLP3(
252
- [self.condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
253
- activation=activate_fct,
254
- output_activation=None,
255
- post_layer_fct=post_layer_fct,
256
- post_act_fct=post_act_fct,
257
- allow_broadcast=self.allow_broadcast,
258
- use_cuda=self.use_cuda,
259
- )
250
+ if np.sum(self.condition_sizes)>0:
251
+ self.condition_effects = nn.ModuleList()
252
+ for condition_size in self.condition_sizes:
253
+ self.condition_effects.append(ZeroBiasMLP3(
254
+ [condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
255
+ activation=activate_fct,
256
+ output_activation=None,
257
+ post_layer_fct=post_layer_fct,
258
+ post_act_fct=post_act_fct,
259
+ allow_broadcast=self.allow_broadcast,
260
+ use_cuda=self.use_cuda,
261
+ )
262
+ )
260
263
  if self.covariate_size>0:
261
264
  self.covariate_effect = ZeroBiasMLP2(
262
265
  [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
@@ -393,8 +396,13 @@ class SURENF(nn.Module):
393
396
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
394
397
 
395
398
  zs = zns
396
- if (self.covariate_size>0) and (cs is not None):
397
- zcs = self.condition_effect([cs,zns])
399
+ if (np.sum(self.condition_sizes)>0) and (cs is not None):
400
+ zcs = torch.zeros_like(zs)
401
+ shift = 0
402
+ for i,condition_size in enumerate(self.condition_sizes):
403
+ cs_i = cs[:,shift:(shift+condition_size)]
404
+ zcs += self.condition_effects[i]([cs_i,zns])
405
+ shift += condition_size
398
406
  else:
399
407
  zcs = torch.zeros_like(zs)
400
408
  if (self.covariate_size>0) and (fs is not None):
@@ -565,7 +573,7 @@ class SURENF(nn.Module):
565
573
  A = np.concatenate(A)
566
574
  return A
567
575
 
568
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
576
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
569
577
  xs = self.preprocess(xs)
570
578
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
571
579
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
@@ -580,11 +588,8 @@ class SURENF(nn.Module):
580
588
  X_batch = X_batch.to(self.get_device())
581
589
  C_batch = cs[idx].to(self.get_device())
582
590
 
583
- #_,ind = self._hard_assignments(X_batch)
584
- #z_basal = cb_loc[ind.squeeze()]
585
- ns = self._soft_assignments(X_batch)
586
- z_basal = torch.matmul(ns, cb_loc)
587
- dzs = self.condition_effect([C_batch,z_basal])
591
+ z_basal = self._get_cell_embedding(X_batch)
592
+ dzs = self.condition_effects[i]([C_batch,z_basal])
588
593
 
589
594
  A.append(tensor_to_numpy(dzs))
590
595
  pbar.update(1)
@@ -596,7 +601,7 @@ class SURENF(nn.Module):
596
601
  zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
597
602
  return self.kmeans.predict(zs)
598
603
 
599
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
604
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
600
605
  """
601
606
  Generate gene expression prediction from given cell data and covariates.
602
607
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -609,6 +614,7 @@ class SURENF(nn.Module):
609
614
  """
610
615
  xs = self.preprocess(xs)
611
616
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
617
+ cs = np.hstack(cs_list)
612
618
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
613
619
 
614
620
  dataset = CustomDataset(xs)
@@ -622,7 +628,14 @@ class SURENF(nn.Module):
622
628
  library_size = torch.sum(X_batch, 1)
623
629
 
624
630
  z_basal = self._get_cell_embedding(X_batch)
625
- zcs = self.condition_effect([C_batch, z_basal])
631
+
632
+ zcs = torch.zeros_like(z_basal)
633
+ shift = 0
634
+ for i, condition_size in enumerate(self.condition_sizes):
635
+ C_batch_i = C_batch[:, shift:(shift+condition_size)]
636
+ zcs += self.condition_effects[i]([C_batch_i,z_basal])
637
+ shift += condition_size
638
+
626
639
  zfs = torch.zeros_like(z_basal)
627
640
 
628
641
  log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
@@ -829,9 +842,9 @@ class SURENF(nn.Module):
829
842
  if name in param_store:
830
843
  param_store[name] = param'''
831
844
 
832
- def fit(self, xs,
833
- cs = None,
834
- fs = None,
845
+ def fit(self, xs:np.array,
846
+ css:list = None,
847
+ fs:np.array = None,
835
848
  num_epochs: int = 100,
836
849
  learning_rate: float = 0.0001,
837
850
  use_mask: bool = False,
@@ -892,7 +905,8 @@ class SURENF(nn.Module):
892
905
 
893
906
  xs = self.preprocess(xs, threshold=threshold)
894
907
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
895
- if cs is not None:
908
+ if css is not None:
909
+ cs = np.hstack(css)
896
910
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
897
911
  if fs is not None:
898
912
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
@@ -948,7 +962,7 @@ class SURENF(nn.Module):
948
962
  for batch_x, idx in dataloader:
949
963
  batch_x = batch_x.to(self.get_device())
950
964
  for loss_id in range(num_losses):
951
- if cs is None:
965
+ if css is None:
952
966
  batch_c = None
953
967
  else:
954
968
  batch_c = cs[idx].to(self.get_device())
@@ -119,7 +119,7 @@ class SUREVAE(nn.Module):
119
119
  def __init__(self,
120
120
  input_dim: int,
121
121
  codebook_size: int,
122
- condition_size: int = 0,
122
+ condition_sizes: list = [0],
123
123
  covariate_size: int = 0,
124
124
  transforms: int = 1,
125
125
  z_dim: int = 50,
@@ -141,7 +141,7 @@ class SUREVAE(nn.Module):
141
141
  super().__init__()
142
142
 
143
143
  self.input_dim = input_dim
144
- self.condition_size = condition_size
144
+ self.condition_sizes = condition_sizes
145
145
  self.covariate_size = covariate_size
146
146
  self.dispersion = dispersion
147
147
  self.latent_dim = z_dim
@@ -255,16 +255,19 @@ class SUREVAE(nn.Module):
255
255
  use_cuda=self.use_cuda,
256
256
  )
257
257
 
258
- if self.condition_size>0:
259
- self.condition_effect = ZeroBiasMLP3(
260
- [self.condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
261
- activation=activate_fct,
262
- output_activation=None,
263
- post_layer_fct=post_layer_fct,
264
- post_act_fct=post_act_fct,
265
- allow_broadcast=self.allow_broadcast,
266
- use_cuda=self.use_cuda,
267
- )
258
+ if np.sum(self.condition_sizes)>0:
259
+ self.condition_effects = nn.ModuleList()
260
+ for condition_size in self.condition_sizes:
261
+ self.condition_effects.append(ZeroBiasMLP3(
262
+ [condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
263
+ activation=activate_fct,
264
+ output_activation=None,
265
+ post_layer_fct=post_layer_fct,
266
+ post_act_fct=post_act_fct,
267
+ allow_broadcast=self.allow_broadcast,
268
+ use_cuda=self.use_cuda,
269
+ )
270
+ )
268
271
  if self.covariate_size>0:
269
272
  self.covariate_effect = ZeroBiasMLP2(
270
273
  [self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
@@ -401,8 +404,13 @@ class SUREVAE(nn.Module):
401
404
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
402
405
 
403
406
  zs = zns
404
- if (self.condition_size>0) and (cs is not None):
405
- zcs = self.condition_effect([cs,zns])
407
+ if (np.sum(self.condition_sizes)>0) and (cs is not None):
408
+ zcs = torch.zeros_like(zs)
409
+ shift = 0
410
+ for i, condition_size in enumerate(self.condition_sizes):
411
+ cs_i = cs[:,shift:(shift+condition_size)]
412
+ zcs += self.condition_effects[i]([cs_i,zns])
413
+ shift += condition_size
406
414
  else:
407
415
  zcs = torch.zeros_like(zs)
408
416
  if (self.covariate_size>0) and (fs is not None):
@@ -578,7 +586,7 @@ class SUREVAE(nn.Module):
578
586
  A = np.concatenate(A)
579
587
  return A
580
588
 
581
- def get_condition_effects(self, xs, cs, batch_size=1024, show_progress=True):
589
+ def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
582
590
  xs = self.preprocess(xs)
583
591
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
584
592
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
@@ -593,10 +601,8 @@ class SUREVAE(nn.Module):
593
601
  X_batch = X_batch.to(self.get_device())
594
602
  C_batch = cs[idx].to(self.get_device())
595
603
 
596
- #ns = self._soft_assignments(X_batch)
597
- #z_basal = torch.matmul(ns, cb_loc)
598
604
  z_basal = self._get_cell_embedding(X_batch)
599
- dzs = self.condition_effect([C_batch,z_basal])
605
+ dzs = self.condition_effects[i]([C_batch,z_basal])
600
606
 
601
607
  A.append(tensor_to_numpy(dzs))
602
608
  pbar.update(1)
@@ -608,7 +614,7 @@ class SUREVAE(nn.Module):
608
614
  zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
609
615
  return self.kmeans.predict(zs)
610
616
 
611
- def predict(self, xs, cs, batch_size=1024, show_progress=True):
617
+ def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
612
618
  """
613
619
  Generate gene expression prediction from given cell data and covariates.
614
620
  This function can be used for simulating cells' transcription profiles at new conditions.
@@ -621,6 +627,7 @@ class SUREVAE(nn.Module):
621
627
  """
622
628
  xs = self.preprocess(xs)
623
629
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
630
+ cs = np.hstack(cs_list)
624
631
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
625
632
 
626
633
  dataset = CustomDataset(xs)
@@ -634,7 +641,14 @@ class SUREVAE(nn.Module):
634
641
  library_size = torch.sum(X_batch, 1)
635
642
 
636
643
  z_basal = self._get_cell_embedding(X_batch)
637
- zcs = self.condition_effect([C_batch,z_basal])
644
+
645
+ zcs = torch.zeros_like(z_basal)
646
+ shift = 0
647
+ for i, condition_size in enumerate(self.condition_sizes):
648
+ C_batch_i = C_batch[:, shift:(shift+condition_size)]
649
+ zcs += self.condition_effects[i]([C_batch_i,z_basal])
650
+ shift += condition_size
651
+
638
652
  zfs = torch.zeros_like(z_basal)
639
653
 
640
654
  log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
@@ -760,9 +774,9 @@ class SUREVAE(nn.Module):
760
774
  pbar.set_postfix({'loss': str_loss})
761
775
  pbar.update(1)'''
762
776
 
763
- def fit(self, xs,
764
- cs = None,
765
- fs = None,
777
+ def fit(self, xs: np.array,
778
+ css: list = None,
779
+ fs: np.array = None,
766
780
  num_epochs: int = 100,
767
781
  learning_rate: float = 0.0001,
768
782
  use_mask: bool = False,
@@ -823,7 +837,8 @@ class SUREVAE(nn.Module):
823
837
 
824
838
  xs = self.preprocess(xs, threshold=threshold)
825
839
  xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
826
- if cs is not None:
840
+ if css is not None:
841
+ cs = np.hstack(css)
827
842
  cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
828
843
  if fs is not None:
829
844
  fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
@@ -879,7 +894,7 @@ class SUREVAE(nn.Module):
879
894
  for batch_x, idx in dataloader:
880
895
  batch_x = batch_x.to(self.get_device())
881
896
  for loss_id in range(num_losses):
882
- if cs is None:
897
+ if css is None:
883
898
  batch_c = None
884
899
  else:
885
900
  batch_c = cs[idx].to(self.get_device())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 3.6.13
3
+ Version: 3.7.0
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='3.6.13',
8
+ version='3.7.0',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes
File without changes