SURE-tools 2.4.32__py3-none-any.whl → 2.4.43__py3-none-any.whl

This diff shows the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/DensityFlow.py CHANGED
@@ -57,12 +57,12 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 100,
60
+ codebook_size: int = 30,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
64
  z_dim: int = 50,
65
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
66
66
  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
67
  dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
@@ -85,6 +85,7 @@ class DensityFlow(nn.Module):
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.config_enum = config_enum
88
89
  self.allow_broadcast = config_enum == 'parallel'
89
90
  self.use_cuda = use_cuda
90
91
  self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):
107
108
 
108
109
  self.codebook_weights = None
109
110
 
111
+ self.seed = seed
110
112
  set_random_seed(seed)
111
113
  self.setup_networks()
112
114
 
113
115
  print(f"🧬 DensityFlow Initialized:")
116
+ print(f" - Codebook size: {self.code_size}")
114
117
  print(f" - Latent Dimension: {self.latent_dim}")
115
118
  print(f" - Gene Dimension: {self.input_size}")
116
119
  print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -395,7 +398,7 @@ class DensityFlow(nn.Module):
395
398
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
396
399
 
397
400
  if self.loss_func == 'negbinomial':
398
- logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
401
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
399
402
  if self.use_zeroinflate:
400
403
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
401
404
  logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -484,7 +487,7 @@ class DensityFlow(nn.Module):
484
487
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
485
488
 
486
489
  if self.loss_func == 'negbinomial':
487
- logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
490
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
488
491
  if self.use_zeroinflate:
489
492
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
490
493
  else:
@@ -583,7 +586,7 @@ class DensityFlow(nn.Module):
583
586
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
584
587
 
585
588
  if self.loss_func == 'negbinomial':
586
- logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
589
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
587
590
  if self.use_zeroinflate:
588
591
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
589
592
  logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -693,7 +696,7 @@ class DensityFlow(nn.Module):
693
696
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
694
697
 
695
698
  if self.loss_func == 'negbinomial':
696
- logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
699
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
697
700
  if self.use_zeroinflate:
698
701
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
699
702
  logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -756,6 +759,28 @@ class DensityFlow(nn.Module):
756
759
  cb = self._get_codebook()
757
760
  cb = tensor_to_numpy(cb)
758
761
  return cb
762
+
763
+ def _get_complete_embedding(self, xs, us):
764
+ basal,_ = self._get_basal_embedding(xs)
765
+ dzs = self._total_effects(basal, us)
766
+ return basal + dzs
767
+
768
+ def get_complete_embedding(self, xs, us, batch_size:int=1024):
769
+ xs = self.preprocess(xs)
770
+ xs = convert_to_tensor(xs, device=self.get_device())
771
+ us = convert_to_tensor(us, device=self.get_device())
772
+ dataset = CustomDataset2(xs, us)
773
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
774
+
775
+ Z = []
776
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
777
+ for X_batch, U_batch, _ in dataloader:
778
+ zns = self._get_complete_embedding(X_batch, U_batch)
779
+ Z.append(tensor_to_numpy(zns))
780
+ pbar.update(1)
781
+
782
+ Z = np.concatenate(Z)
783
+ return Z
759
784
 
760
785
  def _get_basal_embedding(self, xs):
761
786
  loc, scale = self.encoder_zn(xs)
@@ -1167,8 +1192,55 @@ class DensityFlow(nn.Module):
1167
1192
  else:
1168
1193
  with open(file_path, 'rb') as pickle_file:
1169
1194
  model = pickle.load(pickle_file)
1195
+
1196
+ print(f"🧬 DensityFlow Initialized:")
1197
+ print(f" - Codebook size: {model.code_size}")
1198
+ print(f" - Latent Dimension: {model.latent_dim}")
1199
+ print(f" - Gene Dimension: {model.input_size}")
1200
+ print(f" - Hidden Dimensions: {model.hidden_layers}")
1201
+ print(f" - Device: {model.get_device()}")
1202
+ print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
1170
1203
 
1171
1204
  return model
1205
+
1206
+ ''' def save(self, path):
1207
+ """Save model checkpoint"""
1208
+ torch.save({
1209
+ 'model_state_dict': self.state_dict(),
1210
+ 'model_config': {
1211
+ 'input_size': self.input_size,
1212
+ 'codebook_size': self.code_size,
1213
+ 'cell_factor_size': self.cell_factor_size,
1214
+ 'turn_off_cell_specific':self.turn_off_cell_specific,
1215
+ 'supervised_mode':self.supervised_mode,
1216
+ 'z_dim': self.latent_dim,
1217
+ 'z_dist': self.latent_dist,
1218
+ 'loss_func': self.loss_func,
1219
+ 'dispersion': self.dispersion,
1220
+ 'use_zeroinflate': self.use_zeroinflate,
1221
+ 'hidden_layers':self.hidden_layers,
1222
+ 'hidden_layer_activation':self.hidden_layer_activation,
1223
+ 'nn_dropout':self.nn_dropout,
1224
+ 'post_layer_fct':self.post_layer_fct,
1225
+ 'post_act_fct':self.post_act_fct,
1226
+ 'config_enum':self.config_enum,
1227
+ 'use_cuda':self.use_cuda,
1228
+ 'seed':self.seed,
1229
+ 'zero_bias':self.use_bias,
1230
+ 'dtype':self.dtype,
1231
+ }
1232
+ }, path)
1233
+
1234
+ @classmethod
1235
+ def load_model(cls, model_path: str):
1236
+ """Load pre-trained model"""
1237
+ checkpoint = torch.load(model_path)
1238
+ model = DensityFlow(**checkpoint.get('model_config'))
1239
+
1240
+ checkpoint = torch.load(model_path, map_location=model.get_device())
1241
+ model.load_state_dict(checkpoint['model_state_dict'])
1242
+
1243
+ return model'''
1172
1244
 
1173
1245
 
1174
1246
  EXAMPLE_RUN = (
SURE/DensityFlowLinear.py CHANGED
@@ -107,6 +107,7 @@ class DensityFlowLinear(nn.Module):
107
107
 
108
108
  self.codebook_weights = None
109
109
 
110
+ self.seed = seed
110
111
  set_random_seed(seed)
111
112
  self.setup_networks()
112
113
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.4.32
3
+ Version: 2.4.43
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,6 +1,6 @@
1
- SURE/DensityFlow.py,sha256=M8yDMCzH-LSVjFB34pUaN2TNUqo0jnZMmaFcLLjtYBU,57218
1
+ SURE/DensityFlow.py,sha256=fqqI8sHnfXuTK9O1il-dL7F4W7gbUMjGHD8uRwpESlc,60218
2
2
  SURE/DensityFlow2.py,sha256=BBRCoA4NpU4EjghToOvowo17UtwYokTN75KxWYHTX1E,58404
3
- SURE/DensityFlowLinear.py,sha256=RfjIwXxO0gWQni7LHRy4DSDUkdz3HpDjlZbxJse65Ts,57515
3
+ SURE/DensityFlowLinear.py,sha256=bYiPHJ6mza4sOXUjlFq7wButu3rNLYZuqWUTtIO06F4,57540
4
4
  SURE/EfficientTranscriptomeDecoder.py,sha256=O_x-4edKBU5OJJbOOS-59u3TQElZqhAtOVJMPlpw8m0,21667
5
5
  SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
6
6
  SURE/PerturbationAwareDecoder.py,sha256=duhvBvZjOpAk7c2YTfmA2qKbrgVvwT7IW1pxaukq_iU,30231
@@ -25,9 +25,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
25
25
  SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
26
26
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
27
27
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
28
- sure_tools-2.4.32.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
29
- sure_tools-2.4.32.dist-info/METADATA,sha256=9o2WcDEY2JIf1AhclsLePndrjYVexLj9QUHtZjzGc38,2678
30
- sure_tools-2.4.32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- sure_tools-2.4.32.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
32
- sure_tools-2.4.32.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
33
- sure_tools-2.4.32.dist-info/RECORD,,
28
+ sure_tools-2.4.43.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
29
+ sure_tools-2.4.43.dist-info/METADATA,sha256=q0DTzGgBqj5Hi8n2YNmJymHD25dZSUdRVlMrfiy-5Hw,2678
30
+ sure_tools-2.4.43.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ sure_tools-2.4.43.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
32
+ sure_tools-2.4.43.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
33
+ sure_tools-2.4.43.dist-info/RECORD,,