SURE-tools 2.4.34__py3-none-any.whl → 2.4.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/DensityFlow.py CHANGED
@@ -57,7 +57,7 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int = 100,
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
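This release lowers the default codebook_size from 100 to 30. Below is a minimal sketch of pinning the value explicitly so an upgrade does not silently change model capacity; only the keyword arguments visible in the hunk above come from the diff, and the input_size value is illustrative.

    from SURE.DensityFlow import DensityFlow

    # Pin codebook_size explicitly so upgrading from 2.4.34 to 2.4.42
    # does not change the size of the learned codebook.
    model = DensityFlow(
        input_size=2000,       # number of genes in the count matrix (illustrative)
        codebook_size=100,     # old default; the new default is 30
        cell_factor_size=0,
    )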
@@ -113,6 +113,7 @@ class DensityFlow(nn.Module):
         self.setup_networks()
 
         print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {self.code_size}")
         print(f" - Latent Dimension: {self.latent_dim}")
         print(f" - Gene Dimension: {self.input_size}")
         print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -397,7 +398,7 @@ class DensityFlow(nn.Module):
         rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
-            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
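For the negative-binomial likelihood, NegativeBinomial(total_count=dispersion, logits=logits) has mean dispersion * exp(logits), so logits = log(mu) - log(dispersion) recovers a mean of mu; tightening the clamp from ±15 to ±10 caps the mean-to-dispersion ratio at e^10 ≈ 2.2e4. The same one-line change is repeated in the three hunks below. A self-contained sketch with illustrative values (only the logits expression itself is from the diff):

    import torch
    from torch.distributions import NegativeBinomial

    mu = torch.tensor([0.5, 20.0, 3.0e6])        # illustrative per-gene means
    dispersion = torch.tensor([2.0, 2.0, 2.0])   # illustrative dispersions

    # Same parameterization as the diff: logits = log(mu) - log(dispersion),
    # now clamped to [-10, 10] instead of [-15, 15].
    logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)

    nb = NegativeBinomial(total_count=dispersion, logits=logits)
    # mean == dispersion * exp(logits), i.e. mu wherever the clamp is inactive;
    # the third entry is capped at 2 * e^10 ≈ 4.4e4 instead of 3e6.
    print(nb.mean)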
@@ -486,7 +487,7 @@ class DensityFlow(nn.Module):
         rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
-            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -585,7 +586,7 @@ class DensityFlow(nn.Module):
         rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
-            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -695,7 +696,7 @@ class DensityFlow(nn.Module):
         rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
-            logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -758,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
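The new get_complete_embedding adds the summed cell-factor effects to the basal embedding and batches the computation. A hedged usage sketch; the shapes, the random data, and the constructor arguments are illustrative, and the model is untrained here only to keep the sketch self-contained (in practice you would call this on a fitted DensityFlow):

    import numpy as np
    from SURE.DensityFlow import DensityFlow

    # Untrained instance purely for illustration; use a fitted model in practice.
    model = DensityFlow(input_size=2000, codebook_size=30, cell_factor_size=2)

    X = np.random.poisson(1.0, size=(256, 2000)).astype(np.float32)   # cells x genes counts
    U = np.random.binomial(1, 0.5, size=(256, 2)).astype(np.float32)  # cells x cell factors

    # Basal embedding plus the total cell-factor effects, computed in mini-batches.
    Z = model.get_complete_embedding(X, U, batch_size=128)
    print(Z.shape)  # (256, latent_dim)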
@@ -1008,7 +1031,7 @@ class DensityFlow(nn.Module):
             ad = sc.AnnData(xs)
             binarize(ad, threshold=threshold)
             xs = ad.X.copy()
-        else:
+        elif self.loss_func == 'poisson':
             xs = np.round(xs)
 
         if sparse.issparse(xs):
@@ -1142,7 +1165,7 @@ class DensityFlow(nn.Module):
                     pbar.set_postfix({'loss': str_loss})
                     pbar.update(1)
 
-    '''@classmethod
+    @classmethod
     def save_model(cls, model, file_path, compression=False):
         """Save the model to the specified file path."""
         file_path = os.path.abspath(file_path)
@@ -1169,10 +1192,18 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {model.code_size}")
+        print(f" - Latent Dimension: {model.latent_dim}")
+        print(f" - Gene Dimension: {model.input_size}")
+        print(f" - Hidden Dimensions: {model.hidden_layers}")
+        print(f" - Device: {model.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
-        return model'''
+        return model
 
-    def save_model(self, path):
+    ''' def save(self, path):
         """Save model checkpoint"""
         torch.save({
             'model_state_dict': self.state_dict(),
@@ -1209,7 +1240,7 @@ class DensityFlow(nn.Module):
         checkpoint = torch.load(model_path, map_location=model.get_device())
         model.load_state_dict(checkpoint['model_state_dict'])
 
-        return model
+        return model'''
 
 
 EXAMPLE_RUN = (
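This release re-enables the pickle-based save_model classmethod and comments out the torch.save-based instance method instead. A usage sketch: model is assumed to be a fitted DensityFlow, and the companion loader whose body appears in the @@ -1169 hunk is assumed (not confirmed by this diff) to be a load_model classmethod.

    from SURE.DensityFlow import DensityFlow

    # Persist a fitted model with the re-enabled classmethod.
    DensityFlow.save_model(model, 'densityflow.pkl', compression=False)

    # Assumed companion loader (the name is an assumption); per the diff it now
    # reprints the model summary after unpickling.
    model = DensityFlow.load_model('densityflow.pkl')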
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.4.34
+Version: 2.4.42
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
@@ -1,4 +1,4 @@
-SURE/DensityFlow.py,sha256=KsngVc_ciwNAxXQZD4JFqJ-sDFLU_3Ra6Nyjcs6NvWo,58878
+SURE/DensityFlow.py,sha256=hYuPv1X9lsbKBIAGNRfzfCRYE1szCZPVJkOnIlK-Oc4,60246
 SURE/DensityFlow2.py,sha256=BBRCoA4NpU4EjghToOvowo17UtwYokTN75KxWYHTX1E,58404
 SURE/DensityFlowLinear.py,sha256=bYiPHJ6mza4sOXUjlFq7wButu3rNLYZuqWUTtIO06F4,57540
 SURE/EfficientTranscriptomeDecoder.py,sha256=O_x-4edKBU5OJJbOOS-59u3TQElZqhAtOVJMPlpw8m0,21667
@@ -25,9 +25,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
 SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.4.34.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
-sure_tools-2.4.34.dist-info/METADATA,sha256=dRu5kHcrc-T8DPNWOoz5_aNQM_1tc0sJ3BR40ZHGkec,2678
-sure_tools-2.4.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-sure_tools-2.4.34.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
-sure_tools-2.4.34.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
-sure_tools-2.4.34.dist-info/RECORD,,
+sure_tools-2.4.42.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.4.42.dist-info/METADATA,sha256=g9G4ftWszZ4prtVpgsS3GVIhY-nvh1g8zhOunpDGiqg,2678
+sure_tools-2.4.42.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.4.42.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.4.42.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.4.42.dist-info/RECORD,,