nystrom-ncut 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/common.py +18 -5
- nystrom_ncut/distance_utils.py +54 -32
- nystrom_ncut/nystrom/__init__.py +0 -3
- nystrom_ncut/nystrom/distance_realization.py +8 -9
- nystrom_ncut/nystrom/normalized_cut.py +51 -47
- nystrom_ncut/nystrom/nystrom_utils.py +78 -70
- nystrom_ncut/sampling_utils.py +64 -51
- nystrom_ncut/transformer/axis_align.py +0 -1
- nystrom_ncut/transformer/transformer_mixin.py +0 -2
- nystrom_ncut/visualize_utils.py +31 -43
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/METADATA +1 -1
- nystrom_ncut-0.3.0.dist-info/RECORD +18 -0
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/WHEEL +1 -1
- nystrom_ncut-0.2.2.dist-info/RECORD +0 -18
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/top_level.txt +0 -0
nystrom_ncut/common.py
CHANGED
@@ -12,8 +12,8 @@ def ceildiv(a: int, b: int) -> int:
 def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:
     numel = np.prod(x.shape[:-1])
     n = min(n, numel)
-    random_indices = torch.randperm(numel)[:n]
-    _x = x.
+    random_indices = torch.randperm(numel, device=x.device)[:n]
+    _x = x.view((-1, x.shape[-1]))[random_indices]
     if torch.allclose(torch.norm(_x, **normalize_kwargs), torch.ones(n, device=x.device)):
         return x
     else:
@@ -21,13 +21,14 @@ def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> t
 
 
 def quantile_min_max(x: torch.Tensor, q1: float, q2: float, n_sample: int = 10000):
-
+    x = x.flatten()
+    if len(x) > n_sample:
         np.random.seed(0)
-        random_idx = np.random.choice(x
+        random_idx = np.random.choice(len(x), n_sample, replace=False)
         vmin, vmax = x[random_idx].quantile(q1), x[random_idx].quantile(q2)
     else:
         vmin, vmax = x.quantile(q1), x.quantile(q2)
-    return vmin, vmax
+    return vmin.item(), vmax.item()
 
 
 def quantile_normalize(x: torch.Tensor, q: float = 0.95):
@@ -57,5 +58,17 @@ def quantile_normalize(x: torch.Tensor, q: float = 0.95):
     return x
 
 
+class default_device:
+    def __init__(self, device: torch.device):
+        self._device = device
+
+    def __enter__(self):
+        self._original_device = torch.get_default_device()
+        torch.set_default_device(self._device)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch.set_default_device(self._original_device)
+
+
 def profile(name: str, t: torch.Tensor) -> None:
     print(f"{name} --- nan: {t.isnan().any()}, inf: {t.isinf().any()}, max: {t.abs().max()}, min: {t.abs().min()}")
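The default_device context manager added above swaps PyTorch's global default device for the duration of a block, so factory calls inside it allocate on the chosen device without explicit device= arguments. A minimal usage sketch (the import path follows the hunk above; torch.get_default_device requires PyTorch >= 2.3):

import torch

from nystrom_ncut.common import default_device  # path taken from the hunk above

x = torch.randn(8, 3)
with default_device(x.device):
    # factory calls that omit device= now allocate on x.device
    idx = torch.randperm(x.shape[0])[:4]
assert idx.device.type == x.device.type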
nystrom_ncut/distance_utils.py
CHANGED
@@ -1,61 +1,71 @@
-
+import collections
+from typing import List, Literal, OrderedDict
 
 import torch
 
 from .common import lazy_normalize
 
 
-DistanceOptions = Literal["cosine", "euclidean"
+DistanceOptions = Literal["cosine", "euclidean"]
+AffinityOptions = Literal["cosine", "rbf", "laplacian"]
 
+# noinspection PyTypeChecker
+DISTANCE_TO_AFFINITY: OrderedDict[DistanceOptions, List[AffinityOptions]] = collections.OrderedDict([
+    ("cosine", ["cosine"]),
+    ("euclidean", ["rbf", "laplacian"]),
+])
+# noinspection PyTypeChecker
+AFFINITY_TO_DISTANCE: OrderedDict[AffinityOptions, DistanceOptions] = collections.OrderedDict(sum([
+    [(affinity_type, distance_type) for affinity_type in affinity_types]
+    for distance_type, affinity_types in DISTANCE_TO_AFFINITY.items()
+], start=[]))
 
-
-
+
+
+def to_euclidean(x: torch.Tensor, distance_type: DistanceOptions) -> torch.Tensor:
+    if distance_type == "cosine":
         return lazy_normalize(x, p=2, dim=-1)
-    elif
+    elif distance_type == "euclidean":
         return x
     else:
-        raise ValueError(f"to_euclidean not implemented for
+        raise ValueError(f"to_euclidean not implemented for distance_type {distance_type}.")
 
 
 def distance_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor,
-
+    distance_type: DistanceOptions,
 ):
-    """Compute
+    """Compute distance matrix from input features.
     Args:
         features (torch.Tensor): input features, shape (n_samples, n_features)
         features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
-
+        distance_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
     Returns:
         (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
     """
     # compute distance matrix from input features
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        D = D / (torch.linalg.norm(stds) ** 2)
-    else:
-        raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
-    return D
+    shape: torch.Size = features.shape[:-2]
+    features = features.view((-1, *features.shape[-2:]))
+    features_B = features_B.view((-1, *features_B.shape[-2:]))
+
+    match distance_type:
+        case "cosine":
+            features = lazy_normalize(features, dim=-1)
+            features_B = lazy_normalize(features_B, dim=-1)
+            D = 1 - features @ features_B.mT
+        case "euclidean":
+            D = torch.cdist(features, features_B, p=2)
+        case _:
+            raise ValueError("Distance should be 'cosine' or 'euclidean'")
+    return D.view((*shape, *D.shape[-2:]))
 
 
 def affinity_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor = None,
     affinity_focal_gamma: float = 1.0,
-
+    affinity_type: AffinityOptions = "cosine",
 ):
     """Compute affinity matrix from input features.
 
@@ -64,7 +74,7 @@ def affinity_from_features(
         features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
         affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
             on weak connections, default 1.0
-
+        affinity_type (str): distance metric, 'cosine' (default) or 'euclidean'.
     Returns:
         (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
     """
@@ -75,9 +85,21 @@ def affinity_from_features(
     features_B = features if features_B is None else features_B
 
     # compute distance matrix from input features
-    D = distance_from_features(features, features_B,
+    D = distance_from_features(features, features_B, AFFINITY_TO_DISTANCE[affinity_type])
 
-    # torch.exp make affinity matrix positive definite,
     # lower affinity_focal_gamma reduce the weak edge weights
-
+    match affinity_type:
+        case "cosine" | "laplacian":
+            A = torch.exp(-D / affinity_focal_gamma)  # [... x n x n]
+        case "rbf":
+            # Outlier-robust scale invariance using quantiles to estimate standard deviation
+            c = 2.0
+            p = torch.erf(torch.tensor((-c, c), device=features.device) * (2 ** -0.5))
+            stds = torch.nanquantile(features, q=(p + 1) / 2, dim=-2)  # [2 x ... x d]
+            stds = (stds[1] - stds[0]) / (2 * c)  # [... x d]
+            D = 0.5 * (D / torch.norm(stds, dim=-1)[..., None, None]) ** 2
+            A = torch.exp(-D / affinity_focal_gamma)
+        case _:
+            raise ValueError("Affinity should be 'cosine', 'rbf', or 'laplacian'")
+
     return A
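The new "rbf" branch estimates a per-dimension standard deviation from the empirical quantiles a Gaussian would place at ±c standard deviations (hence the erf term), which tolerates NaNs and outliers better than a plain std. A standalone sketch of that estimator on synthetic data, mirroring the hunk above (not part of the package API):

import torch

c = 2.0
p = torch.erf(torch.tensor((-c, c)) * (2 ** -0.5))      # Gaussian CDF values at -c and +c sigma, in (-1, 1)
true_stds = torch.tensor([0.5, 1.0, 2.0, 4.0])
features = torch.randn(100_000, 4) * true_stds
q = torch.nanquantile(features, q=(p + 1) / 2, dim=-2)  # [2 x d] empirical quantiles, NaN-tolerant
est_stds = (q[1] - q[0]) / (2 * c)                      # approximately recovers true_stds
print(est_stds)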
nystrom_ncut/nystrom/__init__.py
CHANGED
nystrom_ncut/nystrom/distance_realization.py
CHANGED
@@ -18,10 +18,10 @@ from ..sampling_utils import (
 class GramKernel(OnlineKernel):
     def __init__(
         self,
-
+        distance_type: DistanceOptions,
         eig_solver: EigSolverOptions,
     ):
-        self.
+        self.distance_type: DistanceOptions = distance_type
         self.eig_solver: EigSolverOptions = eig_solver
 
         # Anchor matrices
@@ -40,7 +40,7 @@ class GramKernel(OnlineKernel):
         self.A = -0.5 * distance_from_features(
             self.anchor_features,  # [n x d]
             self.anchor_features,
-
+            distance_type=self.distance_type,
         )  # [n x n]
         d = features.shape[-1]
         U, L = solve_eig(
@@ -58,7 +58,7 @@ class GramKernel(OnlineKernel):
         B = -0.5 * distance_from_features(
             self.anchor_features,  # [n x d]
             features,  # [m x d]
-
+            distance_type=self.distance_type,
         )  # [n x m]
         b_r = torch.sum(B, dim=-1)  # [n]
         b_c = torch.sum(B, dim=-2)  # [m]
@@ -84,7 +84,7 @@ class GramKernel(OnlineKernel):
         B = -0.5 * distance_from_features(
             self.anchor_features,
             features,
-
+            distance_type=self.distance_type,
         )
         b_c = torch.sum(B, dim=-2)  # [m]
         col_sum = b_c + B.mT @ self.Ainv @ self.b_r  # [m]
@@ -98,7 +98,7 @@ class DistanceRealization(OnlineNystromSubsampleFit):
     def __init__(
         self,
         n_components: int = 100,
-
+        distance_type: DistanceOptions = "cosine",
         sample_config: SampleConfig = SampleConfig(),
         eig_solver: EigSolverOptions = "svd_lowrank",
         chunk_size: int = 8192,
@@ -115,13 +115,12 @@ class DistanceRealization(OnlineNystromSubsampleFit):
         OnlineNystromSubsampleFit.__init__(
             self,
             n_components=n_components,
-            kernel=GramKernel(
-
+            kernel=GramKernel(distance_type, eig_solver),
+            distance_type=distance_type,
             sample_config=sample_config,
             eig_solver=eig_solver,
             chunk_size=chunk_size,
         )
-        self.distance: DistanceOptions = distance
 
     def fit_transform(
         self,
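For callers, the visible change here is the rename of the distance= keyword to distance_type= on DistanceRealization (and the removal of the stored self.distance attribute). A hypothetical usage sketch under the constructor shown in the hunks above (module paths assumed from this diff; pytorch3d is required for the FPS sampler; the return type of fit_transform is not shown in this diff):

import torch

from nystrom_ncut.nystrom.distance_realization import DistanceRealization
from nystrom_ncut.sampling_utils import SampleConfig

dr = DistanceRealization(
    n_components=10,
    distance_type="cosine",  # 0.2.2 called this keyword "distance"
    sample_config=SampleConfig(method="fps", num_sample=256),
)
embedding = dr.fit_transform(torch.randn(1024, 64))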
nystrom_ncut/nystrom/normalized_cut.py
CHANGED
@@ -8,7 +8,8 @@ from .nystrom_utils import (
     solve_eig,
 )
 from ..distance_utils import (
-
+    AffinityOptions,
+    AFFINITY_TO_DISTANCE,
     affinity_from_features,
 )
 from ..sampling_utils import (
@@ -20,80 +21,83 @@ class LaplacianKernel(OnlineKernel):
     def __init__(
         self,
         affinity_focal_gamma: float,
-
+        affinity_type: AffinityOptions,
         adaptive_scaling: bool,
         eig_solver: EigSolverOptions,
     ):
         self.affinity_focal_gamma = affinity_focal_gamma
-        self.
+        self.affinity_type: AffinityOptions = affinity_type
         self.adaptive_scaling: bool = adaptive_scaling
         self.eig_solver: EigSolverOptions = eig_solver
 
         # Anchor matrices
-        self.anchor_features: torch.Tensor = None
-        self.
-        self.
+        self.anchor_features: torch.Tensor = None  # [... x n x d]
+        self.anchor_mask: torch.Tensor = None
+        self.A: torch.Tensor = None  # [... x n x n]
+        self.Ainv: torch.Tensor = None  # [... x n x n]
 
         # Updated matrices
-        self.a_r: torch.Tensor = None
-        self.b_r: torch.Tensor = None
+        self.a_r: torch.Tensor = None  # [... x n]
+        self.b_r: torch.Tensor = None  # [... x n]
 
     def fit(self, features: torch.Tensor) -> None:
-        self.anchor_features = features
-        self.
-
+        self.anchor_features = features  # [... x n x d]
+        self.anchor_mask = torch.all(torch.isnan(self.anchor_features), dim=-1)  # [... x n]
+
+        self.A = torch.nan_to_num(affinity_from_features(
+            self.anchor_features,  # [... x n x d]
             affinity_focal_gamma=self.affinity_focal_gamma,
-
-        )
+            affinity_type=self.affinity_type,
+        ), nan=0.0)  # [... x n x n]
         d = features.shape[-1]
         U, L = solve_eig(
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )
-        self.Ainv = U @ torch.
-        self.a_r = torch.sum(self.A, dim=-1)
-        self.b_r = torch.zeros_like(self.a_r)
+        )  # [... x n x (d + 1)], [... x (d + 1)]
+        self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT  # [... x n x n]
+        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))  # [... x n]
+        self.b_r = torch.zeros_like(self.a_r)  # [... x n]
 
     def _affinity(self, features: torch.Tensor) -> torch.Tensor:
-        B = affinity_from_features(
-            self.anchor_features,
-            features,
+        B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
+            self.anchor_features,  # [... x n x d]
+            features,  # [... x m x d]
             affinity_focal_gamma=self.affinity_focal_gamma,
-
-        )
+            affinity_type=self.affinity_type,
+        ))  # [... x n x m]
         if self.adaptive_scaling:
             diagonal = (
-                einops.rearrange(B, "n m -> m 1 n")
-                @ self.Ainv
-                @ einops.rearrange(B, "n m -> m n 1")
-            ).squeeze(
-            adaptive_scale = diagonal ** -0.5
-            B = B * adaptive_scale
-        return B
+                einops.rearrange(B, "... n m -> ... m 1 n")  # [... x m x 1 x n]
+                @ self.Ainv  # [... x n x n]
+                @ einops.rearrange(B, "... n m -> ... m n 1")  # [... x m x n x 1]
+            ).squeeze(-2, -1)  # [... x m]
+            adaptive_scale = diagonal ** -0.5  # [... x m]
+            B = B * adaptive_scale[..., None, :]
+        return B  # [... x n x m]
 
     def update(self, features: torch.Tensor) -> torch.Tensor:
-        B = self._affinity(features)
-        b_r = torch.sum(B, dim=-1)
-        b_c = torch.sum(B, dim=-2)
-        self.b_r = self.b_r + b_r
+        B = self._affinity(features)  # [... x n x m]
+        b_r = torch.sum(torch.nan_to_num(B, nan=0.0), dim=-1)  # [... x n]
+        b_c = torch.sum(B, dim=-2)  # [... x m]
+        self.b_r = self.b_r + b_r  # [... x n]
 
-        row_sum = self.a_r + self.b_r
-        col_sum = b_c + B.mT @ self.Ainv @ self.b_r
-        scale = (row_sum[:, None] * col_sum) ** -0.5
-        return (B * scale).mT
+        row_sum = self.a_r + self.b_r  # [... x n]
+        col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0]  # [... x m]
+        scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5  # [... x n x m]
+        return (B * scale).mT  # [... x m x n]
 
     def transform(self, features: torch.Tensor = None) -> torch.Tensor:
-        row_sum = self.a_r + self.b_r
+        row_sum = self.a_r + self.b_r  # [... x n]
         if features is None:
-            B = self.A
-            col_sum = row_sum
+            B = self.A  # [... x n x n]
+            col_sum = row_sum  # [... x n]
         else:
             B = self._affinity(features)
-            b_c = torch.sum(B, dim=-2)
-            col_sum = b_c + B.mT @ self.Ainv @ self.b_r
-        scale = (row_sum[:, None] * col_sum) ** -0.5
-        return (B * scale).mT
+            b_c = torch.sum(B, dim=-2)  # [... x m]
+            col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0]  # [... x m]
+        scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5  # [... x n x m]
+        return (B * scale).mT  # [... x m x n]
 
 
 class NCut(OnlineNystromSubsampleFit):
@@ -103,7 +107,7 @@ class NCut(OnlineNystromSubsampleFit):
         self,
         n_components: int = 100,
         affinity_focal_gamma: float = 1.0,
-
+        affinity_type: AffinityOptions = "cosine",
         adaptive_scaling: bool = False,
         sample_config: SampleConfig = SampleConfig(),
         eig_solver: EigSolverOptions = "svd_lowrank",
@@ -124,8 +128,8 @@ class NCut(OnlineNystromSubsampleFit):
         OnlineNystromSubsampleFit.__init__(
             self,
            n_components=n_components,
-            kernel=LaplacianKernel(affinity_focal_gamma,
-
+            kernel=LaplacianKernel(affinity_focal_gamma, affinity_type, adaptive_scaling, eig_solver),
+            distance_type=AFFINITY_TO_DISTANCE[affinity_type],
             sample_config=sample_config,
             eig_solver=eig_solver,
             chunk_size=chunk_size,
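NCut likewise replaces its distance= keyword with affinity_type=, and forwards the matching distance to sampling via AFFINITY_TO_DISTANCE ("rbf" and "laplacian" both resolve to euclidean). A hypothetical construction sketch consistent with the signature in the hunks above (import paths assumed from this diff; the exact return format of fit_transform is not shown here):

import torch

from nystrom_ncut.nystrom.normalized_cut import NCut
from nystrom_ncut.sampling_utils import SampleConfig

ncut = NCut(
    n_components=20,
    affinity_focal_gamma=1.0,
    affinity_type="rbf",                        # sampling distance becomes "euclidean"
    sample_config=SampleConfig(num_sample=300),
)
out = ncut.fit_transform(torch.randn(2000, 64))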
nystrom_ncut/nystrom/nystrom_utils.py
CHANGED
@@ -25,15 +25,15 @@ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
 
 class OnlineKernel:
     @abstractmethod
-    def fit(self, features: torch.Tensor) -> "OnlineKernel":  # [n x d]
+    def fit(self, features: torch.Tensor) -> "OnlineKernel":  # [... x n x d]
        """"""
 
     @abstractmethod
-    def update(self, features: torch.Tensor) -> torch.Tensor:  # [m x d] -> [m x n]
+    def update(self, features: torch.Tensor) -> torch.Tensor:  # [... x m x d] -> [... x m x n]
        """"""
 
     @abstractmethod
-    def transform(self, features: torch.Tensor = None) -> torch.Tensor:  # [m x d] -> [m x n]
+    def transform(self, features: torch.Tensor = None) -> torch.Tensor:  # [... x m x d] -> [... x m x n]
        """"""
 
 
@@ -54,20 +54,21 @@ class OnlineNystrom(TorchTransformerMixin):
         self.n_components: int = n_components
         self.kernel: OnlineKernel = kernel
         self.eig_solver: EigSolverOptions = eig_solver
+        self.shape: torch.Size = None  # ...
 
         self.chunk_size = chunk_size
 
         # Anchor matrices
-        self.anchor_features: torch.Tensor = None  # [n x d]
-        self.A: torch.Tensor = None  # [n x n]
-        self.Ahinv: torch.Tensor = None  # [n x n]
-        self.Ahinv_UL: torch.Tensor = None  # [n x indirect_pca_dim]
-        self.Ahinv_VT: torch.Tensor = None  # [indirect_pca_dim x n]
+        self.anchor_features: torch.Tensor = None  # [... x n x d]
+        self.A: torch.Tensor = None  # [... x n x n]
+        self.Ahinv: torch.Tensor = None  # [... x n x n]
+        self.Ahinv_UL: torch.Tensor = None  # [... x n x indirect_pca_dim]
+        self.Ahinv_VT: torch.Tensor = None  # [... x indirect_pca_dim x n]
 
         # Updated matrices
-        self.S: torch.Tensor = None  # [n x n]
-        self.transform_matrix: torch.Tensor = None  # [n x n_components]
-        self.eigenvalues_: torch.Tensor = None  # [n]
+        self.S: torch.Tensor = None  # [... x n x n]
+        self.transform_matrix: torch.Tensor = None  # [... x n x n_components]
+        self.eigenvalues_: torch.Tensor = None  # [... x n]
 
     def _update_to_kernel(self, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
         self.A = self.S = self.kernel.transform()
@@ -75,10 +76,10 @@ class OnlineNystrom(TorchTransformerMixin):
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )  # [n x (? + 1)], [? + 1]
-        self.Ahinv_UL = U * (L ** -0.5)
-        self.Ahinv_VT = U.mT  # [(? + 1) x n]
-        self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT  # [n x n]
+        )  # [... x n x (? + 1)], [... x (? + 1)]
+        self.Ahinv_UL = U * (L[..., None, :] ** -0.5)  # [... x n x (? + 1)]
+        self.Ahinv_VT = U.mT  # [... x (? + 1) x n]
+        self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT  # [... x n x n]
         return U, L
 
     def fit(self, features: torch.Tensor) -> "OnlineNystrom":
@@ -89,65 +90,63 @@ class OnlineNystrom(TorchTransformerMixin):
         self.anchor_features = features
 
         self.kernel.fit(self.anchor_features)
-        U, L = self._update_to_kernel(features.shape[-1])  # [n x (d + 1)], [d + 1]
+        U, L = self._update_to_kernel(features.shape[-1])  # [... x n x (d + 1)], [... x (d + 1)]
 
-        self.transform_matrix = (U / L)[:, :self.n_components]
-        self.eigenvalues_ = L[:self.n_components]
-        self.
-        return U[:, :self.n_components]  # [n x n_components]
+        self.transform_matrix = (U / L[..., None, :])[..., :, :self.n_components]  # [... x n x n_components]
+        self.eigenvalues_ = L[..., :self.n_components]  # [... x n_components]
+        return U[..., :, :self.n_components]  # [... x n x n_components]
 
     def update(self, features: torch.Tensor) -> torch.Tensor:
         d = features.shape[-1]
-        n_chunks = ceildiv(
+        n_chunks = ceildiv(features.shape[-2], self.chunk_size)
         if n_chunks > 1:
             """ Chunked version """
-            chunks = torch.chunk(features, n_chunks, dim
+            chunks = torch.chunk(features, n_chunks, dim=-2)
             for chunk in chunks:
                 self.kernel.update(chunk)
             self._update_to_kernel(d)
 
-            compressed_BBT = 0.0  # [(? + 1) x (? + 1))]
+            compressed_BBT = 0.0  # [... x (? + 1) x (? + 1))]
             for chunk in chunks:
-                _B = self.kernel.transform(chunk).mT  # [n x _m]
-                _compressed_B = self.Ahinv_VT @ _B  # [(? + 1) x _m]
-
-
-
-            self.
+                _B = self.kernel.transform(chunk).mT  # [... x n x _m]
+                _compressed_B = self.Ahinv_VT @ _B  # [... x (? + 1) x _m]
+                _compressed_B = torch.nan_to_num(_compressed_B, nan=0.0)
+                compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT  # [... x (? + 1) x (? + 1)]
+            self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT  # [... x n x n]
+            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)  # [... x n x n_components], [... x n_components]
+            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5)  # [... x n x n_components]
 
             VS = []
             for chunk in chunks:
-                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
-            VS = torch.cat(VS, dim
-            return VS  # [m x n_components]
+                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [... x _m x n_components]
+            VS = torch.cat(VS, dim=-2)
+            return VS  # [... x m x n_components]
         else:
             """ Unchunked version """
-            B = self.kernel.update(features).mT  # [n x m]
+            B = self.kernel.update(features).mT  # [... x n x m]
             self._update_to_kernel(d)
-            compressed_B = self.Ahinv_VT @ B  # [
+            compressed_B = self.Ahinv_VT @ B  # [... x (? + 1) x m]
+            compressed_B = torch.nan_to_num(compressed_B, nan=0.0)
 
-            self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT  # [n x n]
-            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)  # [n x n_components], [n_components]
-            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5)
+            self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT  # [... x n x n]
+            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)  # [... x n x n_components], [... x n_components]
+            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5)  # [... x n x n_components]
 
-            return B.mT @ self.transform_matrix  # [m x n_components]
+            return B.mT @ self.transform_matrix  # [... x m x n_components]
 
-    def transform(self, features: torch.Tensor
-
-
+    def transform(self, features: torch.Tensor) -> torch.Tensor:
+        n_chunks = ceildiv(features.shape[-2], self.chunk_size)
+        if n_chunks > 1:
+            """ Chunked version """
+            chunks = torch.chunk(features, n_chunks, dim=-2)
+            VS = []
+            for chunk in chunks:
+                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [... x _m x n_components]
+            VS = torch.cat(VS, dim=-2)
         else:
-
-
-
-            chunks = torch.chunk(features, n_chunks, dim=0)
-            VS = []
-            for chunk in chunks:
-                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
-            VS = torch.cat(VS, dim=0)
-        else:
-            """ Unchunked version """
-            VS = self.kernel.transform(features) @ self.transform_matrix  # [m x n_components]
-        return VS  # [m x n_components]
+            """ Unchunked version """
+            VS = self.kernel.transform(features) @ self.transform_matrix  # [... x m x n_components]
+        return VS  # [... x m x n_components]
 
 
 class OnlineNystromSubsampleFit(OnlineNystrom):
@@ -155,7 +154,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         self,
         n_components: int,
         kernel: OnlineKernel,
-
+        distance_type: DistanceOptions,
         sample_config: SampleConfig,
         eig_solver: EigSolverOptions = "svd_lowrank",
         chunk_size: int = 8192,
@@ -167,7 +166,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
             eig_solver=eig_solver,
             chunk_size=chunk_size,
         )
-        self.
+        self.distance_type: DistanceOptions = distance_type
         self.sample_config: SampleConfig = sample_config
         self.sample_config._ncut_obj = copy.deepcopy(self)
         self.anchor_indices: torch.Tensor = None
@@ -177,7 +176,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         features: torch.Tensor,
         precomputed_sampled_indices: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        _n = features.shape[
+        _n = features.shape[-2]
         if self.sample_config.num_sample >= _n:
             logging.info(
                 f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
@@ -189,16 +188,17 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         else:
             self.anchor_indices = subsample_features(
                 features=features,
-
+                distance_type=self.distance_type,
                 config=self.sample_config,
             )
-        sampled_features = features[self.anchor_indices]
+        sampled_features = torch.gather(features, -2, self.anchor_indices[..., None].expand([-1] * self.anchor_indices.ndim + [features.shape[-1]]))
         OnlineNystrom.fit(self, sampled_features)
 
-        _n_not_sampled = _n -
+        _n_not_sampled = _n - self.anchor_indices.shape[-1]
         if _n_not_sampled > 0:
-
-
+            unsampled_mask = torch.full(features.shape[:-1], True, device=features.device).scatter_(-1, self.anchor_indices, False)
+            unsampled_indices = torch.where(unsampled_mask)[-1].view((*features.shape[:-2], -1))
+            unsampled_features = torch.gather(features, -2, unsampled_indices[..., None].expand([-1] * unsampled_indices.ndim + [features.shape[-1]]))
             V_unsampled = OnlineNystrom.update(self, unsampled_features)
         else:
             unsampled_indices = V_unsampled = None
@@ -236,12 +236,12 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
         """
         unsampled_indices, V_unsampled = OnlineNystromSubsampleFit._fit_helper(self, features, precomputed_sampled_indices)
-        V_sampled = OnlineNystrom.transform(self)
+        V_sampled = OnlineNystrom.transform(self, self.anchor_features)
 
         if unsampled_indices is not None:
-            V = torch.zeros((
-
-
+            V = torch.zeros((*features.shape[:-1], self.n_components), device=features.device)
+            for (indices, _V) in [(self.anchor_indices, V_sampled), (unsampled_indices, V_unsampled)]:
+                V.scatter_(-2, indices[..., None].expand([-1] * indices.ndim + [self.n_components]), _V)
         else:
             V = V_sampled
         return V
@@ -264,12 +264,16 @@ def solve_eig(
     (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
     (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
     """
-
+    shape: torch.Size = A.shape[:-2]
+    A = A.view((-1, *A.shape[-2:]))
+    bsz: int = A.shape[0]
+
+    A = A + eig_value_buffer * torch.eye(A.shape[-1], device=A.device)
 
     # compute eigenvectors
     if eig_solver == "svd_lowrank":  # default
         # only top q eigenvectors, fastest
-        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
+        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)  # complex: [(...) x N x D], [(...) x D]
    elif eig_solver == "lobpcg":
         # only top k eigenvectors, fast
         eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
@@ -286,11 +290,15 @@ def solve_eig(
     eigen_value = eigen_value - eig_value_buffer
 
     # sort eigenvectors by eigenvalues, take top (descending order)
-    indices = torch.topk(eigen_value.abs(), k=num_eig, dim
-    eigen_value
+    indices = torch.topk(eigen_value.abs(), k=num_eig, dim=-1).indices  # int: [(...) x S]
+    eigen_value = eigen_value[torch.arange(bsz)[:, None], indices]  # complex: [(...) x S]
+    eigen_vector = eigen_vector[torch.arange(bsz)[:, None], :, indices].mT  # complex: [(...) x N x S]
 
     # correct the random rotation (flipping sign) of eigenvectors
-    sign = torch.sum(eigen_vector.real, dim=
+    sign = torch.sign(torch.sum(eigen_vector.real, dim=-2, keepdim=True))  # float: [(...) x 1 x S]
     sign[sign == 0] = 1.0
     eigen_vector = eigen_vector * sign
+
+    eigen_value = eigen_value.view((*shape, *eigen_value.shape[-1:]))  # complex: [... x S]
+    eigen_vector = eigen_vector.view((*shape, *eigen_vector.shape[-2:]))  # complex: [... x N x S]
     return eigen_vector, eigen_value
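The recurring change in this file is batching: solve_eig now flattens any leading dimensions, decomposes each matrix, and restores the batch shape on both outputs. A small sketch of the batched call, assuming the signature shown in the hunks above:

import torch

from nystrom_ncut.nystrom.nystrom_utils import solve_eig

A = torch.randn(4, 2, 128, 128)
A = 0.5 * (A + A.mT)                          # a [4 x 2] batch of symmetric 128 x 128 matrices
vectors, values = solve_eig(A, num_eig=8, eig_solver="svd_lowrank")
print(vectors.shape, values.shape)            # expected: [4, 2, 128, 8] and [4, 2, 8]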
nystrom_ncut/sampling_utils.py
CHANGED
@@ -1,17 +1,22 @@
-import logging
 from dataclasses import dataclass
 from typing import Literal
 
 import torch
 from pytorch3d.ops import sample_farthest_points
 
+from .common import (
+    default_device,
+)
 from .distance_utils import (
     DistanceOptions,
     to_euclidean,
 )
+from .transformer import (
+    TorchTransformerMixin,
+)
 
 
-SampleOptions = Literal["random", "fps", "fps_recursive"]
+SampleOptions = Literal["full", "random", "fps", "fps_recursive"]
 
 
 @dataclass
@@ -20,69 +25,77 @@ class SampleConfig:
     num_sample: int = 10000
     fps_dim: int = 12
     n_iter: int = None
-    _ncut_obj:
+    _ncut_obj: TorchTransformerMixin = None
 
 
 @torch.no_grad()
 def subsample_features(
     features: torch.Tensor,
-
+    distance_type: DistanceOptions,
     config: SampleConfig,
-    max_draw: int = 1000000,
 ):
-    features = features.detach()
-
-
-
-            "num_sample is larger than total, bypass Nystrom-like approximation"
-        )
-        sampled_indices = torch.arange(features.shape[0])
-    else:
-        # sample subgraph
-        if config.method == "fps":  # default
-            features = to_euclidean(features, disttype)
-            if config.num_sample > max_draw:
-                logging.warning(
-                    f"num_sample is larger than max_draw, apply farthest point sampling on random sampled {max_draw} samples"
-                )
-                draw_indices = torch.randperm(features.shape[0])[:max_draw]
-                sampled_indices = fpsample(features[draw_indices], config)
-                sampled_indices = draw_indices[sampled_indices]
-            else:
-                sampled_indices = fpsample(features, config)
-
-        elif config.method == "random":  # not recommended
-            sampled_indices = torch.randperm(features.shape[0])[:config.num_sample]
-
-        elif config.method == "fps_recursive":
-            features = to_euclidean(features, disttype)
-            sampled_indices = subsample_features(
-                features=features,
-                disttype=disttype,
-                config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
-            )
-            nc = config._ncut_obj
-            for _ in range(config.n_iter):
-                fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
-
-                fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
-                sampled_indices = torch.sort(fpsample(fps_features, config)).values
+    features = features.detach()  # float: [... x n x d]
+    with default_device(features.device):
+        if config.method == "full" or config.num_sample >= features.shape[0]:
+            sampled_indices = torch.arange(features.shape[-2]).expand(features.shape[:-1])  # int: [... x n]
         else:
-
-
-
+            # sample
+            match config.method:
+                case "fps":  # default
+                    sampled_indices = fpsample(to_euclidean(features, distance_type), config)
+
+                case "random":  # not recommended
+                    mask = torch.all(torch.isfinite(features), dim=-1)  # bool: [... x n]
+                    weights = mask.to(torch.float) + torch.rand(mask.shape)  # float: [... x n]
+                    sampled_indices = torch.topk(weights, k=config.num_sample, dim=-1).indices  # int: [... x num_sample]
+
+                case "fps_recursive":
+                    features = to_euclidean(features, distance_type)  # float: [... x n x d]
+                    sampled_indices = subsample_features(
+                        features=features,
+                        distance_type=distance_type,
+                        config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
+                    )  # int: [... x num_sample]
+                    nc = config._ncut_obj
+                    for _ in range(config.n_iter):
+                        fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
+
+                        fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
+                        sampled_indices = torch.sort(fpsample(fps_features, config), dim=-1).values
+
+                case _:
+                    raise ValueError("sample_method should be 'farthest' or 'random'")
+            sampled_indices = torch.sort(sampled_indices, dim=-1).values
+        return sampled_indices
 
 
 def fpsample(
     features: torch.Tensor,
     config: SampleConfig,
 ):
-
-
-
-
+    shape = features.shape[:-2]  # ...
+    features = features.view((-1, *features.shape[-2:]))  # [(...) x n x d]
+    bsz = features.shape[0]
+
+    mask = torch.all(torch.isfinite(features), dim=-1)  # bool: [(...) x n]
+    count = torch.sum(mask, dim=-1)  # int: [(...)]
+    order = torch.topk(mask.to(torch.int), k=torch.max(count).item(), dim=-1).indices  # int: [(...) x max_count]
+
+    features = torch.nan_to_num(features[torch.arange(bsz)[:, None], order], nan=0.0)  # float: [(...) x max_count x d]
+    if features.shape[-1] > config.fps_dim:
+        U, S, V = torch.pca_lowrank(features, q=config.fps_dim)  # float: [(...) x max_count x fps_dim], [(...) x fps_dim], [(...) x fps_dim x fps_dim]
+        features = U * S[..., None, :]  # float: [(...) x max_count x fps_dim]
 
     try:
-
+        sample_indices = sample_farthest_points(
+            features, lengths=count, K=config.num_sample
+        )[1]  # int: [(...) x num_sample]
     except RuntimeError:
-
+        original_device = features.device
+        alternative_device = "cuda" if original_device == "cpu" else "cpu"
+        sample_indices = sample_farthest_points(
+            features.to(alternative_device), lengths=count.to(alternative_device), K=config.num_sample
+        )[1].to(original_device)  # int: [(...) x num_sample]
+    sample_indices = torch.gather(order, 1, sample_indices)  # int: [(...) x num_sample]
+
+    return sample_indices.view((*shape, *sample_indices.shape[-1:]))  # int: [... x num_sample]
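subsample_features gains a "full" method, drops max_draw, sorts the returned indices, and (via fpsample's mask/lengths logic) skips non-finite rows before farthest-point sampling. A usage sketch under the new signature (pytorch3d must be installed; unbatched input shown):

import torch

from nystrom_ncut.sampling_utils import SampleConfig, subsample_features

features = torch.randn(5000, 32)
features[:10] = torch.nan                     # NaN rows are masked out before FPS
indices = subsample_features(
    features,
    distance_type="euclidean",
    config=SampleConfig(method="fps", num_sample=100),
)
print(indices.shape)                          # expected: [100], sorted ascending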
nystrom_ncut/transformer/axis_align.py
CHANGED
@@ -70,7 +70,6 @@ class AxisAlign(TorchTransformerMixin):
             raise ValueError(f"Invalid sort method {self.sort_method}.")
 
         self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
-        self.is_fitted = True
         return self
 
     def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
nystrom_ncut/visualize_utils.py
CHANGED
@@ -1,5 +1,4 @@
-import
-from typing import Any, Callable, Dict, Literal, Union
+from typing import Any, Callable, Dict, Union
 
 import numpy as np
 import torch
@@ -13,7 +12,8 @@ from .common import (
     quantile_normalize,
 )
 from .distance_utils import (
-
+    AffinityOptions,
+    AFFINITY_TO_DISTANCE,
     to_euclidean,
     affinity_from_features,
 )
@@ -27,7 +27,7 @@ def extrapolate_knn(
     anchor_features: torch.Tensor,          # [n x d]
     anchor_output: torch.Tensor,            # [n x d']
     extrapolation_features: torch.Tensor,   # [m x d]
-
+    affinity_type: AffinityOptions,
     knn: int = 10,                          # k
     affinity_focal_gamma: float = 1.0,
     chunk_size: int = 8192,
@@ -41,7 +41,7 @@ def extrapolate_knn(
         anchor_output (torch.Tensor): output from subgraph, shape (num_sample, D)
         extrapolation_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
         knn (int): number of KNN to propagate eigenvectors
-
+        affinity_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
         chunk_size (int): chunk size for matrix multiplication
         device (str): device to use for computation, if None, will not change device
     Returns:
@@ -66,7 +66,7 @@ def extrapolate_knn(
     for _v in torch.chunk(extrapolation_features, n_chunks, dim=0):
         _v = _v.to(device)  # [_m x d]
 
-        _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma,
+        _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, affinity_type).mT  # [_m x n]
         if knn is not None:
             _A, indices = _A.topk(k=knn, dim=-1, largest=True)  # [_m x k], [_m x k]
             _anchor_output = anchor_output[indices]  # [_m x k x d]
@@ -90,7 +90,7 @@ def extrapolate_knn_with_subsampling(
     full_output: torch.Tensor,              # [n x d']
     extrapolation_features: torch.Tensor,   # [m x d]
     sample_config: SampleConfig,
-
+    affinity_type: AffinityOptions,
     knn: int = 10,                          # k
     affinity_focal_gamma: float = 1.0,
     chunk_size: int = 8192,
@@ -122,7 +122,7 @@ def extrapolate_knn_with_subsampling(
     # sample subgraph
     anchor_indices = subsample_features(
         features=full_features,
-
+        distance_type=AFFINITY_TO_DISTANCE[affinity_type],
         config=sample_config,
     )
 
@@ -135,7 +135,7 @@ def extrapolate_knn_with_subsampling(
         anchor_features,
         anchor_output,
         extrapolation_features,
-
+        affinity_type,
         knn=knn,
         affinity_focal_gamma=affinity_focal_gamma,
         chunk_size=chunk_size,
@@ -148,7 +148,7 @@ def extrapolate_knn_with_subsampling(
 def _rgb_with_dimensionality_reduction(
     features: torch.Tensor,
     num_sample: int,
-
+    affinity_type: AffinityOptions,
     rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
     q: float,
     knn: int,
@@ -162,26 +162,26 @@ def _rgb_with_dimensionality_reduction(
     if True:
         _subgraph_indices = subsample_features(
             features=features,
-
+            distance_type=AFFINITY_TO_DISTANCE[affinity_type],
             config=SampleConfig(method="fps"),
         )
         features = extrapolate_knn(
             anchor_features=features[_subgraph_indices],
             anchor_output=features[_subgraph_indices],
             extrapolation_features=features,
-
+            affinity_type=affinity_type,
         )
 
     subgraph_indices = subsample_features(
         features=features,
-
+        distance_type=AFFINITY_TO_DISTANCE[affinity_type],
         config=SampleConfig(method="fps", num_sample=num_sample),
     )
 
     _inp = features[subgraph_indices].numpy(force=True)
     _subgraph_embed = torch.tensor(reduction(
         n_components=reduction_dim,
-        metric=
+        metric=AFFINITY_TO_DISTANCE[affinity_type],
         random_state=seed,
         **reduction_kwargs
     ).fit_transform(_inp), device=features.device, dtype=features.dtype)
@@ -190,7 +190,7 @@ def _rgb_with_dimensionality_reduction(
         features[subgraph_indices],
         _subgraph_embed,
         features,
-
+        affinity_type,
         knn=knn,
         device=device,
         move_output_to_cpu=True
@@ -201,7 +201,7 @@ def _rgb_with_dimensionality_reduction(
 def rgb_from_tsne_2d(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -220,16 +220,12 @@ def rgb_from_tsne_2d(
             "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)
 
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
         rgb_func=rgb_from_2d_colormap,
         q=q,
         knn=knn,
@@ -245,7 +241,7 @@ def rgb_from_tsne_2d(
 def rgb_from_tsne_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -264,16 +260,12 @@ def rgb_from_tsne_3d(
            "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)
 
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
         rgb_func=rgb_from_3d_rgb_cube,
         q=q,
         knn=knn,
@@ -289,7 +281,7 @@ def rgb_from_tsne_3d(
 def rgb_from_euclidean_tsne_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -308,19 +300,15 @@ def rgb_from_euclidean_tsne_3d(
             "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)
 
     def rgb_func(X_3d: torch.Tensor, q: float) -> torch.Tensor:
-        return rgb_from_3d_rgb_cube(to_euclidean(X_3d,
+        return rgb_from_3d_rgb_cube(to_euclidean(X_3d, AFFINITY_TO_DISTANCE[affinity_type]), q=q)
 
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
         rgb_func=rgb_func,
         q=q,
         knn=knn,
@@ -336,7 +324,7 @@ def rgb_from_euclidean_tsne_3d(
 def rgb_from_umap_2d(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -357,7 +345,7 @@ def rgb_from_umap_2d(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
         rgb_func=rgb_from_2d_colormap,
         q=q,
         knn=knn,
@@ -374,7 +362,7 @@ def rgb_from_umap_2d(
 def rgb_from_umap_sphere(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -402,7 +390,7 @@ def rgb_from_umap_sphere(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
        rgb_func=rgb_func,
         q=q,
         knn=knn,
@@ -420,7 +408,7 @@ def rgb_from_umap_sphere(
 def rgb_from_umap_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -441,7 +429,7 @@ def rgb_from_umap_3d(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-
+        affinity_type=affinity_type,
         rgb_func=rgb_from_3d_rgb_cube,
         q=q,
         knn=knn,
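Every rgb_from_* entry point now takes affinity_type= in place of the old metric/distance argument, and derives the reduction metric through AFFINITY_TO_DISTANCE; the perplexity clamp is also simplified to a one-liner. A hypothetical call sketch matching the rgb_from_tsne_3d signature above (scikit-learn required; the output is assumed to be an [n x 3] RGB tensor, which this diff does not show):

import torch

from nystrom_ncut.visualize_utils import rgb_from_tsne_3d

features = torch.randn(4096, 32)
rgb = rgb_from_tsne_3d(features, num_sample=1000, affinity_type="cosine", perplexity=150)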
{nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.2.2
+Version: 0.3.0
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
nystrom_ncut-0.3.0.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
+__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
+nystrom_ncut/common.py,sha256=eie19AHTMk6AGTxNnYq1UcFkHJVimeywAUYryXwaiHk,2428
+nystrom_ncut/distance_utils.py,sha256=pJA8NcIKyS7-YDpRGOkc7mwBQQEsYVemdkHiTjyU4n8,4300
+nystrom_ncut/sampling_utils.py,sha256=6lP8F6gftl4mgkavPsD7Vuk4erj4RtgILPhcj3YqLXk,4840
+nystrom_ncut/visualize_utils.py,sha256=Sfi_kKpvFFzBFoJnbo-pQpH2jhs-A6tH64SV_WGoq58,22740
+nystrom_ncut/nystrom/__init__.py,sha256=1aUXK87g4cXRXqNt6XkZsfyauw1-yv3sv0NmdmkWo-8,42
+nystrom_ncut/nystrom/distance_realization.py,sha256=RTI1_Q8fCUGAPSbXaVuNA-2B-11CEAfy2CwKWPJj6xQ,5830
+nystrom_ncut/nystrom/normalized_cut.py,sha256=jB_QALMY3l5CFfZPsrOFpEaquTrJP17muTrDZXxzUA8,7177
+nystrom_ncut/nystrom/nystrom_utils.py,sha256=hksDO8uuAb9xKoA1ZafGwXDlQN_gZJn_qHscaSoO8JE,14120
+nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
+nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
+nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
+nystrom_ncut-0.3.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
+nystrom_ncut-0.3.0.dist-info/METADATA,sha256=lhxicufu5Eo9HQsUiS_K-CzocemOeNravAaIXeCtriM,6058
+nystrom_ncut-0.3.0.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+nystrom_ncut-0.3.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
+nystrom_ncut-0.3.0.dist-info/RECORD,,
nystrom_ncut-0.2.2.dist-info/RECORD
DELETED
@@ -1,18 +0,0 @@
-__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
-nystrom_ncut/common.py,sha256=_PGJoImSk_Fb_5Ri-e_IsFoCcSfbGS8CxYUUHVoNM50,2036
-nystrom_ncut/distance_utils.py,sha256=p-pYdpRrJsIhzxM_IxUqja7N8okngx52WGXD9pu_Aec,3129
-nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHNU,3105
-nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
-nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
-nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
-nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
-nystrom_ncut/nystrom/nystrom_utils.py,sha256=MMRMq7N_JYM2whSHIhCoPA-SQ28wb9hC8u2CZNmRRN8,12836
-nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
-nystrom_ncut/transformer/axis_align.py,sha256=8PYtSTChHDTrh5SYdhl1ALsUUPJHd9ojQRM1e6KTbHc,3579
-nystrom_ncut/transformer/transformer_mixin.py,sha256=kVjODpIHB6noC7yGY-QPf67Ep58o53wrEMKSNjhhChI,1714
-nystrom_ncut-0.2.2.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
-nystrom_ncut-0.2.2.dist-info/METADATA,sha256=RGX2HMT2uF9bUB_4qecpX87bY6N1gg6-QKXvhJVnzIo,6058
-nystrom_ncut-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-nystrom_ncut-0.2.2.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
-nystrom_ncut-0.2.2.dist-info/RECORD,,
{nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/LICENSE
File without changes

{nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.0.dist-info}/top_level.txt
File without changes