SURE-tools 2.1.91.tar.gz → 2.2.17.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic.
- {sure_tools-2.1.91 → sure_tools-2.2.17}/PKG-INFO +1 -1
- sure_tools-2.1.91/SURE/PerturbFlow.py → sure_tools-2.2.17/SURE/DensityFlow.py +44 -69
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/SURE.py +6 -6
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/__init__.py +3 -3
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/flow/flow_stats.py +12 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/perturb/perturb.py +27 -1
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/SOURCES.txt +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.17}/setup.py +1 -1
- {sure_tools-2.1.91 → sure_tools-2.2.17}/LICENSE +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/README.md +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.91 → sure_tools-2.2.17}/setup.cfg +0 -0
--- sure_tools-2.1.91/SURE/PerturbFlow.py
+++ sure_tools-2.2.17/SURE/DensityFlow.py
@@ -54,19 +54,18 @@ def set_random_seed(seed):
     # Set seed for Pyro
     pyro.set_rng_seed(seed)
 
-class PerturbFlow(nn.Module):
+class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
                  codebook_size: int = 200,
                  cell_factor_size: int = 0,
-                 cell_factor_effect_discrete: bool = False,
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
                  inverse_dispersion: float = 10.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [
+                 hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
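The headline change in this release is the rename of PerturbFlow to DensityFlow, together with new constructor defaults (loss_func='multinomial', hidden_layers=[500]) and the removal of the cell_factor_effect_discrete flag. A minimal migration sketch, assuming only what this hunk and the __init__.py diff further below show; input_size=2000 is a made-up feature count:

```python
from SURE import DensityFlow  # was PerturbFlow in 2.1.91

# Only input_size is required; the other keywords restate the 2.2.17 defaults
# visible in this hunk. cell_factor_effect_discrete is gone from the signature.
model = DensityFlow(
    input_size=2000,            # hypothetical number of features
    codebook_size=200,
    loss_func='multinomial',    # new default in 2.2.17
    hidden_layers=[500],        # new default in 2.2.17
)
```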
@@ -103,7 +102,6 @@ class PerturbFlow(nn.Module):
         else:
             self.use_bias = [not zero_bias] * self.cell_factor_size
         #self.use_bias = not zero_bias
-        self.enumrate = cell_factor_effect_discrete
 
         self.codebook_weights = None
 
@@ -310,7 +308,7 @@ class PerturbFlow(nn.Module):
         return xs
 
     def model1(self, xs):
-        pyro.module('
+        pyro.module('DensityFlow', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
@@ -389,7 +387,7 @@ class PerturbFlow(nn.Module):
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
     def model2(self, xs, us=None):
-        pyro.module('
+        pyro.module('DensityFlow', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
@@ -431,10 +429,7 @@ class PerturbFlow(nn.Module):
         zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
         if self.cell_factor_size>0:
-
-                zus = self._total_effects(zn_loc, us)
-            else:
-                zus = self._total_effects(zns, us)
+            zus = self._total_effects(zns, us)
             zs = zns+zus
         else:
             zs = zns
@@ -476,7 +471,7 @@ class PerturbFlow(nn.Module):
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
 
     def model3(self, xs, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlow', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
@@ -572,7 +567,7 @@ class PerturbFlow(nn.Module):
         zns = embeds
 
     def model4(self, xs, us, ys, embeds=None):
-        pyro.module('
+        pyro.module('DensityFlow', self)
 
         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
@@ -636,10 +631,7 @@ class PerturbFlow(nn.Module):
             # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
             # else:
             #     zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-
-                zus = self._total_effects(zn_loc, us)
-            else:
-                zus = self._total_effects(zns, us)
+            zus = self._total_effects(zns, us)
             zs = zns+zus
         else:
             zs = zns
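Both branches of the removed discrete/continuous split now funnel through a single _total_effects(zns, us) call; the old arm that applied effects to zn_loc is gone. The body of _total_effects is not part of this diff, so the sketch below is a speculative reading inferred only from the commented-out per-factor loop visible in the hunk above, not the package's actual code:

```python
# Speculative sketch of _total_effects, inferred from the commented-out loop
# in the hunk above; the real implementation is not shown in this diff.
def _total_effects(self, zns, us):
    zus = None
    for i in range(self.cell_factor_size):
        # cell_factor_effect[i] maps (embedding, i-th factor column) to a
        # latent-space displacement; displacements are summed across factors
        effect = self.cell_factor_effect[i]([zns, us[:, i].reshape(-1, 1)])
        zus = effect if zus is None else zus + effect
    return zus
```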
@@ -704,7 +696,7 @@ class PerturbFlow(nn.Module):
         """
         Return the mean part of metacell codebook
         """
-        cb = self.
+        cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
 
@@ -828,13 +820,15 @@ class PerturbFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)
 
             # factor effect of xs
-            dzs0 = self.get_cell_response(
+            dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
 
             # perturbation effect
             ps = np.ones_like(us_i)
-
-
-
+            if np.sum(np.abs(ps-us_i))>=1:
+                dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
+                zs = zs + dzs0 + dzs
+            else:
+                zs = zs + dzs0
 
         if library_sizes is None:
             library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -848,47 +842,42 @@ class PerturbFlow(nn.Module):
 
         return counts, zs
 
-    def _cell_response(self,
+    def _cell_response(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
-        zns,_ = self._get_basal_embedding(xs)
+        #zns,_ = self._get_basal_embedding(xs)
+        zns = zs
         if perturb.ndim==2:
-            ms = self.cell_factor_effect[
+            ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
         else:
-            ms = self.cell_factor_effect[
+            ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
 
         return ms
 
     def get_cell_response(self,
-
-
-
+                          zs,
+                          perturb_idx,
+                          perturb_us,
                           batch_size: int = 1024):
         """
         Return cells' changes in the latent space induced by specific perturbation of a factor
 
         """
-        xs = self.preprocess(xs)
-
-        ps = convert_to_tensor(
-        dataset = CustomDataset2(
+        #xs = self.preprocess(xs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        ps = convert_to_tensor(perturb_us, device=self.get_device())
+        dataset = CustomDataset2(zs,ps)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for
-            zns = self._cell_response(
+            for Z_batch, P_batch, _ in dataloader:
+                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def get_metacell_response(self, factor_idx, perturb):
-        zs = self._get_codebook()
-        ps = convert_to_tensor(perturb, device=self.get_device())
-        ms = self.cell_factor_effect[factor_idx]([zs,ps])
-        return tensor_to_numpy(ms)
-
     def _get_expression_response(self, delta_zs):
         return self.decoder_concentrate(delta_zs)
 
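get_cell_response now operates on latent coordinates zs directly instead of re-encoding raw counts, with parameters zs, perturb_idx, perturb_us (and get_metacell_response is removed outright). Note that the call sites in the perturbation hunk above still pass factor_idx=/perturb=, which do not match the new parameter names, so positional calls are the safer pattern. A usage sketch with made-up shapes, reusing the model from the first sketch:

```python
import numpy as np

# Hypothetical inputs: 100 cells in the z_dim=10 latent space, probing cell
# factor 0 with a perturbation value of 1 for every cell.
zs = np.random.randn(100, 10)
perturb_us = np.ones((100, 1))

# Positional call matching the new (zs, perturb_idx, perturb_us) signature;
# returns the per-cell latent displacement for that factor.
delta_z = model.get_cell_response(zs, 0, perturb_us)
```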
@@ -913,7 +902,7 @@ class PerturbFlow(nn.Module):
         R = np.concatenate(R)
         return R
 
-    def _count(self,concentrate, library_size=None):
+    def _count(self, concentrate, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -921,28 +910,17 @@ class PerturbFlow(nn.Module):
             rate = concentrate.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
-            #counts = dist.Poisson(rate=rate).to_event(1).mean
-        return counts
-
-    def _count_sample(self,concentrate):
-        if self.loss_func == 'bernoulli':
-            logits = concentrate
-            counts = dist.Bernoulli(logits=logits).to_event(1).sample()
-        else:
-            counts = self._count(concentrate=concentrate)
-            counts = dist.Poisson(rate=counts).to_event(1).sample()
         return counts
 
     def get_counts(self, zs, library_sizes,
-                   batch_size: int = 1024
-                   use_sampler: bool = False):
+                   batch_size: int = 1024):
 
         zs = convert_to_tensor(zs, device=self.get_device())
 
         if type(library_sizes) == list:
-            library_sizes = np.array(library_sizes).
+            library_sizes = np.array(library_sizes).reshape(-1,1)
         elif len(library_sizes.shape)==1:
-            library_sizes = library_sizes.
+            library_sizes = library_sizes.reshape(-1,1)
         ls = convert_to_tensor(library_sizes, device=self.get_device())
 
         dataset = CustomDataset2(zs,ls)
@@ -952,10 +930,7 @@ class PerturbFlow(nn.Module):
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
                 concentrate = self._get_expression_response(Z_batch)
-
-                    counts = self._count_sample(concentrate)
-                else:
-                    counts = self._count(concentrate, L_batch)
+                counts = self._count(concentrate, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
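With _count_sample and the use_sampler flag removed, get_counts always returns the distribution mean scaled by library size rather than a Poisson or Bernoulli sample. A sketch of the remaining call pattern, continuing the example above:

```python
# library_sizes may be a list, a 1-D array, or an (n_cells, 1) column; the
# first two forms are reshaped to a column internally, as the hunk shows.
library_sizes = np.full(100, 10_000)               # hypothetical per-cell totals
counts = model.get_counts(zs + delta_z, library_sizes)
```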
@@ -978,7 +953,7 @@ class PerturbFlow(nn.Module):
             us = None,
             ys = None,
             zs = None,
-            num_epochs: int =
+            num_epochs: int = 500,
             learning_rate: float = 0.0001,
             batch_size: int = 256,
             algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -989,7 +964,7 @@ class PerturbFlow(nn.Module):
             threshold: int = 0,
             use_jax: bool = True):
         """
-        Train the
+        Train the DensityFlow model.
 
         Parameters
         ----------
@@ -1015,7 +990,7 @@ class PerturbFlow(nn.Module):
             Parameter for optimization.
         use_jax
             If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-            the Python script or Jupyter notebook. It is OK if it is used when runing
+            the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
         """
         xs = self.preprocess(xs, threshold=threshold)
         xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1133,12 +1108,12 @@ class PerturbFlow(nn.Module):
 
 
 EXAMPLE_RUN = (
-    "example run:
+    "example run: DensityFlow --help"
 )
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="
+        description="DensityFlow\n{}".format(EXAMPLE_RUN))
 
     parser.add_argument(
         "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1325,7 +1300,7 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]
 
     ###########################################
-
+    DensityFlow = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
         inverse_dispersion=args.inverse_dispersion,
@@ -1344,7 +1319,7 @@ def main():
         dtype=dtype,
     )
 
-
+    DensityFlow.fit(xs, us=us,
         num_epochs=args.num_epochs,
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
@@ -1356,9 +1331,9 @@ def main():
 
     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-
+            DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
         else:
-
+            DensityFlow.save_model(DensityFlow, args.save_model)
 
 
 
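End to end, main() now constructs a DensityFlow, fits it, and saves it through the class-level save_model helper; note the script reuses the name DensityFlow for the instance, shadowing the class. A condensed sketch of that flow, with synthetic stand-ins for the arrays the script loads from disk:

```python
import numpy as np
from SURE import DensityFlow

# xs: hypothetical cell-by-gene count matrix; us: hypothetical cell-factor matrix.
xs = np.random.poisson(1.0, size=(100, 2000)).astype(float)
us = np.random.randint(0, 2, size=(100, 3)).astype(float)

model = DensityFlow(input_size=xs.shape[1], cell_factor_size=us.shape[1])
model.fit(xs, us=us, num_epochs=500, learning_rate=0.0001, batch_size=256)

# main() calls save_model with the instance as its first argument, passing
# compression=True when the target filename ends in 'gz'.
DensityFlow.save_model(model, 'densityflow.pkl.gz', compression=True)
```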
--- sure_tools-2.1.91/SURE/SURE.py
+++ sure_tools-2.2.17/SURE/SURE.py
@@ -99,17 +99,17 @@ class SURE(nn.Module):
                  cell_factor_size: int = 0,
                  supervised_mode: bool = False,
                  z_dim: int = 10,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = '
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = '
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
                  inverse_dispersion: float = 10.0,
                  use_zeroinflate: bool = True,
-                 hidden_layers: list = [
+                 hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
                  post_act_fct: list = None,
                  config_enum: str = 'parallel',
-                 use_cuda: bool =
+                 use_cuda: bool = True,
                  seed: int = 42,
                  dtype = torch.float32, # type: ignore
                  ):
@@ -817,7 +817,7 @@ class SURE(nn.Module):
             us = None,
             ys = None,
             zs = None,
-            num_epochs: int =
+            num_epochs: int = 500,
             learning_rate: float = 0.0001,
             batch_size: int = 256,
             algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -826,7 +826,7 @@ class SURE(nn.Module):
             decay_rate: float = 0.9,
             config_enum: str = 'parallel',
             threshold: int = 0,
-            use_jax: bool =
+            use_jax: bool = True):
         """
         Train the SURE model.
 
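SURE itself only changes defaults in this release: loss_func becomes 'poisson', use_cuda and use_jax default to True, num_epochs to 500, hidden_layers to [500]. The pre-2.2.17 defaults are truncated in this diff and cannot be recovered here, so users who depended on them should pin values explicitly. A sketch; input_size is assumed by analogy with DensityFlow and is not visible in this hunk:

```python
from SURE import SURE

model = SURE(
    input_size=2000,      # assumed parameter; not visible in this hunk
    loss_func='poisson',  # new default in 2.2.17
    use_cuda=True,        # new default in 2.2.17
)
# fit() now defaults to num_epochs=500 and use_jax=True.
```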
--- sure_tools-2.1.91/SURE/__init__.py
+++ sure_tools-2.2.17/SURE/__init__.py
@@ -1,12 +1,12 @@
 from .SURE import SURE
-from .
+from .DensityFlow import DensityFlow
 
 from . import utils
 from . import codebook
 from . import SURE
-from . import
+from . import DensityFlow
 from . import atac
 from . import flow
 from . import perturb
 
-__all__ = ['SURE', '
+__all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
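For downstream code the rename changes both the class export and the submodule export. The removed lines are truncated in this diff, so the "before" side below is an assumption based on the renamed SURE/PerturbFlow.py module:

```python
# Before (2.1.91), assumed from the removed module name:
# from SURE import PerturbFlow

# After (2.2.17), exactly as exported by the new __init__.py:
from SURE import DensityFlow
from SURE import flow, perturb, atac, utils, codebook  # unchanged exports
```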
--- sure_tools-2.1.91/SURE/flow/flow_stats.py
+++ sure_tools-2.2.17/SURE/flow/flow_stats.py
@@ -41,6 +41,18 @@ class VectorFieldEval:
         divergence[np.isnan(divergence)] = 0
 
         return divergence
+
+    def movement_stats(self,vectors):
+        return calculate_movement_stats(vectors)
+
+    def direction_stats(self, vectors):
+        return calculate_direction_stats(vectors)
+
+    def movement_energy(self, vectors, masses=None):
+        return calculate_movement_energy(vectors, masses)
+
+    def movement_divergence(self, positions, vectors):
+        return calculate_movement_divergence(positions, vectors)
 
 
 def calculate_movement_stats(vectors):
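The four new methods are thin wrappers forwarding to the module-level calculate_* functions, so VectorFieldEval becomes a single entry point for vector-field summaries. A usage sketch with synthetic data; the no-argument constructor is an assumption, since __init__ is outside this hunk:

```python
import numpy as np
from SURE.flow.flow_stats import VectorFieldEval

positions = np.random.rand(50, 2)   # synthetic 2-D positions
vectors = np.random.randn(50, 2)    # synthetic displacement vectors

ev = VectorFieldEval()                            # assumed constructor
stats = ev.movement_stats(vectors)                # -> calculate_movement_stats
energy = ev.movement_energy(vectors)              # masses defaults to None
div = ev.movement_divergence(positions, vectors)  # -> calculate_movement_divergence
```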
--- sure_tools-2.1.91/SURE/perturb/perturb.py
+++ sure_tools-2.2.17/SURE/perturb/perturb.py
@@ -1,5 +1,6 @@
 import re
 import numpy as np
+import pandas as pd
 from numba import njit
 from itertools import chain
 from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
 class LabelMatrix:
     def __init__(self):
         self.labels_ = None
+        self.control_label = None
+        self.sep_pattern = None
 
     def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
         if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
         mat = np.delete(mat, idx, axis=1)
         self.labels_ = np.delete(self.labels_, idx)
 
+        self.control_label = control_label
+        self.sep_pattern=sep_pattern
+
         return mat
-
+
+    def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+        sep_pattern = self.sep_pattern
+        if speedup=='none':
+            mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='vectorize':
+            mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='parallel':
+            mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+        mat_df = pd.DataFrame(mat, columns=labels_)
+
+        labels_valid = [x for x in labels_ if x in self.labels_]
+        mat_df = mat_df[labels_valid]
+
+        mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+        mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+        mat_valid_df[labels_valid] = mat_df
+
+        return mat_valid_df.values
+
     def inverse_transform(self, matrix):
         return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
 
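The new transform method lets a fitted LabelMatrix encode unseen label lists against the vocabulary learned by fit_transform: labels never seen at fit time are dropped, and fitted labels absent from the new batch become zero columns, so the output always has len(self.labels_) columns in a fixed order. A sketch; the direct module import is used because the package-level exports of SURE.perturb are outside this diff:

```python
from SURE.perturb.perturb import LabelMatrix

lm = LabelMatrix()
# '_' is covered by the default separator pattern r'[,;_\s]'.
train = ['drugA_drugB', 'drugA', 'control']
mat = lm.fit_transform(train, control_label='control')  # control column removed

# 'drugC' was never seen during fit, so its column is discarded; the result
# is aligned to lm.labels_ with one column per fitted label.
test = ['drugB', 'drugC_drugA']
mat2 = lm.transform(test)
```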