SURE-tools 2.4.32__py3-none-any.whl → 2.4.43__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
Potentially problematic release: this version of SURE-tools might be problematic.
- SURE/DensityFlow.py +78 -6
- SURE/DensityFlowLinear.py +1 -0
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/METADATA +1 -1
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/RECORD +8 -8
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/WHEEL +0 -0
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/top_level.txt +0 -0
SURE/DensityFlow.py
CHANGED
@@ -57,12 +57,12 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int =
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
                  z_dim: int = 50,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
                  dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
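For orientation, the two constructor defaults changed in this hunk are codebook_size (now 30) and z_dist (now 'studentt'). A minimal sketch of instantiating the model under the new defaults; it assumes the class is importable as SURE.DensityFlow.DensityFlow and that input_size is the only argument without a usable default, which matches the signature shown above:

# Minimal sketch, not an official example: assumes SURE.DensityFlow.DensityFlow
# is importable and that input_size is the only required argument.
from SURE.DensityFlow import DensityFlow

model = DensityFlow(input_size=2000)  # picks up codebook_size=30, z_dist='studentt'

# The same call with the changed defaults spelled out explicitly:
model = DensityFlow(
    input_size=2000,           # number of genes / features
    codebook_size=30,          # new default in 2.4.43
    z_dist='studentt',         # new default latent distribution in 2.4.43
    loss_func='negbinomial',   # unchanged default
    dispersion=8.0,            # unchanged default
)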
@@ -85,6 +85,7 @@ class DensityFlow(nn.Module):
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
         print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {self.code_size}")
         print(f" - Latent Dimension: {self.latent_dim}")
         print(f" - Gene Dimension: {self.input_size}")
         print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -395,7 +398,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
-                logits = (mu.log()-dispersion.log()).clamp(min=-
+                logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
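This hunk, and the three identical ones below it, change the logits clamp to an explicit range of [-10, 10]. In this parameterization logits = log(mu) - log(dispersion), which is log(p / (1 - p)) for success probability p = mu / (mu + dispersion), so the clamp bounds p away from 0 and 1 and keeps the value finite when mu is 0. A standalone sketch of the same computation with toy numbers; the variable names mirror the hunk, and nothing here is taken from the package beyond what the hunk shows:

# Standalone sketch of the clamped logits parameterization; the toy values are made up.
import torch
import pyro.distributions as dist

mu = torch.tensor([[0.0, 3.5, 120.0]])   # per-gene means; zeros are possible
dispersion = torch.tensor(8.0)            # used as total_count of the negative binomial

# mean of NegativeBinomial(total_count=r, logits=l) is r * exp(l),
# so l = log(mu) - log(r); a zero mean would give -inf without the clamp
logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)

x_dist = dist.NegativeBinomial(total_count=dispersion, logits=logits)
print(logits)       # the zero-mean entry is clamped to -10 instead of -inf
print(x_dist.mean)  # approximately dispersion * exp(logits)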
@@ -484,7 +487,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
-                logits = (mu.log()-dispersion.log()).clamp(min=-
+                logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
@@ -583,7 +586,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
-                logits = (mu.log()-dispersion.log()).clamp(min=-
+                logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -693,7 +696,7 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
-                logits = (mu.log()-dispersion.log()).clamp(min=-
+                logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -756,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
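The new public get_complete_embedding evaluates, in mini-batches, the basal embedding plus the summed cell-factor effects. A hedged usage sketch, assuming a trained model (as in the constructor sketch above) whose cell_factor_size matches the second array; the array shapes below are assumptions, not documented API:

# Usage sketch for the new method; `model` is assumed to be a trained DensityFlow,
# xs a (n_cells, n_genes) count matrix and us a (n_cells, cell_factor_size) factor
# matrix — these shapes are assumptions for illustration only.
import numpy as np

xs = np.random.poisson(1.0, size=(5000, 2000)).astype(np.float32)   # toy counts
us = np.random.binomial(1, 0.5, size=(5000, 4)).astype(np.float32)  # toy cell factors

Z = model.get_complete_embedding(xs, us, batch_size=1024)
print(Z.shape)  # expected: (5000, model.latent_dim), i.e. basal embedding + factor shifts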
@@ -1167,8 +1192,55 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Codebook size: {model.code_size}")
+        print(f" - Latent Dimension: {model.latent_dim}")
+        print(f" - Gene Dimension: {model.input_size}")
+        print(f" - Hidden Dimensions: {model.hidden_layers}")
+        print(f" - Device: {model.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''
 
 
 EXAMPLE_RUN = (
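The save/load_model pair added at the end of this hunk ships commented out (wrapped in a triple-quoted string), so pickling remains the active persistence path in 2.4.43. For reference, the checkpoint round trip that the commented code sketches would look roughly like this once enabled; this is hypothetical usage, with the method names and config keys taken from the commented block itself:

# Hypothetical usage of the commented-out checkpoint API; in 2.4.43 the block is
# inert (inside a ''' string) and loading still goes through pickle.
model.save('densityflow_checkpoint.pt')  # torch.save of state_dict + model_config

restored = DensityFlow.load_model('densityflow_checkpoint.pt')
# load_model rebuilds the model from 'model_config', then restores
# 'model_state_dict' with map_location set to the model's device.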
SURE/DensityFlowLinear.py
CHANGED

{sure_tools-2.4.32.dist-info → sure_tools-2.4.43.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
-SURE/DensityFlow.py,sha256=
+SURE/DensityFlow.py,sha256=fqqI8sHnfXuTK9O1il-dL7F4W7gbUMjGHD8uRwpESlc,60218
 SURE/DensityFlow2.py,sha256=BBRCoA4NpU4EjghToOvowo17UtwYokTN75KxWYHTX1E,58404
-SURE/DensityFlowLinear.py,sha256=
+SURE/DensityFlowLinear.py,sha256=bYiPHJ6mza4sOXUjlFq7wButu3rNLYZuqWUTtIO06F4,57540
 SURE/EfficientTranscriptomeDecoder.py,sha256=O_x-4edKBU5OJJbOOS-59u3TQElZqhAtOVJMPlpw8m0,21667
 SURE/PerturbE.py,sha256=DxEp-qef--x8-GMZdPfBf8ts8UDDc34h2P5AnpqZ-YM,52265
 SURE/PerturbationAwareDecoder.py,sha256=duhvBvZjOpAk7c2YTfmA2qKbrgVvwT7IW1pxaukq_iU,30231
@@ -25,9 +25,9 @@ SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
 SURE/utils/custom_mlp.py,sha256=Rn_PQouxPMSda-KKBYrwVVv3GFFuUmCLxp8cV5LszZo,10580
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-sure_tools-2.4.
-sure_tools-2.4.
-sure_tools-2.4.
-sure_tools-2.4.
-sure_tools-2.4.
-sure_tools-2.4.
+sure_tools-2.4.43.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-2.4.43.dist-info/METADATA,sha256=q0DTzGgBqj5Hi8n2YNmJymHD25dZSUdRVlMrfiy-5Hw,2678
+sure_tools-2.4.43.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-2.4.43.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+sure_tools-2.4.43.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-2.4.43.dist-info/RECORD,,
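Each RECORD line has the form path,sha256=digest,size, where the digest is the urlsafe-base64-encoded SHA-256 of the file with trailing '=' padding stripped and size is the byte count (60218 for the new DensityFlow.py above). A small sketch of recomputing one entry for verification; the wheel filename follows the sure_tools-2.4.43.dist-info naming in RECORD and is an assumption:

# Recompute a RECORD-style entry for one file inside the wheel and compare it
# with the + lines above. The wheel filename below is an assumption.
import base64
import hashlib
import zipfile

wheel = 'sure_tools-2.4.43-py3-none-any.whl'
member = 'SURE/DensityFlow.py'

with zipfile.ZipFile(wheel) as zf:
    data = zf.read(member)

digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
print(f"{member},sha256={digest},{len(data)}")
# expected per the RECORD diff: SURE/DensityFlow.py,sha256=fqqI8sHnfXuTK9O1il-dL7F4W7gbUMjGHD8uRwpESlc,60218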
File without changes
File without changes
File without changes
File without changes