SURE-tools 3.6.14__tar.gz → 3.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sure_tools-3.6.14 → sure_tools-3.7.0}/PKG-INFO +1 -1
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE.py +11 -11
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_nsf.py +41 -24
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_vae.py +40 -23
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-3.6.14 → sure_tools-3.7.0}/setup.py +1 -1
- {sure_tools-3.6.14 → sure_tools-3.7.0}/LICENSE +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/README.md +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SUREMO.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/SURE_vanilla.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/__init__.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/atac/__init__.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/atac/utils.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/dist/__init__.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/dist/negbinomial.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/graph/__init__.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/graph/graph_utils.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/__init__.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/label.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/queue.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE/utils/utils.py +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-3.6.14 → sure_tools-3.7.0}/setup.cfg +0 -0
|
@@ -109,7 +109,7 @@ class SURE(nn.Module):
|
|
|
109
109
|
def __init__(self,
|
|
110
110
|
input_dim: int,
|
|
111
111
|
codebook_size: int,
|
|
112
|
-
|
|
112
|
+
condition_sizes: int = 0,
|
|
113
113
|
covariate_size: int = 0,
|
|
114
114
|
method: Literal['flow','vae'] = 'vae',
|
|
115
115
|
transforms: int = 1,
|
|
@@ -134,7 +134,7 @@ class SURE(nn.Module):
|
|
|
134
134
|
if method == 'flow':
|
|
135
135
|
self.engine = SURENF(input_dim=input_dim,
|
|
136
136
|
codebook_size=codebook_size,
|
|
137
|
-
|
|
137
|
+
condition_sizes=condition_sizes,
|
|
138
138
|
covariate_size=covariate_size,
|
|
139
139
|
transforms=transforms,
|
|
140
140
|
z_dim=z_dim,
|
|
@@ -155,7 +155,7 @@ class SURE(nn.Module):
|
|
|
155
155
|
elif method == 'vae':
|
|
156
156
|
self.engine = SUREVAE(input_dim=input_dim,
|
|
157
157
|
codebook_size=codebook_size,
|
|
158
|
-
|
|
158
|
+
condition_sizes=condition_sizes,
|
|
159
159
|
covariate_size=covariate_size,
|
|
160
160
|
z_dim=z_dim,
|
|
161
161
|
z_dist=z_dist,
|
|
@@ -214,13 +214,13 @@ class SURE(nn.Module):
|
|
|
214
214
|
"""
|
|
215
215
|
return self.engine.hard_assignments(xs=xs, batch_size=batch_size, show_progress=show_progress)
|
|
216
216
|
|
|
217
|
-
def
|
|
218
|
-
return self.engine.
|
|
217
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
218
|
+
return self.engine.get_condition_effect(xs, cs, i, batch_size=batch_size, show_progress=show_progress)
|
|
219
219
|
|
|
220
220
|
def predict_cluster(self, xs, batch_size=1024, show_progress=True):
|
|
221
221
|
return self.engine.predict_cluster(xs, batch_size=batch_size, show_progress=show_progress)
|
|
222
222
|
|
|
223
|
-
def predict(self, xs,
|
|
223
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
224
224
|
"""
|
|
225
225
|
Generate gene expression prediction from given cell data and covariates.
|
|
226
226
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -231,14 +231,14 @@ class SURE(nn.Module):
|
|
|
231
231
|
:param batch_size: Data size per batch
|
|
232
232
|
:param show_progress: Toggle on or off message output
|
|
233
233
|
"""
|
|
234
|
-
return self.engine.predict(xs,
|
|
234
|
+
return self.engine.predict(xs, cs_list, batch_size, show_progress)
|
|
235
235
|
|
|
236
236
|
def preprocess(self, xs, threshold=0):
|
|
237
237
|
return self.engine.preprocess(xs=xs, threshold=threshold)
|
|
238
238
|
|
|
239
|
-
def fit(self, xs,
|
|
240
|
-
|
|
241
|
-
fs = None,
|
|
239
|
+
def fit(self, xs:np.array,
|
|
240
|
+
css:list = None,
|
|
241
|
+
fs:np.array = None,
|
|
242
242
|
num_epochs: int = 100,
|
|
243
243
|
learning_rate: float = 0.0001,
|
|
244
244
|
use_mask: bool = False,
|
|
@@ -284,7 +284,7 @@ class SURE(nn.Module):
|
|
|
284
284
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
285
285
|
the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
|
|
286
286
|
"""
|
|
287
|
-
self.engine.fit(xs=xs,
|
|
287
|
+
self.engine.fit(xs=xs, css=css, fs=fs, num_epochs=num_epochs, learning_rate=learning_rate, use_mask=use_mask, mask_ratio=mask_ratio, batch_size=batch_size, algo=algo,
|
|
288
288
|
beta_1=beta_1, weight_decay=weight_decay, decay_rate=decay_rate, config_enum=config_enum, threshold=threshold,
|
|
289
289
|
use_jax=use_jax, show_progress=show_progress, patience=patience, min_delta=min_delta, restore_best_weights=restore_best_weights,
|
|
290
290
|
monitor=monitor)
|
|
@@ -119,12 +119,12 @@ class SURENF(nn.Module):
|
|
|
119
119
|
def __init__(self,
|
|
120
120
|
input_dim: int,
|
|
121
121
|
codebook_size: int,
|
|
122
|
-
|
|
122
|
+
condition_sizes: list = [0],
|
|
123
123
|
covariate_size: int = 0,
|
|
124
124
|
transforms: int = 1,
|
|
125
125
|
z_dim: int = 50,
|
|
126
126
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
|
|
127
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
127
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
128
128
|
dispersion: float = 10.0,
|
|
129
129
|
use_zeroinflate: bool = True,
|
|
130
130
|
hidden_layers: list = [500],
|
|
@@ -141,7 +141,7 @@ class SURENF(nn.Module):
|
|
|
141
141
|
super().__init__()
|
|
142
142
|
|
|
143
143
|
self.input_dim = input_dim
|
|
144
|
-
self.
|
|
144
|
+
self.condition_sizes = condition_sizes
|
|
145
145
|
self.covariate_size = covariate_size
|
|
146
146
|
self.dispersion = dispersion
|
|
147
147
|
self.latent_dim = z_dim
|
|
@@ -247,16 +247,19 @@ class SURENF(nn.Module):
|
|
|
247
247
|
self.encoder_zn = zuko.flows.NSF(features=self.latent_dim, context=self.input_dim,
|
|
248
248
|
transforms=self.transforms, hidden_features=self.flow_hidden_layers)
|
|
249
249
|
|
|
250
|
-
if self.
|
|
251
|
-
self.
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
250
|
+
if np.sum(self.condition_sizes)>0:
|
|
251
|
+
self.condition_effects = nn.ModuleList()
|
|
252
|
+
for condition_size in self.condition_sizes:
|
|
253
|
+
self.condition_effects.append(ZeroBiasMLP3(
|
|
254
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
255
|
+
activation=activate_fct,
|
|
256
|
+
output_activation=None,
|
|
257
|
+
post_layer_fct=post_layer_fct,
|
|
258
|
+
post_act_fct=post_act_fct,
|
|
259
|
+
allow_broadcast=self.allow_broadcast,
|
|
260
|
+
use_cuda=self.use_cuda,
|
|
261
|
+
)
|
|
262
|
+
)
|
|
260
263
|
if self.covariate_size>0:
|
|
261
264
|
self.covariate_effect = ZeroBiasMLP2(
|
|
262
265
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -393,8 +396,13 @@ class SURENF(nn.Module):
|
|
|
393
396
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
394
397
|
|
|
395
398
|
zs = zns
|
|
396
|
-
if (self.
|
|
397
|
-
zcs =
|
|
399
|
+
if (np.sum(self.condition_sizes)>0) and (cs is not None):
|
|
400
|
+
zcs = torch.zeros_like(zs)
|
|
401
|
+
shift = 0
|
|
402
|
+
for i,condition_size in enumerate(self.condition_sizes):
|
|
403
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
404
|
+
zcs += self.condition_effects[i]([cs_i,zns])
|
|
405
|
+
shift += condition_size
|
|
398
406
|
else:
|
|
399
407
|
zcs = torch.zeros_like(zs)
|
|
400
408
|
if (self.covariate_size>0) and (fs is not None):
|
|
@@ -565,7 +573,7 @@ class SURENF(nn.Module):
|
|
|
565
573
|
A = np.concatenate(A)
|
|
566
574
|
return A
|
|
567
575
|
|
|
568
|
-
def
|
|
576
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
569
577
|
xs = self.preprocess(xs)
|
|
570
578
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
571
579
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
@@ -581,7 +589,7 @@ class SURENF(nn.Module):
|
|
|
581
589
|
C_batch = cs[idx].to(self.get_device())
|
|
582
590
|
|
|
583
591
|
z_basal = self._get_cell_embedding(X_batch)
|
|
584
|
-
dzs = self.
|
|
592
|
+
dzs = self.condition_effects[i]([C_batch,z_basal])
|
|
585
593
|
|
|
586
594
|
A.append(tensor_to_numpy(dzs))
|
|
587
595
|
pbar.update(1)
|
|
@@ -593,7 +601,7 @@ class SURENF(nn.Module):
|
|
|
593
601
|
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
594
602
|
return self.kmeans.predict(zs)
|
|
595
603
|
|
|
596
|
-
def predict(self, xs,
|
|
604
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
597
605
|
"""
|
|
598
606
|
Generate gene expression prediction from given cell data and covariates.
|
|
599
607
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -606,6 +614,7 @@ class SURENF(nn.Module):
|
|
|
606
614
|
"""
|
|
607
615
|
xs = self.preprocess(xs)
|
|
608
616
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
617
|
+
cs = np.hstack(cs_list)
|
|
609
618
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
610
619
|
|
|
611
620
|
dataset = CustomDataset(xs)
|
|
@@ -619,7 +628,14 @@ class SURENF(nn.Module):
|
|
|
619
628
|
library_size = torch.sum(X_batch, 1)
|
|
620
629
|
|
|
621
630
|
z_basal = self._get_cell_embedding(X_batch)
|
|
622
|
-
|
|
631
|
+
|
|
632
|
+
zcs = torch.zeros_like(z_basal)
|
|
633
|
+
shift = 0
|
|
634
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
635
|
+
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
636
|
+
zcs += self.condition_effects[i]([C_batch_i,z_basal])
|
|
637
|
+
shift += condition_size
|
|
638
|
+
|
|
623
639
|
zfs = torch.zeros_like(z_basal)
|
|
624
640
|
|
|
625
641
|
log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
|
|
@@ -826,9 +842,9 @@ class SURENF(nn.Module):
|
|
|
826
842
|
if name in param_store:
|
|
827
843
|
param_store[name] = param'''
|
|
828
844
|
|
|
829
|
-
def fit(self, xs,
|
|
830
|
-
|
|
831
|
-
fs = None,
|
|
845
|
+
def fit(self, xs:np.array,
|
|
846
|
+
css:list = None,
|
|
847
|
+
fs:np.array = None,
|
|
832
848
|
num_epochs: int = 100,
|
|
833
849
|
learning_rate: float = 0.0001,
|
|
834
850
|
use_mask: bool = False,
|
|
@@ -889,7 +905,8 @@ class SURENF(nn.Module):
|
|
|
889
905
|
|
|
890
906
|
xs = self.preprocess(xs, threshold=threshold)
|
|
891
907
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
892
|
-
if
|
|
908
|
+
if css is not None:
|
|
909
|
+
cs = np.hstack(css)
|
|
893
910
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
894
911
|
if fs is not None:
|
|
895
912
|
fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
|
|
@@ -945,7 +962,7 @@ class SURENF(nn.Module):
|
|
|
945
962
|
for batch_x, idx in dataloader:
|
|
946
963
|
batch_x = batch_x.to(self.get_device())
|
|
947
964
|
for loss_id in range(num_losses):
|
|
948
|
-
if
|
|
965
|
+
if css is None:
|
|
949
966
|
batch_c = None
|
|
950
967
|
else:
|
|
951
968
|
batch_c = cs[idx].to(self.get_device())
|
|
@@ -119,7 +119,7 @@ class SUREVAE(nn.Module):
|
|
|
119
119
|
def __init__(self,
|
|
120
120
|
input_dim: int,
|
|
121
121
|
codebook_size: int,
|
|
122
|
-
|
|
122
|
+
condition_sizes: list = [0],
|
|
123
123
|
covariate_size: int = 0,
|
|
124
124
|
transforms: int = 1,
|
|
125
125
|
z_dim: int = 50,
|
|
@@ -141,7 +141,7 @@ class SUREVAE(nn.Module):
|
|
|
141
141
|
super().__init__()
|
|
142
142
|
|
|
143
143
|
self.input_dim = input_dim
|
|
144
|
-
self.
|
|
144
|
+
self.condition_sizes = condition_sizes
|
|
145
145
|
self.covariate_size = covariate_size
|
|
146
146
|
self.dispersion = dispersion
|
|
147
147
|
self.latent_dim = z_dim
|
|
@@ -255,16 +255,19 @@ class SUREVAE(nn.Module):
|
|
|
255
255
|
use_cuda=self.use_cuda,
|
|
256
256
|
)
|
|
257
257
|
|
|
258
|
-
if self.
|
|
259
|
-
self.
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
258
|
+
if np.sum(self.condition_sizes)>0:
|
|
259
|
+
self.condition_effects = nn.ModuleList()
|
|
260
|
+
for condition_size in self.condition_sizes:
|
|
261
|
+
self.condition_effects.append(ZeroBiasMLP3(
|
|
262
|
+
[condition_size + self.latent_dim] + self.decoder_hidden_layers + [self.latent_dim],
|
|
263
|
+
activation=activate_fct,
|
|
264
|
+
output_activation=None,
|
|
265
|
+
post_layer_fct=post_layer_fct,
|
|
266
|
+
post_act_fct=post_act_fct,
|
|
267
|
+
allow_broadcast=self.allow_broadcast,
|
|
268
|
+
use_cuda=self.use_cuda,
|
|
269
|
+
)
|
|
270
|
+
)
|
|
268
271
|
if self.covariate_size>0:
|
|
269
272
|
self.covariate_effect = ZeroBiasMLP2(
|
|
270
273
|
[self.covariate_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
@@ -401,8 +404,13 @@ class SUREVAE(nn.Module):
|
|
|
401
404
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
402
405
|
|
|
403
406
|
zs = zns
|
|
404
|
-
if (self.
|
|
405
|
-
zcs =
|
|
407
|
+
if (np.sum(self.condition_sizes)>0) and (cs is not None):
|
|
408
|
+
zcs = torch.zeros_like(zs)
|
|
409
|
+
shift = 0
|
|
410
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
411
|
+
cs_i = cs[:,shift:(shift+condition_size)]
|
|
412
|
+
zcs += self.condition_effects[i]([cs_i,zns])
|
|
413
|
+
shift += condition_size
|
|
406
414
|
else:
|
|
407
415
|
zcs = torch.zeros_like(zs)
|
|
408
416
|
if (self.covariate_size>0) and (fs is not None):
|
|
@@ -578,7 +586,7 @@ class SUREVAE(nn.Module):
|
|
|
578
586
|
A = np.concatenate(A)
|
|
579
587
|
return A
|
|
580
588
|
|
|
581
|
-
def
|
|
589
|
+
def get_condition_effect(self, xs, cs, i, batch_size=1024, show_progress=True):
|
|
582
590
|
xs = self.preprocess(xs)
|
|
583
591
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
584
592
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
@@ -594,7 +602,7 @@ class SUREVAE(nn.Module):
|
|
|
594
602
|
C_batch = cs[idx].to(self.get_device())
|
|
595
603
|
|
|
596
604
|
z_basal = self._get_cell_embedding(X_batch)
|
|
597
|
-
dzs = self.
|
|
605
|
+
dzs = self.condition_effects[i]([C_batch,z_basal])
|
|
598
606
|
|
|
599
607
|
A.append(tensor_to_numpy(dzs))
|
|
600
608
|
pbar.update(1)
|
|
@@ -606,7 +614,7 @@ class SUREVAE(nn.Module):
|
|
|
606
614
|
zs = self.get_cell_embedding(xs, batch_size=batch_size, show_progress=show_progress)
|
|
607
615
|
return self.kmeans.predict(zs)
|
|
608
616
|
|
|
609
|
-
def predict(self, xs,
|
|
617
|
+
def predict(self, xs, cs_list, batch_size=1024, show_progress=True):
|
|
610
618
|
"""
|
|
611
619
|
Generate gene expression prediction from given cell data and covariates.
|
|
612
620
|
This function can be used for simulating cells' transcription profiles at new conditions.
|
|
@@ -619,6 +627,7 @@ class SUREVAE(nn.Module):
|
|
|
619
627
|
"""
|
|
620
628
|
xs = self.preprocess(xs)
|
|
621
629
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
630
|
+
cs = np.hstack(cs_list)
|
|
622
631
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
623
632
|
|
|
624
633
|
dataset = CustomDataset(xs)
|
|
@@ -632,7 +641,14 @@ class SUREVAE(nn.Module):
|
|
|
632
641
|
library_size = torch.sum(X_batch, 1)
|
|
633
642
|
|
|
634
643
|
z_basal = self._get_cell_embedding(X_batch)
|
|
635
|
-
|
|
644
|
+
|
|
645
|
+
zcs = torch.zeros_like(z_basal)
|
|
646
|
+
shift = 0
|
|
647
|
+
for i, condition_size in enumerate(self.condition_sizes):
|
|
648
|
+
C_batch_i = C_batch[:, shift:(shift+condition_size)]
|
|
649
|
+
zcs += self.condition_effects[i]([C_batch_i,z_basal])
|
|
650
|
+
shift += condition_size
|
|
651
|
+
|
|
636
652
|
zfs = torch.zeros_like(z_basal)
|
|
637
653
|
|
|
638
654
|
log_mu = self.decoder_log_mu([z_basal, zcs, zfs])
|
|
@@ -758,9 +774,9 @@ class SUREVAE(nn.Module):
|
|
|
758
774
|
pbar.set_postfix({'loss': str_loss})
|
|
759
775
|
pbar.update(1)'''
|
|
760
776
|
|
|
761
|
-
def fit(self, xs,
|
|
762
|
-
|
|
763
|
-
fs = None,
|
|
777
|
+
def fit(self, xs: np.array,
|
|
778
|
+
css: list = None,
|
|
779
|
+
fs: np.array = None,
|
|
764
780
|
num_epochs: int = 100,
|
|
765
781
|
learning_rate: float = 0.0001,
|
|
766
782
|
use_mask: bool = False,
|
|
@@ -821,7 +837,8 @@ class SUREVAE(nn.Module):
|
|
|
821
837
|
|
|
822
838
|
xs = self.preprocess(xs, threshold=threshold)
|
|
823
839
|
xs = convert_to_tensor(xs, dtype=self.dtype, device='cpu')
|
|
824
|
-
if
|
|
840
|
+
if css is not None:
|
|
841
|
+
cs = np.hstack(css)
|
|
825
842
|
cs = convert_to_tensor(cs, dtype=self.dtype, device='cpu')
|
|
826
843
|
if fs is not None:
|
|
827
844
|
fs = convert_to_tensor(fs, dtype=self.dtype, device='cpu')
|
|
@@ -877,7 +894,7 @@ class SUREVAE(nn.Module):
|
|
|
877
894
|
for batch_x, idx in dataloader:
|
|
878
895
|
batch_x = batch_x.to(self.get_device())
|
|
879
896
|
for loss_id in range(num_losses):
|
|
880
|
-
if
|
|
897
|
+
if css is None:
|
|
881
898
|
batch_c = None
|
|
882
899
|
else:
|
|
883
900
|
batch_c = cs[idx].to(self.get_device())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|