SURE-tools 3.6.13__tar.gz → 3.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-3.6.13 → sure_tools-3.7.0}/PKG-INFO +1 -1
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE.py +11 -11
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_nsf.py +42 -28
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_vae.py +40 -25
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-3.6.13 → sure_tools-3.7.0}/setup.py +1 -1
- {sure_tools-3.6.13 → sure_tools-3.7.0}/LICENSE +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/README.md +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SUREMO.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/__init__.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/atac/utils.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/dist/__init__.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/graph/__init__.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/label.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/queue.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE/utils/utils.py +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-3.6.13 → sure_tools-3.7.0}/setup.cfg +0 -0
|
@@ -109,7 +109,7 @@ class SURE(nn.Module):
|
|
|
109
109
|
def __init__(self,
|
|
110
110
|
input_dim: int,
|
|
111
111
|
codebook_size: int,
|
|
112
|
-
|
|
112
|
+
condition_sizes: int = 0,
|
|
113
113
|
covariate_size: int = 0,
|
|
114
114
|
method: Literal['flow','vae'] = 'vae',
|
|
115
115
|
transforms: int = 1,
|
|
@@ -134,7 +134,7 @@ class SURE(nn.Module):
|
|
|
134
134
|
if method == 'flow':
|
|
135
135
|
self.engine = SURENF(input_dim=input_dim,
|
|
136
136
|
codebook_size=codebook_size,
|
|
137
|
-
|
|
137
|
+
condition_sizes=condition_sizes,
|
|
138
138
|
covariate_size=covariate_size,
|
|
139
139
|
transforms=transforms,
|
|
140
140
|
z_dim=z_dim,
|
|
@@ -155,7 +155,7 @@ class SURE(nn.Module):
|
|
|
155
155
|
elif method == 'vae':
|
|
156
156
|
self.engine = SUREVAE(input_dim=input_dim,
|
|
157
157
|
codebook_size=codebook_size,
|
|
158
|
-
|
|
158
|
+
condition_sizes=condition_sizes,
|
|
159
159
|
covariate_size=covariate_size,
|
|
160
160
|
z_dim=z_dim,
|
|
161
161
|
z_dist=z_dist,
|
|
@@ -214,13 +214,13 @@ class SURE(nn.Module):
|
|
|
214
214
|
"""
|
|
215
215
|
return self.engine.hard_assignments(xs=xs, batch_size=batch_size, show_progress=show_progress)
|
|
216
216
|
|
|
217
|
-
def
|
|
218
|
-
return self.engine.
|
|
217
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
218
|
+
return self.engine.get_condition_effect(xs, cs, i, batch_size=batch_size, show_progress=show_progress)
|
|
219
219
|
|
|
220
220
|
def predict_cluster(self, xs, batch_size=1024, show_progress=True):
|
|
221
221
|
return self.engine.predict_cluster(xs, batch_size=batch_size, show_progress=show_progress)
|
|
222
222
|
|
|
223
|
-
def predict(self, xs,
|
|
223
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
224
224
|
"""
|
|
225
225
|
Generate gene expression prediction from given cell data and covariates.
|
|
226
226
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -231,14 +231,14 @@ class SURE(nn.Module):
|
|
|
231
231
|
:param batch_size: Data size per batch
|
|
232
232
|
:param show_progress: Toggle on or off message output
|
|
233
233
|
"""
|
|
234
|
-
return self.engine.predict(xs,
|
|
234
|
+
return self.engine.predict(xs, cs_list, batch_size, show_progress)
|
|
235
235
|
|
|
236
236
|
def preprocess(self, xs, threshold=0):
|
|
237
237
|
return self.engine.preprocess(xs=xs, threshold=threshold)
|
|
238
238
|
|
|
239
|
-
def fit(self, xs,
|
|
240
|
-
|
|
241
|
-
fs = None,
|
|
239
|
+
def fit(self, xs:np.array,
|
|
240
|
+
css:list = None,
|
|
241
|
+
fs:np.array = None,
|
|
242
242
|
num_epochs: int = 100,
|
|
243
243
|
learning_rate: float = 0.0001,
|
|
244
244
|
use_mask: bool = False,
|
|
@@ -284,7 +284,7 @@ class SURE(nn.Module):
|
|
|
284
284
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
285
285
|
the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
|
|
286
286
|
"""
|
|
287
|
-
self.engine.fit(xs=xs,
|
|
287
|
+
self.engine.fit(xs=xs, css=css, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
|
|
288
288
|
beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
|
|
289
289
|
use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
|
|
290
290
|
monitor=monitor)
|
|
@@ -119,12 +119,12 @@ class SURENF(nn.Module):
|
|
|
119
119
|
def __init__(self,
|
|
120
120
|
input_dim: int,
|
|
121
121
|
codebook_size: int,
|
|
122
|
-
|
|
122
|
+
condition_sizes: list = [0],
|
|
123
123
|
covariate_size: int = 0,
|
|
124
124
|
transforms: int = 1,
|
|
125
125
|
z_dim: int = 50,
|
|
126
126
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
|
|
127
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
127
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
128
128
|
dispersion: float = 10.0,
|
|
129
129
|
use_zeroinflate: bool = True,
|
|
130
130
|
hidden_layers: list = [500],
|
|
@@ -141,7 +141,7 @@ class SURENF(nn.Module):
|
|
|
141
141
|
super().__init__()
|
|
142
142
|
|
|
143
143
|
self.input_dim = input_dim
|
|
144
|
-
self.
|
|
144
|
+
self.condition_sizes = condition_sizes
|
|
145
145
|
self.covariate_size = covariate_size
|
|
146
146
|
self.dispersion = dispersion
|
|
147
147
|
self.latent_dim = z_dim
|
|
@@ -247,16 +247,19 @@ class SURENF(nn.Module):
|
|
|
247
247
|
self.encoder_zn = zuko.flows.NSF(features=self.latent_dim, context=self.input_dim,
|
|
248
248
|
transforms=self.transforms, hidden_features=self.flow_hidden_layers)
|
|
249
249
|
|
|
250
|
-
if self.
|
|
251
|
-
self.
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
250
|
+
if np.sum(self.condition_sizes)>0:
|
|
251
|
+
self.condition_effects = nn.ModuleList()
|
|
252
|
+
for condition_size in self.condition_sizes:
|
|
253
|
+
self.condition_effects.append(ZeroBiasMLP3(
|
|
254
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
255
|
+
activation=activate_fct,
|
|
256
|
+
output_activation=None,
|
|
257
|
+
post_layer_fct=post_layer_fct,
|
|
258
|
+
post_act_fct=post_act_fct,
|
|
259
|
+
allow_broadcast=self.allow_broadcast,
|
|
260
|
+
use_cuda=self.use_cuda,
|
|
261
|
+
)
|
|
262
|
+
)
|
|
260
263
|
if self.covariate_size>0:
|
|
261
264
|
self.covariate_effect = ZeroBiasMLP2(
|
|
262
265
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -393,8 +396,13 @@ class SURENF(nn.Module):
|
|
|
393
396
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
394
397
|
|
|
395
398
|
zs = zns
|
|
396
|
-
if (self.
|
|
397
|
-
zcs =
|
|
399
|
+
if (np.sum(self.condition_sizes)>0) and (cs is not None):
|
|
400
|
+
zcs = torch.zeros_like(zs)
|
|
401
|
+
shift = 0
|
|
402
|
+
for i,condition_size in enumerate(self.condition_sizes):
|
|
403
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
404
|
+
zcs += self.condition_effects[i]([cs_i,zns])
|
|
405
|
+
shift += condition_size
|
|
398
406
|
else:
|
|
399
407
|
zcs = torch.zeros_like(zs)
|
|
400
408
|
if (self.covariate_size>0) and (fs is not None):
|
|
@@ -565,7 +573,7 @@ class SURENF(nn.Module):
|
|
|
565
573
|
A = np.concatenate(A)
|
|
566
574
|
return A
|
|
567
575
|
|
|
568
|
-
def
|
|
576
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
569
577
|
xs = self.preprocess(xs)
|
|
570
578
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
571
579
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
@@ -580,11 +588,8 @@ class SURENF(nn.Module):
|
|
|
580
588
|
X_batch = X_batch.to(self.get_device())
|
|
581
589
|
C_batch = cs[idx].to(self.get_device())
|
|
582
590
|
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
ns = self._soft_assignments(X_batch)
|
|
586
|
-
z_basal = torch.matmul(ns, cb_loc)
|
|
587
|
-
dzs = self.condition_effect([C_batch,z_basal])
|
|
591
|
+
z_basal = self._get_cell_embedding(X_batch)
|
|
592
|
+
dzs = self.condition_effects[i]([C_batch,z_basal])
|
|
588
593
|
|
|
589
594
|
A.append(tensor_to_numpy(dzs))
|
|
590
595
|
pbar.update(1)
|
|
@@ -596,7 +601,7 @@ class SURENF(nn.Module):
|
|
|
596
601
|
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
597
602
|
return self.kmeans.predict(zs)
|
|
598
603
|
|
|
599
|
-
def predict(self, xs,
|
|
604
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
600
605
|
"""
|
|
601
606
|
Generate gene expression prediction from given cell data and covariates.
|
|
602
607
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -609,6 +614,7 @@ class SURENF(nn.Module):
|
|
|
609
614
|
"""
|
|
610
615
|
xs = self.preprocess(xs)
|
|
611
616
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
617
|
+
cs = np.hstack(cs_list)
|
|
612
618
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
613
619
|
|
|
614
620
|
dataset = CustomDataset(xs)
|
|
@@ -622,7 +628,14 @@ class SURENF(nn.Module):
|
|
|
622
628
|
library_size = torch.sum(X_batch, 1)
|
|
623
629
|
|
|
624
630
|
z_basal = self._get_cell_embedding(X_batch)
|
|
625
|
-
|
|
631
|
+
|
|
632
|
+
zcs = torch.zeros_like(z_basal)
|
|
633
|
+
shift = 0
|
|
634
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
635
|
+
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
636
|
+
zcs += self.condition_effects[i]([C_batch_i,z_basal])
|
|
637
|
+
shift += condition_size
|
|
638
|
+
|
|
626
639
|
zfs = torch.zeros_like(z_basal)
|
|
627
640
|
|
|
628
641
|
log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
|
|
@@ -829,9 +842,9 @@ class SURENF(nn.Module):
|
|
|
829
842
|
if name in param_store:
|
|
830
843
|
param_store[name] = param'''
|
|
831
844
|
|
|
832
|
-
def fit(self, xs,
|
|
833
|
-
|
|
834
|
-
fs = None,
|
|
845
|
+
def fit(self, xs:np.array,
|
|
846
|
+
css:list = None,
|
|
847
|
+
fs:np.array = None,
|
|
835
848
|
num_epochs: int = 100,
|
|
836
849
|
learning_rate: float = 0.0001,
|
|
837
850
|
use_mask: bool = False,
|
|
@@ -892,7 +905,8 @@ class SURENF(nn.Module):
|
|
|
892
905
|
|
|
893
906
|
xs = self.preprocess(xs, threshold=threshold)
|
|
894
907
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
895
|
-
if
|
|
908
|
+
if css is not None:
|
|
909
|
+
cs = np.hstack(css)
|
|
896
910
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
897
911
|
if fs is not None:
|
|
898
912
|
fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
|
|
@@ -948,7 +962,7 @@ class SURENF(nn.Module):
|
|
|
948
962
|
for batch_x, idx in dataloader:
|
|
949
963
|
batch_x = batch_x.to(self.get_device())
|
|
950
964
|
for loss_id in range(num_losses):
|
|
951
|
-
if
|
|
965
|
+
if css is None:
|
|
952
966
|
batch_c = None
|
|
953
967
|
else:
|
|
954
968
|
batch_c = cs[idx].to(self.get_device())
|
|
@@ -119,7 +119,7 @@ class SUREVAE(nn.Module):
|
|
|
119
119
|
def __init__(self,
|
|
120
120
|
input_dim: int,
|
|
121
121
|
codebook_size: int,
|
|
122
|
-
|
|
122
|
+
condition_sizes: list = [0],
|
|
123
123
|
covariate_size: int = 0,
|
|
124
124
|
transforms: int = 1,
|
|
125
125
|
z_dim: int = 50,
|
|
@@ -141,7 +141,7 @@ class SUREVAE(nn.Module):
|
|
|
141
141
|
super().__init__()
|
|
142
142
|
|
|
143
143
|
self.input_dim = input_dim
|
|
144
|
-
self.
|
|
144
|
+
self.condition_sizes = condition_sizes
|
|
145
145
|
self.covariate_size = covariate_size
|
|
146
146
|
self.dispersion = dispersion
|
|
147
147
|
self.latent_dim = z_dim
|
|
@@ -255,16 +255,19 @@ class SUREVAE(nn.Module):
|
|
|
255
255
|
use_cuda=self.use_cuda,
|
|
256
256
|
)
|
|
257
257
|
|
|
258
|
-
if self.
|
|
259
|
-
self.
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
258
|
+
if np.sum(self.condition_sizes)>0:
|
|
259
|
+
self.condition_effects = nn.ModuleList()
|
|
260
|
+
for condition_size in self.condition_sizes:
|
|
261
|
+
self.condition_effects.append(ZeroBiasMLP3(
|
|
262
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
263
|
+
activation=activate_fct,
|
|
264
|
+
output_activation=None,
|
|
265
|
+
post_layer_fct=post_layer_fct,
|
|
266
|
+
post_act_fct=post_act_fct,
|
|
267
|
+
allow_broadcast=self.allow_broadcast,
|
|
268
|
+
use_cuda=self.use_cuda,
|
|
269
|
+
)
|
|
270
|
+
)
|
|
268
271
|
if self.covariate_size>0:
|
|
269
272
|
self.covariate_effect = ZeroBiasMLP2(
|
|
270
273
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -401,8 +404,13 @@ class SUREVAE(nn.Module):
|
|
|
401
404
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
402
405
|
|
|
403
406
|
zs = zns
|
|
404
|
-
if (self.
|
|
405
|
-
zcs =
|
|
407
|
+
if (np.sum(self.condition_sizes)>0) and (cs is not None):
|
|
408
|
+
zcs = torch.zeros_like(zs)
|
|
409
|
+
shift = 0
|
|
410
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
411
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
412
|
+
zcs += self.condition_effects[i]([cs_i,zns])
|
|
413
|
+
shift += condition_size
|
|
406
414
|
else:
|
|
407
415
|
zcs = torch.zeros_like(zs)
|
|
408
416
|
if (self.covariate_size>0) and (fs is not None):
|
|
@@ -578,7 +586,7 @@ class SUREVAE(nn.Module):
|
|
|
578
586
|
A = np.concatenate(A)
|
|
579
587
|
return A
|
|
580
588
|
|
|
581
|
-
def
|
|
589
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
582
590
|
xs = self.preprocess(xs)
|
|
583
591
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
584
592
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
@@ -593,10 +601,8 @@ class SUREVAE(nn.Module):
|
|
|
593
601
|
X_batch = X_batch.to(self.get_device())
|
|
594
602
|
C_batch = cs[idx].to(self.get_device())
|
|
595
603
|
|
|
596
|
-
#ns = self._soft_assignments(X_batch)
|
|
597
|
-
#z_basal = torch.matmul(ns, cb_loc)
|
|
598
604
|
z_basal = self._get_cell_embedding(X_batch)
|
|
599
|
-
dzs = self.
|
|
605
|
+
dzs = self.condition_effects[i]([C_batch,z_basal])
|
|
600
606
|
|
|
601
607
|
A.append(tensor_to_numpy(dzs))
|
|
602
608
|
pbar.update(1)
|
|
@@ -608,7 +614,7 @@ class SUREVAE(nn.Module):
|
|
|
608
614
|
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
609
615
|
return self.kmeans.predict(zs)
|
|
610
616
|
|
|
611
|
-
def predict(self, xs,
|
|
617
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
612
618
|
"""
|
|
613
619
|
Generate gene expression prediction from given cell data and covariates.
|
|
614
620
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -621,6 +627,7 @@ class SUREVAE(nn.Module):
|
|
|
621
627
|
"""
|
|
622
628
|
xs = self.preprocess(xs)
|
|
623
629
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
630
|
+
cs = np.hstack(cs_list)
|
|
624
631
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
625
632
|
|
|
626
633
|
dataset = CustomDataset(xs)
|
|
@@ -634,7 +641,14 @@ class SUREVAE(nn.Module):
|
|
|
634
641
|
library_size = torch.sum(X_batch, 1)
|
|
635
642
|
|
|
636
643
|
z_basal = self._get_cell_embedding(X_batch)
|
|
637
|
-
|
|
644
|
+
|
|
645
|
+
zcs = torch.zeros_like(z_basal)
|
|
646
|
+
shift = 0
|
|
647
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
648
|
+
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
649
|
+
zcs += self.condition_effects[i]([C_batch_i,z_basal])
|
|
650
|
+
shift += condition_size
|
|
651
|
+
|
|
638
652
|
zfs = torch.zeros_like(z_basal)
|
|
639
653
|
|
|
640
654
|
log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
|
|
@@ -760,9 +774,9 @@ class SUREVAE(nn.Module):
|
|
|
760
774
|
pbar.set_postfix({'loss': str_loss})
|
|
761
775
|
pbar.update(1)'''
|
|
762
776
|
|
|
763
|
-
def fit(self, xs,
|
|
764
|
-
|
|
765
|
-
fs = None,
|
|
777
|
+
def fit(self, xs: np.array,
|
|
778
|
+
css: list = None,
|
|
779
|
+
fs: np.array = None,
|
|
766
780
|
num_epochs: int = 100,
|
|
767
781
|
learning_rate: float = 0.0001,
|
|
768
782
|
use_mask: bool = False,
|
|
@@ -823,7 +837,8 @@ class SUREVAE(nn.Module):
|
|
|
823
837
|
|
|
824
838
|
xs = self.preprocess(xs, threshold=threshold)
|
|
825
839
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
826
|
-
if
|
|
840
|
+
if css is not None:
|
|
841
|
+
cs = np.hstack(css)
|
|
827
842
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
828
843
|
if fs is not None:
|
|
829
844
|
fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
|
|
@@ -879,7 +894,7 @@ class SUREVAE(nn.Module):
|
|
|
879
894
|
for batch_x, idx in dataloader:
|
|
880
895
|
batch_x = batch_x.to(self.get_device())
|
|
881
896
|
for loss_id in range(num_losses):
|
|
882
|
-
if
|
|
897
|
+
if css is None:
|
|
883
898
|
batch_c = None
|
|
884
899
|
else:
|
|
885
900
|
batch_c = cs[idx].to(self.get_device())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|