nystrom-ncut 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- nystrom_ncut/nystrom/normalized_cut.py +4 -4
- nystrom_ncut/transformer/axis_align.py +58 -46
- {nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/METADATA +1 -1
- {nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/RECORD +7 -7
- {nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/WHEEL +0 -0
- {nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/top_level.txt +0 -0
nystrom_ncut/nystrom/normalized_cut.py

```diff
@@ -54,10 +54,10 @@ class LaplacianKernel(OnlineKernel):
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )
-        self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT
-        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))
-        self.b_r = torch.zeros_like(self.a_r)
+        )                                                                                           # [... x n x (d + 1)], [... x (d + 1)]
+        self.Ainv = U @ torch.nan_to_num(torch.diag_embed(1 / L), posinf=0.0, neginf=0.0) @ U.mT    # [... x n x n]
+        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))              # [... x n]
+        self.b_r = torch.zeros_like(self.a_r)                                                       # [... x n]
 
     def _affinity(self, features: torch.Tensor) -> torch.Tensor:
         B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
```
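In 0.3.2, the `Ainv` computation guards against zero eigenvalues: reciprocating `L` turns any zero into `inf`, and `torch.nan_to_num(..., posinf=0.0, neginf=0.0)` zeroes those entries, which amounts to computing a Moore-Penrose pseudo-inverse rather than a hard inverse. A minimal standalone sketch (toy data, not from the package) of the effect:

```python
import torch

# A symmetric matrix with a zero eigenvalue: the naive inverse U diag(1/L) U^T
# blows up, while zeroing the infinite reciprocals recovers the pseudo-inverse.
L = torch.tensor([2.0, 1.0, 0.0])                # eigenvalues, one of them zero
U = torch.linalg.qr(torch.randn(3, 3)).Q         # a random orthonormal basis
A = U @ torch.diag(L) @ U.mT

naive = U @ torch.diag_embed(1 / L) @ U.mT       # contains inf/nan entries
guarded = U @ torch.nan_to_num(torch.diag_embed(1 / L), posinf=0.0, neginf=0.0) @ U.mT

print(torch.isfinite(naive).all())                               # tensor(False)
print(torch.allclose(guarded, torch.linalg.pinv(A), atol=1e-5))  # tensor(True)
```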
nystrom_ncut/transformer/axis_align.py

```diff
@@ -3,6 +3,9 @@ from typing import Literal
 import torch
 import torch.nn.functional as Fn
 
+from ..common import (
+    default_device,
+)
 from .transformer_mixin import (
     TorchTransformerMixin,
 )
```
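`default_device` is imported from the package's `common` module, whose implementation is not part of this diff; judging by its use in `fit` below, it makes tensors allocated inside the block (`torch.empty`, `torch.zeros`, `torch.rand`, `torch.arange`) default to the device of the input. A hypothetical sketch of such a helper, assuming it simply delegates to PyTorch's device context manager (available since PyTorch 2.0):

```python
import contextlib

import torch

# Hypothetical stand-in for nystrom_ncut.common.default_device: within the
# context, factory calls like torch.empty and torch.rand allocate on `device`
# unless a device is passed explicitly.
@contextlib.contextmanager
def default_device(device: torch.device):
    with torch.device(device):
        yield

with default_device(torch.device("cpu")):
    R = torch.empty(4, 4)
print(R.device)  # cpu
```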
```diff
@@ -27,50 +30,59 @@ class AxisAlign(TorchTransformerMixin):
 
     def fit(self, X: torch.Tensor) -> "AxisAlign":
         # Normalize eigenvectors
-        [old lines 30-73: the previous fit implementation, truncated in the rendered diff; only the fragment "self.R[" from line 41 survives]
+        with default_device(X.device):
+            d = X.shape[-1]
+            normalized_X = Fn.normalize(X, p=2, dim=-1)                                             # float: [... x n x d]
+
+            # Initialize R matrix with the first column from a random row of EigenVectors
+            def get_idx(idx: torch.Tensor) -> torch.Tensor:
+                return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
+
+            self.R = torch.empty((*X.shape[:-2], d, d))                                             # float: [... x d x d]
+            mask = torch.all(torch.isfinite(normalized_X), dim=-1)                                  # bool: [... x n]
+            start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1)         # int: [...]
+            self.R[..., 0, :] = get_idx(start_idx)
+
+            # Loop to populate R with k orthogonal directions
+            c = torch.zeros(X.shape[:-1])                                                           # float: [... x n]
+            for i in range(1, d):
+                c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
+                self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
+
+            # Iterative optimization loop
+            normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
+            idx, prev_objective = None, torch.inf
+            for _ in range(self.max_iter):
+                # Discretize the projected eigenvectors
+                idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)                                # int: [... x n]
+                M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3)   # float: [... x d x d]
+
+                # Check for convergence
+                objective = torch.norm(M)
+                if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
+                    break
+                prev_objective = objective
+
+                # SVD decomposition to compute the next R
+                U, S, Vh = torch.linalg.svd(M, full_matrices=False)
+                self.R = U @ Vh
+
+            # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
+            match self.sort_method:
+                case "count":
+                    sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
+                case "norm":
+                    rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
+                    sort_metric = torch.linalg.norm(rotated_X, dim=-2)
+                case "marginal_norm":
+                    rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
+                    sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
+                case _:
+                    raise ValueError(f"Invalid sort method {self.sort_method}.")
+
+            order = torch.argsort(sort_metric, dim=-1, descending=True)
+            self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
+        return self
 
     def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
         """
```
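The rewritten `fit` follows the usual axis-alignment recipe for spectral embeddings: greedily seed the rows of `R` with mutually dissimilar embedding rows, then alternate between hard cluster assignment (`argmax` over the rotated embeddings) and a Procrustes update of `R` from the SVD of the assignment-weighted matrix `M`, and finally reorder the rows of `R` by the chosen `sort_method`. A hypothetical usage sketch (the constructor signature is assumed for illustration, not taken from this diff):

```python
import torch
from nystrom_ncut.transformer import AxisAlign

# X holds n x d spectral embeddings, e.g. NCut eigenvectors.
X = torch.randn(1000, 8)
aa = AxisAlign(sort_method="count")   # assumed: "count" | "norm" | "marginal_norm"

soft = aa.fit_transform(X)            # rotated embeddings, shape (1000, 8)
labels = aa.transform(X, hard=True)   # one cluster index per row, shape (1000,)
```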
```diff
@@ -82,9 +94,9 @@ class AxisAlign(TorchTransformerMixin):
             torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
         """
         if normalize:
-            X = Fn.normalize(X, p=2, dim
+            X = Fn.normalize(X, p=2, dim=-1)
         rotated_X = X @ self.R.mT
-        return torch.argmax(rotated_X, dim
+        return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
 
     def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
         return self.fit(X).transform(X, normalize=normalize, hard=hard)
```
{nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.3.0
+Version: 0.3.2
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
```
{nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/RECORD

```diff
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=6lP8F6gftl4mgkavPsD7Vuk4erj4RtgILPhcj3YqLX
 nystrom_ncut/visualize_utils.py,sha256=Sfi_kKpvFFzBFoJnbo-pQpH2jhs-A6tH64SV_WGoq58,22740
 nystrom_ncut/nystrom/__init__.py,sha256=1aUXK87g4cXRXqNt6XkZsfyauw1-yv3sv0NmdmkWo-8,42
 nystrom_ncut/nystrom/distance_realization.py,sha256=RTI1_Q8fCUGAPSbXaVuNA-2B-11CEAfy2CwKWPJj6xQ,5830
-nystrom_ncut/nystrom/normalized_cut.py,sha256=
+nystrom_ncut/nystrom/normalized_cut.py,sha256=cjkG8JeDmTPDK8KwfkAIqF9f1dI-D9s1muJ9WWZlUoc,7237
 nystrom_ncut/nystrom/nystrom_utils.py,sha256=hksDO8uuAb9xKoA1ZafGwXDlQN_gZJn_qHscaSoO8JE,14120
 nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
-nystrom_ncut/transformer/axis_align.py,sha256=
+nystrom_ncut/transformer/axis_align.py,sha256=j3LlAPrp8O_jQAlwZz-gu3D7n_wICEJranye-YK5wvA,4880
 nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
-nystrom_ncut-0.3.0.dist-info/LICENSE,sha256=
-nystrom_ncut-0.3.0.dist-info/METADATA,sha256=
-nystrom_ncut-0.3.0.dist-info/WHEEL,sha256=
-nystrom_ncut-0.3.0.dist-info/top_level.txt,sha256=
-nystrom_ncut-0.3.0.dist-info/RECORD,,
+nystrom_ncut-0.3.2.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
+nystrom_ncut-0.3.2.dist-info/METADATA,sha256=4E42fHnXLNvbErWrwxE5K_oeOdpi5Bfabed9VF-YkV0,6058
+nystrom_ncut-0.3.2.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+nystrom_ncut-0.3.2.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
+nystrom_ncut-0.3.2.dist-info/RECORD,,
```
{nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/LICENSE: file without changes
{nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/WHEEL: file without changes
{nystrom_ncut-0.3.0.dist-info → nystrom_ncut-0.3.2.dist-info}/top_level.txt: file without changes