SURE-tools 2.4.34__py3-none-any.whl → 2.4.42__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- SURE/DensityFlow.py +41 -10
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/METADATA +1 -1
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/RECORD +7 -7
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.34.dist-info → sure_tools-2.4.42.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
|
@@ -57,7 +57,7 @@ def set_random_seed(seed):
|
|
|
57
57
|
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
|
-
codebook_size: int =
|
|
60
|
+
codebook_size: int = 30,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
62
|
turn_off_cell_specific: bool = False,
|
|
63
63
|
supervised_mode: bool = False,
|
|
@@ -113,6 +113,7 @@ class DensityFlow(nn.Module):
|
|
|
113
113
|
self.setup_networks()
|
|
114
114
|
|
|
115
115
|
print(f"🧬 DensityFlow Initialized:")
|
|
116
|
+
print(f" - Codebook size: {self.code_size}")
|
|
116
117
|
print(f" - Latent Dimension: {self.latent_dim}")
|
|
117
118
|
print(f" - Gene Dimension: {self.input_size}")
|
|
118
119
|
print(f" - Hidden Dimensions: {self.hidden_layers}")
|
|
@@ -397,7 +398,7 @@ class DensityFlow(nn.Module):
|
|
|
397
398
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
398
399
|
|
|
399
400
|
if self.loss_func == 'negbinomial':
|
|
400
|
-
logits = (mu.log()-dispersion.log()).clamp(min=-
|
|
401
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
|
|
401
402
|
if self.use_zeroinflate:
|
|
402
403
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
403
404
|
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -486,7 +487,7 @@ class DensityFlow(nn.Module):
|
|
|
486
487
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
487
488
|
|
|
488
489
|
if self.loss_func == 'negbinomial':
|
|
489
|
-
logits = (mu.log()-dispersion.log()).clamp(min=-
|
|
490
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
|
|
490
491
|
if self.use_zeroinflate:
|
|
491
492
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
492
493
|
else:
|
|
@@ -585,7 +586,7 @@ class DensityFlow(nn.Module):
|
|
|
585
586
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
586
587
|
|
|
587
588
|
if self.loss_func == 'negbinomial':
|
|
588
|
-
logits = (mu.log()-dispersion.log()).clamp(min=-
|
|
589
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
|
|
589
590
|
if self.use_zeroinflate:
|
|
590
591
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
591
592
|
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -695,7 +696,7 @@ class DensityFlow(nn.Module):
|
|
|
695
696
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
696
697
|
|
|
697
698
|
if self.loss_func == 'negbinomial':
|
|
698
|
-
logits = (mu.log()-dispersion.log()).clamp(min=-
|
|
699
|
+
logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
|
|
699
700
|
if self.use_zeroinflate:
|
|
700
701
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
|
|
701
702
|
logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
@@ -758,6 +759,28 @@ class DensityFlow(nn.Module):
|
|
|
758
759
|
cb = self._get_codebook()
|
|
759
760
|
cb = tensor_to_numpy(cb)
|
|
760
761
|
return cb
|
|
762
|
+
|
|
763
|
+
def _get_complete_embedding(self, xs, us):
|
|
764
|
+
basal,_ = self._get_basal_embedding(xs)
|
|
765
|
+
dzs = self._total_effects(basal, us)
|
|
766
|
+
return basal + dzs
|
|
767
|
+
|
|
768
|
+
def get_complete_embedding(self, xs, us, batch_size:int=1024):
|
|
769
|
+
xs = self.preprocess(xs)
|
|
770
|
+
xs = convert_to_tensor(xs, device=self.get_device())
|
|
771
|
+
us = convert_to_tensor(us, device=self.get_device())
|
|
772
|
+
dataset = CustomDataset2(xs, us)
|
|
773
|
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
774
|
+
|
|
775
|
+
Z = []
|
|
776
|
+
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
777
|
+
for X_batch, U_batch, _ in dataloader:
|
|
778
|
+
zns = self._get_complete_embedding(X_batch, U_batch)
|
|
779
|
+
Z.append(tensor_to_numpy(zns))
|
|
780
|
+
pbar.update(1)
|
|
781
|
+
|
|
782
|
+
Z = np.concatenate(Z)
|
|
783
|
+
return Z
|
|
761
784
|
|
|
762
785
|
def _get_basal_embedding(self, xs):
|
|
763
786
|
loc, scale = self.encoder_zn(xs)
|
|
@@ -1008,7 +1031,7 @@ class DensityFlow(nn.Module):
|
|
|
1008
1031
|
ad = sc.AnnData(xs)
|
|
1009
1032
|
binarize(ad, threshold=threshold)
|
|
1010
1033
|
xs = ad.X.copy()
|
|
1011
|
-
|
|
1034
|
+
elif self.loss_func == 'poisson':
|
|
1012
1035
|
xs = np.round(xs)
|
|
1013
1036
|
|
|
1014
1037
|
if sparse.issparse(xs):
|
|
@@ -1142,7 +1165,7 @@ class DensityFlow(nn.Module):
|
|
|
1142
1165
|
pbar.set_postfix({'loss': str_loss})
|
|
1143
1166
|
pbar.update(1)
|
|
1144
1167
|
|
|
1145
|
-
|
|
1168
|
+
@classmethod
|
|
1146
1169
|
def save_model(cls, model, file_path, compression=False):
|
|
1147
1170
|
"""Save the model to the specified file path."""
|
|
1148
1171
|
file_path = os.path.abspath(file_path)
|
|
@@ -1169,10 +1192,18 @@ class DensityFlow(nn.Module):
|
|
|
1169
1192
|
else:
|
|
1170
1193
|
with open(file_path, 'rb') as pickle_file:
|
|
1171
1194
|
model = pickle.load(pickle_file)
|
|
1195
|
+
|
|
1196
|
+
print(f"🧬 DensityFlow Initialized:")
|
|
1197
|
+
print(f" - Codebook size: {model.code_size}")
|
|
1198
|
+
print(f" - Latent Dimension: {model.latent_dim}")
|
|
1199
|
+
print(f" - Gene Dimension: {model.input_size}")
|
|
1200
|
+
print(f" - Hidden Dimensions: {model.hidden_layers}")
|
|
1201
|
+
print(f" - Device: {model.get_device()}")
|
|
1202
|
+
print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
|
1172
1203
|
|
|
1173
|
-
return model
|
|
1204
|
+
return model
|
|
1174
1205
|
|
|
1175
|
-
def
|
|
1206
|
+
''' def save(self, path):
|
|
1176
1207
|
"""Save model checkpoint"""
|
|
1177
1208
|
torch.save({
|
|
1178
1209
|
'model_state_dict': self.state_dict(),
|
|
@@ -1209,7 +1240,7 @@ class DensityFlow(nn.Module):
|
|
|
1209
1240
|
checkpoint = torch.load(model_path, map_location=model.get_device())
|
|
1210
1241
|
model.load_state_dict(checkpoint['model_state_dict'])
|
|
1211
1242
|
|
|
1212
|
-
return model
|
|
1243
|
+
return model'''
|
|
1213
1244
|
|
|
1214
1245
|
|
|
1215
1246
|
EXAMPLE_RUN = (
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/DensityFlow.py,sha256=
|
|
1
|
+
SURE/DensityFlow.py,sha256=hYuPv1X9lsbKBIAGNRfzfCRYE1szCZPVJkOnIlK-Oc4,60246
|
|
2
2
|
SURE/DensityFlow2.py,sha256=BBRCoA4NpU4EjghToOvowo17UtwYokTN75KxWYHTX1E,58404
|
|
3
3
|
SURE/DensityFlowLinear.py,sha256=bYiPHJ6mza4sOXUjlFq7wButu3rNLYZuqWUTtIO06F4,57540
|
|
4
4
|
SURE/EfficientTranscriptomeDecoder.py,sha256=O_x-4edKBU5OJJbOOS-59u3TQElZqhAtOVJMPlpw8m0,21667
|
|
@@ -25,9 +25,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
|
25
25
|
SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
|
|
26
26
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
27
27
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
28
|
-
sure_tools-2.4.
|
|
29
|
-
sure_tools-2.4.
|
|
30
|
-
sure_tools-2.4.
|
|
31
|
-
sure_tools-2.4.
|
|
32
|
-
sure_tools-2.4.
|
|
33
|
-
sure_tools-2.4.
|
|
28
|
+
sure_tools-2.4.42.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
29
|
+
sure_tools-2.4.42.dist-info/METADATA,sha256=g9G4ftWszZ4prtVpgsS3GVIhY-nvh1g8zhOunpDGiqg,2678
|
|
30
|
+
sure_tools-2.4.42.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
+
sure_tools-2.4.42.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
32
|
+
sure_tools-2.4.42.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
33
|
+
sure_tools-2.4.42.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|