nystrom-ncut 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/common.py +18 -5
- nystrom_ncut/distance_utils.py +54 -32
- nystrom_ncut/nystrom/__init__.py +0 -3
- nystrom_ncut/nystrom/distance_realization.py +8 -9
- nystrom_ncut/nystrom/normalized_cut.py +51 -47
- nystrom_ncut/nystrom/nystrom_utils.py +78 -69
- nystrom_ncut/sampling_utils.py +64 -51
- nystrom_ncut/visualize_utils.py +31 -43
- {nystrom_ncut-0.2.1.dist-info → nystrom_ncut-0.3.0.dist-info}/METADATA +1 -1
- nystrom_ncut-0.3.0.dist-info/RECORD +18 -0
- {nystrom_ncut-0.2.1.dist-info → nystrom_ncut-0.3.0.dist-info}/WHEEL +1 -1
- nystrom_ncut-0.2.1.dist-info/RECORD +0 -18
- {nystrom_ncut-0.2.1.dist-info → nystrom_ncut-0.3.0.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.2.1.dist-info → nystrom_ncut-0.3.0.dist-info}/top_level.txt +0 -0
nystrom_ncut/common.py
CHANGED
@@ -12,8 +12,8 @@ def ceildiv(a: int, b: int) -> int:
|
|
12
12
|
def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:
|
13
13
|
numel = np.prod(x.shape[:-1])
|
14
14
|
n = min(n, numel)
|
15
|
-
random_indices = torch.randperm(numel)[:n]
|
16
|
-
_x = x.
|
15
|
+
random_indices = torch.randperm(numel, device=x.device)[:n]
|
16
|
+
_x = x.view((-1, x.shape[-1]))[random_indices]
|
17
17
|
if torch.allclose(torch.norm(_x, **normalize_kwargs), torch.ones(n, device=x.device)):
|
18
18
|
return x
|
19
19
|
else:
|
@@ -21,13 +21,14 @@ def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> t
|
|
21
21
|
|
22
22
|
|
23
23
|
def quantile_min_max(x: torch.Tensor, q1: float, q2: float, n_sample: int = 10000):
|
24
|
-
|
24
|
+
x = x.flatten()
|
25
|
+
if len(x) > n_sample:
|
25
26
|
np.random.seed(0)
|
26
|
-
random_idx = np.random.choice(x
|
27
|
+
random_idx = np.random.choice(len(x), n_sample, replace=False)
|
27
28
|
vmin, vmax = x[random_idx].quantile(q1), x[random_idx].quantile(q2)
|
28
29
|
else:
|
29
30
|
vmin, vmax = x.quantile(q1), x.quantile(q2)
|
30
|
-
return vmin, vmax
|
31
|
+
return vmin.item(), vmax.item()
|
31
32
|
|
32
33
|
|
33
34
|
def quantile_normalize(x: torch.Tensor, q: float = 0.95):
|
@@ -57,5 +58,17 @@ def quantile_normalize(x: torch.Tensor, q: float = 0.95):
|
|
57
58
|
return x
|
58
59
|
|
59
60
|
|
61
|
+
class default_device:
|
62
|
+
def __init__(self, device: torch.device):
|
63
|
+
self._device = device
|
64
|
+
|
65
|
+
def __enter__(self):
|
66
|
+
self._original_device = torch.get_default_device()
|
67
|
+
torch.set_default_device(self._device)
|
68
|
+
|
69
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
70
|
+
torch.set_default_device(self._original_device)
|
71
|
+
|
72
|
+
|
60
73
|
def profile(name: str, t: torch.Tensor) -> None:
|
61
74
|
print(f"{name} --- nan: {t.isnan().any()}, inf: {t.isinf().any()}, max: {t.abs().max()}, min: {t.abs().min()}")
|
nystrom_ncut/distance_utils.py
CHANGED
@@ -1,61 +1,71 @@
|
|
1
|
-
|
1
|
+
import collections
|
2
|
+
from typing import List, Literal, OrderedDict
|
2
3
|
|
3
4
|
import torch
|
4
5
|
|
5
6
|
from .common import lazy_normalize
|
6
7
|
|
7
8
|
|
8
|
-
DistanceOptions = Literal["cosine", "euclidean"
|
9
|
+
DistanceOptions = Literal["cosine", "euclidean"]
|
10
|
+
AffinityOptions = Literal["cosine", "rbf", "laplacian"]
|
9
11
|
|
12
|
+
# noinspection PyTypeChecker
|
13
|
+
DISTANCE_TO_AFFINITY: OrderedDict[DistanceOptions, List[AffinityOptions]] = collections.OrderedDict([
|
14
|
+
("cosine", ["cosine"]),
|
15
|
+
("euclidean", ["rbf", "laplacian"]),
|
16
|
+
])
|
17
|
+
# noinspection PyTypeChecker
|
18
|
+
AFFINITY_TO_DISTANCE: OrderedDict[AffinityOptions, DistanceOptions] = collections.OrderedDict(sum([
|
19
|
+
[(affinity_type, distance_type) for affinity_type in affinity_types]
|
20
|
+
for distance_type, affinity_types in DISTANCE_TO_AFFINITY.items()
|
21
|
+
], start=[]))
|
10
22
|
|
11
|
-
|
12
|
-
|
23
|
+
|
24
|
+
|
25
|
+
def to_euclidean(x: torch.Tensor, distance_type: DistanceOptions) -> torch.Tensor:
|
26
|
+
if distance_type == "cosine":
|
13
27
|
return lazy_normalize(x, p=2, dim=-1)
|
14
|
-
elif
|
28
|
+
elif distance_type == "euclidean":
|
15
29
|
return x
|
16
30
|
else:
|
17
|
-
raise ValueError(f"to_euclidean not implemented for
|
31
|
+
raise ValueError(f"to_euclidean not implemented for distance_type {distance_type}.")
|
18
32
|
|
19
33
|
|
20
34
|
def distance_from_features(
|
21
35
|
features: torch.Tensor,
|
22
36
|
features_B: torch.Tensor,
|
23
|
-
|
37
|
+
distance_type: DistanceOptions,
|
24
38
|
):
|
25
|
-
"""Compute
|
39
|
+
"""Compute distance matrix from input features.
|
26
40
|
Args:
|
27
41
|
features (torch.Tensor): input features, shape (n_samples, n_features)
|
28
42
|
features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
|
29
|
-
|
43
|
+
distance_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
|
30
44
|
Returns:
|
31
45
|
(torch.Tensor): affinity matrix, shape (n_samples, n_samples)
|
32
46
|
"""
|
33
47
|
# compute distance matrix from input features
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
D = D / (torch.linalg.norm(stds) ** 2)
|
49
|
-
else:
|
50
|
-
raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
|
51
|
-
return D
|
48
|
+
shape: torch.Size = features.shape[:-2]
|
49
|
+
features = features.view((-1, *features.shape[-2:]))
|
50
|
+
features_B = features_B.view((-1, *features_B.shape[-2:]))
|
51
|
+
|
52
|
+
match distance_type:
|
53
|
+
case "cosine":
|
54
|
+
features = lazy_normalize(features, dim=-1)
|
55
|
+
features_B = lazy_normalize(features_B, dim=-1)
|
56
|
+
D = 1 - features @ features_B.mT
|
57
|
+
case "euclidean":
|
58
|
+
D = torch.cdist(features, features_B, p=2)
|
59
|
+
case _:
|
60
|
+
raise ValueError("Distance should be 'cosine' or 'euclidean'")
|
61
|
+
return D.view((*shape, *D.shape[-2:]))
|
52
62
|
|
53
63
|
|
54
64
|
def affinity_from_features(
|
55
65
|
features: torch.Tensor,
|
56
66
|
features_B: torch.Tensor = None,
|
57
67
|
affinity_focal_gamma: float = 1.0,
|
58
|
-
|
68
|
+
affinity_type: AffinityOptions = "cosine",
|
59
69
|
):
|
60
70
|
"""Compute affinity matrix from input features.
|
61
71
|
|
@@ -64,7 +74,7 @@ def affinity_from_features(
|
|
64
74
|
features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
|
65
75
|
affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
|
66
76
|
on weak connections, default 1.0
|
67
|
-
|
77
|
+
affinity_type (str): distance metric, 'cosine' (default) or 'euclidean'.
|
68
78
|
Returns:
|
69
79
|
(torch.Tensor): affinity matrix, shape (n_samples, n_samples)
|
70
80
|
"""
|
@@ -75,9 +85,21 @@ def affinity_from_features(
|
|
75
85
|
features_B = features if features_B is None else features_B
|
76
86
|
|
77
87
|
# compute distance matrix from input features
|
78
|
-
D = distance_from_features(features, features_B,
|
88
|
+
D = distance_from_features(features, features_B, AFFINITY_TO_DISTANCE[affinity_type])
|
79
89
|
|
80
|
-
# torch.exp make affinity matrix positive definite,
|
81
90
|
# lower affinity_focal_gamma reduce the weak edge weights
|
82
|
-
|
91
|
+
match affinity_type:
|
92
|
+
case "cosine" | "laplacian":
|
93
|
+
A = torch.exp(-D / affinity_focal_gamma) # [... x n x n]
|
94
|
+
case "rbf":
|
95
|
+
# Outlier-robust scale invariance using quantiles to estimate standard deviation
|
96
|
+
c = 2.0
|
97
|
+
p = torch.erf(torch.tensor((-c, c), device=features.device) * (2 ** -0.5))
|
98
|
+
stds = torch.nanquantile(features, q=(p + 1) / 2, dim=-2) # [2 x ... x d]
|
99
|
+
stds = (stds[1] - stds[0]) / (2 * c) # [... x d]
|
100
|
+
D = 0.5 * (D / torch.norm(stds, dim=-1)[..., None, None]) ** 2
|
101
|
+
A = torch.exp(-D / affinity_focal_gamma)
|
102
|
+
case _:
|
103
|
+
raise ValueError("Affinity should be 'cosine', 'rbf', or 'laplacian'")
|
104
|
+
|
83
105
|
return A
|
nystrom_ncut/nystrom/__init__.py
CHANGED
@@ -18,10 +18,10 @@ from ..sampling_utils import (
|
|
18
18
|
class GramKernel(OnlineKernel):
|
19
19
|
def __init__(
|
20
20
|
self,
|
21
|
-
|
21
|
+
distance_type: DistanceOptions,
|
22
22
|
eig_solver: EigSolverOptions,
|
23
23
|
):
|
24
|
-
self.
|
24
|
+
self.distance_type: DistanceOptions = distance_type
|
25
25
|
self.eig_solver: EigSolverOptions = eig_solver
|
26
26
|
|
27
27
|
# Anchor matrices
|
@@ -40,7 +40,7 @@ class GramKernel(OnlineKernel):
|
|
40
40
|
self.A = -0.5 * distance_from_features(
|
41
41
|
self.anchor_features, # [n x d]
|
42
42
|
self.anchor_features,
|
43
|
-
|
43
|
+
distance_type=self.distance_type,
|
44
44
|
) # [n x n]
|
45
45
|
d = features.shape[-1]
|
46
46
|
U, L = solve_eig(
|
@@ -58,7 +58,7 @@ class GramKernel(OnlineKernel):
|
|
58
58
|
B = -0.5 * distance_from_features(
|
59
59
|
self.anchor_features, # [n x d]
|
60
60
|
features, # [m x d]
|
61
|
-
|
61
|
+
distance_type=self.distance_type,
|
62
62
|
) # [n x m]
|
63
63
|
b_r = torch.sum(B, dim=-1) # [n]
|
64
64
|
b_c = torch.sum(B, dim=-2) # [m]
|
@@ -84,7 +84,7 @@ class GramKernel(OnlineKernel):
|
|
84
84
|
B = -0.5 * distance_from_features(
|
85
85
|
self.anchor_features,
|
86
86
|
features,
|
87
|
-
|
87
|
+
distance_type=self.distance_type,
|
88
88
|
)
|
89
89
|
b_c = torch.sum(B, dim=-2) # [m]
|
90
90
|
col_sum = b_c + B.mT @ self.Ainv @ self.b_r # [m]
|
@@ -98,7 +98,7 @@ class DistanceRealization(OnlineNystromSubsampleFit):
|
|
98
98
|
def __init__(
|
99
99
|
self,
|
100
100
|
n_components: int = 100,
|
101
|
-
|
101
|
+
distance_type: DistanceOptions = "cosine",
|
102
102
|
sample_config: SampleConfig = SampleConfig(),
|
103
103
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
104
104
|
chunk_size: int = 8192,
|
@@ -115,13 +115,12 @@ class DistanceRealization(OnlineNystromSubsampleFit):
|
|
115
115
|
OnlineNystromSubsampleFit.__init__(
|
116
116
|
self,
|
117
117
|
n_components=n_components,
|
118
|
-
kernel=GramKernel(
|
119
|
-
|
118
|
+
kernel=GramKernel(distance_type, eig_solver),
|
119
|
+
distance_type=distance_type,
|
120
120
|
sample_config=sample_config,
|
121
121
|
eig_solver=eig_solver,
|
122
122
|
chunk_size=chunk_size,
|
123
123
|
)
|
124
|
-
self.distance: DistanceOptions = distance
|
125
124
|
|
126
125
|
def fit_transform(
|
127
126
|
self,
|
@@ -8,7 +8,8 @@ from .nystrom_utils import (
|
|
8
8
|
solve_eig,
|
9
9
|
)
|
10
10
|
from ..distance_utils import (
|
11
|
-
|
11
|
+
AffinityOptions,
|
12
|
+
AFFINITY_TO_DISTANCE,
|
12
13
|
affinity_from_features,
|
13
14
|
)
|
14
15
|
from ..sampling_utils import (
|
@@ -20,80 +21,83 @@ class LaplacianKernel(OnlineKernel):
|
|
20
21
|
def __init__(
|
21
22
|
self,
|
22
23
|
affinity_focal_gamma: float,
|
23
|
-
|
24
|
+
affinity_type: AffinityOptions,
|
24
25
|
adaptive_scaling: bool,
|
25
26
|
eig_solver: EigSolverOptions,
|
26
27
|
):
|
27
28
|
self.affinity_focal_gamma = affinity_focal_gamma
|
28
|
-
self.
|
29
|
+
self.affinity_type: AffinityOptions = affinity_type
|
29
30
|
self.adaptive_scaling: bool = adaptive_scaling
|
30
31
|
self.eig_solver: EigSolverOptions = eig_solver
|
31
32
|
|
32
33
|
# Anchor matrices
|
33
|
-
self.anchor_features: torch.Tensor = None
|
34
|
-
self.
|
35
|
-
self.
|
34
|
+
self.anchor_features: torch.Tensor = None # [... x n x d]
|
35
|
+
self.anchor_mask: torch.Tensor = None
|
36
|
+
self.A: torch.Tensor = None # [... x n x n]
|
37
|
+
self.Ainv: torch.Tensor = None # [... x n x n]
|
36
38
|
|
37
39
|
# Updated matrices
|
38
|
-
self.a_r: torch.Tensor = None
|
39
|
-
self.b_r: torch.Tensor = None
|
40
|
+
self.a_r: torch.Tensor = None # [... x n]
|
41
|
+
self.b_r: torch.Tensor = None # [... x n]
|
40
42
|
|
41
43
|
def fit(self, features: torch.Tensor) -> None:
|
42
|
-
self.anchor_features = features
|
43
|
-
self.
|
44
|
-
|
44
|
+
self.anchor_features = features # [... x n x d]
|
45
|
+
self.anchor_mask = torch.all(torch.isnan(self.anchor_features), dim=-1) # [... x n]
|
46
|
+
|
47
|
+
self.A = torch.nan_to_num(affinity_from_features(
|
48
|
+
self.anchor_features, # [... x n x d]
|
45
49
|
affinity_focal_gamma=self.affinity_focal_gamma,
|
46
|
-
|
47
|
-
)
|
50
|
+
affinity_type=self.affinity_type,
|
51
|
+
), nan=0.0) # [... x n x n]
|
48
52
|
d = features.shape[-1]
|
49
53
|
U, L = solve_eig(
|
50
54
|
self.A,
|
51
55
|
num_eig=d + 1, # d * (d + 3) // 2 + 1,
|
52
56
|
eig_solver=self.eig_solver,
|
53
|
-
)
|
54
|
-
self.Ainv = U @ torch.
|
55
|
-
self.a_r = torch.sum(self.A, dim=-1)
|
56
|
-
self.b_r = torch.zeros_like(self.a_r)
|
57
|
+
) # [... x n x (d + 1)], [... x (d + 1)]
|
58
|
+
self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT # [... x n x n]
|
59
|
+
self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1)) # [... x n]
|
60
|
+
self.b_r = torch.zeros_like(self.a_r) # [... x n]
|
57
61
|
|
58
62
|
def _affinity(self, features: torch.Tensor) -> torch.Tensor:
|
59
|
-
B = affinity_from_features(
|
60
|
-
self.anchor_features,
|
61
|
-
features,
|
63
|
+
B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
|
64
|
+
self.anchor_features, # [... x n x d]
|
65
|
+
features, # [... x m x d]
|
62
66
|
affinity_focal_gamma=self.affinity_focal_gamma,
|
63
|
-
|
64
|
-
)
|
67
|
+
affinity_type=self.affinity_type,
|
68
|
+
)) # [... x n x m]
|
65
69
|
if self.adaptive_scaling:
|
66
70
|
diagonal = (
|
67
|
-
einops.rearrange(B, "n m -> m 1 n")
|
68
|
-
@ self.Ainv
|
69
|
-
@ einops.rearrange(B, "n m -> m n 1")
|
70
|
-
).squeeze(
|
71
|
-
adaptive_scale = diagonal ** -0.5
|
72
|
-
B = B * adaptive_scale
|
73
|
-
return B
|
71
|
+
einops.rearrange(B, "... n m -> ... m 1 n") # [... x m x 1 x n]
|
72
|
+
@ self.Ainv # [... x n x n]
|
73
|
+
@ einops.rearrange(B, "... n m -> ... m n 1") # [... x m x n x 1]
|
74
|
+
).squeeze(-2, -1) # [... x m]
|
75
|
+
adaptive_scale = diagonal ** -0.5 # [... x m]
|
76
|
+
B = B * adaptive_scale[..., None, :]
|
77
|
+
return B # [... x n x m]
|
74
78
|
|
75
79
|
def update(self, features: torch.Tensor) -> torch.Tensor:
|
76
|
-
B = self._affinity(features)
|
77
|
-
b_r = torch.sum(B, dim=-1)
|
78
|
-
b_c = torch.sum(B, dim=-2)
|
79
|
-
self.b_r = self.b_r + b_r
|
80
|
+
B = self._affinity(features) # [... x n x m]
|
81
|
+
b_r = torch.sum(torch.nan_to_num(B, nan=0.0), dim=-1) # [... x n]
|
82
|
+
b_c = torch.sum(B, dim=-2) # [... x m]
|
83
|
+
self.b_r = self.b_r + b_r # [... x n]
|
80
84
|
|
81
|
-
row_sum = self.a_r + self.b_r
|
82
|
-
col_sum = b_c + B.mT @ self.Ainv @ self.b_r
|
83
|
-
scale = (row_sum[:, None] * col_sum) ** -0.5
|
84
|
-
return (B * scale).mT
|
85
|
+
row_sum = self.a_r + self.b_r # [... x n]
|
86
|
+
col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0] # [... x m]
|
87
|
+
scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5 # [... x n x m]
|
88
|
+
return (B * scale).mT # [... x m x n]
|
85
89
|
|
86
90
|
def transform(self, features: torch.Tensor = None) -> torch.Tensor:
|
87
|
-
row_sum = self.a_r + self.b_r
|
91
|
+
row_sum = self.a_r + self.b_r # [... x n]
|
88
92
|
if features is None:
|
89
|
-
B = self.A
|
90
|
-
col_sum = row_sum
|
93
|
+
B = self.A # [... x n x n]
|
94
|
+
col_sum = row_sum # [... x n]
|
91
95
|
else:
|
92
96
|
B = self._affinity(features)
|
93
|
-
b_c = torch.sum(B, dim=-2)
|
94
|
-
col_sum = b_c + B.mT @ self.Ainv @ self.b_r
|
95
|
-
scale = (row_sum[:, None] * col_sum) ** -0.5
|
96
|
-
return (B * scale).mT
|
97
|
+
b_c = torch.sum(B, dim=-2) # [... x m]
|
98
|
+
col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0] # [... x m]
|
99
|
+
scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5 # [... x n x m]
|
100
|
+
return (B * scale).mT # [... x m x n]
|
97
101
|
|
98
102
|
|
99
103
|
class NCut(OnlineNystromSubsampleFit):
|
@@ -103,7 +107,7 @@ class NCut(OnlineNystromSubsampleFit):
|
|
103
107
|
self,
|
104
108
|
n_components: int = 100,
|
105
109
|
affinity_focal_gamma: float = 1.0,
|
106
|
-
|
110
|
+
affinity_type: AffinityOptions = "cosine",
|
107
111
|
adaptive_scaling: bool = False,
|
108
112
|
sample_config: SampleConfig = SampleConfig(),
|
109
113
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
@@ -124,8 +128,8 @@ class NCut(OnlineNystromSubsampleFit):
|
|
124
128
|
OnlineNystromSubsampleFit.__init__(
|
125
129
|
self,
|
126
130
|
n_components=n_components,
|
127
|
-
kernel=LaplacianKernel(affinity_focal_gamma,
|
128
|
-
|
131
|
+
kernel=LaplacianKernel(affinity_focal_gamma, affinity_type, adaptive_scaling, eig_solver),
|
132
|
+
distance_type=AFFINITY_TO_DISTANCE[affinity_type],
|
129
133
|
sample_config=sample_config,
|
130
134
|
eig_solver=eig_solver,
|
131
135
|
chunk_size=chunk_size,
|
@@ -25,15 +25,15 @@ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
|
|
25
25
|
|
26
26
|
class OnlineKernel:
|
27
27
|
@abstractmethod
|
28
|
-
def fit(self, features: torch.Tensor) -> "OnlineKernel": # [n x d]
|
28
|
+
def fit(self, features: torch.Tensor) -> "OnlineKernel": # [... x n x d]
|
29
29
|
""""""
|
30
30
|
|
31
31
|
@abstractmethod
|
32
|
-
def update(self, features: torch.Tensor) -> torch.Tensor: # [m x d] -> [m x n]
|
32
|
+
def update(self, features: torch.Tensor) -> torch.Tensor: # [... x m x d] -> [... x m x n]
|
33
33
|
""""""
|
34
34
|
|
35
35
|
@abstractmethod
|
36
|
-
def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [m x d] -> [m x n]
|
36
|
+
def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [... x m x d] -> [... x m x n]
|
37
37
|
""""""
|
38
38
|
|
39
39
|
|
@@ -54,20 +54,21 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
54
54
|
self.n_components: int = n_components
|
55
55
|
self.kernel: OnlineKernel = kernel
|
56
56
|
self.eig_solver: EigSolverOptions = eig_solver
|
57
|
+
self.shape: torch.Size = None # ...
|
57
58
|
|
58
59
|
self.chunk_size = chunk_size
|
59
60
|
|
60
61
|
# Anchor matrices
|
61
|
-
self.anchor_features: torch.Tensor = None # [n x d]
|
62
|
-
self.A: torch.Tensor = None # [n x n]
|
63
|
-
self.Ahinv: torch.Tensor = None # [n x n]
|
64
|
-
self.Ahinv_UL: torch.Tensor = None # [n x indirect_pca_dim]
|
65
|
-
self.Ahinv_VT: torch.Tensor = None # [indirect_pca_dim x n]
|
62
|
+
self.anchor_features: torch.Tensor = None # [... x n x d]
|
63
|
+
self.A: torch.Tensor = None # [... x n x n]
|
64
|
+
self.Ahinv: torch.Tensor = None # [... x n x n]
|
65
|
+
self.Ahinv_UL: torch.Tensor = None # [... x n x indirect_pca_dim]
|
66
|
+
self.Ahinv_VT: torch.Tensor = None # [... x indirect_pca_dim x n]
|
66
67
|
|
67
68
|
# Updated matrices
|
68
|
-
self.S: torch.Tensor = None # [n x n]
|
69
|
-
self.transform_matrix: torch.Tensor = None # [n x n_components]
|
70
|
-
self.eigenvalues_: torch.Tensor = None # [n]
|
69
|
+
self.S: torch.Tensor = None # [... x n x n]
|
70
|
+
self.transform_matrix: torch.Tensor = None # [... x n x n_components]
|
71
|
+
self.eigenvalues_: torch.Tensor = None # [... x n]
|
71
72
|
|
72
73
|
def _update_to_kernel(self, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
73
74
|
self.A = self.S = self.kernel.transform()
|
@@ -75,10 +76,10 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
75
76
|
self.A,
|
76
77
|
num_eig=d + 1, # d * (d + 3) // 2 + 1,
|
77
78
|
eig_solver=self.eig_solver,
|
78
|
-
) # [n x (? + 1)], [? + 1]
|
79
|
-
self.Ahinv_UL = U * (L ** -0.5)
|
80
|
-
self.Ahinv_VT = U.mT # [(? + 1) x n]
|
81
|
-
self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [n x n]
|
79
|
+
) # [... x n x (? + 1)], [... x (? + 1)]
|
80
|
+
self.Ahinv_UL = U * (L[..., None, :] ** -0.5) # [... x n x (? + 1)]
|
81
|
+
self.Ahinv_VT = U.mT # [... x (? + 1) x n]
|
82
|
+
self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [... x n x n]
|
82
83
|
return U, L
|
83
84
|
|
84
85
|
def fit(self, features: torch.Tensor) -> "OnlineNystrom":
|
@@ -89,64 +90,63 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
89
90
|
self.anchor_features = features
|
90
91
|
|
91
92
|
self.kernel.fit(self.anchor_features)
|
92
|
-
U, L = self._update_to_kernel(features.shape[-1]) # [n x (d + 1)], [d + 1]
|
93
|
+
U, L = self._update_to_kernel(features.shape[-1]) # [... x n x (d + 1)], [... x (d + 1)]
|
93
94
|
|
94
|
-
self.transform_matrix = (U / L)[:, :self.n_components]
|
95
|
-
self.eigenvalues_ = L[:self.n_components]
|
96
|
-
return U[:, :self.n_components]
|
95
|
+
self.transform_matrix = (U / L[..., None, :])[..., :, :self.n_components] # [... x n x n_components]
|
96
|
+
self.eigenvalues_ = L[..., :self.n_components] # [... x n_components]
|
97
|
+
return U[..., :, :self.n_components] # [... x n x n_components]
|
97
98
|
|
98
99
|
def update(self, features: torch.Tensor) -> torch.Tensor:
|
99
100
|
d = features.shape[-1]
|
100
|
-
n_chunks = ceildiv(
|
101
|
+
n_chunks = ceildiv(features.shape[-2], self.chunk_size)
|
101
102
|
if n_chunks > 1:
|
102
103
|
""" Chunked version """
|
103
|
-
chunks = torch.chunk(features, n_chunks, dim
|
104
|
+
chunks = torch.chunk(features, n_chunks, dim=-2)
|
104
105
|
for chunk in chunks:
|
105
106
|
self.kernel.update(chunk)
|
106
107
|
self._update_to_kernel(d)
|
107
108
|
|
108
|
-
compressed_BBT = 0.0 # [(? + 1) x (? + 1))]
|
109
|
+
compressed_BBT = 0.0 # [... x (? + 1) x (? + 1))]
|
109
110
|
for chunk in chunks:
|
110
|
-
_B = self.kernel.transform(chunk).mT # [n x _m]
|
111
|
-
_compressed_B = self.Ahinv_VT @ _B # [(? + 1) x _m]
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
self.
|
111
|
+
_B = self.kernel.transform(chunk).mT # [... x n x _m]
|
112
|
+
_compressed_B = self.Ahinv_VT @ _B # [... x (? + 1) x _m]
|
113
|
+
_compressed_B = torch.nan_to_num(_compressed_B, nan=0.0)
|
114
|
+
compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT # [... x (? + 1) x (? + 1)]
|
115
|
+
self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT # [... x n x n]
|
116
|
+
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
|
117
|
+
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
|
116
118
|
|
117
119
|
VS = []
|
118
120
|
for chunk in chunks:
|
119
|
-
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
|
120
|
-
VS = torch.cat(VS, dim
|
121
|
-
return VS # [m x n_components]
|
121
|
+
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
|
122
|
+
VS = torch.cat(VS, dim=-2)
|
123
|
+
return VS # [... x m x n_components]
|
122
124
|
else:
|
123
125
|
""" Unchunked version """
|
124
|
-
B = self.kernel.update(features).mT # [n x m]
|
126
|
+
B = self.kernel.update(features).mT # [... x n x m]
|
125
127
|
self._update_to_kernel(d)
|
126
|
-
compressed_B = self.Ahinv_VT @ B # [
|
128
|
+
compressed_B = self.Ahinv_VT @ B # [... x (? + 1) x m]
|
129
|
+
compressed_B = torch.nan_to_num(compressed_B, nan=0.0)
|
127
130
|
|
128
|
-
self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [n x n]
|
129
|
-
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [n x n_components], [n_components]
|
130
|
-
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5)
|
131
|
+
self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [... x n x n]
|
132
|
+
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
|
133
|
+
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
|
131
134
|
|
132
|
-
return B.mT @ self.transform_matrix # [m x n_components]
|
135
|
+
return B.mT @ self.transform_matrix # [... x m x n_components]
|
133
136
|
|
134
|
-
def transform(self, features: torch.Tensor
|
135
|
-
|
136
|
-
|
137
|
+
def transform(self, features: torch.Tensor) -> torch.Tensor:
|
138
|
+
n_chunks = ceildiv(features.shape[-2], self.chunk_size)
|
139
|
+
if n_chunks > 1:
|
140
|
+
""" Chunked version """
|
141
|
+
chunks = torch.chunk(features, n_chunks, dim=-2)
|
142
|
+
VS = []
|
143
|
+
for chunk in chunks:
|
144
|
+
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
|
145
|
+
VS = torch.cat(VS, dim=-2)
|
137
146
|
else:
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
chunks = torch.chunk(features, n_chunks, dim=0)
|
142
|
-
VS = []
|
143
|
-
for chunk in chunks:
|
144
|
-
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
|
145
|
-
VS = torch.cat(VS, dim=0)
|
146
|
-
else:
|
147
|
-
""" Unchunked version """
|
148
|
-
VS = self.kernel.transform(features) @ self.transform_matrix # [m x n_components]
|
149
|
-
return VS # [m x n_components]
|
147
|
+
""" Unchunked version """
|
148
|
+
VS = self.kernel.transform(features) @ self.transform_matrix # [... x m x n_components]
|
149
|
+
return VS # [... x m x n_components]
|
150
150
|
|
151
151
|
|
152
152
|
class OnlineNystromSubsampleFit(OnlineNystrom):
|
@@ -154,7 +154,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
154
154
|
self,
|
155
155
|
n_components: int,
|
156
156
|
kernel: OnlineKernel,
|
157
|
-
|
157
|
+
distance_type: DistanceOptions,
|
158
158
|
sample_config: SampleConfig,
|
159
159
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
160
160
|
chunk_size: int = 8192,
|
@@ -166,7 +166,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
166
166
|
eig_solver=eig_solver,
|
167
167
|
chunk_size=chunk_size,
|
168
168
|
)
|
169
|
-
self.
|
169
|
+
self.distance_type: DistanceOptions = distance_type
|
170
170
|
self.sample_config: SampleConfig = sample_config
|
171
171
|
self.sample_config._ncut_obj = copy.deepcopy(self)
|
172
172
|
self.anchor_indices: torch.Tensor = None
|
@@ -176,7 +176,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
176
176
|
features: torch.Tensor,
|
177
177
|
precomputed_sampled_indices: torch.Tensor,
|
178
178
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
179
|
-
_n = features.shape[
|
179
|
+
_n = features.shape[-2]
|
180
180
|
if self.sample_config.num_sample >= _n:
|
181
181
|
logging.info(
|
182
182
|
f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
|
@@ -188,16 +188,17 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
188
188
|
else:
|
189
189
|
self.anchor_indices = subsample_features(
|
190
190
|
features=features,
|
191
|
-
|
191
|
+
distance_type=self.distance_type,
|
192
192
|
config=self.sample_config,
|
193
193
|
)
|
194
|
-
sampled_features = features[self.anchor_indices]
|
194
|
+
sampled_features = torch.gather(features, -2, self.anchor_indices[..., None].expand([-1] * self.anchor_indices.ndim + [features.shape[-1]]))
|
195
195
|
OnlineNystrom.fit(self, sampled_features)
|
196
196
|
|
197
|
-
_n_not_sampled = _n -
|
197
|
+
_n_not_sampled = _n - self.anchor_indices.shape[-1]
|
198
198
|
if _n_not_sampled > 0:
|
199
|
-
|
200
|
-
|
199
|
+
unsampled_mask = torch.full(features.shape[:-1], True, device=features.device).scatter_(-1, self.anchor_indices, False)
|
200
|
+
unsampled_indices = torch.where(unsampled_mask)[-1].view((*features.shape[:-2], -1))
|
201
|
+
unsampled_features = torch.gather(features, -2, unsampled_indices[..., None].expand([-1] * unsampled_indices.ndim + [features.shape[-1]]))
|
201
202
|
V_unsampled = OnlineNystrom.update(self, unsampled_features)
|
202
203
|
else:
|
203
204
|
unsampled_indices = V_unsampled = None
|
@@ -235,12 +236,12 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
235
236
|
(torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
|
236
237
|
"""
|
237
238
|
unsampled_indices, V_unsampled = OnlineNystromSubsampleFit._fit_helper(self, features, precomputed_sampled_indices)
|
238
|
-
V_sampled = OnlineNystrom.transform(self)
|
239
|
+
V_sampled = OnlineNystrom.transform(self, self.anchor_features)
|
239
240
|
|
240
241
|
if unsampled_indices is not None:
|
241
|
-
V = torch.zeros((
|
242
|
-
|
243
|
-
|
242
|
+
V = torch.zeros((*features.shape[:-1], self.n_components), device=features.device)
|
243
|
+
for (indices, _V) in [(self.anchor_indices, V_sampled), (unsampled_indices, V_unsampled)]:
|
244
|
+
V.scatter_(-2, indices[..., None].expand([-1] * indices.ndim + [self.n_components]), _V)
|
244
245
|
else:
|
245
246
|
V = V_sampled
|
246
247
|
return V
|
@@ -263,12 +264,16 @@ def solve_eig(
|
|
263
264
|
(torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
|
264
265
|
(torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
|
265
266
|
"""
|
266
|
-
|
267
|
+
shape: torch.Size = A.shape[:-2]
|
268
|
+
A = A.view((-1, *A.shape[-2:]))
|
269
|
+
bsz: int = A.shape[0]
|
270
|
+
|
271
|
+
A = A + eig_value_buffer * torch.eye(A.shape[-1], device=A.device)
|
267
272
|
|
268
273
|
# compute eigenvectors
|
269
274
|
if eig_solver == "svd_lowrank": # default
|
270
275
|
# only top q eigenvectors, fastest
|
271
|
-
eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
|
276
|
+
eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig) # complex: [(...) x N x D], [(...) x D]
|
272
277
|
elif eig_solver == "lobpcg":
|
273
278
|
# only top k eigenvectors, fast
|
274
279
|
eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
|
@@ -285,11 +290,15 @@ def solve_eig(
|
|
285
290
|
eigen_value = eigen_value - eig_value_buffer
|
286
291
|
|
287
292
|
# sort eigenvectors by eigenvalues, take top (descending order)
|
288
|
-
indices = torch.topk(eigen_value.abs(), k=num_eig, dim
|
289
|
-
eigen_value
|
293
|
+
indices = torch.topk(eigen_value.abs(), k=num_eig, dim=-1).indices # int: [(...) x S]
|
294
|
+
eigen_value = eigen_value[torch.arange(bsz)[:, None], indices] # complex: [(...) x S]
|
295
|
+
eigen_vector = eigen_vector[torch.arange(bsz)[:, None], :, indices].mT # complex: [(...) x N x S]
|
290
296
|
|
291
297
|
# correct the random rotation (flipping sign) of eigenvectors
|
292
|
-
sign = torch.sum(eigen_vector.real, dim=
|
298
|
+
sign = torch.sign(torch.sum(eigen_vector.real, dim=-2, keepdim=True)) # float: [(...) x 1 x S]
|
293
299
|
sign[sign == 0] = 1.0
|
294
300
|
eigen_vector = eigen_vector * sign
|
301
|
+
|
302
|
+
eigen_value = eigen_value.view((*shape, *eigen_value.shape[-1:])) # complex: [... x S]
|
303
|
+
eigen_vector = eigen_vector.view((*shape, *eigen_vector.shape[-2:])) # complex: [... x N x S]
|
295
304
|
return eigen_vector, eigen_value
|
nystrom_ncut/sampling_utils.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1
|
-
import logging
|
2
1
|
from dataclasses import dataclass
|
3
2
|
from typing import Literal
|
4
3
|
|
5
4
|
import torch
|
6
5
|
from pytorch3d.ops import sample_farthest_points
|
7
6
|
|
7
|
+
from .common import (
|
8
|
+
default_device,
|
9
|
+
)
|
8
10
|
from .distance_utils import (
|
9
11
|
DistanceOptions,
|
10
12
|
to_euclidean,
|
11
13
|
)
|
14
|
+
from .transformer import (
|
15
|
+
TorchTransformerMixin,
|
16
|
+
)
|
12
17
|
|
13
18
|
|
14
|
-
SampleOptions = Literal["random", "fps", "fps_recursive"]
|
19
|
+
SampleOptions = Literal["full", "random", "fps", "fps_recursive"]
|
15
20
|
|
16
21
|
|
17
22
|
@dataclass
|
@@ -20,69 +25,77 @@ class SampleConfig:
|
|
20
25
|
num_sample: int = 10000
|
21
26
|
fps_dim: int = 12
|
22
27
|
n_iter: int = None
|
23
|
-
_ncut_obj:
|
28
|
+
_ncut_obj: TorchTransformerMixin = None
|
24
29
|
|
25
30
|
|
26
31
|
@torch.no_grad()
|
27
32
|
def subsample_features(
|
28
33
|
features: torch.Tensor,
|
29
|
-
|
34
|
+
distance_type: DistanceOptions,
|
30
35
|
config: SampleConfig,
|
31
|
-
max_draw: int = 1000000,
|
32
36
|
):
|
33
|
-
features = features.detach()
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
"num_sample is larger than total, bypass Nystrom-like approximation"
|
38
|
-
)
|
39
|
-
sampled_indices = torch.arange(features.shape[0])
|
40
|
-
else:
|
41
|
-
# sample subgraph
|
42
|
-
if config.method == "fps": # default
|
43
|
-
features = to_euclidean(features, disttype)
|
44
|
-
if config.num_sample > max_draw:
|
45
|
-
logging.warning(
|
46
|
-
f"num_sample is larger than max_draw, apply farthest point sampling on random sampled {max_draw} samples"
|
47
|
-
)
|
48
|
-
draw_indices = torch.randperm(features.shape[0])[:max_draw]
|
49
|
-
sampled_indices = fpsample(features[draw_indices], config)
|
50
|
-
sampled_indices = draw_indices[sampled_indices]
|
51
|
-
else:
|
52
|
-
sampled_indices = fpsample(features, config)
|
53
|
-
|
54
|
-
elif config.method == "random": # not recommended
|
55
|
-
sampled_indices = torch.randperm(features.shape[0])[:config.num_sample]
|
56
|
-
|
57
|
-
elif config.method == "fps_recursive":
|
58
|
-
features = to_euclidean(features, disttype)
|
59
|
-
sampled_indices = subsample_features(
|
60
|
-
features=features,
|
61
|
-
disttype=disttype,
|
62
|
-
config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
|
63
|
-
)
|
64
|
-
nc = config._ncut_obj
|
65
|
-
for _ in range(config.n_iter):
|
66
|
-
fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
|
67
|
-
|
68
|
-
fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
|
69
|
-
sampled_indices = torch.sort(fpsample(fps_features, config)).values
|
37
|
+
features = features.detach() # float: [... x n x d]
|
38
|
+
with default_device(features.device):
|
39
|
+
if config.method == "full" or config.num_sample >= features.shape[0]:
|
40
|
+
sampled_indices = torch.arange(features.shape[-2]).expand(features.shape[:-1]) # int: [... x n]
|
70
41
|
else:
|
71
|
-
|
72
|
-
|
73
|
-
|
42
|
+
# sample
|
43
|
+
match config.method:
|
44
|
+
case "fps": # default
|
45
|
+
sampled_indices = fpsample(to_euclidean(features, distance_type), config)
|
46
|
+
|
47
|
+
case "random": # not recommended
|
48
|
+
mask = torch.all(torch.isfinite(features), dim=-1) # bool: [... x n]
|
49
|
+
weights = mask.to(torch.float) + torch.rand(mask.shape) # float: [... x n]
|
50
|
+
sampled_indices = torch.topk(weights, k=config.num_sample, dim=-1).indices # int: [... x num_sample]
|
51
|
+
|
52
|
+
case "fps_recursive":
|
53
|
+
features = to_euclidean(features, distance_type) # float: [... x n x d]
|
54
|
+
sampled_indices = subsample_features(
|
55
|
+
features=features,
|
56
|
+
distance_type=distance_type,
|
57
|
+
config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
|
58
|
+
) # int: [... x num_sample]
|
59
|
+
nc = config._ncut_obj
|
60
|
+
for _ in range(config.n_iter):
|
61
|
+
fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
|
62
|
+
|
63
|
+
fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
|
64
|
+
sampled_indices = torch.sort(fpsample(fps_features, config), dim=-1).values
|
65
|
+
|
66
|
+
case _:
|
67
|
+
raise ValueError("sample_method should be 'farthest' or 'random'")
|
68
|
+
sampled_indices = torch.sort(sampled_indices, dim=-1).values
|
69
|
+
return sampled_indices
|
74
70
|
|
75
71
|
|
76
72
|
def fpsample(
|
77
73
|
features: torch.Tensor,
|
78
74
|
config: SampleConfig,
|
79
75
|
):
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
76
|
+
shape = features.shape[:-2] # ...
|
77
|
+
features = features.view((-1, *features.shape[-2:])) # [(...) x n x d]
|
78
|
+
bsz = features.shape[0]
|
79
|
+
|
80
|
+
mask = torch.all(torch.isfinite(features), dim=-1) # bool: [(...) x n]
|
81
|
+
count = torch.sum(mask, dim=-1) # int: [(...)]
|
82
|
+
order = torch.topk(mask.to(torch.int), k=torch.max(count).item(), dim=-1).indices # int: [(...) x max_count]
|
83
|
+
|
84
|
+
features = torch.nan_to_num(features[torch.arange(bsz)[:, None], order], nan=0.0) # float: [(...) x max_count x d]
|
85
|
+
if features.shape[-1] > config.fps_dim:
|
86
|
+
U, S, V = torch.pca_lowrank(features, q=config.fps_dim) # float: [(...) x max_count x fps_dim], [(...) x fps_dim], [(...) x fps_dim x fps_dim]
|
87
|
+
features = U * S[..., None, :] # float: [(...) x max_count x fps_dim]
|
84
88
|
|
85
89
|
try:
|
86
|
-
|
90
|
+
sample_indices = sample_farthest_points(
|
91
|
+
features, lengths=count, K=config.num_sample
|
92
|
+
)[1] # int: [(...) x num_sample]
|
87
93
|
except RuntimeError:
|
88
|
-
|
94
|
+
original_device = features.device
|
95
|
+
alternative_device = "cuda" if original_device == "cpu" else "cpu"
|
96
|
+
sample_indices = sample_farthest_points(
|
97
|
+
features.to(alternative_device), lengths=count.to(alternative_device), K=config.num_sample
|
98
|
+
)[1].to(original_device) # int: [(...) x num_sample]
|
99
|
+
sample_indices = torch.gather(order, 1, sample_indices) # int: [(...) x num_sample]
|
100
|
+
|
101
|
+
return sample_indices.view((*shape, *sample_indices.shape[-1:])) # int: [... x num_sample]
|
nystrom_ncut/visualize_utils.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
import
|
2
|
-
from typing import Any, Callable, Dict, Literal, Union
|
1
|
+
from typing import Any, Callable, Dict, Union
|
3
2
|
|
4
3
|
import numpy as np
|
5
4
|
import torch
|
@@ -13,7 +12,8 @@ from .common import (
|
|
13
12
|
quantile_normalize,
|
14
13
|
)
|
15
14
|
from .distance_utils import (
|
16
|
-
|
15
|
+
AffinityOptions,
|
16
|
+
AFFINITY_TO_DISTANCE,
|
17
17
|
to_euclidean,
|
18
18
|
affinity_from_features,
|
19
19
|
)
|
@@ -27,7 +27,7 @@ def extrapolate_knn(
|
|
27
27
|
anchor_features: torch.Tensor, # [n x d]
|
28
28
|
anchor_output: torch.Tensor, # [n x d']
|
29
29
|
extrapolation_features: torch.Tensor, # [m x d]
|
30
|
-
|
30
|
+
affinity_type: AffinityOptions,
|
31
31
|
knn: int = 10, # k
|
32
32
|
affinity_focal_gamma: float = 1.0,
|
33
33
|
chunk_size: int = 8192,
|
@@ -41,7 +41,7 @@ def extrapolate_knn(
|
|
41
41
|
anchor_output (torch.Tensor): output from subgraph, shape (num_sample, D)
|
42
42
|
extrapolation_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
|
43
43
|
knn (int): number of KNN to propagate eige nvectors
|
44
|
-
|
44
|
+
affinity_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
|
45
45
|
chunk_size (int): chunk size for matrix multiplication
|
46
46
|
device (str): device to use for computation, if None, will not change device
|
47
47
|
Returns:
|
@@ -66,7 +66,7 @@ def extrapolate_knn(
|
|
66
66
|
for _v in torch.chunk(extrapolation_features, n_chunks, dim=0):
|
67
67
|
_v = _v.to(device) # [_m x d]
|
68
68
|
|
69
|
-
_A = affinity_from_features(anchor_features, _v, affinity_focal_gamma,
|
69
|
+
_A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, affinity_type).mT # [_m x n]
|
70
70
|
if knn is not None:
|
71
71
|
_A, indices = _A.topk(k=knn, dim=-1, largest=True) # [_m x k], [_m x k]
|
72
72
|
_anchor_output = anchor_output[indices] # [_m x k x d]
|
@@ -90,7 +90,7 @@ def extrapolate_knn_with_subsampling(
|
|
90
90
|
full_output: torch.Tensor, # [n x d']
|
91
91
|
extrapolation_features: torch.Tensor, # [m x d]
|
92
92
|
sample_config: SampleConfig,
|
93
|
-
|
93
|
+
affinity_type: AffinityOptions,
|
94
94
|
knn: int = 10, # k
|
95
95
|
affinity_focal_gamma: float = 1.0,
|
96
96
|
chunk_size: int = 8192,
|
@@ -122,7 +122,7 @@ def extrapolate_knn_with_subsampling(
|
|
122
122
|
# sample subgraph
|
123
123
|
anchor_indices = subsample_features(
|
124
124
|
features=full_features,
|
125
|
-
|
125
|
+
distance_type=AFFINITY_TO_DISTANCE[affinity_type],
|
126
126
|
config=sample_config,
|
127
127
|
)
|
128
128
|
|
@@ -135,7 +135,7 @@ def extrapolate_knn_with_subsampling(
|
|
135
135
|
anchor_features,
|
136
136
|
anchor_output,
|
137
137
|
extrapolation_features,
|
138
|
-
|
138
|
+
affinity_type,
|
139
139
|
knn=knn,
|
140
140
|
affinity_focal_gamma=affinity_focal_gamma,
|
141
141
|
chunk_size=chunk_size,
|
@@ -148,7 +148,7 @@ def extrapolate_knn_with_subsampling(
|
|
148
148
|
def _rgb_with_dimensionality_reduction(
|
149
149
|
features: torch.Tensor,
|
150
150
|
num_sample: int,
|
151
|
-
|
151
|
+
affinity_type: AffinityOptions,
|
152
152
|
rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
|
153
153
|
q: float,
|
154
154
|
knn: int,
|
@@ -162,26 +162,26 @@ def _rgb_with_dimensionality_reduction(
|
|
162
162
|
if True:
|
163
163
|
_subgraph_indices = subsample_features(
|
164
164
|
features=features,
|
165
|
-
|
165
|
+
distance_type=AFFINITY_TO_DISTANCE[affinity_type],
|
166
166
|
config=SampleConfig(method="fps"),
|
167
167
|
)
|
168
168
|
features = extrapolate_knn(
|
169
169
|
anchor_features=features[_subgraph_indices],
|
170
170
|
anchor_output=features[_subgraph_indices],
|
171
171
|
extrapolation_features=features,
|
172
|
-
|
172
|
+
affinity_type=affinity_type,
|
173
173
|
)
|
174
174
|
|
175
175
|
subgraph_indices = subsample_features(
|
176
176
|
features=features,
|
177
|
-
|
177
|
+
distance_type=AFFINITY_TO_DISTANCE[affinity_type],
|
178
178
|
config=SampleConfig(method="fps", num_sample=num_sample),
|
179
179
|
)
|
180
180
|
|
181
181
|
_inp = features[subgraph_indices].numpy(force=True)
|
182
182
|
_subgraph_embed = torch.tensor(reduction(
|
183
183
|
n_components=reduction_dim,
|
184
|
-
metric=
|
184
|
+
metric=AFFINITY_TO_DISTANCE[affinity_type],
|
185
185
|
random_state=seed,
|
186
186
|
**reduction_kwargs
|
187
187
|
).fit_transform(_inp), device=features.device, dtype=features.dtype)
|
@@ -190,7 +190,7 @@ def _rgb_with_dimensionality_reduction(
|
|
190
190
|
features[subgraph_indices],
|
191
191
|
_subgraph_embed,
|
192
192
|
features,
|
193
|
-
|
193
|
+
affinity_type,
|
194
194
|
knn=knn,
|
195
195
|
device=device,
|
196
196
|
move_output_to_cpu=True
|
@@ -201,7 +201,7 @@ def _rgb_with_dimensionality_reduction(
|
|
201
201
|
def rgb_from_tsne_2d(
|
202
202
|
features: torch.Tensor,
|
203
203
|
num_sample: int = 1000,
|
204
|
-
|
204
|
+
affinity_type: AffinityOptions = "cosine",
|
205
205
|
perplexity: int = 150,
|
206
206
|
q: float = 0.95,
|
207
207
|
knn: int = 10,
|
@@ -220,16 +220,12 @@ def rgb_from_tsne_2d(
|
|
220
220
|
"sklearn import failed, please install `pip install scikit-learn`"
|
221
221
|
)
|
222
222
|
num_sample = min(num_sample, features.shape[0])
|
223
|
-
|
224
|
-
logging.warning(
|
225
|
-
f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
|
226
|
-
)
|
227
|
-
perplexity = num_sample // 2
|
223
|
+
perplexity = min(perplexity, num_sample // 2)
|
228
224
|
|
229
225
|
rgb = _rgb_with_dimensionality_reduction(
|
230
226
|
features=features,
|
231
227
|
num_sample=num_sample,
|
232
|
-
|
228
|
+
affinity_type=affinity_type,
|
233
229
|
rgb_func=rgb_from_2d_colormap,
|
234
230
|
q=q,
|
235
231
|
knn=knn,
|
@@ -245,7 +241,7 @@ def rgb_from_tsne_2d(
|
|
245
241
|
def rgb_from_tsne_3d(
|
246
242
|
features: torch.Tensor,
|
247
243
|
num_sample: int = 1000,
|
248
|
-
|
244
|
+
affinity_type: AffinityOptions = "cosine",
|
249
245
|
perplexity: int = 150,
|
250
246
|
q: float = 0.95,
|
251
247
|
knn: int = 10,
|
@@ -264,16 +260,12 @@ def rgb_from_tsne_3d(
|
|
264
260
|
"sklearn import failed, please install `pip install scikit-learn`"
|
265
261
|
)
|
266
262
|
num_sample = min(num_sample, features.shape[0])
|
267
|
-
|
268
|
-
logging.warning(
|
269
|
-
f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
|
270
|
-
)
|
271
|
-
perplexity = num_sample // 2
|
263
|
+
perplexity = min(perplexity, num_sample // 2)
|
272
264
|
|
273
265
|
rgb = _rgb_with_dimensionality_reduction(
|
274
266
|
features=features,
|
275
267
|
num_sample=num_sample,
|
276
|
-
|
268
|
+
affinity_type=affinity_type,
|
277
269
|
rgb_func=rgb_from_3d_rgb_cube,
|
278
270
|
q=q,
|
279
271
|
knn=knn,
|
@@ -289,7 +281,7 @@ def rgb_from_tsne_3d(
|
|
289
281
|
def rgb_from_euclidean_tsne_3d(
|
290
282
|
features: torch.Tensor,
|
291
283
|
num_sample: int = 1000,
|
292
|
-
|
284
|
+
affinity_type: AffinityOptions = "cosine",
|
293
285
|
perplexity: int = 150,
|
294
286
|
q: float = 0.95,
|
295
287
|
knn: int = 10,
|
@@ -308,19 +300,15 @@ def rgb_from_euclidean_tsne_3d(
|
|
308
300
|
"sklearn import failed, please install `pip install scikit-learn`"
|
309
301
|
)
|
310
302
|
num_sample = min(num_sample, features.shape[0])
|
311
|
-
|
312
|
-
logging.warning(
|
313
|
-
f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
|
314
|
-
)
|
315
|
-
perplexity = num_sample // 2
|
303
|
+
perplexity = min(perplexity, num_sample // 2)
|
316
304
|
|
317
305
|
def rgb_func(X_3d: torch.Tensor, q: float) -> torch.Tensor:
|
318
|
-
return rgb_from_3d_rgb_cube(to_euclidean(X_3d,
|
306
|
+
return rgb_from_3d_rgb_cube(to_euclidean(X_3d, AFFINITY_TO_DISTANCE[affinity_type]), q=q)
|
319
307
|
|
320
308
|
rgb = _rgb_with_dimensionality_reduction(
|
321
309
|
features=features,
|
322
310
|
num_sample=num_sample,
|
323
|
-
|
311
|
+
affinity_type=affinity_type,
|
324
312
|
rgb_func=rgb_func,
|
325
313
|
q=q,
|
326
314
|
knn=knn,
|
@@ -336,7 +324,7 @@ def rgb_from_euclidean_tsne_3d(
|
|
336
324
|
def rgb_from_umap_2d(
|
337
325
|
features: torch.Tensor,
|
338
326
|
num_sample: int = 1000,
|
339
|
-
|
327
|
+
affinity_type: AffinityOptions = "cosine",
|
340
328
|
n_neighbors: int = 150,
|
341
329
|
min_dist: float = 0.1,
|
342
330
|
q: float = 0.95,
|
@@ -357,7 +345,7 @@ def rgb_from_umap_2d(
|
|
357
345
|
rgb = _rgb_with_dimensionality_reduction(
|
358
346
|
features=features,
|
359
347
|
num_sample=num_sample,
|
360
|
-
|
348
|
+
affinity_type=affinity_type,
|
361
349
|
rgb_func=rgb_from_2d_colormap,
|
362
350
|
q=q,
|
363
351
|
knn=knn,
|
@@ -374,7 +362,7 @@ def rgb_from_umap_2d(
|
|
374
362
|
def rgb_from_umap_sphere(
|
375
363
|
features: torch.Tensor,
|
376
364
|
num_sample: int = 1000,
|
377
|
-
|
365
|
+
affinity_type: AffinityOptions = "cosine",
|
378
366
|
n_neighbors: int = 150,
|
379
367
|
min_dist: float = 0.1,
|
380
368
|
q: float = 0.95,
|
@@ -402,7 +390,7 @@ def rgb_from_umap_sphere(
|
|
402
390
|
rgb = _rgb_with_dimensionality_reduction(
|
403
391
|
features=features,
|
404
392
|
num_sample=num_sample,
|
405
|
-
|
393
|
+
affinity_type=affinity_type,
|
406
394
|
rgb_func=rgb_func,
|
407
395
|
q=q,
|
408
396
|
knn=knn,
|
@@ -420,7 +408,7 @@ def rgb_from_umap_sphere(
|
|
420
408
|
def rgb_from_umap_3d(
|
421
409
|
features: torch.Tensor,
|
422
410
|
num_sample: int = 1000,
|
423
|
-
|
411
|
+
affinity_type: AffinityOptions = "cosine",
|
424
412
|
n_neighbors: int = 150,
|
425
413
|
min_dist: float = 0.1,
|
426
414
|
q: float = 0.95,
|
@@ -441,7 +429,7 @@ def rgb_from_umap_3d(
|
|
441
429
|
rgb = _rgb_with_dimensionality_reduction(
|
442
430
|
features=features,
|
443
431
|
num_sample=num_sample,
|
444
|
-
|
432
|
+
affinity_type=affinity_type,
|
445
433
|
rgb_func=rgb_from_3d_rgb_cube,
|
446
434
|
q=q,
|
447
435
|
knn=knn,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.3.0
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -0,0 +1,18 @@
|
|
1
|
+
__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
|
3
|
+
nystrom_ncut/common.py,sha256=eie19AHTMk6AGTxNnYq1UcFkHJVimeywAUYryXwaiHk,2428
|
4
|
+
nystrom_ncut/distance_utils.py,sha256=pJA8NcIKyS7-YDpRGOkc7mwBQQEsYVemdkHiTjyU4n8,4300
|
5
|
+
nystrom_ncut/sampling_utils.py,sha256=6lP8F6gftl4mgkavPsD7Vuk4erj4RtgILPhcj3YqLXk,4840
|
6
|
+
nystrom_ncut/visualize_utils.py,sha256=Sfi_kKpvFFzBFoJnbo-pQpH2jhs-A6tH64SV_WGoq58,22740
|
7
|
+
nystrom_ncut/nystrom/__init__.py,sha256=1aUXK87g4cXRXqNt6XkZsfyauw1-yv3sv0NmdmkWo-8,42
|
8
|
+
nystrom_ncut/nystrom/distance_realization.py,sha256=RTI1_Q8fCUGAPSbXaVuNA-2B-11CEAfy2CwKWPJj6xQ,5830
|
9
|
+
nystrom_ncut/nystrom/normalized_cut.py,sha256=jB_QALMY3l5CFfZPsrOFpEaquTrJP17muTrDZXxzUA8,7177
|
10
|
+
nystrom_ncut/nystrom/nystrom_utils.py,sha256=hksDO8uuAb9xKoA1ZafGwXDlQN_gZJn_qHscaSoO8JE,14120
|
11
|
+
nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
|
12
|
+
nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
|
13
|
+
nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
|
14
|
+
nystrom_ncut-0.3.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
15
|
+
nystrom_ncut-0.3.0.dist-info/METADATA,sha256=lhxicufu5Eo9HQsUiS_K-CzocemOeNravAaIXeCtriM,6058
|
16
|
+
nystrom_ncut-0.3.0.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
17
|
+
nystrom_ncut-0.3.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
|
18
|
+
nystrom_ncut-0.3.0.dist-info/RECORD,,
|
@@ -1,18 +0,0 @@
|
|
1
|
-
__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
|
3
|
-
nystrom_ncut/common.py,sha256=_PGJoImSk_Fb_5Ri-e_IsFoCcSfbGS8CxYUUHVoNM50,2036
|
4
|
-
nystrom_ncut/distance_utils.py,sha256=p-pYdpRrJsIhzxM_IxUqja7N8okngx52WGXD9pu_Aec,3129
|
5
|
-
nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHNU,3105
|
6
|
-
nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
|
7
|
-
nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
|
8
|
-
nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
|
9
|
-
nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
|
10
|
-
nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
|
11
|
-
nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
|
12
|
-
nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
|
13
|
-
nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
|
14
|
-
nystrom_ncut-0.2.1.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
15
|
-
nystrom_ncut-0.2.1.dist-info/METADATA,sha256=l5t4vEFtPANsQY8PK0YHDJ1tw6dZUulU5daxX9T8QC0,6058
|
16
|
-
nystrom_ncut-0.2.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
17
|
-
nystrom_ncut-0.2.1.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
|
18
|
-
nystrom_ncut-0.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|