ncut-pytorch 3.0.0.dev3__tar.gz → 3.0.0.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/PKG-INFO +1 -1
  2. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/grad.py +139 -0
  3. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/math.py +4 -5
  4. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/PKG-INFO +1 -1
  5. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/pyproject.toml +1 -1
  6. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/LICENSE +0 -0
  7. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/README.md +0 -0
  8. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/__init__.py +0 -0
  9. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/__init__.py +0 -0
  10. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/coloring.py +0 -0
  11. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/mspace.py +0 -0
  12. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/color/mspace_nopl.py +0 -0
  13. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncut.py +0 -0
  14. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/__init__.py +0 -0
  15. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_click.py +0 -0
  16. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
  17. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/ncuts/ncut_nystrom.py +0 -0
  18. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/__init__.py +0 -0
  19. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  20. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/api.py +0 -0
  21. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
  22. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  23. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  24. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/patch.py +0 -0
  25. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino/transform.py +0 -0
  26. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  27. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  28. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/predictor.py +0 -0
  29. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/predictor/vision_predictor.py +0 -0
  30. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/__init__.py +0 -0
  31. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/device.py +0 -0
  32. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/sample.py +0 -0
  33. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/sigma.py +0 -0
  34. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch/utils/torch_mod.py +0 -0
  35. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
  36. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  37. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/requires.txt +0 -0
  38. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/ncut_pytorch.egg-info/top_level.txt +0 -0
  39. {ncut_pytorch-3.0.0.dev3 → ncut_pytorch-3.0.0.dev5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev3
3
+ Version: 3.0.0.dev5
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -113,3 +113,142 @@ def rbf_eigvec_manual_grad(
113
113
 
114
114
  return grad_u
115
115
 
116
+
117
class MultiSpectralProjectorFromMasks(torch.autograd.Function):
    """
    A (symmetric) -> {P_b}_b, where P_b = U_{S_b} U_{S_b}^T and S_b is specified by a boolean mask.

    Computes eigh(A) ONCE, sorts eigenpairs DESCENDING (largest-first), then forms projectors
    for each mask.

    Inputs:
        A: [N,N] (float), symmetric (or symmetrized as 0.5 * (A + A^T) if symmetrize=True)
        masks: [B,N] (bool), masks[b,i]=True selects eigenvector i in DESCENDING eigen-order.
            Every row must select at least one eigenvector; B may be 0.
        gap_eps: if > 0, eigen-gaps lam_S - lam_perp whose magnitude is below gap_eps are
            clamped to +/- gap_eps in the backward pass (an exact tie counts as +gap_eps).
        symmetrize: symmetrize A before the eigendecomposition.

    Output:
        P: [B,N,N]

    Gradient (first-order eigenvector perturbation theory). For each mask b, the selected
    block is U_S / lam_S and its complement is U_perp / lam_perp:
        H = U_perp^T sym(dL/dP_b) U_S                  [N-p, p]
        D = lam_S[None, :] - lam_perp[:, None]         [N-p, p]
        dL/dA += M + M^T,  where M = U_perp (H / D) U_S^T
    Only selected-vs-unselected eigenvalue differences appear, so ties inside S (or inside
    its complement) are harmless. A tie across the boundary yields inf/nan unless gap_eps > 0.
    Double backward is not supported.
    """

    @staticmethod
    def forward(
        ctx,
        A: torch.Tensor,  # [N,N]
        masks: torch.Tensor,  # [B,N] bool (in DESCENDING eigen-order)
        gap_eps: float = 0.0,
        symmetrize: bool = True,
    ) -> torch.Tensor:
        """Validate inputs, eigendecompose A once, and return the [B,N,N] projectors."""
        if A.ndim != 2 or A.shape[0] != A.shape[1]:
            raise ValueError(f"A must be square [N,N], got {tuple(A.shape)}")
        if masks.ndim != 2:
            raise ValueError(f"masks must be [B,N], got {tuple(masks.shape)}")
        if masks.dtype != torch.bool:
            raise ValueError("masks must be boolean")
        N = A.shape[0]
        B, N2 = masks.shape
        if N2 != N:
            raise ValueError(f"masks second dim must equal N={N}, got {N2}")
        if (masks.sum(dim=1) == 0).any():
            raise ValueError("Each mask row must select at least one eigenvector.")

        masks = masks.to(device=A.device)
        A_used = 0.5 * (A + A.T) if symmetrize else A

        # eigh returns ascending eigenvalues -> flip to descending (largest first)
        evals_asc, U_asc = torch.linalg.eigh(A_used)
        evals = torch.flip(evals_asc, dims=[0])  # [N]
        U = torch.flip(U_asc, dims=[1])  # [N,N], column i pairs with evals[i]

        if B == 0:
            # torch.stack rejects an empty list; return an empty batch instead.
            P = A_used.new_zeros((0, N, N))
        else:
            projectors = []
            for row in masks:
                U_S = U[:, row]  # [N,p_b]
                # O(N^2 p_b): cheaper than U diag(row) U^T when few vectors are selected
                projectors.append(U_S @ U_S.T)  # [N,N]
            P = torch.stack(projectors, dim=0)  # [B,N,N]

        ctx.save_for_backward(U, evals, masks)
        ctx.gap_eps = float(gap_eps)
        ctx.symmetrize = bool(symmetrize)
        return P

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_P: torch.Tensor):
        """Return (dL/dA, None, None, None); masks, gap_eps and symmetrize get no gradient."""
        U, evals, masks = ctx.saved_tensors
        gap_eps = ctx.gap_eps
        symmetrize = ctx.symmetrize

        if grad_P.ndim != 3:
            raise ValueError(f"grad_P must be [B,N,N], got {tuple(grad_P.shape)}")
        _, N, N2 = grad_P.shape
        if N != N2:
            raise ValueError("grad_P must be square per batch")

        grad_A_used = torch.zeros((N, N), device=grad_P.device, dtype=grad_P.dtype)

        for G, mask in zip(grad_P, masks):
            U_S = U[:, mask]  # [N,p]
            U_perp = U[:, ~mask]  # [N,N-p]
            lam_S = evals[mask]  # [p]
            lam_perp = evals[~mask]  # [N-p]

            # P_b is symmetric, so only the symmetric part of its incoming gradient matters
            Gs = 0.5 * (G + G.T)
            H = U_perp.T @ (Gs @ U_S)  # [N-p,p]

            denom = lam_S[None, :] - lam_perp[:, None]  # [N-p,p]
            if gap_eps > 0.0:
                # torch.sign(0) == 0 would leave an exact tie at 0 (-> inf/nan);
                # clamp the magnitude and give ties a positive sign instead.
                magnitude = denom.abs().clamp(min=gap_eps)
                denom = torch.where(denom < 0, -magnitude, magnitude)

            Q = H / denom  # [N-p,p]
            Bmat = U_perp @ (Q @ U_S.T)  # [N,N]
            grad_A_used += Bmat + Bmat.T

        # grad_A_used is already symmetric; this mirrors A_used = 0.5 * (A + A^T) in forward
        grad_A = 0.5 * (grad_A_used + grad_A_used.T) if symmetrize else grad_A_used
        return grad_A, None, None, None
216
+
217
+
218
def spectral_projectors_from_masks(
    A: torch.Tensor,
    masks: torch.Tensor,
    gap_eps: float = 0.0,
    symmetrize: bool = True,
):
    """Functional front end for ``MultiSpectralProjectorFromMasks``.

    Args:
        A: [N,N] matrix whose eigendecomposition defines the projectors.
        masks: [B,N] bool in DESCENDING eigen-order (column 0 = largest eigenvalue).
        gap_eps: passed through unchanged to the autograd Function.
        symmetrize: passed through unchanged to the autograd Function.

    Returns:
        P: [B,N,N], one projector per mask row.
    """
    projector_fn = MultiSpectralProjectorFromMasks.apply
    return projector_fn(A, masks, gap_eps, symmetrize)
231
+
232
+
233
if __name__ == "__main__":
    # Smoke test: build projectors for two random symmetric matrices, then
    # backpropagate a scalar discrepancy between the two projector stacks.
    B, N = 2, 1000
    masks = torch.zeros(B, N, dtype=torch.bool)
    masks[0, 0:3] = True  # eigenvectors 0..2 (largest eigenvalues first)
    masks[1, 3:6] = True  # eigenvectors 3..5

    def _random_symmetric(n):
        raw = torch.randn(n, n)
        return (0.5 * (raw + raw.T)).requires_grad_(True)

    A1 = _random_symmetric(N)
    P1 = spectral_projectors_from_masks(A1, masks)
    A2 = _random_symmetric(N)
    P2 = spectral_projectors_from_masks(A2, masks)

    loss = torch.norm(P1 - P2, p=2, dim=(0, 1)).sum()
    loss.backward()
    for leaf in (A1, A2):
        print(leaf.grad.shape)
@@ -38,7 +38,7 @@ def rbf_affinity(
38
38
  X1: torch.Tensor, # [N,D]
39
39
  X2: torch.Tensor | None = None, # [M,D]
40
40
  sigma: float = 1.0,
41
- zero_diag: bool = True,
41
+ zero_diag: bool = False,
42
42
  gamma: float | None = None, # deprecated
43
43
  ) -> torch.Tensor: # [N,M]
44
44
  """Computes RBF affinity matrix: W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2))."""
@@ -58,7 +58,7 @@ def cosine_affinity(
58
58
  X2: torch.Tensor | None = None, # [M,D]
59
59
  sigma: float = 1.0,
60
60
  repulse: bool = False,
61
- zero_diag: bool = True,
61
+ zero_diag: bool = False,
62
62
  gamma: float | None = None, # deprecated
63
63
  ) -> torch.Tensor: # [N,M]
64
64
  """Computes cosine-based affinity matrix."""
@@ -98,11 +98,10 @@ def grad_safe_eig_solve(
98
98
  is_symmetric = mat.shape[0] == mat.shape[1]
99
99
  if is_symmetric:
100
100
  s, u = torch.linalg.eigh(mat)
101
+ s = torch.flip(s, dims=[0])
102
+ u = torch.flip(u, dims=[1])
101
103
  else:
102
104
  s, u = torch.linalg.eig(mat)
103
- sort_idx = torch.argsort(s, dim=0, descending=True)
104
- s = s[sort_idx]
105
- u = u[:, sort_idx]
106
105
  return u.to(dtype), s.to(dtype), None
107
106
 
108
107
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev3
3
+ Version: 3.0.0.dev5
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ncut_pytorch"
7
- version = "3.0.0dev3"
7
+ version = "3.0.0dev5"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  ]