pertpy 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +5 -1
- pertpy/_doc.py +1 -3
- pertpy/_types.py +6 -0
- pertpy/data/_dataloader.py +68 -24
- pertpy/data/_datasets.py +9 -9
- pertpy/metadata/__init__.py +2 -1
- pertpy/metadata/_cell_line.py +133 -25
- pertpy/metadata/_look_up.py +13 -19
- pertpy/metadata/_moa.py +1 -1
- pertpy/preprocessing/_guide_rna.py +138 -44
- pertpy/preprocessing/_guide_rna_mixture.py +17 -19
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +106 -98
- pertpy/tools/_cinemaot.py +74 -114
- pertpy/tools/_coda/_base_coda.py +129 -145
- pertpy/tools/_coda/_sccoda.py +66 -69
- pertpy/tools/_coda/_tasccoda.py +71 -79
- pertpy/tools/_dialogue.py +48 -40
- pertpy/tools/_differential_gene_expression/_base.py +21 -31
- pertpy/tools/_differential_gene_expression/_checks.py +4 -6
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +5 -6
- pertpy/tools/_differential_gene_expression/_edger.py +6 -10
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +1 -1
- pertpy/tools/_differential_gene_expression/_simple_tests.py +3 -3
- pertpy/tools/_differential_gene_expression/_statsmodels.py +8 -5
- pertpy/tools/_distances/_distance_tests.py +1 -2
- pertpy/tools/_distances/_distances.py +31 -45
- pertpy/tools/_enrichment.py +7 -22
- pertpy/tools/_milo.py +19 -15
- pertpy/tools/_mixscape.py +73 -75
- pertpy/tools/_perturbation_space/_clustering.py +4 -4
- pertpy/tools/_perturbation_space/_comparison.py +4 -4
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +83 -32
- pertpy/tools/_perturbation_space/_perturbation_space.py +10 -10
- pertpy/tools/_perturbation_space/_simple.py +12 -14
- pertpy/tools/_scgen/_scgen.py +16 -17
- pertpy/tools/_scgen/_scgenvae.py +2 -2
- pertpy/tools/_scgen/_utils.py +3 -1
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/METADATA +36 -20
- pertpy-0.11.0.dist-info/RECORD +58 -0
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/licenses/LICENSE +1 -0
- pertpy/tools/_kernel_pca.py +0 -50
- pertpy-0.10.0.dist-info/RECORD +0 -58
- {pertpy-0.10.0.dist-info → pertpy-0.11.0.dist-info}/WHEEL +0 -0
pertpy/tools/_cinemaot.py
CHANGED
@@ -2,18 +2,19 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
+import jax
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import scanpy as sc
 import scipy.stats as ss
-import seaborn as sns
 import sklearn.metrics
-from ott.geometry import
+from ott.geometry.pointcloud import PointCloud
 from ott.problems.linear import linear_problem
 from ott.solvers.linear import sinkhorn, sinkhorn_lr
 from scanpy.plotting import _utils
 from scipy.sparse import issparse
+from seaborn import heatmap
 from sklearn.decomposition import FastICA
 from sklearn.linear_model import LinearRegression
 from sklearn.neighbors import NearestNeighbors
@@ -49,6 +50,7 @@ class Cinemaot:
         eps: float = 1e-3,
         solver: str = "Sinkhorn",
         preweight_label: str | None = None,
+        random_state: int | None = 0,
     ):
         """Calculate the confounding variation, optimal transport counterfactual pairs, and single-cell level treatment effects.
 
@@ -69,6 +71,7 @@ class Cinemaot:
             solver: Either "Sinkhorn" or "LRSinkhorn". The ott-jax solver used.
             preweight_label: The annotated label (e.g. cell type) that is used to assign weights for treated
                 and control cells to balance across the label. Helps overcome the differential abundance issue.
+            random_state: The random seed for the shuffling.
 
         Returns:
             Returns an AnnData object that contains the single-cell level treatment effect as de.X and the
@@ -85,7 +88,7 @@ class Cinemaot:
         """
         available_solvers = ["Sinkhorn", "LRSinkhorn"]
         if solver not in available_solvers:
-            raise ValueError(f"solver = {solver} is not one of the supported solvers:
+            raise ValueError(f"solver = {solver} is not one of the supported solvers: {available_solvers}")
 
         if dim is None:
            dim = self.get_dim(adata, use_rep=use_rep)
@@ -96,7 +99,7 @@ class Cinemaot:
         xi = np.zeros(dim)
         j = 0
         for source_row in X_transformed.T:
-            xi_obj = Xi(source_row, groupvec * 1)
+            xi_obj = Xi(source_row, groupvec * 1, random_state=random_state)
             xi[j] = xi_obj.correlation
             j = j + 1
 
@@ -111,7 +114,7 @@ class Cinemaot:
         sklearn.metrics.pairwise_distances(cf1, cf2)
 
         e = smoothness * sum(xi < thres)
-        geom =
+        geom = PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
 
         if preweight_label is None:
             ot_prob = linear_problem.LinearProblem(geom, a=None, b=None)
@@ -122,17 +125,15 @@ class Cinemaot:
             a = np.zeros(cf1.shape[0])
             b = np.zeros(cf2.shape[0])
 
-            adata1 = adata[adata.obs[pert_key] == control
-            adata2 = adata[adata.obs[pert_key] != control
+            adata1 = adata[adata.obs[pert_key] == control]
+            adata2 = adata[adata.obs[pert_key] != control]
 
             for label in adata1.obs[pert_key].unique():
+                mask_label = adata1.obs[pert_key] == label
                 for ct in adata1.obs[preweight_label].unique():
-
-
-
-                a[(adata1.obs[pert_key] == label).values] = a[(adata1.obs[pert_key] == label).values] / np.sum(
-                    a[(adata1.obs[pert_key] == label).values]
-                )
+                    mask_ct = adata1.obs[preweight_label] == ct
+                    a[mask_ct & mask_label] = np.sum(adata2.obs[preweight_label] == ct) / np.sum(mask_ct)
+                a[mask_label] /= np.sum(a[mask_label])
 
             a = a / np.sum(a)
             b[:] = 1 / cf2.shape[0]
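Note: the rewritten preweight branch above balances cell-type composition between control and treated cells before the OT problem is solved. Below is a standalone toy sketch of that weighting logic in plain NumPy/pandas, not the pertpy API; the "perturbation" and "cell_type" labels and the toy observations are made up for illustration.

import numpy as np
import pandas as pd

# toy control (adata1-like) and treated (adata2-like) observations
obs1 = pd.DataFrame({"perturbation": ["control"] * 4, "cell_type": ["A", "A", "A", "B"]})
obs2 = pd.DataFrame({"perturbation": ["stim"] * 4, "cell_type": ["A", "B", "B", "B"]})

a = np.zeros(len(obs1))
for label in obs1["perturbation"].unique():
    mask_label = (obs1["perturbation"] == label).to_numpy()
    for ct in obs1["cell_type"].unique():
        mask_ct = (obs1["cell_type"] == ct).to_numpy()
        # weight each control cell of type ct by its abundance among treated
        # cells relative to its abundance among control cells
        a[mask_ct & mask_label] = np.sum(obs2["cell_type"] == ct) / np.sum(mask_ct)
    a[mask_label] /= np.sum(a[mask_label])
a = a / np.sum(a)

print(a.round(3))  # [0.083 0.083 0.083 0.75]: the rare control type B is up-weighted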
@@ -141,25 +142,22 @@ class Cinemaot:
         if solver == "LRSinkhorn":
             if rank is None:
                 rank = int(min(cf1.shape[0], cf2.shape[0]) / 2)
-            _solver = sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps)
+            _solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=rank, threshold=eps))
             ot_sink = _solver(ot_prob)
             embedding = (
                 X_transformed[adata.obs[pert_key] != control, :]
                 - ot_sink.apply(X_transformed[adata.obs[pert_key] == control, :].T).T
                 / ot_sink.apply(np.ones_like(X_transformed[adata.obs[pert_key] == control, :].T)).T
             )
-
-
-
-
-
-            )
-
-
-
-                - ot_sink.apply(adata.X[adata.obs[pert_key] == control, :].T).T
-                / ot_sink.apply(np.ones_like(adata.X[adata.obs[pert_key] == control, :].T)).T
-            )
+
+            X = adata.X.toarray() if issparse(adata.X) else adata.X
+            te2 = (
+                X[adata.obs[pert_key] != control, :]
+                - ot_sink.apply(X[adata.obs[pert_key] == control, :].T).T
+                / ot_sink.apply(np.ones_like(X[adata.obs[pert_key] == control, :].T)).T
+            )
+            if issparse(X):
+                del X
 
             adata.obsm[cf_rep] = cf
             adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = (
@@ -168,21 +166,20 @@ class Cinemaot:
             )
 
         else:
-            _solver = sinkhorn.Sinkhorn(threshold=eps)
+            _solver = jax.jit(sinkhorn.Sinkhorn(threshold=eps))
             ot_sink = _solver(ot_prob)
             ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
             embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
                 ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
             )
 
-            if issparse(adata.X)
-
-
-            )
-
-
-
-            )
+            X = adata.X.toarray() if issparse(adata.X) else adata.X
+
+            te2 = X[adata.obs[pert_key] != control, :] - np.matmul(
+                ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X[adata.obs[pert_key] == control, :]
+            )
+            if issparse(X):
+                del X
 
             adata.obsm[cf_rep] = cf
             adata.obsm[cf_rep][adata.obs[pert_key] != control, :] = np.matmul(
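Note: a minimal usage sketch of the updated causaleffect() call with the new random_state keyword. The dataset and column names (pt.dt.cinemaot_example(), pert_key="perturbation", control="No stimulation") follow the doctest elsewhere in this file; the remaining keyword values are illustrative rather than prescribed.

import pertpy as pt

adata = pt.dt.cinemaot_example()
model = pt.tl.Cinemaot()
# random_state (new in this release) seeds the shuffling used by the Xi
# cross-rank dependence metric, making confounder selection reproducible.
de = model.causaleffect(
    adata,
    pert_key="perturbation",
    control="No stimulation",
    solver="Sinkhorn",
    random_state=0,
)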
@@ -251,7 +248,7 @@ class Cinemaot:
         """
         available_solvers = ["Sinkhorn", "LRSinkhorn"]
         assert solver in available_solvers, (
-            f"solver = {solver} is not one of the supported solvers:
+            f"solver = {solver} is not one of the supported solvers: {available_solvers}"
         )
 
         if dim is None:
@@ -259,7 +256,9 @@ class Cinemaot:
 
         adata.obs_names_make_unique()
 
-        idx = self.
+        idx = self._get_weightidx(
+            adata, pert_key=pert_key, control=control, k=k, use_rep=use_rep, resolution=resolution
+        )
         adata_ = adata[idx].copy()
         TE = self.causaleffect(
             adata_,
@@ -329,11 +328,10 @@ class Cinemaot:
                 df = pd.DataFrame(adata.raw.X.toarray(), columns=adata.raw.var_names, index=adata.raw.obs_names)
             else:
                 df = pd.DataFrame(adata.raw.X, columns=adata.raw.var_names, index=adata.raw.obs_names)
+        elif issparse(adata.X):
+            df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
         else:
-
-                df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
-            else:
-                df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
+            df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
 
         if label_list is None:
             label_list = ["ct"]
@@ -377,10 +375,7 @@ class Cinemaot:
         >>> dim = model.get_dim(adata)
         """
         sk = SinkhornKnopp()
-        if issparse(adata.raw.X)
-            data = adata.raw.X.toarray()
-        else:
-            data = adata.raw.X
+        data = adata.raw.X.toarray() if issparse(adata.raw.X) else adata.raw.X
         vm = (1e-3 + data + c * data * data) / (1 + c)
         sk.fit(vm)
         wm = np.dot(np.dot(np.sqrt(sk._D1), vm), np.sqrt(sk._D2))
@@ -388,9 +383,10 @@ class Cinemaot:
         dim = min(sum(s > (np.sqrt(data.shape[0]) + np.sqrt(data.shape[1]))), adata.obsm[use_rep].shape[1])
         return dim
 
-    def
+    def _get_weightidx(
         self,
         adata: AnnData,
+        *,
         pert_key: str,
         control: str,
         use_rep: str = "X_pca",
@@ -401,7 +397,8 @@ class Cinemaot:
 
         Args:
             adata: The annotated data object.
-
+            pert_key: Key of the perturbation col.
+            control: Key of the control col.
             use_rep: the embedding used to give a upper bound for the estimated rank.
             k: the number of neighbors used in the k-NN matching phase.
             resolution: the clustering resolution used in the sampling phase.
@@ -413,7 +410,7 @@ class Cinemaot:
         >>> import pertpy as pt
         >>> adata = pt.dt.cinemaot_example()
         >>> model = pt.tl.Cinemaot()
-        >>> idx = model.
+        >>> idx = model._get_weightidx(adata, pert_key="perturbation", control="No stimulation")
         """
         adata_ = adata.copy()
         X_pca1 = adata_.obsm[use_rep][adata_.obs[pert_key] == control, :]
@@ -621,13 +618,12 @@ class Cinemaot:
             else:
                 Y0 = adata.raw.X[adata.obs[pert_key] == control, :]
                 Y1 = adata.raw.X[adata.obs[pert_key] != control, :]
+        elif issparse(adata.X):
+            Y0 = adata.X.toarray()[adata.obs[pert_key] == control, :]
+            Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
         else:
-
-
-                Y1 = adata.X.toarray()[adata.obs[pert_key] != control, :]
-            else:
-                Y0 = adata.X[adata.obs[pert_key] == control, :]
-                Y1 = adata.X[adata.obs[pert_key] != control, :]
+            Y0 = adata.X[adata.obs[pert_key] == control, :]
+            Y1 = adata.X[adata.obs[pert_key] != control, :]
         X0 = cf[adata.obs[pert_key] == control, :]
         X1 = cf[adata.obs[pert_key] != control, :]
         ols0 = LinearRegression()
@@ -643,7 +639,7 @@ class Cinemaot:
         return c_effect, s_effect
 
     @_doc_params(common_plot_args=doc_common_plot_args)
-    def plot_vis_matching(
+    def plot_vis_matching(  # pragma: no cover  # noqa: D417
         self,
         adata: AnnData,
         de: AnnData,
@@ -671,9 +667,11 @@ class Cinemaot:
             de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
             source_label: the confounder / cell type label.
             matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
+            resolution: Leiden resolution.
             normalize: normalize the coarse-grained matching matrix by row / column.
             title: the title for the figure.
             min_val: The min value to truncate the matching matrix.
+            ax: Matplotlib axes object.
             {common_plot_args}
             **kwargs: Other parameters to input for seaborn.heatmap.
 
@@ -703,17 +701,11 @@ class Cinemaot:
         df["source_label"] = adata_.obs[source_label].astype(str).values
         df = df.groupby("source_label").sum()
 
-        if normalize == "col"
-            df = df / df.sum(axis=0)
-        else:
-            df = (df.T / df.sum(axis=1)).T
+        df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
         df = df.clip(lower=min_val) - min_val
-        if normalize == "col"
-            df = df / df.sum(axis=0)
-        else:
-            df = (df.T / df.sum(axis=1)).T
+        df = df / df.sum(axis=0) if normalize == "col" else (df.T / df.sum(axis=1)).T
 
-        g =
+        g = heatmap(df, annot=True, ax=ax, **kwargs)
         plt.title(title)
 
         if return_fig:
@@ -723,14 +715,12 @@ class Cinemaot:
 
 
 class Xi:
-    """
-    A fast implementation of cross-rank dependence metric used in CINEMA-OT.
-
-    """
+    """A fast implementation of cross-rank dependence metric used in CINEMA-OT."""
 
-    def __init__(self, x, y):
+    def __init__(self, x, y, random_state: int | None = 0):
         self.x = x
         self.y = y
+        self.random_state = random_state
 
     @property
     def sample_size(self):
@@ -739,18 +729,15 @@ class Xi:
     @property
     def x_ordered_rank(self):
         # PI is the rank vector for x, with ties broken at random
-        # Not mine: source (https://stackoverflow.com/a/47430384/1628971)
-        # random shuffling of the data - reason to use random.choice is that
-        # pd.sample(frac=1) uses the same randomizing algorithm
         len_x = len(self.x)
-        rng = np.random.default_rng()
-
-
-
-
-
-
-        return
+        rng = np.random.default_rng(self.random_state)
+        perm = rng.permutation(len_x)
+        x_shuffled = self.x[perm]
+
+        ranks = np.empty(len_x, dtype=int)
+        ranks[perm[np.argsort(x_shuffled, stable=True)]] = np.arange(1, len_x + 1)
+
+        return ranks
 
     @property
     def y_rank_max(self):
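Note: the rewrite of x_ordered_rank above replaces the unseeded shuffling with a seeded permutation plus a stable argsort. Below is a self-contained sketch of that ranking scheme in plain NumPy/SciPy, not the pertpy API; stable=True in the diff is the NumPy 2.x spelling, and kind="stable" used here is the older equivalent.

import numpy as np
import scipy.stats as ss

def ordered_rank(x: np.ndarray, random_state: int | None = 0) -> np.ndarray:
    """1-based rank vector of x with ties broken at random, reproducibly."""
    rng = np.random.default_rng(random_state)
    perm = rng.permutation(len(x))  # shuffle once with a seeded generator
    x_shuffled = x[perm]
    ranks = np.empty(len(x), dtype=int)
    # stable argsort of the shuffled copy, mapped back through perm so that
    # the r-th smallest original element receives rank r
    ranks[perm[np.argsort(x_shuffled, kind="stable")]] = np.arange(1, len(x) + 1)
    return ranks

x = np.array([0.3, 0.1, 0.3, 0.2])
print(ordered_rank(x))                   # e.g. [3 1 4 2] or [4 1 3 2], depending on the tie break
print(ss.rankdata(x, method="ordinal"))  # [3 1 4 2]; matches ordered_rank exactly when x has no ties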
@@ -760,7 +747,7 @@ class Xi:
     @property
     def g(self):
         # g[i] is number of j s.t. y[j] >= y[i], divided by n.
-        return ss.rankdata(
+        return ss.rankdata(-self.y, method="max") / self.sample_size
 
     @property
     def x_ordered(self):
@@ -769,32 +756,14 @@ class Xi:
 
     @property
     def x_rank_max_ordered(self):
-
-        y_rank_max_result = self.y_rank_max
-        # Rearrange f according to ord.
-        return [y_rank_max_result[i] for i in x_ordered_result]
+        return self.y_rank_max[self.x_ordered]
 
     @property
     def mean_absolute(self):
         x1 = self.x_rank_max_ordered[0 : (self.sample_size - 1)]
         x2 = self.x_rank_max_ordered[1 : self.sample_size]
 
-        return (
-            np.mean(
-                np.abs(
-                    [
-                        x - y
-                        for x, y in zip(
-                            x1,
-                            x2,
-                            strict=False,
-                        )
-                    ]
-                )
-            )
-            * (self.sample_size - 1)
-            / (2 * self.sample_size)
-        )
+        return np.mean(np.abs(x1 - x2)) * (self.sample_size - 1) / (2 * self.sample_size)
 
     @property
     def inverse_g_mean(self):
@@ -803,7 +772,7 @@ class Xi:
 
     @property
     def correlation(self):
-        """
+        """Xi correlation."""
         return 1 - self.mean_absolute / self.inverse_g_mean
 
     @classmethod
@@ -852,10 +821,7 @@ class Xi:
 
 
 class SinkhornKnopp:
-    """
-    An simple implementation of Sinkhorn iteration used in the biwhitening approach.
-
-    """
+    """An simple implementation of Sinkhorn iteration used in the biwhitening approach."""
 
     def __init__(self, max_iter: float = 1000, setr: int = 0, setc: float = 0, epsilon: float = 1e-3):
         if max_iter < 0:
@@ -889,14 +855,8 @@ class SinkhornKnopp:
         assert P.ndim == 2
 
         N = P.shape[0]
-        if np.sum(abs(self._setr)) == 0
-
-        else:
-            rsum = self._setr
-        if np.sum(abs(self._setc)) == 0:
-            csum = P.shape[0]
-        else:
-            csum = self._setc
+        rsum = P.shape[1] if np.sum(abs(self._setr)) == 0 else self._setr
+        csum = P.shape[0] if np.sum(abs(self._setc)) == 0 else self._setc
         max_threshr = rsum + self._epsilon
         min_threshr = rsum - self._epsilon
         max_threshc = csum + self._epsilon