SURE-tools 2.2.10__tar.gz → 2.2.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.2.10 → sure_tools-2.2.25}/PKG-INFO +1 -1
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/DensityFlow.py +91 -71
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/utils/custom_mlp.py +8 -2
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.2.10 → sure_tools-2.2.25}/setup.py +1 -1
- {sure_tools-2.2.10 → sure_tools-2.2.25}/LICENSE +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/README.md +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/SURE.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/atac/utils.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/utils/queue.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE/utils/utils.py +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.2.10 → sure_tools-2.2.25}/setup.cfg +0 -0
|
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
|
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
+
turn_off_cell_specific: bool = False,
|
|
62
63
|
supervised_mode: bool = False,
|
|
63
64
|
z_dim: int = 10,
|
|
64
65
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
66
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
|
|
66
67
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
+
use_zeroinflate: bool = False,
|
|
68
69
|
hidden_layers: list = [500],
|
|
69
70
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
71
|
nn_dropout: float = 0.1,
|
|
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.turn_off_cell_specific = turn_off_cell_specific
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
|
|
|
203
205
|
self.cell_factor_effect = nn.ModuleList()
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
if self.turn_off_cell_specific:
|
|
209
|
+
self.cell_factor_effect.append(MLP(
|
|
210
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
211
|
+
activation=activate_fct,
|
|
212
|
+
output_activation=None,
|
|
213
|
+
post_layer_fct=post_layer_fct,
|
|
214
|
+
post_act_fct=post_act_fct,
|
|
215
|
+
allow_broadcast=self.allow_broadcast,
|
|
216
|
+
use_cuda=self.use_cuda,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
self.cell_factor_effect.append(MLP(
|
|
221
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
222
|
+
activation=activate_fct,
|
|
223
|
+
output_activation=None,
|
|
224
|
+
post_layer_fct=post_layer_fct,
|
|
225
|
+
post_act_fct=post_act_fct,
|
|
226
|
+
allow_broadcast=self.allow_broadcast,
|
|
227
|
+
use_cuda=self.use_cuda,
|
|
228
|
+
)
|
|
214
229
|
)
|
|
215
|
-
)
|
|
216
230
|
else:
|
|
217
|
-
self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if self.turn_off_cell_specific:
|
|
232
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
233
|
+
[1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
234
|
+
activation=activate_fct,
|
|
235
|
+
output_activation=None,
|
|
236
|
+
post_layer_fct=post_layer_fct,
|
|
237
|
+
post_act_fct=post_act_fct,
|
|
238
|
+
allow_broadcast=self.allow_broadcast,
|
|
239
|
+
use_cuda=self.use_cuda,
|
|
240
|
+
)
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
244
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
245
|
+
activation=activate_fct,
|
|
246
|
+
output_activation=None,
|
|
247
|
+
post_layer_fct=post_layer_fct,
|
|
248
|
+
post_act_fct=post_act_fct,
|
|
249
|
+
allow_broadcast=self.allow_broadcast,
|
|
250
|
+
use_cuda=self.use_cuda,
|
|
251
|
+
)
|
|
225
252
|
)
|
|
226
|
-
)
|
|
227
253
|
|
|
228
254
|
self.decoder_concentrate = MLP(
|
|
229
255
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -234,16 +260,6 @@ class DensityFlow(nn.Module):
|
|
|
234
260
|
allow_broadcast=self.allow_broadcast,
|
|
235
261
|
use_cuda=self.use_cuda,
|
|
236
262
|
)
|
|
237
|
-
if self.loss_func == 'negbinomial':
|
|
238
|
-
self.decoder_total_count = MLP(
|
|
239
|
-
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
240
|
-
activation=activate_fct,
|
|
241
|
-
output_activation=Exp,
|
|
242
|
-
post_layer_fct=post_layer_fct,
|
|
243
|
-
post_act_fct=post_act_fct,
|
|
244
|
-
allow_broadcast=self.allow_broadcast,
|
|
245
|
-
use_cuda=self.use_cuda,
|
|
246
|
-
)
|
|
247
263
|
|
|
248
264
|
if self.latent_dist == 'studentt':
|
|
249
265
|
self.codebook = MLP(
|
|
@@ -324,9 +340,9 @@ class DensityFlow(nn.Module):
|
|
|
324
340
|
batch_size = xs.size(0)
|
|
325
341
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
326
342
|
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
343
|
+
if self.loss_func=='negbinomial':
|
|
344
|
+
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
345
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
330
346
|
|
|
331
347
|
if self.use_zeroinflate:
|
|
332
348
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -370,7 +386,6 @@ class DensityFlow(nn.Module):
|
|
|
370
386
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
371
387
|
|
|
372
388
|
if self.loss_func == 'negbinomial':
|
|
373
|
-
total_count = self.decoder_total_count(zs)
|
|
374
389
|
if self.use_zeroinflate:
|
|
375
390
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
376
391
|
else:
|
|
@@ -404,9 +419,9 @@ class DensityFlow(nn.Module):
|
|
|
404
419
|
batch_size = xs.size(0)
|
|
405
420
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
406
421
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
422
|
+
if self.loss_func=='negbinomial':
|
|
423
|
+
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
424
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
410
425
|
|
|
411
426
|
if self.use_zeroinflate:
|
|
412
427
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -455,7 +470,6 @@ class DensityFlow(nn.Module):
|
|
|
455
470
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
456
471
|
|
|
457
472
|
if self.loss_func == 'negbinomial':
|
|
458
|
-
total_count = self.decoder_total_count(zs)
|
|
459
473
|
if self.use_zeroinflate:
|
|
460
474
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
461
475
|
else:
|
|
@@ -489,9 +503,9 @@ class DensityFlow(nn.Module):
|
|
|
489
503
|
batch_size = xs.size(0)
|
|
490
504
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
491
505
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
506
|
+
if self.loss_func=='negbinomial':
|
|
507
|
+
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
508
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
495
509
|
|
|
496
510
|
if self.use_zeroinflate:
|
|
497
511
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -552,7 +566,6 @@ class DensityFlow(nn.Module):
|
|
|
552
566
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
553
567
|
|
|
554
568
|
if self.loss_func == 'negbinomial':
|
|
555
|
-
total_count = self.decoder_total_count(zs)
|
|
556
569
|
if self.use_zeroinflate:
|
|
557
570
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
558
571
|
else:
|
|
@@ -586,9 +599,9 @@ class DensityFlow(nn.Module):
|
|
|
586
599
|
batch_size = xs.size(0)
|
|
587
600
|
self.options = dict(dtype=xs.dtype, device=xs.device)
|
|
588
601
|
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
602
|
+
if self.loss_func=='negbinomial':
|
|
603
|
+
total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
|
|
604
|
+
xs.new_ones(self.input_size), constraint=constraints.positive)
|
|
592
605
|
|
|
593
606
|
if self.use_zeroinflate:
|
|
594
607
|
gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
|
|
@@ -659,7 +672,6 @@ class DensityFlow(nn.Module):
|
|
|
659
672
|
rate = theta * torch.sum(xs, dim=1, keepdim=True)
|
|
660
673
|
|
|
661
674
|
if self.loss_func == 'negbinomial':
|
|
662
|
-
total_count = self.decoder_total_count(zs)
|
|
663
675
|
if self.use_zeroinflate:
|
|
664
676
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
665
677
|
else:
|
|
@@ -690,9 +702,17 @@ class DensityFlow(nn.Module):
|
|
|
690
702
|
zus = None
|
|
691
703
|
for i in np.arange(self.cell_factor_size):
|
|
692
704
|
if i==0:
|
|
693
|
-
|
|
705
|
+
#if self.turn_off_cell_specific:
|
|
706
|
+
# zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
707
|
+
#else:
|
|
708
|
+
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
709
|
+
zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
694
710
|
else:
|
|
695
|
-
|
|
711
|
+
#if self.turn_off_cell_specific:
|
|
712
|
+
# zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
|
|
713
|
+
#else:
|
|
714
|
+
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
715
|
+
zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
|
|
696
716
|
return zus
|
|
697
717
|
|
|
698
718
|
def _get_codebook_identity(self):
|
|
@@ -710,7 +730,7 @@ class DensityFlow(nn.Module):
|
|
|
710
730
|
"""
|
|
711
731
|
Return the mean part of metacell codebook
|
|
712
732
|
"""
|
|
713
|
-
cb = self.
|
|
733
|
+
cb = self._get_codebook()
|
|
714
734
|
cb = tensor_to_numpy(cb)
|
|
715
735
|
return cb
|
|
716
736
|
|
|
@@ -834,12 +854,12 @@ class DensityFlow(nn.Module):
|
|
|
834
854
|
us_i = us[:,pert_idx].reshape(-1,1)
|
|
835
855
|
|
|
836
856
|
# factor effect of xs
|
|
837
|
-
dzs0 = self.get_cell_response(
|
|
857
|
+
dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
|
|
838
858
|
|
|
839
859
|
# perturbation effect
|
|
840
860
|
ps = np.ones_like(us_i)
|
|
841
861
|
if np.sum(np.abs(ps-us_i))>=1:
|
|
842
|
-
dzs = self.get_cell_response(
|
|
862
|
+
dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
|
|
843
863
|
zs = zs + dzs0 + dzs
|
|
844
864
|
else:
|
|
845
865
|
zs = zs + dzs0
|
|
@@ -856,47 +876,48 @@ class DensityFlow(nn.Module):
|
|
|
856
876
|
|
|
857
877
|
return counts, zs
|
|
858
878
|
|
|
859
|
-
def _cell_response(self,
|
|
879
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
860
880
|
#zns,_ = self.encoder_zn(xs)
|
|
861
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
881
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
882
|
+
zns = zs
|
|
862
883
|
if perturb.ndim==2:
|
|
863
|
-
|
|
884
|
+
if self.turn_off_cell_specific:
|
|
885
|
+
ms = self.cell_factor_effect[perturb_idx](perturb)
|
|
886
|
+
else:
|
|
887
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
864
888
|
else:
|
|
865
|
-
|
|
889
|
+
if self.turn_off_cell_specific:
|
|
890
|
+
ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
|
|
891
|
+
else:
|
|
892
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
866
893
|
|
|
867
894
|
return ms
|
|
868
895
|
|
|
869
896
|
def get_cell_response(self,
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
897
|
+
zs,
|
|
898
|
+
perturb_idx,
|
|
899
|
+
perturb_us,
|
|
873
900
|
batch_size: int = 1024):
|
|
874
901
|
"""
|
|
875
902
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
876
903
|
|
|
877
904
|
"""
|
|
878
|
-
xs = self.preprocess(xs)
|
|
879
|
-
|
|
880
|
-
ps = convert_to_tensor(
|
|
881
|
-
dataset = CustomDataset2(
|
|
905
|
+
#xs = self.preprocess(xs)
|
|
906
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
907
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
908
|
+
dataset = CustomDataset2(zs,ps)
|
|
882
909
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
883
910
|
|
|
884
911
|
Z = []
|
|
885
912
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
886
|
-
for
|
|
887
|
-
zns = self._cell_response(
|
|
913
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
914
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
888
915
|
Z.append(tensor_to_numpy(zns))
|
|
889
916
|
pbar.update(1)
|
|
890
917
|
|
|
891
918
|
Z = np.concatenate(Z)
|
|
892
919
|
return Z
|
|
893
920
|
|
|
894
|
-
def get_metacell_response(self, factor_idx, perturb):
|
|
895
|
-
zs = self._get_codebook()
|
|
896
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
897
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
898
|
-
return tensor_to_numpy(ms)
|
|
899
|
-
|
|
900
921
|
def _get_expression_response(self, delta_zs):
|
|
901
922
|
return self.decoder_concentrate(delta_zs)
|
|
902
923
|
|
|
@@ -1357,5 +1378,4 @@ def main():
|
|
|
1357
1378
|
|
|
1358
1379
|
|
|
1359
1380
|
if __name__ == "__main__":
|
|
1360
|
-
|
|
1361
1381
|
main()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
|
@@ -240,9 +240,15 @@ class ZeroBiasMLP(nn.Module):
|
|
|
240
240
|
y = self.mlp(x)
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
242
|
if len(y.shape)==2:
|
|
243
|
-
|
|
243
|
+
if len(x)>2:
|
|
244
|
+
mask[x[1][:,0]>0,:] = 1
|
|
245
|
+
else:
|
|
246
|
+
mask[x[:,0]>0,:] = 1
|
|
244
247
|
elif len(y.shape)==3:
|
|
245
|
-
|
|
248
|
+
if len(x)>1:
|
|
249
|
+
mask[:,x[1][:,0]>0,:] = 1
|
|
250
|
+
else:
|
|
251
|
+
mask[:,x[:,0]>0,:] = 1
|
|
246
252
|
return y*mask
|
|
247
253
|
|
|
248
254
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|