pertpy 0.9.5__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +2 -5
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +136 -30
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +221 -39
- pertpy/preprocessing/_guide_rna_mixture.py +177 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +138 -142
- pertpy/tools/_cinemaot.py +75 -117
- pertpy/tools/_coda/_base_coda.py +150 -174
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +60 -56
- pertpy/tools/_differential_gene_expression/_base.py +25 -43
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +86 -92
- pertpy/tools/_enrichment.py +8 -25
- pertpy/tools/_milo.py +23 -27
- pertpy/tools/_mixscape.py +261 -175
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +13 -17
- pertpy/tools/_scgen/_scgen.py +17 -20
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/METADATA +37 -21
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.9.5.dist-info/RECORD +0 -57
- {pertpy-0.9.5.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_cinemaot.py
CHANGED
@@ -2,18 +2,19 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
5
|
+
import jax
|
5
6
|
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
8
9
|
import scanpy as sc
|
9
10
|
import scipy.stats as ss
|
10
|
-
import seaborn as sns
|
11
11
|
import sklearn.metrics
|
12
|
-
from ott.geometry import
|
12
|
+
from ott.geometry.pointcloud import PointCloud
|
13
13
|
from ott.problems.linear import linear_problem
|
14
14
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
15
15
|
from scanpy.plotting import _utils
|
16
16
|
from scipy.sparse import issparse
|
17
|
+
from seaborn import heatmap
|
17
18
|
from sklearn.decomposition import FastICA
|
18
19
|
from sklearn.linear_model import LinearRegression
|
19
20
|
from sklearn.neighbors import NearestNeighbors
|
@@ -49,6 +50,7 @@ class Cinemaot:
|
|
49
50
|
eps: float = 1e-3,
|
50
51
|
solver: str = "Sinkhorn",
|
51
52
|
preweight_label: str | None = None,
|
53
|
+
random_state: int | None = 0,
|
52
54
|
):
|
53
55
|
"""Calculate the confounding variation, optimal transport counterfactual pairs, and single-cell level treatment effects.
|
54
56
|
|
@@ -69,6 +71,7 @@ class Cinemaot:
|
|
69
71
|
solver: Either "Sinkhorn" or "LRSinkhorn". The ott-jax solver used.
|
70
72
|
preweight_label: The annotated label (e.g. cell type) that is used to assign weights for treated
|
71
73
|
and control cells to balance across the label. Helps overcome the differential abundance issue.
|
74
|
+
random_state: The random seed for the shuffling.
|
72
75
|
|
73
76
|
Returns:
|
74
77
|
Returns an AnnData object that contains the single-cell level treatment effect as de.X and the
|
@@ -85,7 +88,7 @@ class Cinemaot:
|
|
85
88
|
"""
|
86
89
|
available_solvers = ["Sinkhorn", "LRSinkhorn"]
|
87
90
|
if solver not in available_solvers:
|
88
|
-
raise ValueError(f"solver = {solver} is not one of the supported solvers:
|
91
|
+
raise ValueError(f"solver = {solver} is not one of the supported solvers: {available_solvers}")
|
89
92
|
|
90
93
|
if dim is None:
|
91
94
|
dim = self.get_dim(adata, use_rep=use_rep)
|
@@ -96,7 +99,7 @@ class Cinemaot:
|
|
96
99
|
xi = np.zeros(dim)
|
97
100
|
j = 0
|
98
101
|
for source_row in X_transformed.T:
|
99
|
-
xi_obj = Xi(source_row, groupvec * 1)
|
102
|
+
xi_obj = Xi(source_row, groupvec * 1, random_state=random_state)
|
100
103
|
xi[j] = xi_obj.correlation
|
101
104
|
j = j + 1
|
102
105
|
|
@@ -111,7 +114,7 @@ class Cinemaot:
|
|
111
114
|
sklearn.metrics.pairwise_distances(cf1, cf2)
|
112
115
|
|
113
116
|
e = smoothness * sum(xi < thres)
|
114
|
-
geom =
|
117
|
+
geom = PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
|
115
118
|
|
116
119
|
if preweight_label is None:
|
117
120
|
ot_prob = linear_problem.LinearProblem(geom, a=None, b=None)
|
@@ -122,17 +125,15 @@ class Cinemaot:
|
|
122
125
|
a = np.zeros(cf1.shape[0])
|
123
126
|
b = np.zeros(cf2.shape[0])
|
124
127
|
|
125
|
-
adata1 = adata[adata.obs[pert_key] == control
|
126
|
-
adata2 = adata[adata.obs[pert_key] != control
|
128
|
+
adata1 = adata[adata.obs[pert_key] == control]
|
129
|
+
adata2 = adata[adata.obs[pert_key] != control]
|
127
130
|
|
128
131
|
for label in adata1.obs[pert_key].unique():
|
132
|
+
mask_label = adata1.obs[pert_key] == label
|
129
133
|
for ct in adata1.obs[preweight_label].unique():
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
a[(adata1.obs[pert_key] == label).values] = a[(adata1.obs[pert_key] == label).values] / np.sum(
|
134
|
-
a[(adata1.obs[pert_key] == label).values]
|
135
|
-
)
|
134
|
+
mask_ct = adata1.obs[preweight_label] == ct
|
135
|
+
a[mask_ct & mask_label] = np.sum(adata2.obs[preweight_label] == ct) / np.sum(mask_ct)
|
136
|
+
a[mask_label] /= np.sum(a[mask_label])
|
136
137
|
|
137
138
|
a = a / np.sum(a)
|
138
139
|
b[:] = 1 / cf2.shape[0]
|
@@ -141,25 +142,22 @@ class Cinemaot:
|
|
141
142
|
if solver == "LRSinkhorn":
|
142
143
|
if rank is None:
|
143
144
|
rank = int(min(cf1.shape[0], cf2.shape[0]) / 2)
|
144
|
-
_solver = sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps)
|
145
|
+
_solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps))
|
145
146
|
ot_sink = _solver(ot_prob)
|
146
147
|
embedding = (
|
147
148
|
X_transformed[adata.obs[pert_key] != control, :]
|
148
149
|
- ot_sink.apply(X_transformed[adata.obs[pert_key] == control, :].T).T
|
149
150
|
/ ot_sink.apply(np.ones_like(X_transformed[adata.obs[pert_key] == control, :].T)).T
|
150
151
|
)
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
)
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
- ot_sink.apply(adata.X[adata.obs[pert_key] == control, :].T).T
|
161
|
-
/ ot_sink.apply(np.ones_like(adata.X[adata.obs[pert_key] == control, :].T)).T
|
162
|
-
)
|
152
|
+
|
153
|
+
X = adata.X.toarray() if issparse(adata.X) else adata.X
|
154
|
+
te2 = (
|
155
|
+
X[adata.obs[pert_key] != control, :]
|
156
|
+
- ot_sink.apply(X[adata.obs[pert_key] == control, :].T).T
|
157
|
+
/ ot_sink.apply(np.ones_like(X[adata.obs[pert_key] == control, :].T)).T
|
158
|
+
)
|
159
|
+
if issparse(X):
|
160
|
+
del X
|
163
161
|
|
164
162
|
adata.obsm[cf_rep] = cf
|
165
163
|
adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = (
|
@@ -168,21 +166,20 @@ class Cinemaot:
|
|
168
166
|
)
|
169
167
|
|
170
168
|
else:
|
171
|
-
_solver = sinkhorn.Sinkhorn(threshold=eps)
|
169
|
+
_solver = jax.jit(sinkhorn.Sinkhorn(threshold=eps))
|
172
170
|
ot_sink = _solver(ot_prob)
|
173
171
|
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
|
174
172
|
embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
|
175
173
|
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
|
176
174
|
)
|
177
175
|
|
178
|
-
if issparse(adata.X)
|
179
|
-
|
180
|
-
|
181
|
-
)
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
)
|
176
|
+
X = adata.X.toarray() if issparse(adata.X) else adata.X
|
177
|
+
|
178
|
+
te2 = X[adata.obs[pert_key] != control, :] - np.matmul(
|
179
|
+
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X[adata.obs[pert_key] == control, :]
|
180
|
+
)
|
181
|
+
if issparse(X):
|
182
|
+
del X
|
186
183
|
|
187
184
|
adata.obsm[cf_rep] = cf
|
188
185
|
adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = np.matmul(
|
@@ -251,7 +248,7 @@ class Cinemaot:
|
|
251
248
|
"""
|
252
249
|
available_solvers = ["Sinkhorn", "LRSinkhorn"]
|
253
250
|
assert solver in available_solvers, (
|
254
|
-
f"solver = {solver} is not one of the supported solvers:
|
251
|
+
f"solver = {solver} is not one of the supported solvers: {available_solvers}"
|
255
252
|
)
|
256
253
|
|
257
254
|
if dim is None:
|
@@ -259,7 +256,9 @@ class Cinemaot:
|
|
259
256
|
|
260
257
|
adata.obs_names_make_unique()
|
261
258
|
|
262
|
-
idx = self.
|
259
|
+
idx = self._get_weightidx(
|
260
|
+
adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution
|
261
|
+
)
|
263
262
|
adata_ = adata[idx].copy()
|
264
263
|
TE = self.causaleffect(
|
265
264
|
adata_,
|
@@ -329,11 +328,10 @@ class Cinemaot:
|
|
329
328
|
df = pd.DataFrame(adata.raw.X.toarray(), columns=adata.raw.var_names, index=adata.raw.obs_names)
|
330
329
|
else:
|
331
330
|
df = pd.DataFrame(adata.raw.X, columns=adata.raw.var_names, index=adata.raw.obs_names)
|
331
|
+
elif issparse(adata.X):
|
332
|
+
df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
|
332
333
|
else:
|
333
|
-
|
334
|
-
df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
|
335
|
-
else:
|
336
|
-
df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
|
334
|
+
df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
|
337
335
|
|
338
336
|
if label_list is None:
|
339
337
|
label_list = ["ct"]
|
@@ -377,10 +375,7 @@ class Cinemaot:
|
|
377
375
|
>>> dim = model.get_dim(adata)
|
378
376
|
"""
|
379
377
|
sk = SinkhornKnopp()
|
380
|
-
if issparse(adata.raw.X)
|
381
|
-
data = adata.raw.X.toarray()
|
382
|
-
else:
|
383
|
-
data = adata.raw.X
|
378
|
+
data = adata.raw.X.toarray() if issparse(adata.raw.X) else adata.raw.X
|
384
379
|
vm = (1e-3 + data + c * data * data) / (1 + c)
|
385
380
|
sk.fit(vm)
|
386
381
|
wm = np.dot(np.dot(np.sqrt(sk._D1), vm), np.sqrt(sk._D2))
|
@@ -388,9 +383,10 @@ class Cinemaot:
|
|
388
383
|
dim = min(sum(s > (np.sqrt(data.shape[0]) + np.sqrt(data.shape[1]))), adata.obsm[use_rep].shape[1])
|
389
384
|
return dim
|
390
385
|
|
391
|
-
def
|
386
|
+
def _get_weightidx(
|
392
387
|
self,
|
393
388
|
adata: AnnData,
|
389
|
+
*,
|
394
390
|
pert_key: str,
|
395
391
|
control: str,
|
396
392
|
use_rep: str = "X_pca",
|
@@ -401,7 +397,8 @@ class Cinemaot:
|
|
401
397
|
|
402
398
|
Args:
|
403
399
|
adata: The annotated data object.
|
404
|
-
|
400
|
+
pert_key: Key of the perturbation col.
|
401
|
+
control: Key of the control col.
|
405
402
|
use_rep: the embedding used to give a upper bound for the estimated rank.
|
406
403
|
k: the number of neighbors used in the k-NN matching phase.
|
407
404
|
resolution: the clustering resolution used in the sampling phase.
|
@@ -413,7 +410,7 @@ class Cinemaot:
|
|
413
410
|
>>> import pertpy as pt
|
414
411
|
>>> adata = pt.dt.cinemaot_example()
|
415
412
|
>>> model = pt.tl.Cinemaot()
|
416
|
-
>>> idx = model.
|
413
|
+
>>> idx = model._get_weightidx(adata, pert_key="perturbation", control="No stimulation")
|
417
414
|
"""
|
418
415
|
adata_ = adata.copy()
|
419
416
|
X_pca1 = adata_.obsm[use_rep][adata_.obs[pert_key] == control, :]
|
@@ -621,13 +618,12 @@ class Cinemaot:
|
|
621
618
|
else:
|
622
619
|
Y0 = adata.raw.X[adata.obs[pert_key] == control, :]
|
623
620
|
Y1 = adata.raw.X[adata.obs[pert_key] != control, :]
|
621
|
+
elif issparse(adata.X):
|
622
|
+
Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
|
623
|
+
Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
|
624
624
|
else:
|
625
|
-
|
626
|
-
|
627
|
-
Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
|
628
|
-
else:
|
629
|
-
Y0 = adata.X[adata.obs[pert_key] == control, :]
|
630
|
-
Y1 = adata.X[adata.obs[pert_key] != control, :]
|
625
|
+
Y0 = adata.X[adata.obs[pert_key] == control, :]
|
626
|
+
Y1 = adata.X[adata.obs[pert_key] != control, :]
|
631
627
|
X0 = cf[adata.obs[pert_key] == control, :]
|
632
628
|
X1 = cf[adata.obs[pert_key] != control, :]
|
633
629
|
ols0 = LinearRegression()
|
@@ -643,7 +639,7 @@ class Cinemaot:
|
|
643
639
|
return c_effect, s_effect
|
644
640
|
|
645
641
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
646
|
-
def plot_vis_matching(
|
642
|
+
def plot_vis_matching( # pragma: no cover # noqa: D417
|
647
643
|
self,
|
648
644
|
adata: AnnData,
|
649
645
|
de: AnnData,
|
@@ -658,7 +654,6 @@ class Cinemaot:
|
|
658
654
|
title: str = "CINEMA-OT matching matrix",
|
659
655
|
min_val: float = 0.01,
|
660
656
|
ax: Axes | None = None,
|
661
|
-
show: bool = True,
|
662
657
|
return_fig: bool = False,
|
663
658
|
**kwargs,
|
664
659
|
) -> Figure | None:
|
@@ -672,9 +667,11 @@ class Cinemaot:
|
|
672
667
|
de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
|
673
668
|
source_label: the confounder / cell type label.
|
674
669
|
matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
|
670
|
+
resolution: Leiden resolution.
|
675
671
|
normalize: normalize the coarse-grained matching matrix by row / column.
|
676
672
|
title: the title for the figure.
|
677
673
|
min_val: The min value to truncate the matching matrix.
|
674
|
+
ax: Matplotlib axes object.
|
678
675
|
{common_plot_args}
|
679
676
|
**kwargs: Other parameters to input for seaborn.heatmap.
|
680
677
|
|
@@ -704,35 +701,26 @@ class Cinemaot:
|
|
704
701
|
df["source_label"] = adata_.obs[source_label].astype(str).values
|
705
702
|
df = df.groupby("source_label").sum()
|
706
703
|
|
707
|
-
if normalize == "col"
|
708
|
-
df = df / df.sum(axis=0)
|
709
|
-
else:
|
710
|
-
df = (df.T / df.sum(axis=1)).T
|
704
|
+
df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
|
711
705
|
df = df.clip(lower=min_val) - min_val
|
712
|
-
if normalize == "col"
|
713
|
-
df = df / df.sum(axis=0)
|
714
|
-
else:
|
715
|
-
df = (df.T / df.sum(axis=1)).T
|
706
|
+
df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
|
716
707
|
|
717
|
-
g =
|
708
|
+
g = heatmap(df, annot=True, ax=ax, **kwargs)
|
718
709
|
plt.title(title)
|
719
710
|
|
720
|
-
if show:
|
721
|
-
plt.show()
|
722
711
|
if return_fig:
|
723
712
|
return g
|
713
|
+
plt.show()
|
724
714
|
return None
|
725
715
|
|
726
716
|
|
727
717
|
class Xi:
|
728
|
-
"""
|
729
|
-
A fast implementation of cross-rank dependence metric used in CINEMA-OT.
|
730
|
-
|
731
|
-
"""
|
718
|
+
"""A fast implementation of cross-rank dependence metric used in CINEMA-OT."""
|
732
719
|
|
733
|
-
def __init__(self, x, y):
|
720
|
+
def __init__(self, x, y, random_state: int | None = 0):
|
734
721
|
self.x = x
|
735
722
|
self.y = y
|
723
|
+
self.random_state = random_state
|
736
724
|
|
737
725
|
@property
|
738
726
|
def sample_size(self):
|
@@ -741,18 +729,15 @@ class Xi:
|
|
741
729
|
@property
|
742
730
|
def x_ordered_rank(self):
|
743
731
|
# PI is the rank vector for x, with ties broken at random
|
744
|
-
# Not mine: source (https://stackoverflow.com/a/47430384/1628971)
|
745
|
-
# random shuffling of the data - reason to use random.choice is that
|
746
|
-
# pd.sample(frac=1) uses the same randomizing algorithm
|
747
732
|
len_x = len(self.x)
|
748
|
-
rng = np.random.default_rng()
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
return
|
733
|
+
rng = np.random.default_rng(self.random_state)
|
734
|
+
perm = rng.permutation(len_x)
|
735
|
+
x_shuffled = self.x[perm]
|
736
|
+
|
737
|
+
ranks = np.empty(len_x, dtype=int)
|
738
|
+
ranks[perm[np.argsort(x_shuffled, stable=True)]] = np.arange(1, len_x + 1)
|
739
|
+
|
740
|
+
return ranks
|
756
741
|
|
757
742
|
@property
|
758
743
|
def y_rank_max(self):
|
@@ -762,7 +747,7 @@ class Xi:
|
|
762
747
|
@property
|
763
748
|
def g(self):
|
764
749
|
# g[i] is number of j s.t. y[j] >= y[i], divided by n.
|
765
|
-
return ss.rankdata(
|
750
|
+
return ss.rankdata(-self.y, method="max") / self.sample_size
|
766
751
|
|
767
752
|
@property
|
768
753
|
def x_ordered(self):
|
@@ -771,32 +756,14 @@ class Xi:
|
|
771
756
|
|
772
757
|
@property
|
773
758
|
def x_rank_max_ordered(self):
|
774
|
-
|
775
|
-
y_rank_max_result = self.y_rank_max
|
776
|
-
# Rearrange f according to ord.
|
777
|
-
return [y_rank_max_result[i] for i in x_ordered_result]
|
759
|
+
return self.y_rank_max[self.x_ordered]
|
778
760
|
|
779
761
|
@property
|
780
762
|
def mean_absolute(self):
|
781
763
|
x1 = self.x_rank_max_ordered[0 : (self.sample_size - 1)]
|
782
764
|
x2 = self.x_rank_max_ordered[1 : self.sample_size]
|
783
765
|
|
784
|
-
return (
|
785
|
-
np.mean(
|
786
|
-
np.abs(
|
787
|
-
[
|
788
|
-
x - y
|
789
|
-
for x, y in zip(
|
790
|
-
x1,
|
791
|
-
x2,
|
792
|
-
strict=False,
|
793
|
-
)
|
794
|
-
]
|
795
|
-
)
|
796
|
-
)
|
797
|
-
* (self.sample_size - 1)
|
798
|
-
/ (2 * self.sample_size)
|
799
|
-
)
|
766
|
+
return np.mean(np.abs(x1 - x2)) * (self.sample_size - 1) / (2 * self.sample_size)
|
800
767
|
|
801
768
|
@property
|
802
769
|
def inverse_g_mean(self):
|
@@ -805,7 +772,7 @@ class Xi:
|
|
805
772
|
|
806
773
|
@property
|
807
774
|
def correlation(self):
|
808
|
-
"""
|
775
|
+
"""Xi correlation."""
|
809
776
|
return 1 - self.mean_absolute / self.inverse_g_mean
|
810
777
|
|
811
778
|
@classmethod
|
@@ -854,10 +821,7 @@ class Xi:
|
|
854
821
|
|
855
822
|
|
856
823
|
class SinkhornKnopp:
|
857
|
-
"""
|
858
|
-
An simple implementation of Sinkhorn iteration used in the biwhitening approach.
|
859
|
-
|
860
|
-
"""
|
824
|
+
"""An simple implementation of Sinkhorn iteration used in the biwhitening approach."""
|
861
825
|
|
862
826
|
def __init__(self, max_iter: float = 1000, setr: int = 0, setc: float = 0, epsilon: float = 1e-3):
|
863
827
|
if max_iter < 0:
|
@@ -891,14 +855,8 @@ class SinkhornKnopp:
|
|
891
855
|
assert P.ndim == 2
|
892
856
|
|
893
857
|
N = P.shape[0]
|
894
|
-
if np.sum(abs(self._setr)) == 0
|
895
|
-
|
896
|
-
else:
|
897
|
-
rsum = self._setr
|
898
|
-
if np.sum(abs(self._setc)) == 0:
|
899
|
-
csum = P.shape[0]
|
900
|
-
else:
|
901
|
-
csum = self._setc
|
858
|
+
rsum = P.shape[1] if np.sum(abs(self._setr)) == 0 else self._setr
|
859
|
+
csum = P.shape[0] if np.sum(abs(self._setc)) == 0 else self._setc
|
902
860
|
max_threshr = rsum + self._epsilon
|
903
861
|
min_threshr = rsum - self._epsilon
|
904
862
|
max_threshc = csum + self._epsilon
|