SURE-tools 2.1.91__tar.gz → 2.2.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools may be problematic; see the package's page on the registry for more details.
- {sure_tools-2.1.91 → sure_tools-2.2.14}/PKG-INFO +1 -1
- sure_tools-2.1.91/SURE/PerturbFlow.py → sure_tools-2.2.14/SURE/DensityFlow.py +43 -62
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/SURE.py +6 -6
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/__init__.py +3 -3
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/flow_stats.py +12 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.14}/setup.py +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.14}/LICENSE +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/README.md +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.14}/setup.cfg +0 -0
|
@@ -54,19 +54,18 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
61
61
|
cell_factor_size: int = 0,
|
|
62
|
-
cell_factor_effect_discrete: bool = False,
|
|
63
62
|
supervised_mode: bool = False,
|
|
64
63
|
z_dim: int = 10,
|
|
65
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
66
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
67
66
|
inverse_dispersion: float = 10.0,
|
|
68
|
-
use_zeroinflate: bool =
|
|
69
|
-
hidden_layers: list = [
|
|
67
|
+
use_zeroinflate: bool = True,
|
|
68
|
+
hidden_layers: list = [500],
|
|
70
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
71
70
|
nn_dropout: float = 0.1,
|
|
72
71
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -103,7 +102,6 @@ class PerturbFlow(nn.Module):
|
|
|
103
102
|
else:
|
|
104
103
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
105
104
|
#self.use_bias = not zero_bias
|
|
106
|
-
self.enumrate = cell_factor_effect_discrete
|
|
107
105
|
|
|
108
106
|
self.codebook_weights = None
|
|
109
107
|
|
|
@@ -310,7 +308,7 @@ class PerturbFlow(nn.Module):
|
|
|
310
308
|
return xs
|
|
311
309
|
|
|
312
310
|
def model1(self, xs):
|
|
313
|
-
pyro.module('
|
|
311
|
+
pyro.module('DensityFlow', self)
|
|
314
312
|
|
|
315
313
|
eps = torch.finfo(xs.dtype).eps
|
|
316
314
|
batch_size = xs.size(0)
|
|
@@ -389,7 +387,7 @@ class PerturbFlow(nn.Module):
|
|
|
389
387
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
390
388
|
|
|
391
389
|
def model2(self, xs, us=None):
|
|
392
|
-
pyro.module('
|
|
390
|
+
pyro.module('DensityFlow', self)
|
|
393
391
|
|
|
394
392
|
eps = torch.finfo(xs.dtype).eps
|
|
395
393
|
batch_size = xs.size(0)
|
|
@@ -431,10 +429,7 @@ class PerturbFlow(nn.Module):
|
|
|
431
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
432
430
|
|
|
433
431
|
if self.cell_factor_size>0:
|
|
434
|
-
|
|
435
|
-
zus = self._total_effects(zn_loc, us)
|
|
436
|
-
else:
|
|
437
|
-
zus = self._total_effects(zns, us)
|
|
432
|
+
zus = self._total_effects(zns, us)
|
|
438
433
|
zs = zns+zus
|
|
439
434
|
else:
|
|
440
435
|
zs = zns
|
|
@@ -476,7 +471,7 @@ class PerturbFlow(nn.Module):
|
|
|
476
471
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
477
472
|
|
|
478
473
|
def model3(self, xs, ys, embeds=None):
|
|
479
|
-
pyro.module('
|
|
474
|
+
pyro.module('DensityFlow', self)
|
|
480
475
|
|
|
481
476
|
eps = torch.finfo(xs.dtype).eps
|
|
482
477
|
batch_size = xs.size(0)
|
|
@@ -572,7 +567,7 @@ class PerturbFlow(nn.Module):
|
|
|
572
567
|
zns = embeds
|
|
573
568
|
|
|
574
569
|
def model4(self, xs, us, ys, embeds=None):
|
|
575
|
-
pyro.module('
|
|
570
|
+
pyro.module('DensityFlow', self)
|
|
576
571
|
|
|
577
572
|
eps = torch.finfo(xs.dtype).eps
|
|
578
573
|
batch_size = xs.size(0)
|
|
@@ -636,10 +631,7 @@ class PerturbFlow(nn.Module):
|
|
|
636
631
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
637
632
|
# else:
|
|
638
633
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
639
|
-
|
|
640
|
-
zus = self._total_effects(zn_loc, us)
|
|
641
|
-
else:
|
|
642
|
-
zus = self._total_effects(zns, us)
|
|
634
|
+
zus = self._total_effects(zns, us)
|
|
643
635
|
zs = zns+zus
|
|
644
636
|
else:
|
|
645
637
|
zs = zns
|
|
@@ -832,9 +824,11 @@ class PerturbFlow(nn.Module):
|
|
|
832
824
|
|
|
833
825
|
# perturbation effect
|
|
834
826
|
ps = np.ones_like(us_i)
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
827
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
828
|
+
dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
|
|
829
|
+
zs = zs + dzs0 + dzs
|
|
830
|
+
else:
|
|
831
|
+
zs = zs + dzs0
|
|
838
832
|
|
|
839
833
|
if library_sizes is None:
|
|
840
834
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -848,35 +842,36 @@ class PerturbFlow(nn.Module):
|
|
|
848
842
|
|
|
849
843
|
return counts, zs
|
|
850
844
|
|
|
851
|
-
def _cell_response(self,
|
|
845
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
852
846
|
#zns,_ = self.encoder_zn(xs)
|
|
853
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
847
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
848
|
+
zns = zs
|
|
854
849
|
if perturb.ndim==2:
|
|
855
|
-
ms = self.cell_factor_effect[
|
|
850
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
856
851
|
else:
|
|
857
|
-
ms = self.cell_factor_effect[
|
|
852
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
858
853
|
|
|
859
854
|
return ms
|
|
860
855
|
|
|
861
856
|
def get_cell_response(self,
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
857
|
+
zs,
|
|
858
|
+
perturb_idx,
|
|
859
|
+
perturb_us,
|
|
865
860
|
batch_size: int = 1024):
|
|
866
861
|
"""
|
|
867
862
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
868
863
|
|
|
869
864
|
"""
|
|
870
|
-
xs = self.preprocess(xs)
|
|
871
|
-
|
|
872
|
-
ps = convert_to_tensor(
|
|
873
|
-
dataset = CustomDataset2(
|
|
865
|
+
#xs = self.preprocess(xs)
|
|
866
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
867
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
868
|
+
dataset = CustomDataset2(zs,ps)
|
|
874
869
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
875
870
|
|
|
876
871
|
Z = []
|
|
877
872
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
878
|
-
for
|
|
879
|
-
zns = self._cell_response(
|
|
873
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
874
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
880
875
|
Z.append(tensor_to_numpy(zns))
|
|
881
876
|
pbar.update(1)
|
|
882
877
|
|
|
@@ -913,7 +908,7 @@ class PerturbFlow(nn.Module):
|
|
|
913
908
|
R = np.concatenate(R)
|
|
914
909
|
return R
|
|
915
910
|
|
|
916
|
-
def _count(self,concentrate, library_size=None):
|
|
911
|
+
def _count(self, concentrate, library_size=None):
|
|
917
912
|
if self.loss_func == 'bernoulli':
|
|
918
913
|
#counts = self.sigmoid(concentrate)
|
|
919
914
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -921,28 +916,17 @@ class PerturbFlow(nn.Module):
|
|
|
921
916
|
rate = concentrate.exp()
|
|
922
917
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
923
918
|
counts = theta * library_size
|
|
924
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
925
|
-
return counts
|
|
926
|
-
|
|
927
|
-
def _count_sample(self,concentrate):
|
|
928
|
-
if self.loss_func == 'bernoulli':
|
|
929
|
-
logits = concentrate
|
|
930
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
931
|
-
else:
|
|
932
|
-
counts = self._count(concentrate=concentrate)
|
|
933
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
934
919
|
return counts
|
|
935
920
|
|
|
936
921
|
def get_counts(self, zs, library_sizes,
|
|
937
|
-
batch_size: int = 1024
|
|
938
|
-
use_sampler: bool = False):
|
|
922
|
+
batch_size: int = 1024):
|
|
939
923
|
|
|
940
924
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
941
925
|
|
|
942
926
|
if type(library_sizes) == list:
|
|
943
|
-
library_sizes = np.array(library_sizes).
|
|
927
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
944
928
|
elif len(library_sizes.shape)==1:
|
|
945
|
-
library_sizes = library_sizes.
|
|
929
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
946
930
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
947
931
|
|
|
948
932
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -952,10 +936,7 @@ class PerturbFlow(nn.Module):
|
|
|
952
936
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
953
937
|
for Z_batch, L_batch, _ in dataloader:
|
|
954
938
|
concentrate = self._get_expression_response(Z_batch)
|
|
955
|
-
|
|
956
|
-
counts = self._count_sample(concentrate)
|
|
957
|
-
else:
|
|
958
|
-
counts = self._count(concentrate, L_batch)
|
|
939
|
+
counts = self._count(concentrate, L_batch)
|
|
959
940
|
E.append(tensor_to_numpy(counts))
|
|
960
941
|
pbar.update(1)
|
|
961
942
|
|
|
@@ -978,7 +959,7 @@ class PerturbFlow(nn.Module):
|
|
|
978
959
|
us = None,
|
|
979
960
|
ys = None,
|
|
980
961
|
zs = None,
|
|
981
|
-
num_epochs: int =
|
|
962
|
+
num_epochs: int = 500,
|
|
982
963
|
learning_rate: float = 0.0001,
|
|
983
964
|
batch_size: int = 256,
|
|
984
965
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -989,7 +970,7 @@ class PerturbFlow(nn.Module):
|
|
|
989
970
|
threshold: int = 0,
|
|
990
971
|
use_jax: bool = True):
|
|
991
972
|
"""
|
|
992
|
-
Train the
|
|
973
|
+
Train the DensityFlow model.
|
|
993
974
|
|
|
994
975
|
Parameters
|
|
995
976
|
----------
|
|
@@ -1015,7 +996,7 @@ class PerturbFlow(nn.Module):
|
|
|
1015
996
|
Parameter for optimization.
|
|
1016
997
|
use_jax
|
|
1017
998
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1018
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
999
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
|
|
1019
1000
|
"""
|
|
1020
1001
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1021
1002
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1133,12 +1114,12 @@ class PerturbFlow(nn.Module):
|
|
|
1133
1114
|
|
|
1134
1115
|
|
|
1135
1116
|
EXAMPLE_RUN = (
|
|
1136
|
-
"example run:
|
|
1117
|
+
"example run: DensityFlow --help"
|
|
1137
1118
|
)
|
|
1138
1119
|
|
|
1139
1120
|
def parse_args():
|
|
1140
1121
|
parser = argparse.ArgumentParser(
|
|
1141
|
-
description="
|
|
1122
|
+
description="DensityFlow\n{}".format(EXAMPLE_RUN))
|
|
1142
1123
|
|
|
1143
1124
|
parser.add_argument(
|
|
1144
1125
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1325,7 +1306,7 @@ def main():
|
|
|
1325
1306
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1326
1307
|
|
|
1327
1308
|
###########################################
|
|
1328
|
-
|
|
1309
|
+
DensityFlow = DensityFlow(
|
|
1329
1310
|
input_size=input_size,
|
|
1330
1311
|
cell_factor_size=cell_factor_size,
|
|
1331
1312
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1344,7 +1325,7 @@ def main():
|
|
|
1344
1325
|
dtype=dtype,
|
|
1345
1326
|
)
|
|
1346
1327
|
|
|
1347
|
-
|
|
1328
|
+
DensityFlow.fit(xs, us=us,
|
|
1348
1329
|
num_epochs=args.num_epochs,
|
|
1349
1330
|
learning_rate=args.learning_rate,
|
|
1350
1331
|
batch_size=args.batch_size,
|
|
@@ -1356,9 +1337,9 @@ def main():
|
|
|
1356
1337
|
|
|
1357
1338
|
if args.save_model is not None:
|
|
1358
1339
|
if args.save_model.endswith('gz'):
|
|
1359
|
-
|
|
1340
|
+
DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
|
|
1360
1341
|
else:
|
|
1361
|
-
|
|
1342
|
+
DensityFlow.save_model(DensityFlow, args.save_model)
|
|
1362
1343
|
|
|
1363
1344
|
|
|
1364
1345
|
|
|
@@ -99,17 +99,17 @@ class SURE(nn.Module):
|
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
100
|
supervised_mode: bool = False,
|
|
101
101
|
z_dim: int = 10,
|
|
102
|
-
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
|
|
103
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
102
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
103
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
104
104
|
inverse_dispersion: float = 10.0,
|
|
105
105
|
use_zeroinflate: bool = True,
|
|
106
|
-
hidden_layers: list = [
|
|
106
|
+
hidden_layers: list = [500],
|
|
107
107
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
108
108
|
nn_dropout: float = 0.1,
|
|
109
109
|
post_layer_fct: list = ['layernorm'],
|
|
110
110
|
post_act_fct: list = None,
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
|
-
use_cuda: bool =
|
|
112
|
+
use_cuda: bool = True,
|
|
113
113
|
seed: int = 42,
|
|
114
114
|
dtype = torch.float32, # type: ignore
|
|
115
115
|
):
|
|
@@ -817,7 +817,7 @@ class SURE(nn.Module):
|
|
|
817
817
|
us = None,
|
|
818
818
|
ys = None,
|
|
819
819
|
zs = None,
|
|
820
|
-
num_epochs: int =
|
|
820
|
+
num_epochs: int = 500,
|
|
821
821
|
learning_rate: float = 0.0001,
|
|
822
822
|
batch_size: int = 256,
|
|
823
823
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -826,7 +826,7 @@ class SURE(nn.Module):
|
|
|
826
826
|
decay_rate: float = 0.9,
|
|
827
827
|
config_enum: str = 'parallel',
|
|
828
828
|
threshold: int = 0,
|
|
829
|
-
use_jax: bool =
|
|
829
|
+
use_jax: bool = True):
|
|
830
830
|
"""
|
|
831
831
|
Train the SURE model.
|
|
832
832
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .DensityFlow import DensityFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
-
from . import
|
|
7
|
+
from . import DensityFlow
|
|
8
8
|
from . import atac
|
|
9
9
|
from . import flow
|
|
10
10
|
from . import perturb
|
|
11
11
|
|
|
12
|
-
__all__ = ['SURE', '
|
|
12
|
+
__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
@@ -41,6 +41,18 @@ class VectorFieldEval:
|
|
|
41
41
|
divergence[np.isnan(divergence)] = 0
|
|
42
42
|
|
|
43
43
|
return divergence
|
|
44
|
+
|
|
45
|
+
def movement_stats(self,vectors):
|
|
46
|
+
return calculate_movement_stats(vectors)
|
|
47
|
+
|
|
48
|
+
def direction_stats(self, vectors):
|
|
49
|
+
return calculate_direction_stats(vectors)
|
|
50
|
+
|
|
51
|
+
def movement_energy(self, vectors, masses=None):
|
|
52
|
+
return calculate_movement_energy(vectors, masses)
|
|
53
|
+
|
|
54
|
+
def movement_divergence(self, positions, vectors):
|
|
55
|
+
return calculate_movement_divergence(positions, vectors)
|
|
44
56
|
|
|
45
57
|
|
|
46
58
|
def calculate_movement_stats(vectors):
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|