SURE-tools 2.1.84__tar.gz → 2.2.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.84 → sure_tools-2.2.14}/PKG-INFO +1 -1
- sure_tools-2.1.84/SURE/PerturbFlow.py → sure_tools-2.2.14/SURE/DensityFlow.py +43 -54
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/SURE.py +6 -6
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/__init__.py +3 -3
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/flow_stats.py +12 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.1.84 → sure_tools-2.2.14}/setup.py +1 -1
- {sure_tools-2.1.84 → sure_tools-2.2.14}/LICENSE +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/README.md +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.84 → sure_tools-2.2.14}/setup.cfg +0 -0
|
@@ -54,7 +54,7 @@ def set_random_seed(seed):
|
|
|
54
54
|
# Set seed for Pyro
|
|
55
55
|
pyro.set_rng_seed(seed)
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class DensityFlow(nn.Module):
|
|
58
58
|
def __init__(self,
|
|
59
59
|
input_size: int,
|
|
60
60
|
codebook_size: int = 200,
|
|
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
|
|
|
62
62
|
supervised_mode: bool = False,
|
|
63
63
|
z_dim: int = 10,
|
|
64
64
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
65
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
65
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
|
|
66
66
|
inverse_dispersion: float = 10.0,
|
|
67
|
-
use_zeroinflate: bool =
|
|
68
|
-
hidden_layers: list = [
|
|
67
|
+
use_zeroinflate: bool = True,
|
|
68
|
+
hidden_layers: list = [500],
|
|
69
69
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
70
70
|
nn_dropout: float = 0.1,
|
|
71
71
|
post_layer_fct: list = ['layernorm'],
|
|
@@ -308,7 +308,7 @@ class PerturbFlow(nn.Module):
|
|
|
308
308
|
return xs
|
|
309
309
|
|
|
310
310
|
def model1(self, xs):
|
|
311
|
-
pyro.module('
|
|
311
|
+
pyro.module('DensityFlow', self)
|
|
312
312
|
|
|
313
313
|
eps = torch.finfo(xs.dtype).eps
|
|
314
314
|
batch_size = xs.size(0)
|
|
@@ -387,7 +387,7 @@ class PerturbFlow(nn.Module):
|
|
|
387
387
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
388
388
|
|
|
389
389
|
def model2(self, xs, us=None):
|
|
390
|
-
pyro.module('
|
|
390
|
+
pyro.module('DensityFlow', self)
|
|
391
391
|
|
|
392
392
|
eps = torch.finfo(xs.dtype).eps
|
|
393
393
|
batch_size = xs.size(0)
|
|
@@ -429,7 +429,7 @@ class PerturbFlow(nn.Module):
|
|
|
429
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
430
430
|
|
|
431
431
|
if self.cell_factor_size>0:
|
|
432
|
-
zus = self._total_effects(
|
|
432
|
+
zus = self._total_effects(zns, us)
|
|
433
433
|
zs = zns+zus
|
|
434
434
|
else:
|
|
435
435
|
zs = zns
|
|
@@ -471,7 +471,7 @@ class PerturbFlow(nn.Module):
|
|
|
471
471
|
ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
|
|
472
472
|
|
|
473
473
|
def model3(self, xs, ys, embeds=None):
|
|
474
|
-
pyro.module('
|
|
474
|
+
pyro.module('DensityFlow', self)
|
|
475
475
|
|
|
476
476
|
eps = torch.finfo(xs.dtype).eps
|
|
477
477
|
batch_size = xs.size(0)
|
|
@@ -567,7 +567,7 @@ class PerturbFlow(nn.Module):
|
|
|
567
567
|
zns = embeds
|
|
568
568
|
|
|
569
569
|
def model4(self, xs, us, ys, embeds=None):
|
|
570
|
-
pyro.module('
|
|
570
|
+
pyro.module('DensityFlow', self)
|
|
571
571
|
|
|
572
572
|
eps = torch.finfo(xs.dtype).eps
|
|
573
573
|
batch_size = xs.size(0)
|
|
@@ -631,7 +631,7 @@ class PerturbFlow(nn.Module):
|
|
|
631
631
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
632
632
|
# else:
|
|
633
633
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
634
|
-
zus = self._total_effects(
|
|
634
|
+
zus = self._total_effects(zns, us)
|
|
635
635
|
zs = zns+zus
|
|
636
636
|
else:
|
|
637
637
|
zs = zns
|
|
@@ -824,9 +824,11 @@ class PerturbFlow(nn.Module):
|
|
|
824
824
|
|
|
825
825
|
# perturbation effect
|
|
826
826
|
ps = np.ones_like(us_i)
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
827
|
+
if np.sum(np.abs(ps-us_i))>=1:
|
|
828
|
+
dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
|
|
829
|
+
zs = zs + dzs0 + dzs
|
|
830
|
+
else:
|
|
831
|
+
zs = zs + dzs0
|
|
830
832
|
|
|
831
833
|
if library_sizes is None:
|
|
832
834
|
library_sizes = np.sum(xs, axis=1, keepdims=True)
|
|
@@ -840,35 +842,36 @@ class PerturbFlow(nn.Module):
|
|
|
840
842
|
|
|
841
843
|
return counts, zs
|
|
842
844
|
|
|
843
|
-
def _cell_response(self,
|
|
845
|
+
def _cell_response(self, zs, perturb_idx, perturb):
|
|
844
846
|
#zns,_ = self.encoder_zn(xs)
|
|
845
|
-
zns,_ = self._get_basal_embedding(xs)
|
|
847
|
+
#zns,_ = self._get_basal_embedding(xs)
|
|
848
|
+
zns = zs
|
|
846
849
|
if perturb.ndim==2:
|
|
847
|
-
ms = self.cell_factor_effect[
|
|
850
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
|
|
848
851
|
else:
|
|
849
|
-
ms = self.cell_factor_effect[
|
|
852
|
+
ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
|
|
850
853
|
|
|
851
854
|
return ms
|
|
852
855
|
|
|
853
856
|
def get_cell_response(self,
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
+
zs,
|
|
858
|
+
perturb_idx,
|
|
859
|
+
perturb_us,
|
|
857
860
|
batch_size: int = 1024):
|
|
858
861
|
"""
|
|
859
862
|
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
860
863
|
|
|
861
864
|
"""
|
|
862
|
-
xs = self.preprocess(xs)
|
|
863
|
-
|
|
864
|
-
ps = convert_to_tensor(
|
|
865
|
-
dataset = CustomDataset2(
|
|
865
|
+
#xs = self.preprocess(xs)
|
|
866
|
+
zs = convert_to_tensor(zs, device=self.get_device())
|
|
867
|
+
ps = convert_to_tensor(perturb_us, device=self.get_device())
|
|
868
|
+
dataset = CustomDataset2(zs,ps)
|
|
866
869
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
867
870
|
|
|
868
871
|
Z = []
|
|
869
872
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
870
|
-
for
|
|
871
|
-
zns = self._cell_response(
|
|
873
|
+
for Z_batch, P_batch, _ in dataloader:
|
|
874
|
+
zns = self._cell_response(Z_batch, perturb_idx, P_batch)
|
|
872
875
|
Z.append(tensor_to_numpy(zns))
|
|
873
876
|
pbar.update(1)
|
|
874
877
|
|
|
@@ -905,7 +908,7 @@ class PerturbFlow(nn.Module):
|
|
|
905
908
|
R = np.concatenate(R)
|
|
906
909
|
return R
|
|
907
910
|
|
|
908
|
-
def _count(self,concentrate, library_size=None):
|
|
911
|
+
def _count(self, concentrate, library_size=None):
|
|
909
912
|
if self.loss_func == 'bernoulli':
|
|
910
913
|
#counts = self.sigmoid(concentrate)
|
|
911
914
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -913,28 +916,17 @@ class PerturbFlow(nn.Module):
|
|
|
913
916
|
rate = concentrate.exp()
|
|
914
917
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
915
918
|
counts = theta * library_size
|
|
916
|
-
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
917
|
-
return counts
|
|
918
|
-
|
|
919
|
-
def _count_sample(self,concentrate):
|
|
920
|
-
if self.loss_func == 'bernoulli':
|
|
921
|
-
logits = concentrate
|
|
922
|
-
counts = dist.Bernoulli(logits=logits).to_event(1).sample()
|
|
923
|
-
else:
|
|
924
|
-
counts = self._count(concentrate=concentrate)
|
|
925
|
-
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
926
919
|
return counts
|
|
927
920
|
|
|
928
921
|
def get_counts(self, zs, library_sizes,
|
|
929
|
-
batch_size: int = 1024
|
|
930
|
-
use_sampler: bool = False):
|
|
922
|
+
batch_size: int = 1024):
|
|
931
923
|
|
|
932
924
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
933
925
|
|
|
934
926
|
if type(library_sizes) == list:
|
|
935
|
-
library_sizes = np.array(library_sizes).
|
|
927
|
+
library_sizes = np.array(library_sizes).reshape(-1,1)
|
|
936
928
|
elif len(library_sizes.shape)==1:
|
|
937
|
-
library_sizes = library_sizes.
|
|
929
|
+
library_sizes = library_sizes.reshape(-1,1)
|
|
938
930
|
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
939
931
|
|
|
940
932
|
dataset = CustomDataset2(zs,ls)
|
|
@@ -944,10 +936,7 @@ class PerturbFlow(nn.Module):
|
|
|
944
936
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
945
937
|
for Z_batch, L_batch, _ in dataloader:
|
|
946
938
|
concentrate = self._get_expression_response(Z_batch)
|
|
947
|
-
|
|
948
|
-
counts = self._count_sample(concentrate)
|
|
949
|
-
else:
|
|
950
|
-
counts = self._count(concentrate, L_batch)
|
|
939
|
+
counts = self._count(concentrate, L_batch)
|
|
951
940
|
E.append(tensor_to_numpy(counts))
|
|
952
941
|
pbar.update(1)
|
|
953
942
|
|
|
@@ -970,7 +959,7 @@ class PerturbFlow(nn.Module):
|
|
|
970
959
|
us = None,
|
|
971
960
|
ys = None,
|
|
972
961
|
zs = None,
|
|
973
|
-
num_epochs: int =
|
|
962
|
+
num_epochs: int = 500,
|
|
974
963
|
learning_rate: float = 0.0001,
|
|
975
964
|
batch_size: int = 256,
|
|
976
965
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -981,7 +970,7 @@ class PerturbFlow(nn.Module):
|
|
|
981
970
|
threshold: int = 0,
|
|
982
971
|
use_jax: bool = True):
|
|
983
972
|
"""
|
|
984
|
-
Train the
|
|
973
|
+
Train the DensityFlow model.
|
|
985
974
|
|
|
986
975
|
Parameters
|
|
987
976
|
----------
|
|
@@ -1007,7 +996,7 @@ class PerturbFlow(nn.Module):
|
|
|
1007
996
|
Parameter for optimization.
|
|
1008
997
|
use_jax
|
|
1009
998
|
If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
|
|
1010
|
-
the Python script or Jupyter notebook. It is OK if it is used when runing
|
|
999
|
+
the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
|
|
1011
1000
|
"""
|
|
1012
1001
|
xs = self.preprocess(xs, threshold=threshold)
|
|
1013
1002
|
xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
|
|
@@ -1125,12 +1114,12 @@ class PerturbFlow(nn.Module):
|
|
|
1125
1114
|
|
|
1126
1115
|
|
|
1127
1116
|
EXAMPLE_RUN = (
|
|
1128
|
-
"example run:
|
|
1117
|
+
"example run: DensityFlow --help"
|
|
1129
1118
|
)
|
|
1130
1119
|
|
|
1131
1120
|
def parse_args():
|
|
1132
1121
|
parser = argparse.ArgumentParser(
|
|
1133
|
-
description="
|
|
1122
|
+
description="DensityFlow\n{}".format(EXAMPLE_RUN))
|
|
1134
1123
|
|
|
1135
1124
|
parser.add_argument(
|
|
1136
1125
|
"--cuda", action="store_true", help="use GPU(s) to speed up training"
|
|
@@ -1317,7 +1306,7 @@ def main():
|
|
|
1317
1306
|
cell_factor_size = 0 if us is None else us.shape[1]
|
|
1318
1307
|
|
|
1319
1308
|
###########################################
|
|
1320
|
-
|
|
1309
|
+
DensityFlow = DensityFlow(
|
|
1321
1310
|
input_size=input_size,
|
|
1322
1311
|
cell_factor_size=cell_factor_size,
|
|
1323
1312
|
inverse_dispersion=args.inverse_dispersion,
|
|
@@ -1336,7 +1325,7 @@ def main():
|
|
|
1336
1325
|
dtype=dtype,
|
|
1337
1326
|
)
|
|
1338
1327
|
|
|
1339
|
-
|
|
1328
|
+
DensityFlow.fit(xs, us=us,
|
|
1340
1329
|
num_epochs=args.num_epochs,
|
|
1341
1330
|
learning_rate=args.learning_rate,
|
|
1342
1331
|
batch_size=args.batch_size,
|
|
@@ -1348,9 +1337,9 @@ def main():
|
|
|
1348
1337
|
|
|
1349
1338
|
if args.save_model is not None:
|
|
1350
1339
|
if args.save_model.endswith('gz'):
|
|
1351
|
-
|
|
1340
|
+
DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
|
|
1352
1341
|
else:
|
|
1353
|
-
|
|
1342
|
+
DensityFlow.save_model(DensityFlow, args.save_model)
|
|
1354
1343
|
|
|
1355
1344
|
|
|
1356
1345
|
|
|
@@ -99,17 +99,17 @@ class SURE(nn.Module):
|
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
100
|
supervised_mode: bool = False,
|
|
101
101
|
z_dim: int = 10,
|
|
102
|
-
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
|
|
103
|
-
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
|
|
102
|
+
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
|
|
103
|
+
loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
|
|
104
104
|
inverse_dispersion: float = 10.0,
|
|
105
105
|
use_zeroinflate: bool = True,
|
|
106
|
-
hidden_layers: list = [
|
|
106
|
+
hidden_layers: list = [500],
|
|
107
107
|
hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
|
|
108
108
|
nn_dropout: float = 0.1,
|
|
109
109
|
post_layer_fct: list = ['layernorm'],
|
|
110
110
|
post_act_fct: list = None,
|
|
111
111
|
config_enum: str = 'parallel',
|
|
112
|
-
use_cuda: bool =
|
|
112
|
+
use_cuda: bool = True,
|
|
113
113
|
seed: int = 42,
|
|
114
114
|
dtype = torch.float32, # type: ignore
|
|
115
115
|
):
|
|
@@ -817,7 +817,7 @@ class SURE(nn.Module):
|
|
|
817
817
|
us = None,
|
|
818
818
|
ys = None,
|
|
819
819
|
zs = None,
|
|
820
|
-
num_epochs: int =
|
|
820
|
+
num_epochs: int = 500,
|
|
821
821
|
learning_rate: float = 0.0001,
|
|
822
822
|
batch_size: int = 256,
|
|
823
823
|
algo: Literal['adam','rmsprop','adamw'] = 'adam',
|
|
@@ -826,7 +826,7 @@ class SURE(nn.Module):
|
|
|
826
826
|
decay_rate: float = 0.9,
|
|
827
827
|
config_enum: str = 'parallel',
|
|
828
828
|
threshold: int = 0,
|
|
829
|
-
use_jax: bool =
|
|
829
|
+
use_jax: bool = True):
|
|
830
830
|
"""
|
|
831
831
|
Train the SURE model.
|
|
832
832
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .DensityFlow import DensityFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
-
from . import
|
|
7
|
+
from . import DensityFlow
|
|
8
8
|
from . import atac
|
|
9
9
|
from . import flow
|
|
10
10
|
from . import perturb
|
|
11
11
|
|
|
12
|
-
__all__ = ['SURE', '
|
|
12
|
+
__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
|
|
@@ -41,6 +41,18 @@ class VectorFieldEval:
|
|
|
41
41
|
divergence[np.isnan(divergence)] = 0
|
|
42
42
|
|
|
43
43
|
return divergence
|
|
44
|
+
|
|
45
|
+
def movement_stats(self,vectors):
|
|
46
|
+
return calculate_movement_stats(vectors)
|
|
47
|
+
|
|
48
|
+
def direction_stats(self, vectors):
|
|
49
|
+
return calculate_direction_stats(vectors)
|
|
50
|
+
|
|
51
|
+
def movement_energy(self, vectors, masses=None):
|
|
52
|
+
return calculate_movement_energy(vectors, masses)
|
|
53
|
+
|
|
54
|
+
def movement_divergence(self, positions, vectors):
|
|
55
|
+
return calculate_movement_divergence(positions, vectors)
|
|
44
56
|
|
|
45
57
|
|
|
46
58
|
def calculate_movement_stats(vectors):
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
3
4
|
from numba import njit
|
|
4
5
|
from itertools import chain
|
|
5
6
|
from joblib import Parallel, delayed
|
|
@@ -8,6 +9,8 @@ from typing import Literal
|
|
|
8
9
|
class LabelMatrix:
|
|
9
10
|
def __init__(self):
|
|
10
11
|
self.labels_ = None
|
|
12
|
+
self.control_label = None
|
|
13
|
+
self.sep_pattern = None
|
|
11
14
|
|
|
12
15
|
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
16
|
if speedup=='none':
|
|
@@ -24,8 +27,31 @@ class LabelMatrix:
|
|
|
24
27
|
mat = np.delete(mat, idx, axis=1)
|
|
25
28
|
self.labels_ = np.delete(self.labels_, idx)
|
|
26
29
|
|
|
30
|
+
self.control_label = control_label
|
|
31
|
+
self.sep_pattern=sep_pattern
|
|
32
|
+
|
|
27
33
|
return mat
|
|
28
|
-
|
|
34
|
+
|
|
35
|
+
def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
|
|
36
|
+
sep_pattern = self.sep_pattern
|
|
37
|
+
if speedup=='none':
|
|
38
|
+
mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
39
|
+
elif speedup=='vectorize':
|
|
40
|
+
mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
41
|
+
elif speedup=='parallel':
|
|
42
|
+
mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
43
|
+
|
|
44
|
+
mat_df = pd.DataFrame(mat, columns=labels_)
|
|
45
|
+
|
|
46
|
+
labels_valid = [x for x in labels_ if x in self.labels_]
|
|
47
|
+
mat_df = mat_df[labels_valid]
|
|
48
|
+
|
|
49
|
+
mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
|
|
50
|
+
mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
|
|
51
|
+
mat_valid_df[labels_valid] = mat_df
|
|
52
|
+
|
|
53
|
+
return mat_valid_df.values
|
|
54
|
+
|
|
29
55
|
def inverse_transform(self, matrix):
|
|
30
56
|
return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
|
|
31
57
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|