SURE-tools 4.0.1.tar.gz → 4.0.3.tar.gz
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- {sure_tools-4.0.1 → sure_tools-4.0.3}/PKG-INFO +1 -1
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/SURE.py +6 -6
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/SURE_nsf.py +27 -18
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/SURE_vae.py +30 -21
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-4.0.1 → sure_tools-4.0.3}/setup.py +1 -1
- {sure_tools-4.0.1 → sure_tools-4.0.3}/LICENSE +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/README.md +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/SUREMO.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/__init__.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/atac/__init__.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/atac/utils.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/dist/__init__.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/graph/__init__.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/utils/__init__.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/utils/label.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/utils/queue.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE/utils/utils.py +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-4.0.1 → sure_tools-4.0.3}/setup.cfg +0 -0
SURE/SURE.py (+6 -6)

```diff
@@ -79,7 +79,7 @@ class SURE(nn.Module):
         Number of features (e.g., genes, peaks, proteins, etc.) per cell.
     codebook_size
         Number of metacells.
-
+    covariate_sizes
         Number of cell-level factors.
     transforms
         Number of neural spline flows
@@ -111,7 +111,7 @@ class SURE(nn.Module):
                 codebook_size: int,
                 context_sizes: list = [0],
                 perturb_size: int = 0,
-
+                covariate_sizes: int = 0,
                 method: Literal['flow','vae'] = 'vae',
                 transforms: int = 1,
                 z_dim: int = 30,
@@ -137,7 +137,7 @@ class SURE(nn.Module):
                 codebook_size=codebook_size,
                 context_sizes=context_sizes,
                 perturb_size=perturb_size,
-
+                covariate_sizes=covariate_sizes,
                 transforms=transforms,
                 z_dim=z_dim,
                 z_dist=z_dist,
@@ -159,7 +159,7 @@ class SURE(nn.Module):
                 codebook_size=codebook_size,
                 context_sizes=context_sizes,
                 perturb_size=perturb_size,
-
+                covariate_sizes=covariate_sizes,
                 z_dim=z_dim,
                 z_dist=z_dist,
                 loss_func=loss_func,
@@ -245,7 +245,7 @@ class SURE(nn.Module):
     def fit(self, xs:np.array,
             css:list = None,
             ps:np.array = None,
-
+            fss:list = None,
             num_epochs: int = 100,
             learning_rate: float = 0.0001,
             use_mask: bool = False,
@@ -291,7 +291,7 @@ class SURE(nn.Module):
            If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
            the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
        """
-        self.engine.fit(xs=xs, css=css, ps=ps,
+        self.engine.fit(xs=xs, css=css, ps=ps, fss=fss, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
                         beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
                         use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
                         monitor=monitor)
```
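Taken together, the SURE.py changes add a `covariate_sizes` argument to the wrapper's constructor and an `fss` argument to `fit()`, both forwarded unchanged to the underlying engine. Below is a rough usage sketch based only on the signatures visible in this diff; the import path, the `input_dim` argument, the synthetic data, and the one-hot covariate encoding are assumptions for illustration. Note that although the annotation reads `covariate_sizes: int`, the engine code iterates over it, so a list of per-covariate sizes is shown here.

```python
import numpy as np
from SURE import SURE  # assumed import path; the package ships SURE/SURE.py

n_cells, n_genes = 1000, 2000
# Hypothetical count matrix (cells x genes).
xs = np.random.poisson(1.0, size=(n_cells, n_genes)).astype(np.float32)

# Two hypothetical cell-level covariates, one-hot encoded: 3 donors, 2 chemistries.
donor = np.eye(3)[np.random.randint(0, 3, n_cells)]
chem = np.eye(2)[np.random.randint(0, 2, n_cells)]

model = SURE(
    input_dim=n_genes,        # number of features per cell
    codebook_size=100,        # number of metacells
    covariate_sizes=[3, 2],   # one entry per cell-level covariate (added between 4.0.1 and 4.0.3)
    method='vae',
)

# fss takes one array per covariate; the engine stacks them internally (added between 4.0.1 and 4.0.3).
model.fit(xs, fss=[donor, chem], num_epochs=10)
```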
SURE/SURE_nsf.py (+27 -18)

```diff
@@ -89,7 +89,7 @@ class SURENF(nn.Module):
         Number of features (e.g., genes, peaks, proteins, etc.) per cell.
     codebook_size
         Number of metacells.
-
+    covariate_sizes
         Number of cell-level factors.
     transforms
         Number of neural spline flows
@@ -121,7 +121,7 @@ class SURENF(nn.Module):
                 codebook_size: int,
                 context_sizes: list = [0],
                 perturb_size: int = 0,
-
+                covariate_sizes: int = 0,
                 transforms: int = 1,
                 z_dim: int = 50,
                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
@@ -144,7 +144,7 @@ class SURENF(nn.Module):
        self.input_dim = input_dim
        self.context_sizes = context_sizes
        self.perturb_size = perturb_size
-        self.
+        self.covariate_sizes = covariate_sizes
        self.dispersion = dispersion
        self.latent_dim = z_dim
        self.latent_dist = z_dist
@@ -272,16 +272,19 @@ class SURENF(nn.Module):
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )
-        if self.
-            self.
-
-
-
-
-
-
-
-
+        if np.sum(self.covariate_sizes)>0:
+            self.covariate_effects = nn.ModuleList()
+            for covariate_size in self.covariate_sizes:
+                self.covariate_effects.append(ZeroBiasMLP2(
+                    [covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
+                    activation=activate_fct,
+                    output_activation=None,
+                    post_layer_fct=post_layer_fct,
+                    post_act_fct=post_act_fct,
+                    allow_broadcast=self.allow_broadcast,
+                    use_cuda=self.use_cuda,
+                    )
+                )
 
        self.decoder_log_mu = MLP(
            [self.latent_dim+self.latent_dim+self.latent_dim] + self.decoder_hidden_layers + [self.input_dim],
@@ -421,8 +424,13 @@ class SURENF(nn.Module):
            zps = self.perturb_effect([ps, zs+zcs])
        else:
            zps = torch.zeros_like(zs)
-        if (self.
-            zfs =
+        if (self.covariate_sizes>0) and (fs is not None):
+            zfs = torch.zeros_like(zs)
+            shift = 0
+            for i, covariate_size in enumerate(self.covariate_sizes):
+                fs_i = fs[:,shift:(shift+covariate_size)]
+                zfs += self.covariate_effects[i](fs_i)
+                shift += covariate_size
        else:
            zfs = torch.zeros_like(zs)
 
@@ -891,7 +899,7 @@ class SURENF(nn.Module):
    def fit(self, xs:np.array,
            css:list = None,
            ps: np.array = None,
-
+            fss:list = None,
            num_epochs: int = 100,
            learning_rate: float = 0.0001,
            use_mask: bool = False,
@@ -957,7 +965,8 @@ class SURENF(nn.Module):
            cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
        if ps is not None:
            ps = convert_to_tensor(ps, dtype=self.dtype, device='cpu')
-        if
+        if fss is not None:
+            fs = np.hstack(fss)
            fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
 
        dataset = CustomDataset(xs)
@@ -1019,7 +1028,7 @@ class SURENF(nn.Module):
                batch_p = None
            else:
                batch_p = ps[idx].to(self.get_device())
-            if
+            if fss is None:
                batch_f = None
            else:
                batch_f = fs[idx].to(self.get_device())
```
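The covariate handling added to SURENF works in two steps: `fit()` stacks the per-covariate arrays column-wise with `np.hstack(fss)`, and the model later slices each covariate block back out with a running offset and adds its learned effect to the latent representation. A small self-contained sketch of that bookkeeping, using plain NumPy and fixed random matrices as stand-ins for the package's ZeroBiasMLP2 covariate effects:

```python
import numpy as np

# Hypothetical per-covariate one-hot blocks: 3 donors, 2 chemistries.
covariate_sizes = [3, 2]
n_cells = 5
fss = [np.eye(3)[np.random.randint(0, 3, n_cells)],
       np.eye(2)[np.random.randint(0, 2, n_cells)]]

# What fit() does before training: stack the blocks column-wise.
fs = np.hstack(fss)                      # shape (n_cells, sum(covariate_sizes)) == (5, 5)

# Stand-ins for the per-covariate effects (fixed linear maps into a 4-dim latent space).
latent_dim = 4
effects = [np.random.randn(size, latent_dim) for size in covariate_sizes]

# What the model does: slice each block back out with a running offset and sum its effect.
zfs = np.zeros((n_cells, latent_dim))
shift = 0
for i, covariate_size in enumerate(covariate_sizes):
    fs_i = fs[:, shift:shift + covariate_size]   # recover the i-th covariate block
    zfs += fs_i @ effects[i]                     # additive contribution to the latent shift
    shift += covariate_size

print(zfs.shape)  # (5, 4)
```

The same slice-and-sum pattern appears again in SURE_vae.py below.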
SURE/SURE_vae.py (+30 -21)

```diff
@@ -121,7 +121,7 @@ class SUREVAE(nn.Module):
                 codebook_size: int,
                 context_sizes: list = [0],
                 perturb_size: int = 0,
-
+                covariate_sizes: int = 0,
                 transforms: int = 1,
                 z_dim: int = 50,
                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
@@ -144,7 +144,7 @@ class SUREVAE(nn.Module):
        self.input_dim = input_dim
        self.context_sizes = context_sizes
        self.perturb_size = perturb_size
-        self.
+        self.covariate_sizes = covariate_sizes
        self.dispersion = dispersion
        self.latent_dim = z_dim
        self.latent_dist = z_dist
@@ -273,7 +273,7 @@ class SUREVAE(nn.Module):
 
        if self.perturb_size>0:
            self.perturb_effect = ZeroBiasMLP3(
-                [self.perturb_size+self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
+                [self.perturb_size+self.latent_dim+self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
                activation=activate_fct,
                output_activation=None,
                post_layer_fct=post_layer_fct,
@@ -282,16 +282,19 @@ class SUREVAE(nn.Module):
                use_cuda=self.use_cuda,
            )
 
-        if self.
-            self.
-
-
-
-
-
-
-
-
+        if np.sum(self.covariate_sizes)>0:
+            self.covariate_effects = nn.ModuleList()
+            for covariate_size in self.covariate_sizes:
+                self.covariate_effects.append(ZeroBiasMLP2(
+                    [covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
+                    activation=activate_fct,
+                    output_activation=None,
+                    post_layer_fct=post_layer_fct,
+                    post_act_fct=post_act_fct,
+                    allow_broadcast=self.allow_broadcast,
+                    use_cuda=self.use_cuda,
+                    )
+                )
 
        self.decoder_log_mu = MLP(
            [self.latent_dim] + self.decoder_hidden_layers + [self.input_dim],
@@ -428,11 +431,16 @@ class SUREVAE(nn.Module):
        else:
            zcs = torch.zeros_like(zs)
        if (self.perturb_size>0) and (ps is not None):
-            zps = self.perturb_effect([ps, zs
+            zps = self.perturb_effect([ps, zs, zcs])
        else:
            zps = torch.zeros_like(zs)
-        if (self.
-            zfs =
+        if (np.sum(self.covariate_sizes)>0) and (fs is not None):
+            zfs = torch.zeros_like(zs)
+            shift = 0
+            for i, covariate_size in enumerate(self.covariate_sizes):
+                fs_i = fs[:,shift:(shift+covariate_size)]
+                zfs += self.covariate_effects[i](fs_i)
+                shift += covariate_size
        else:
            zfs = torch.zeros_like(zs)
 
@@ -640,7 +648,7 @@ class SUREVAE(nn.Module):
                C_batch = zcs[idx].to(self.get_device())
                P_batch = ps[idx].to(self.get_device())
 
-                dzs = self.perturb_effect([P_batch,Z_batch
+                dzs = self.perturb_effect([P_batch,Z_batch,C_batch])
 
                A.append(tensor_to_numpy(dzs))
                pbar.update(1)
@@ -693,7 +701,7 @@ class SUREVAE(nn.Module):
 
            if ps is not None:
                P_batch = ps[idx].to(self.get_device())
-                zps = self.perturb_effect([P_batch,z_basal
+                zps = self.perturb_effect([P_batch,z_basal,zcs])
            else:
                zps = torch.zeros_like(z_basal)
 
@@ -825,7 +833,7 @@ class SUREVAE(nn.Module):
    def fit(self, xs: np.array,
            css: list = None,
            ps: np.array = None,
-
+            fss: list = None,
            num_epochs: int = 100,
            learning_rate: float = 0.0001,
            use_mask: bool = False,
@@ -891,7 +899,8 @@ class SUREVAE(nn.Module):
            cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
        if ps is not None:
            ps = convert_to_tensor(ps, dtype=self.dtype, device='cpu')
-        if
+        if fss is not None:
+            fs = np.hstack(fss)
            fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
 
        dataset = CustomDataset(xs)
@@ -953,7 +962,7 @@ class SUREVAE(nn.Module):
                batch_p = None
            else:
                batch_p = ps[idx].to(self.get_device())
-            if
+            if fss is None:
                batch_f = None
            else:
                batch_f = fs[idx].to(self.get_device())
```
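One further change in SUREVAE: the perturbation-effect network (ZeroBiasMLP3) now receives the context latent as a third input alongside the perturbation labels and the basal latent state, and its first-layer width grows from `perturb_size + latent_dim` to `perturb_size + latent_dim + latent_dim` accordingly. A quick shape check of that bookkeeping, assuming the list input is combined by concatenation along the feature axis (how the package's MLP classes combine list inputs is an assumption); `perturb_size` and the cell count here are made up, while `z_dim = 50` matches the SUREVAE default:

```python
import numpy as np

perturb_size, z_dim, n_cells = 6, 50, 8

ps  = np.zeros((n_cells, perturb_size))   # perturbation labels
zs  = np.zeros((n_cells, z_dim))          # basal latent state
zcs = np.zeros((n_cells, z_dim))          # context latent shift

# 4.0.1 first-layer width: perturb_size + latent_dim
old_width = np.concatenate([ps, zs], axis=1).shape[1]
# 4.0.3 first-layer width: perturb_size + latent_dim + latent_dim
new_width = np.concatenate([ps, zs, zcs], axis=1).shape[1]

print(old_width, new_width)   # 56 106
```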
|