nystrom-ncut 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/common.py +18 -5
- nystrom_ncut/distance_utils.py +54 -32
- nystrom_ncut/nystrom/__init__.py +0 -3
- nystrom_ncut/nystrom/distance_realization.py +8 -9
- nystrom_ncut/nystrom/normalized_cut.py +51 -47
- nystrom_ncut/nystrom/nystrom_utils.py +78 -70
- nystrom_ncut/sampling_utils.py +64 -51
- nystrom_ncut/transformer/axis_align.py +58 -47
- nystrom_ncut/transformer/transformer_mixin.py +0 -2
- nystrom_ncut/visualize_utils.py +31 -43
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.1.dist-info}/METADATA +1 -1
- nystrom_ncut-0.3.1.dist-info/RECORD +18 -0
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.1.dist-info}/WHEEL +1 -1
- nystrom_ncut-0.2.2.dist-info/RECORD +0 -18
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.1.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.2.2.dist-info → nystrom_ncut-0.3.1.dist-info}/top_level.txt +0 -0
@@ -25,15 +25,15 @@ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
|
|
25
25
|
|
26
26
|
class OnlineKernel:
|
27
27
|
@abstractmethod
|
28
|
-
def fit(self, features: torch.Tensor) -> "OnlineKernel": # [n x d]
|
28
|
+
def fit(self, features: torch.Tensor) -> "OnlineKernel": # [... x n x d]
|
29
29
|
""""""
|
30
30
|
|
31
31
|
@abstractmethod
|
32
|
-
def update(self, features: torch.Tensor) -> torch.Tensor: # [m x d] -> [m x n]
|
32
|
+
def update(self, features: torch.Tensor) -> torch.Tensor: # [... x m x d] -> [... x m x n]
|
33
33
|
""""""
|
34
34
|
|
35
35
|
@abstractmethod
|
36
|
-
def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [m x d] -> [m x n]
|
36
|
+
def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [... x m x d] -> [... x m x n]
|
37
37
|
""""""
|
38
38
|
|
39
39
|
|
@@ -54,20 +54,21 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
54
54
|
self.n_components: int = n_components
|
55
55
|
self.kernel: OnlineKernel = kernel
|
56
56
|
self.eig_solver: EigSolverOptions = eig_solver
|
57
|
+
self.shape: torch.Size = None # ...
|
57
58
|
|
58
59
|
self.chunk_size = chunk_size
|
59
60
|
|
60
61
|
# Anchor matrices
|
61
|
-
self.anchor_features: torch.Tensor = None # [n x d]
|
62
|
-
self.A: torch.Tensor = None # [n x n]
|
63
|
-
self.Ahinv: torch.Tensor = None # [n x n]
|
64
|
-
self.Ahinv_UL: torch.Tensor = None # [n x indirect_pca_dim]
|
65
|
-
self.Ahinv_VT: torch.Tensor = None # [indirect_pca_dim x n]
|
62
|
+
self.anchor_features: torch.Tensor = None # [... x n x d]
|
63
|
+
self.A: torch.Tensor = None # [... x n x n]
|
64
|
+
self.Ahinv: torch.Tensor = None # [... x n x n]
|
65
|
+
self.Ahinv_UL: torch.Tensor = None # [... x n x indirect_pca_dim]
|
66
|
+
self.Ahinv_VT: torch.Tensor = None # [... x indirect_pca_dim x n]
|
66
67
|
|
67
68
|
# Updated matrices
|
68
|
-
self.S: torch.Tensor = None # [n x n]
|
69
|
-
self.transform_matrix: torch.Tensor = None # [n x n_components]
|
70
|
-
self.eigenvalues_: torch.Tensor = None # [n]
|
69
|
+
self.S: torch.Tensor = None # [... x n x n]
|
70
|
+
self.transform_matrix: torch.Tensor = None # [... x n x n_components]
|
71
|
+
self.eigenvalues_: torch.Tensor = None # [... x n]
|
71
72
|
|
72
73
|
def _update_to_kernel(self, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
73
74
|
self.A = self.S = self.kernel.transform()
|
@@ -75,10 +76,10 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
75
76
|
self.A,
|
76
77
|
num_eig=d + 1, # d * (d + 3) // 2 + 1,
|
77
78
|
eig_solver=self.eig_solver,
|
78
|
-
) # [n x (? + 1)], [? + 1]
|
79
|
-
self.Ahinv_UL = U * (L ** -0.5)
|
80
|
-
self.Ahinv_VT = U.mT # [(? + 1) x n]
|
81
|
-
self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [n x n]
|
79
|
+
) # [... x n x (? + 1)], [... x (? + 1)]
|
80
|
+
self.Ahinv_UL = U * (L[..., None, :] ** -0.5) # [... x n x (? + 1)]
|
81
|
+
self.Ahinv_VT = U.mT # [... x (? + 1) x n]
|
82
|
+
self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [... x n x n]
|
82
83
|
return U, L
|
83
84
|
|
84
85
|
def fit(self, features: torch.Tensor) -> "OnlineNystrom":
|
@@ -89,65 +90,63 @@ class OnlineNystrom(TorchTransformerMixin):
|
|
89
90
|
self.anchor_features = features
|
90
91
|
|
91
92
|
self.kernel.fit(self.anchor_features)
|
92
|
-
U, L = self._update_to_kernel(features.shape[-1]) # [n x (d + 1)], [d + 1]
|
93
|
+
U, L = self._update_to_kernel(features.shape[-1]) # [... x n x (d + 1)], [... x (d + 1)]
|
93
94
|
|
94
|
-
self.transform_matrix = (U / L)[:, :self.n_components]
|
95
|
-
self.eigenvalues_ = L[:self.n_components]
|
96
|
-
self.
|
97
|
-
return U[:, :self.n_components] # [n x n_components]
|
95
|
+
self.transform_matrix = (U / L[..., None, :])[..., :, :self.n_components] # [... x n x n_components]
|
96
|
+
self.eigenvalues_ = L[..., :self.n_components] # [... x n_components]
|
97
|
+
return U[..., :, :self.n_components] # [... x n x n_components]
|
98
98
|
|
99
99
|
def update(self, features: torch.Tensor) -> torch.Tensor:
|
100
100
|
d = features.shape[-1]
|
101
|
-
n_chunks = ceildiv(
|
101
|
+
n_chunks = ceildiv(features.shape[-2], self.chunk_size)
|
102
102
|
if n_chunks > 1:
|
103
103
|
""" Chunked version """
|
104
|
-
chunks = torch.chunk(features, n_chunks, dim
|
104
|
+
chunks = torch.chunk(features, n_chunks, dim=-2)
|
105
105
|
for chunk in chunks:
|
106
106
|
self.kernel.update(chunk)
|
107
107
|
self._update_to_kernel(d)
|
108
108
|
|
109
|
-
compressed_BBT = 0.0 # [(? + 1) x (? + 1))]
|
109
|
+
compressed_BBT = 0.0 # [... x (? + 1) x (? + 1))]
|
110
110
|
for chunk in chunks:
|
111
|
-
_B = self.kernel.transform(chunk).mT # [n x _m]
|
112
|
-
_compressed_B = self.Ahinv_VT @ _B # [(? + 1) x _m]
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
self.
|
111
|
+
_B = self.kernel.transform(chunk).mT # [... x n x _m]
|
112
|
+
_compressed_B = self.Ahinv_VT @ _B # [... x (? + 1) x _m]
|
113
|
+
_compressed_B = torch.nan_to_num(_compressed_B, nan=0.0)
|
114
|
+
compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT # [... x (? + 1) x (? + 1)]
|
115
|
+
self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT # [... x n x n]
|
116
|
+
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
|
117
|
+
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
|
117
118
|
|
118
119
|
VS = []
|
119
120
|
for chunk in chunks:
|
120
|
-
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
|
121
|
-
VS = torch.cat(VS, dim
|
122
|
-
return VS # [m x n_components]
|
121
|
+
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
|
122
|
+
VS = torch.cat(VS, dim=-2)
|
123
|
+
return VS # [... x m x n_components]
|
123
124
|
else:
|
124
125
|
""" Unchunked version """
|
125
|
-
B = self.kernel.update(features).mT # [n x m]
|
126
|
+
B = self.kernel.update(features).mT # [... x n x m]
|
126
127
|
self._update_to_kernel(d)
|
127
|
-
compressed_B = self.Ahinv_VT @ B # [
|
128
|
+
compressed_B = self.Ahinv_VT @ B # [... x (? + 1) x m]
|
129
|
+
compressed_B = torch.nan_to_num(compressed_B, nan=0.0)
|
128
130
|
|
129
|
-
self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [n x n]
|
130
|
-
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [n x n_components], [n_components]
|
131
|
-
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5)
|
131
|
+
self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [... x n x n]
|
132
|
+
US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
|
133
|
+
self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
|
132
134
|
|
133
|
-
return B.mT @ self.transform_matrix # [m x n_components]
|
135
|
+
return B.mT @ self.transform_matrix # [... x m x n_components]
|
134
136
|
|
135
|
-
def transform(self, features: torch.Tensor
|
136
|
-
|
137
|
-
|
137
|
+
def transform(self, features: torch.Tensor) -> torch.Tensor:
|
138
|
+
n_chunks = ceildiv(features.shape[-2], self.chunk_size)
|
139
|
+
if n_chunks > 1:
|
140
|
+
""" Chunked version """
|
141
|
+
chunks = torch.chunk(features, n_chunks, dim=-2)
|
142
|
+
VS = []
|
143
|
+
for chunk in chunks:
|
144
|
+
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
|
145
|
+
VS = torch.cat(VS, dim=-2)
|
138
146
|
else:
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
chunks = torch.chunk(features, n_chunks, dim=0)
|
143
|
-
VS = []
|
144
|
-
for chunk in chunks:
|
145
|
-
VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
|
146
|
-
VS = torch.cat(VS, dim=0)
|
147
|
-
else:
|
148
|
-
""" Unchunked version """
|
149
|
-
VS = self.kernel.transform(features) @ self.transform_matrix # [m x n_components]
|
150
|
-
return VS # [m x n_components]
|
147
|
+
""" Unchunked version """
|
148
|
+
VS = self.kernel.transform(features) @ self.transform_matrix # [... x m x n_components]
|
149
|
+
return VS # [... x m x n_components]
|
151
150
|
|
152
151
|
|
153
152
|
class OnlineNystromSubsampleFit(OnlineNystrom):
|
@@ -155,7 +154,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
155
154
|
self,
|
156
155
|
n_components: int,
|
157
156
|
kernel: OnlineKernel,
|
158
|
-
|
157
|
+
distance_type: DistanceOptions,
|
159
158
|
sample_config: SampleConfig,
|
160
159
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
161
160
|
chunk_size: int = 8192,
|
@@ -167,7 +166,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
167
166
|
eig_solver=eig_solver,
|
168
167
|
chunk_size=chunk_size,
|
169
168
|
)
|
170
|
-
self.
|
169
|
+
self.distance_type: DistanceOptions = distance_type
|
171
170
|
self.sample_config: SampleConfig = sample_config
|
172
171
|
self.sample_config._ncut_obj = copy.deepcopy(self)
|
173
172
|
self.anchor_indices: torch.Tensor = None
|
@@ -177,7 +176,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
177
176
|
features: torch.Tensor,
|
178
177
|
precomputed_sampled_indices: torch.Tensor,
|
179
178
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
180
|
-
_n = features.shape[
|
179
|
+
_n = features.shape[-2]
|
181
180
|
if self.sample_config.num_sample >= _n:
|
182
181
|
logging.info(
|
183
182
|
f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
|
@@ -189,16 +188,17 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
189
188
|
else:
|
190
189
|
self.anchor_indices = subsample_features(
|
191
190
|
features=features,
|
192
|
-
|
191
|
+
distance_type=self.distance_type,
|
193
192
|
config=self.sample_config,
|
194
193
|
)
|
195
|
-
sampled_features = features[self.anchor_indices]
|
194
|
+
sampled_features = torch.gather(features, -2, self.anchor_indices[..., None].expand([-1] * self.anchor_indices.ndim + [features.shape[-1]]))
|
196
195
|
OnlineNystrom.fit(self, sampled_features)
|
197
196
|
|
198
|
-
_n_not_sampled = _n -
|
197
|
+
_n_not_sampled = _n - self.anchor_indices.shape[-1]
|
199
198
|
if _n_not_sampled > 0:
|
200
|
-
|
201
|
-
|
199
|
+
unsampled_mask = torch.full(features.shape[:-1], True, device=features.device).scatter_(-1, self.anchor_indices, False)
|
200
|
+
unsampled_indices = torch.where(unsampled_mask)[-1].view((*features.shape[:-2], -1))
|
201
|
+
unsampled_features = torch.gather(features, -2, unsampled_indices[..., None].expand([-1] * unsampled_indices.ndim + [features.shape[-1]]))
|
202
202
|
V_unsampled = OnlineNystrom.update(self, unsampled_features)
|
203
203
|
else:
|
204
204
|
unsampled_indices = V_unsampled = None
|
@@ -236,12 +236,12 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
|
|
236
236
|
(torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
|
237
237
|
"""
|
238
238
|
unsampled_indices, V_unsampled = OnlineNystromSubsampleFit._fit_helper(self, features, precomputed_sampled_indices)
|
239
|
-
V_sampled = OnlineNystrom.transform(self)
|
239
|
+
V_sampled = OnlineNystrom.transform(self, self.anchor_features)
|
240
240
|
|
241
241
|
if unsampled_indices is not None:
|
242
|
-
V = torch.zeros((
|
243
|
-
|
244
|
-
|
242
|
+
V = torch.zeros((*features.shape[:-1], self.n_components), device=features.device)
|
243
|
+
for (indices, _V) in [(self.anchor_indices, V_sampled), (unsampled_indices, V_unsampled)]:
|
244
|
+
V.scatter_(-2, indices[..., None].expand([-1] * indices.ndim + [self.n_components]), _V)
|
245
245
|
else:
|
246
246
|
V = V_sampled
|
247
247
|
return V
|
@@ -264,12 +264,16 @@ def solve_eig(
|
|
264
264
|
(torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
|
265
265
|
(torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
|
266
266
|
"""
|
267
|
-
|
267
|
+
shape: torch.Size = A.shape[:-2]
|
268
|
+
A = A.view((-1, *A.shape[-2:]))
|
269
|
+
bsz: int = A.shape[0]
|
270
|
+
|
271
|
+
A = A + eig_value_buffer * torch.eye(A.shape[-1], device=A.device)
|
268
272
|
|
269
273
|
# compute eigenvectors
|
270
274
|
if eig_solver == "svd_lowrank": # default
|
271
275
|
# only top q eigenvectors, fastest
|
272
|
-
eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
|
276
|
+
eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig) # complex: [(...) x N x D], [(...) x D]
|
273
277
|
elif eig_solver == "lobpcg":
|
274
278
|
# only top k eigenvectors, fast
|
275
279
|
eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
|
@@ -286,11 +290,15 @@ def solve_eig(
|
|
286
290
|
eigen_value = eigen_value - eig_value_buffer
|
287
291
|
|
288
292
|
# sort eigenvectors by eigenvalues, take top (descending order)
|
289
|
-
indices = torch.topk(eigen_value.abs(), k=num_eig, dim
|
290
|
-
eigen_value
|
293
|
+
indices = torch.topk(eigen_value.abs(), k=num_eig, dim=-1).indices # int: [(...) x S]
|
294
|
+
eigen_value = eigen_value[torch.arange(bsz)[:, None], indices] # complex: [(...) x S]
|
295
|
+
eigen_vector = eigen_vector[torch.arange(bsz)[:, None], :, indices].mT # complex: [(...) x N x S]
|
291
296
|
|
292
297
|
# correct the random rotation (flipping sign) of eigenvectors
|
293
|
-
sign = torch.sum(eigen_vector.real, dim=
|
298
|
+
sign = torch.sign(torch.sum(eigen_vector.real, dim=-2, keepdim=True)) # float: [(...) x 1 x S]
|
294
299
|
sign[sign == 0] = 1.0
|
295
300
|
eigen_vector = eigen_vector * sign
|
301
|
+
|
302
|
+
eigen_value = eigen_value.view((*shape, *eigen_value.shape[-1:])) # complex: [... x S]
|
303
|
+
eigen_vector = eigen_vector.view((*shape, *eigen_vector.shape[-2:])) # complex: [... x N x S]
|
296
304
|
return eigen_vector, eigen_value
|
nystrom_ncut/sampling_utils.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1
|
-
import logging
|
2
1
|
from dataclasses import dataclass
|
3
2
|
from typing import Literal
|
4
3
|
|
5
4
|
import torch
|
6
5
|
from pytorch3d.ops import sample_farthest_points
|
7
6
|
|
7
|
+
from .common import (
|
8
|
+
default_device,
|
9
|
+
)
|
8
10
|
from .distance_utils import (
|
9
11
|
DistanceOptions,
|
10
12
|
to_euclidean,
|
11
13
|
)
|
14
|
+
from .transformer import (
|
15
|
+
TorchTransformerMixin,
|
16
|
+
)
|
12
17
|
|
13
18
|
|
14
|
-
SampleOptions = Literal["random", "fps", "fps_recursive"]
|
19
|
+
SampleOptions = Literal["full", "random", "fps", "fps_recursive"]
|
15
20
|
|
16
21
|
|
17
22
|
@dataclass
|
@@ -20,69 +25,77 @@ class SampleConfig:
|
|
20
25
|
num_sample: int = 10000
|
21
26
|
fps_dim: int = 12
|
22
27
|
n_iter: int = None
|
23
|
-
_ncut_obj:
|
28
|
+
_ncut_obj: TorchTransformerMixin = None
|
24
29
|
|
25
30
|
|
26
31
|
@torch.no_grad()
|
27
32
|
def subsample_features(
|
28
33
|
features: torch.Tensor,
|
29
|
-
|
34
|
+
distance_type: DistanceOptions,
|
30
35
|
config: SampleConfig,
|
31
|
-
max_draw: int = 1000000,
|
32
36
|
):
|
33
|
-
features = features.detach()
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
"num_sample is larger than total, bypass Nystrom-like approximation"
|
38
|
-
)
|
39
|
-
sampled_indices = torch.arange(features.shape[0])
|
40
|
-
else:
|
41
|
-
# sample subgraph
|
42
|
-
if config.method == "fps": # default
|
43
|
-
features = to_euclidean(features, disttype)
|
44
|
-
if config.num_sample > max_draw:
|
45
|
-
logging.warning(
|
46
|
-
f"num_sample is larger than max_draw, apply farthest point sampling on random sampled {max_draw} samples"
|
47
|
-
)
|
48
|
-
draw_indices = torch.randperm(features.shape[0])[:max_draw]
|
49
|
-
sampled_indices = fpsample(features[draw_indices], config)
|
50
|
-
sampled_indices = draw_indices[sampled_indices]
|
51
|
-
else:
|
52
|
-
sampled_indices = fpsample(features, config)
|
53
|
-
|
54
|
-
elif config.method == "random": # not recommended
|
55
|
-
sampled_indices = torch.randperm(features.shape[0])[:config.num_sample]
|
56
|
-
|
57
|
-
elif config.method == "fps_recursive":
|
58
|
-
features = to_euclidean(features, disttype)
|
59
|
-
sampled_indices = subsample_features(
|
60
|
-
features=features,
|
61
|
-
disttype=disttype,
|
62
|
-
config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
|
63
|
-
)
|
64
|
-
nc = config._ncut_obj
|
65
|
-
for _ in range(config.n_iter):
|
66
|
-
fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
|
67
|
-
|
68
|
-
fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
|
69
|
-
sampled_indices = torch.sort(fpsample(fps_features, config)).values
|
37
|
+
features = features.detach() # float: [... x n x d]
|
38
|
+
with default_device(features.device):
|
39
|
+
if config.method == "full" or config.num_sample >= features.shape[0]:
|
40
|
+
sampled_indices = torch.arange(features.shape[-2]).expand(features.shape[:-1]) # int: [... x n]
|
70
41
|
else:
|
71
|
-
|
72
|
-
|
73
|
-
|
42
|
+
# sample
|
43
|
+
match config.method:
|
44
|
+
case "fps": # default
|
45
|
+
sampled_indices = fpsample(to_euclidean(features, distance_type), config)
|
46
|
+
|
47
|
+
case "random": # not recommended
|
48
|
+
mask = torch.all(torch.isfinite(features), dim=-1) # bool: [... x n]
|
49
|
+
weights = mask.to(torch.float) + torch.rand(mask.shape) # float: [... x n]
|
50
|
+
sampled_indices = torch.topk(weights, k=config.num_sample, dim=-1).indices # int: [... x num_sample]
|
51
|
+
|
52
|
+
case "fps_recursive":
|
53
|
+
features = to_euclidean(features, distance_type) # float: [... x n x d]
|
54
|
+
sampled_indices = subsample_features(
|
55
|
+
features=features,
|
56
|
+
distance_type=distance_type,
|
57
|
+
config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
|
58
|
+
) # int: [... x num_sample]
|
59
|
+
nc = config._ncut_obj
|
60
|
+
for _ in range(config.n_iter):
|
61
|
+
fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
|
62
|
+
|
63
|
+
fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
|
64
|
+
sampled_indices = torch.sort(fpsample(fps_features, config), dim=-1).values
|
65
|
+
|
66
|
+
case _:
|
67
|
+
raise ValueError("sample_method should be 'farthest' or 'random'")
|
68
|
+
sampled_indices = torch.sort(sampled_indices, dim=-1).values
|
69
|
+
return sampled_indices
|
74
70
|
|
75
71
|
|
76
72
|
def fpsample(
|
77
73
|
features: torch.Tensor,
|
78
74
|
config: SampleConfig,
|
79
75
|
):
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
76
|
+
shape = features.shape[:-2] # ...
|
77
|
+
features = features.view((-1, *features.shape[-2:])) # [(...) x n x d]
|
78
|
+
bsz = features.shape[0]
|
79
|
+
|
80
|
+
mask = torch.all(torch.isfinite(features), dim=-1) # bool: [(...) x n]
|
81
|
+
count = torch.sum(mask, dim=-1) # int: [(...)]
|
82
|
+
order = torch.topk(mask.to(torch.int), k=torch.max(count).item(), dim=-1).indices # int: [(...) x max_count]
|
83
|
+
|
84
|
+
features = torch.nan_to_num(features[torch.arange(bsz)[:, None], order], nan=0.0) # float: [(...) x max_count x d]
|
85
|
+
if features.shape[-1] > config.fps_dim:
|
86
|
+
U, S, V = torch.pca_lowrank(features, q=config.fps_dim) # float: [(...) x max_count x fps_dim], [(...) x fps_dim], [(...) x fps_dim x fps_dim]
|
87
|
+
features = U * S[..., None, :] # float: [(...) x max_count x fps_dim]
|
84
88
|
|
85
89
|
try:
|
86
|
-
|
90
|
+
sample_indices = sample_farthest_points(
|
91
|
+
features, lengths=count, K=config.num_sample
|
92
|
+
)[1] # int: [(...) x num_sample]
|
87
93
|
except RuntimeError:
|
88
|
-
|
94
|
+
original_device = features.device
|
95
|
+
alternative_device = "cuda" if original_device == "cpu" else "cpu"
|
96
|
+
sample_indices = sample_farthest_points(
|
97
|
+
features.to(alternative_device), lengths=count.to(alternative_device), K=config.num_sample
|
98
|
+
)[1].to(original_device) # int: [(...) x num_sample]
|
99
|
+
sample_indices = torch.gather(order, 1, sample_indices) # int: [(...) x num_sample]
|
100
|
+
|
101
|
+
return sample_indices.view((*shape, *sample_indices.shape[-1:])) # int: [... x num_sample]
|
@@ -3,6 +3,9 @@ from typing import Literal
|
|
3
3
|
import torch
|
4
4
|
import torch.nn.functional as Fn
|
5
5
|
|
6
|
+
from ..common import (
|
7
|
+
default_device,
|
8
|
+
)
|
6
9
|
from .transformer_mixin import (
|
7
10
|
TorchTransformerMixin,
|
8
11
|
)
|
@@ -27,51 +30,59 @@ class AxisAlign(TorchTransformerMixin):
|
|
27
30
|
|
28
31
|
def fit(self, X: torch.Tensor) -> "AxisAlign":
|
29
32
|
# Normalize eigenvectors
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
self.R[
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
33
|
+
with default_device(X.device):
|
34
|
+
d = X.shape[-1]
|
35
|
+
normalized_X = Fn.normalize(X, p=2, dim=-1) # float: [... x n x d]
|
36
|
+
|
37
|
+
# Initialize R matrix with the first column from a random row of EigenVectors
|
38
|
+
def get_idx(idx: torch.Tensor) -> torch.Tensor:
|
39
|
+
return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
|
40
|
+
|
41
|
+
self.R = torch.empty((*X.shape[:-2], d, d)) # float: [... x d x d]
|
42
|
+
mask = torch.all(torch.isfinite(normalized_X), dim=-1) # bool: [... x n]
|
43
|
+
start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1) # int: [...]
|
44
|
+
self.R[..., 0, :] = get_idx(start_idx)
|
45
|
+
|
46
|
+
# Loop to populate R with k orthogonal directions
|
47
|
+
c = torch.zeros(X.shape[:-1]) # float: [... x n]
|
48
|
+
for i in range(1, d):
|
49
|
+
c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
|
50
|
+
self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
|
51
|
+
|
52
|
+
# Iterative optimization loop
|
53
|
+
normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
|
54
|
+
idx, prev_objective = None, torch.inf
|
55
|
+
for _ in range(self.max_iter):
|
56
|
+
# Discretize the projected eigenvectors
|
57
|
+
idx = torch.argmax(normalized_X @ self.R.mT, dim=-1) # int: [... x n]
|
58
|
+
M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3) # float: [... x d x d]
|
59
|
+
|
60
|
+
# Check for convergence
|
61
|
+
objective = torch.norm(M)
|
62
|
+
if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
|
63
|
+
break
|
64
|
+
prev_objective = objective
|
65
|
+
|
66
|
+
# SVD decomposition to compute the next R
|
67
|
+
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
|
68
|
+
self.R = U @ Vh
|
69
|
+
|
70
|
+
# Permute the rotation matrix so the dimensions are sorted in descending cluster significance
|
71
|
+
match self.sort_method:
|
72
|
+
case "count":
|
73
|
+
sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
|
74
|
+
case "norm":
|
75
|
+
rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
|
76
|
+
sort_metric = torch.linalg.norm(rotated_X, dim=-2)
|
77
|
+
case "marginal_norm":
|
78
|
+
rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
|
79
|
+
sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
|
80
|
+
case _:
|
81
|
+
raise ValueError(f"Invalid sort method {self.sort_method}.")
|
82
|
+
|
83
|
+
order = torch.argsort(sort_metric, dim=-1, descending=True)
|
84
|
+
self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
|
85
|
+
return self
|
75
86
|
|
76
87
|
def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
77
88
|
"""
|
@@ -83,9 +94,9 @@ class AxisAlign(TorchTransformerMixin):
|
|
83
94
|
torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
|
84
95
|
"""
|
85
96
|
if normalize:
|
86
|
-
X = Fn.normalize(X, p=2, dim
|
97
|
+
X = Fn.normalize(X, p=2, dim=-1)
|
87
98
|
rotated_X = X @ self.R.mT
|
88
|
-
return torch.argmax(rotated_X, dim
|
99
|
+
return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
|
89
100
|
|
90
101
|
def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
91
102
|
return self.fit(X).transform(X, normalize=normalize, hard=hard)
|