ncut-pytorch 3.0.0.dev3__tar.gz → 3.0.0.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/grad.py +139 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/math.py +4 -5
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/pyproject.toml +1 -1
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/LICENSE +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/README.md +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/coloring.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/mspace.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/mspace_nopl.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncut.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_click.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_nystrom.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/api.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/patch.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/transform.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/vision_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/device.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/sample.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/sigma.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/torch_mod.py +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/requires.txt +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/top_level.txt +0 -0
- {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/setup.cfg +0 -0
|
@@ -113,3 +113,142 @@ def rbf_eigvec_manual_grad(
|
|
|
113
113
|
|
|
114
114
|
return grad_u
|
|
115
115
|
|
|
116
|
+
|
|
117
|
+
class MultiSpectralProjectorFromMasks(torch.autograd.Function):
    """
    A (symmetric) -> {P_b}_b, where P_b = U_{S_b} U_{S_b}^T and S_b is specified by a boolean mask.

    Computes eigh(A) ONCE, sorts eigenpairs DESCENDING (largest-first), then forms projectors
    for each mask. The backward pass is hand-written from the saved eigenpairs: only the
    coupling between selected (S) and unselected (perp) eigenvectors contributes, each term
    divided by the eigenvalue gap lam_S - lam_perp.

    Inputs:
        A: [N,N] (float), symmetric (or will be symmetrized if symmetrize=True)
        masks: [B,N] (bool), masks[b,i]=True selects eigenvector i in DESCENDING eigen-order.
        gap_eps: if > 0, every eigenvalue gap used in backward is pushed away from zero so
            that |gap| >= gap_eps (sign preserved; an exactly zero gap is treated as +gap_eps).
            With gap_eps == 0 the raw gaps are used and degenerate gaps give inf/nan gradients.
        symmetrize: if True, 0.5 * (A + A.T) is decomposed instead of A, and the returned
            gradient is symmetrized accordingly.

    Output:
        P: [B,N,N]

    Raises:
        ValueError: on non-square A, non-2D or non-bool masks, a mask width different
            from N, or a mask row that selects no eigenvector.
    """

    @staticmethod
    def forward(
        ctx,
        A: torch.Tensor,  # [N,N]
        masks: torch.Tensor,  # [B,N] bool (in DESCENDING eigen-order)
        gap_eps: float = 0.0,
        symmetrize: bool = True,
    ):
        if A.ndim != 2 or A.shape[0] != A.shape[1]:
            raise ValueError(f"A must be square [N,N], got {tuple(A.shape)}")
        if masks.ndim != 2:
            raise ValueError(f"masks must be [B,N], got {tuple(masks.shape)}")
        if masks.dtype != torch.bool:
            raise ValueError("masks must be boolean")
        N = A.shape[0]
        B, N2 = masks.shape
        if N2 != N:
            raise ValueError(f"masks second dim must equal N={N}, got {N2}")
        if (masks.sum(dim=1) == 0).any():
            raise ValueError("Each mask row must select at least one eigenvector.")

        device = A.device
        masks = masks.to(device=device)

        A_used = 0.5 * (A + A.T) if symmetrize else A

        # eigh returns ascending eigenvalues -> flip both to descending order
        evals_asc, U_asc = torch.linalg.eigh(A_used)
        evals = torch.flip(evals_asc, dims=[0])  # [N] descending
        U = torch.flip(U_asc, dims=[1])  # [N,N] descending columns

        # Build one projector per mask row
        projectors = []
        for b in range(B):
            U_S = U[:, masks[b]]  # [N,p_b]
            projectors.append(U_S @ U_S.T)  # [N,N]
        if projectors:
            P = torch.stack(projectors, dim=0)  # [B,N,N]
        else:
            # torch.stack rejects an empty list; B == 0 simply yields no projectors.
            P = U.new_zeros((0, N, N))

        ctx.save_for_backward(U, evals, masks)
        ctx.gap_eps = float(gap_eps)
        ctx.symmetrize = bool(symmetrize)

        return P

    @staticmethod
    def backward(ctx, grad_P: torch.Tensor):
        U, evals, masks = ctx.saved_tensors
        gap_eps = ctx.gap_eps
        symmetrize = ctx.symmetrize

        if grad_P.ndim != 3:
            raise ValueError(f"grad_P must be [B,N,N], got {tuple(grad_P.shape)}")
        B, N, N2 = grad_P.shape
        if N != N2:
            raise ValueError("grad_P must be square per batch")

        grad_A_used = torch.zeros((N, N), device=grad_P.device, dtype=grad_P.dtype)

        for b in range(B):
            mask = masks[b]  # [N]
            U_S = U[:, mask]  # [N,p]
            U_perp = U[:, ~mask]  # [N,N-p]
            lam_S = evals[mask]  # [p]
            lam_perp = evals[~mask]  # [N-p]

            # P_b is symmetric, so only the symmetric part of its gradient matters
            G = grad_P[b]
            Gs = 0.5 * (G + G.T)

            # H = U_perp^T Gs U_S
            H = U_perp.T @ (Gs @ U_S)  # [N-p,p]

            denom = lam_S[None, :] - lam_perp[:, None]  # [N-p,p]
            if gap_eps > 0.0:
                # torch.sign(0) is 0, which would leave an exactly degenerate gap
                # at zero and defeat the clamp; use an explicit +/-1 sign instead.
                sign = torch.where(
                    denom < 0, -torch.ones_like(denom), torch.ones_like(denom)
                )
                denom = sign * torch.clamp(denom.abs(), min=gap_eps)

            Q = H / denom  # [N-p,p]

            Bmat = U_perp @ (Q @ U_S.T)  # [N,N]
            grad_A_used = grad_A_used + (Bmat + Bmat.T)

        grad_A = 0.5 * (grad_A_used + grad_A_used.T) if symmetrize else grad_A_used
        # One gradient slot per forward() argument; only A is differentiable.
        return grad_A, None, None, None
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def spectral_projectors_from_masks(
    A: torch.Tensor,
    masks: torch.Tensor,
    gap_eps: float = 0.0,
    symmetrize: bool = True,
):
    """
    Functional entry point for MultiSpectralProjectorFromMasks.

    A: [N,N] matrix to decompose.
    masks: [B,N] bool in DESCENDING eigen-order (0 = largest eigenvalue).
    gap_eps, symmetrize: passed through unchanged.
    returns P: [B,N,N]
    """
    # Arguments are handed over positionally, in the order forward() declares them.
    projectors = MultiSpectralProjectorFromMasks.apply(A, masks, gap_eps, symmetrize)
    return projectors
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
if __name__ == "__main__":
    # Smoke test: run forward and backward on two random symmetric matrices
    # that share one set of masks, then print the gradient shapes.
    num_masks, size = 2, 1000

    masks = torch.zeros(num_masks, size, dtype=torch.bool)
    masks[0, :3] = True  # top-3 eigenvectors (largest-first)
    masks[1, 3:6] = True  # next-3

    leaves = []
    projectors = []
    for _ in range(2):
        # Draw, symmetrize, then mark as a leaf that requires grad.
        sym = torch.randn(size, size)
        sym = 0.5 * (sym + sym.T)
        sym.requires_grad_(True)
        leaves.append(sym)
        projectors.append(spectral_projectors_from_masks(sym, masks))

    difference = projectors[0] - projectors[1]
    loss = torch.norm(difference, p=2, dim=(0, 1)).sum()
    loss.backward()

    for leaf in leaves:
        print(leaf.grad.shape)
|
|
@@ -38,7 +38,7 @@ def rbf_affinity(
|
|
|
38
38
|
X1: torch.Tensor, # [N,D]
|
|
39
39
|
X2: torch.Tensor | None = None, # [M,D]
|
|
40
40
|
sigma: float = 1.0,
|
|
41
|
-
zero_diag: bool =
|
|
41
|
+
zero_diag: bool = False,
|
|
42
42
|
gamma: float | None = None, # deprecated
|
|
43
43
|
) -> torch.Tensor: # [N,M]
|
|
44
44
|
"""Computes RBF affinity matrix: W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2))."""
|
|
@@ -58,7 +58,7 @@ def cosine_affinity(
|
|
|
58
58
|
X2: torch.Tensor | None = None, # [M,D]
|
|
59
59
|
sigma: float = 1.0,
|
|
60
60
|
repulse: bool = False,
|
|
61
|
-
zero_diag: bool =
|
|
61
|
+
zero_diag: bool = False,
|
|
62
62
|
gamma: float | None = None, # deprecated
|
|
63
63
|
) -> torch.Tensor: # [N,M]
|
|
64
64
|
"""Computes cosine-based affinity matrix."""
|
|
@@ -98,11 +98,10 @@ def grad_safe_eig_solve(
|
|
|
98
98
|
is_symmetric = mat.shape[0] == mat.shape[1]
|
|
99
99
|
if is_symmetric:
|
|
100
100
|
s, u = torch.linalg.eigh(mat)
|
|
101
|
+
s = torch.flip(s, dims=[0])
|
|
102
|
+
u = torch.flip(u, dims=[1])
|
|
101
103
|
else:
|
|
102
104
|
s, u = torch.linalg.eig(mat)
|
|
103
|
-
sort_idx = torch.argsort(s, dim=0, descending=True)
|
|
104
|
-
s = s[sort_idx]
|
|
105
|
-
u = u[:, sort_idx]
|
|
106
105
|
return u.to(dtype), s.to(dtype), None
|
|
107
106
|
|
|
108
107
|
try:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/hires_dino.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/lowres_dino.py
RENAMED
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/transform.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino_predictor.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/jafar_predictor.py
RENAMED
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/vision_predictor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|