ncut-pytorch 3.0.0.dev4__tar.gz → 3.0.0.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/PKG-INFO +1 -1
  2. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/ncuts/ncut_click.py +0 -1
  3. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/ncuts/ncut_nystrom.py +31 -22
  4. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/grad.py +139 -0
  5. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/math.py +7 -4
  6. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/sample.py +1 -1
  7. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch.egg-info/PKG-INFO +1 -1
  8. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/pyproject.toml +1 -1
  9. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/LICENSE +0 -0
  10. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/README.md +0 -0
  11. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/__init__.py +0 -0
  12. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/color/__init__.py +0 -0
  13. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/color/coloring.py +0 -0
  14. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/color/mspace.py +0 -0
  15. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/color/mspace_nopl.py +0 -0
  16. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/ncut.py +0 -0
  17. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/ncuts/__init__.py +0 -0
  18. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
  19. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/__init__.py +0 -0
  20. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  21. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/api.py +0 -0
  22. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
  23. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  24. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  25. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/patch.py +0 -0
  26. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino/transform.py +0 -0
  27. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  28. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  29. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/predictor.py +0 -0
  30. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/predictor/vision_predictor.py +0 -0
  31. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/__init__.py +0 -0
  32. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/device.py +0 -0
  33. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/sigma.py +0 -0
  34. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch/utils/torch_mod.py +0 -0
  35. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
  36. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  37. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch.egg-info/requires.txt +0 -0
  38. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/ncut_pytorch.egg-info/top_level.txt +0 -0
  39. {ncut_pytorch-3.0.0.dev4 → ncut_pytorch-3.0.0.dev6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev4
3
+ Version: 3.0.0.dev6
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -89,7 +89,6 @@ def ncut_click_prompt(
89
89
  nystrom_X,
90
90
  n_neighbors=config.n_neighbors,
91
91
  n_sample=config.n_sample2,
92
- matmul_chunk_size=config.matmul_chunk_size,
93
92
  device=device,
94
93
  return_indices=True,
95
94
  )
@@ -9,7 +9,10 @@ from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
9
9
  from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row, svd_lowrank
10
10
  from ncut_pytorch.utils.sample import farthest_point_sampling
11
11
  from ncut_pytorch.utils.device import auto_device
12
+ import logging
12
13
 
14
# Starting (largest) number of rows per chunk when computing affinities
# against the Nystrom nodes; _find_max_chunk_size halves it on failure.
MATMUL_CHUNK_SIZE = 64 * 1024
# Graphs with at most this many nodes skip farthest-point subsampling:
# every node is used as a Nystrom sample, i.e. exact ncut.
SMALL_SCALE_THRESHOLD = 8 * 1024
13
16
 
14
17
  class NystromConfig:
15
18
  """
@@ -21,7 +24,6 @@ class NystromConfig:
21
24
  n_sample2 = 1024 # number of samples for eigenvector propagation, 1024 is large enough for most cases
22
25
  n_neighbors = 32 # number of neighbors for eigenvector propagation, 10 is large enough for most cases
23
26
  n_neighbors_max_ratio = 1/32 # max ratio of n_neighbors to n_sample2, to avoid over smoothing
24
- matmul_chunk_size = 65536 # chunk size for matrix multiplication, larger chunk size is faster but requires more memory
25
27
 
26
28
  def update(self, kwargs: dict):
27
29
  for key, value in kwargs.items():
@@ -78,9 +80,11 @@ def ncut_fn(
78
80
  device = auto_device(X.device, device)
79
81
 
80
82
  # subsample for nystrom approximation
81
- is_enough_data = X.shape[0] > config.n_sample
82
83
  n_sample = min(config.n_sample, int(X.shape[0]*config.n_sample_max_ratio))
83
- nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device) if is_enough_data else np.arange(X.shape[0])
84
+ if X.shape[0] > SMALL_SCALE_THRESHOLD:
85
+ nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device)
86
+ else:
87
+ nystrom_indices = torch.arange(X.shape[0])
84
88
  nystrom_X = X[nystrom_indices].to(device)
85
89
 
86
90
  sigma, repulsion_sigma = find_optimal_sigma(nystrom_X, quantile_sigma, quantile_sigma_repulsion, sigma, repulsion_sigma, affinity_fn)
@@ -95,10 +99,6 @@ def ncut_fn(
95
99
  if no_propagation:
96
100
  return nystrom_eigvec, eigval, nystrom_indices, sigma
97
101
 
98
- if not is_enough_data:
99
- # skip nystrom approximation if not enough data, use exact ncut
100
- return nystrom_eigvec, eigval
101
-
102
102
  # propagate eigenvectors from subgraph to full graph
103
103
  eigvec = nystrom_propagate(
104
104
  nystrom_eigvec,
@@ -107,7 +107,6 @@ def ncut_fn(
107
107
  extrapolation_factor=extrapolation_factor,
108
108
  n_neighbors=config.n_neighbors,
109
109
  n_sample=config.n_sample2,
110
- matmul_chunk_size=config.matmul_chunk_size,
111
110
  device=device,
112
111
  )
113
112
 
@@ -117,6 +116,7 @@ def ncut_fn(
117
116
 
118
117
  return eigvec, eigval
119
118
 
119
+
120
120
  def find_optimal_sigma(
121
121
  X: torch.Tensor,
122
122
  quantile_sigma: float = 0.25,
@@ -137,6 +137,7 @@ def find_optimal_sigma(
137
137
  raise ValueError(f"`sigma` need to be provided for affinity function {affinity_fn}, (sigma=0.5, repulsion_sigma=0.3)")
138
138
  return sigma, repulsion_sigma
139
139
 
140
+
140
141
  def ncut_with_repulsion(
141
142
  X: torch.Tensor,
142
143
  n_eig: int = 100,
@@ -197,11 +198,16 @@ def nystrom_propagate(
197
198
  nystrom_X (torch.Tensor): input features from nystrom sampled nodes, shape (m, D)
198
199
  extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
199
200
  device (str): device to use for computation, if 'auto', will detect GPU automatically
200
- affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
201
+ return_indices (bool): whether to return the indices used for propagation
201
202
 
202
203
  Returns:
203
204
  torch.Tensor: output propagated by nearest neighbors, shape (N, D)
204
205
  """
206
+ if X.shape[0] <= SMALL_SCALE_THRESHOLD and nystrom_out.shape == X.shape and torch.allclose(nystrom_X.to(X.device), X, atol=1e-6):
207
+ # skip propagation if nystrom_out is the same as X, for small scale graph that don't need nystrom approximation
208
+ if return_indices:
209
+ return nystrom_out, np.arange(X.shape[0])
210
+ return nystrom_out
205
211
 
206
212
  config = NystromConfig()
207
213
  config.update(kwargs)
@@ -217,33 +223,36 @@ def nystrom_propagate(
217
223
 
218
224
  D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
219
225
 
220
- all_outs = []
221
- n_chunk = config.matmul_chunk_size
222
226
  n_neighbors = int(min(config.n_neighbors, len(indices)*config.n_neighbors_max_ratio))
223
227
  n_neighbors = max(n_neighbors, 4)
228
+ n_chunk = _find_max_chunk_size(X, nystrom_X, device)
229
+
230
+ all_outs = torch.empty((X.shape[0], nystrom_out.shape[-1]), device=output_device, dtype=nystrom_out.dtype)
224
231
  for i in range(0, X.shape[0], n_chunk):
225
232
  end = min(i + n_chunk, X.shape[0])
226
233
 
227
234
  _Ai = rbf_affinity(X[i:end].to(device), nystrom_X, sigma=sigma)
228
235
  _Ai, _indices = keep_topk_per_row(_Ai, n_neighbors) # (n, n_neighbors)
236
+
229
237
  _Di = D[_indices].sum(1)
230
238
  _Ai = _Ai / _Di[:, None]
231
239
 
232
- weights = _Ai[..., None] # (n, n_neighbors, 1)
233
- neighbors = nystrom_out[_indices.flatten()]
234
- neighbors = neighbors.reshape(-1, n_neighbors, nystrom_out.shape[-1]) # (n, n_neighbors, d)
235
- out = weights * neighbors # (n, n_neighbors, d)
236
- out = out.sum(dim=1) # (n, d)
237
-
238
- out = out.to(output_device)
239
- all_outs.append(out)
240
+ out = torch.einsum('nk,nkd->nd', _Ai, nystrom_out[_indices])
240
241
 
241
- all_outs = torch.cat(all_outs, dim=0)
242
+ all_outs[i:end] = out.to(output_device)
242
243
 
243
244
  if return_indices:
244
245
  return all_outs, indices
245
-
246
246
  return all_outs
247
247
 
248
248
 
249
-
249
def _find_max_chunk_size(X: torch.Tensor, nystrom_X: torch.Tensor, device: str):
    """
    Find the largest row-chunk size whose affinity against the Nystrom nodes can be computed.

    Starts at ``MATMUL_CHUNK_SIZE`` and halves the size after every failed probe
    ``rbf_affinity(X[:chunk_size].to(device), nystrom_X)``, down to and including 1.

    Args:
        X (torch.Tensor): full input features; only the first ``chunk_size`` rows are probed.
        nystrom_X (torch.Tensor): features of the Nystrom-sampled nodes.
        device (str): device the probed rows are moved to.

    Returns:
        int: the first size in MATMUL_CHUNK_SIZE, MATMUL_CHUNK_SIZE // 2, ..., 1 whose probe succeeded.

    Raises:
        RuntimeError: if even a single-row probe fails; the last underlying error is chained
            as ``__cause__`` so non-memory failures (e.g. device mismatch) remain visible.
    """
    chunk_size = MATMUL_CHUNK_SIZE
    last_error = None
    while chunk_size >= 1:  # include chunk_size == 1 as the last attempt
        try:
            probe = rbf_affinity(X[:chunk_size].to(device), nystrom_X)
        except RuntimeError as err:
            # Any RuntimeError is treated as "chunk too large", as before;
            # remember it so the final failure is not context-free.
            last_error = err
            chunk_size //= 2
            if torch.cuda.is_available():
                # release blocks cached by the failed attempt before retrying
                torch.cuda.empty_cache()
            continue
        del probe  # free the probe matrix before the caller allocates real chunks
        return chunk_size
    raise RuntimeError("failed to find max chunk size") from last_error
@@ -113,3 +113,142 @@ def rbf_eigvec_manual_grad(
113
113
 
114
114
  return grad_u
115
115
 
116
+
117
class MultiSpectralProjectorFromMasks(torch.autograd.Function):
    """
    A (symmetric) -> {P_b}_b, where P_b = U_{S_b} U_{S_b}^T and S_b is specified by a boolean mask.

    Computes eigh(A) ONCE, sorts eigenpairs DESCENDING (largest-first), then forms projectors
    for each mask.

    Inputs:
        A: [N,N] (float), symmetric (or will be symmetrized if symmetrize=True)
        masks: [B,N] (bool), masks[b,i]=True selects eigenvector i in DESCENDING eigen-order.
        gap_eps: if > 0, every eigen-gap lambda_S - lambda_perp used in the backward pass is
            clamped to magnitude >= gap_eps (an exactly-zero gap becomes +gap_eps). With the
            default 0.0 no clamping is done, so a zero gap between a selected and an
            unselected eigenvalue yields inf/nan gradients.
        symmetrize: if True, eigh is applied to 0.5 * (A + A^T) and the returned gradient is
            symmetrized accordingly.

    Output:
        P: [B,N,N]
    """

    @staticmethod
    def forward(
        ctx,
        A: torch.Tensor,  # [N,N]
        masks: torch.Tensor,  # [B,N] bool (in DESCENDING eigen-order)
        gap_eps: float = 0.0,
        symmetrize: bool = True,
    ):
        if A.ndim != 2 or A.shape[0] != A.shape[1]:
            raise ValueError(f"A must be square [N,N], got {tuple(A.shape)}")
        if masks.ndim != 2:
            raise ValueError(f"masks must be [B,N], got {tuple(masks.shape)}")
        if masks.dtype != torch.bool:
            raise ValueError("masks must be boolean")
        N = A.shape[0]
        B, N2 = masks.shape
        if N2 != N:
            raise ValueError(f"masks second dim must equal N={N}, got {N2}")
        if (masks.sum(dim=1) == 0).any():
            raise ValueError("Each mask row must select at least one eigenvector.")

        device = A.device
        masks = masks.to(device=device)

        A_used = 0.5 * (A + A.T) if symmetrize else A

        # eigh returns ascending order -> flip to descending so mask column 0 is the largest
        evals_asc, U_asc = torch.linalg.eigh(A_used)
        evals = torch.flip(evals_asc, dims=[0])  # [N] descending
        U = torch.flip(U_asc, dims=[1])  # [N,N] descending columns

        # Build one projector per mask row
        P_out = []
        for b in range(B):
            U_S = U[:, masks[b]]  # [N,p_b]
            P_out.append(U_S @ U_S.T)  # [N,N]
        P = torch.stack(P_out, dim=0)  # [B,N,N]

        ctx.save_for_backward(U, evals, masks)
        ctx.gap_eps = float(gap_eps)
        ctx.symmetrize = bool(symmetrize)

        return P

    @staticmethod
    def backward(ctx, grad_P: torch.Tensor):
        U, evals, masks = ctx.saved_tensors
        gap_eps = ctx.gap_eps
        symmetrize = ctx.symmetrize

        if grad_P.ndim != 3:
            raise ValueError(f"grad_P must be [B,N,N], got {tuple(grad_P.shape)}")
        B, N, N2 = grad_P.shape
        if N != N2:
            raise ValueError("grad_P must be square per batch")

        grad_A_used = torch.zeros((N, N), device=grad_P.device, dtype=grad_P.dtype)

        for b in range(B):
            mask = masks[b]  # [N]
            U_S = U[:, mask]  # [N,p]
            U_perp = U[:, ~mask]  # [N,N-p]
            lam_S = evals[mask]  # [p]
            lam_perp = evals[~mask]  # [N-p]

            # only the symmetric part of the upstream gradient matters (P_b is symmetric)
            G = grad_P[b]
            Gs = 0.5 * (G + G.T)

            # First-order perturbation of P_b for a symmetric dA:
            #   dP = U_perp Q U_S^T + (U_perp Q U_S^T)^T,
            #   Q  = (U_perp^T dA U_S) / (lam_S - lam_perp)   (elementwise)
            # whose adjoint gives the accumulation below.
            H = U_perp.T @ (Gs @ U_S)  # [N-p,p]

            denom = lam_S[None, :] - lam_perp[:, None]  # [N-p,p]
            if gap_eps > 0.0:
                # torch.sign(0) == 0 would turn an exactly-zero gap back into 0 (-> inf/nan);
                # map non-negative gaps to +1 so degenerate pairs are clamped to +gap_eps.
                sign = torch.where(denom >= 0, torch.ones_like(denom), -torch.ones_like(denom))
                denom = sign * torch.clamp(denom.abs(), min=gap_eps)

            Q = H / denom  # [N-p,p]

            Bmat = U_perp @ (Q @ U_S.T)  # [N,N]
            grad_A_used = grad_A_used + (Bmat + Bmat.T)

        # forward used 0.5 * (A + A^T) when symmetrize=True; chain rule symmetrizes the gradient
        grad_A = 0.5 * (grad_A_used + grad_A_used.T) if symmetrize else grad_A_used
        return grad_A, None, None, None
216
+
217
+
218
def spectral_projectors_from_masks(
    A: torch.Tensor,
    masks: torch.Tensor,
    gap_eps: float = 0.0,
    symmetrize: bool = True,
):
    """
    Build one eigenspace projector per boolean mask, differentiable w.r.t. ``A``.

    Functional front-end for ``MultiSpectralProjectorFromMasks.apply``.

    Args:
        A: [N,N] matrix whose eigendecomposition is taken.
        masks: [B,N] bool; column k refers to the k-th largest eigenvalue (0 = largest).
        gap_eps: forwarded unchanged; eigen-gap clamping used by the backward pass.
        symmetrize: forwarded unchanged; whether A is symmetrized before eigh.

    Returns:
        P: [B,N,N], one projector per mask row.
    """
    projector_fn = MultiSpectralProjectorFromMasks.apply
    return projector_fn(A, masks, gap_eps, symmetrize)
231
+
232
+
233
if __name__ == "__main__":
    # Smoke test: projectors of two random symmetric matrices, backprop a distance
    # between them, and print the resulting gradient shapes.
    B, N = 2, 1000
    masks = torch.zeros(B, N, dtype=torch.bool)
    masks[0, 0:3] = True  # eigenvectors 0..2 (largest eigenvalues first)
    masks[1, 3:6] = True  # eigenvectors 3..5

    def _random_symmetric(n):
        raw = torch.randn(n, n)
        return (0.5 * (raw + raw.T)).requires_grad_(True)

    A1 = _random_symmetric(N)
    P1 = spectral_projectors_from_masks(A1, masks)

    A2 = _random_symmetric(N)
    P2 = spectral_projectors_from_masks(A2, masks)

    loss = torch.norm(P1 - P2, p=2, dim=(0, 1)).sum()
    loss.backward()
    for leaf in (A1, A2):
        print(leaf.grad.shape)
@@ -45,7 +45,11 @@ def rbf_affinity(
45
45
  sigma = sigma if gamma is None else check_gamma_deprecated(gamma)
46
46
  X2 = X1 if X2 is None else X2
47
47
 
48
- dist2 = torch.cdist(X1, X2, p=2)**2
48
+ try:
49
+ dist2 = torch.cdist(X1, X2, p=2)**2
50
+ except NotImplementedError:
51
+ dist2 = X1.unsqueeze(1) - X2.unsqueeze(0)
52
+ dist2 = dist2.pow(2).sum(dim=-1)
49
53
  W = torch.exp(-dist2 / (2.0 * sigma * sigma)) # [N,M]
50
54
  if zero_diag and X1 is X2:
51
55
  W = W.clone()
@@ -98,11 +102,10 @@ def grad_safe_eig_solve(
98
102
  is_symmetric = mat.shape[0] == mat.shape[1]
99
103
  if is_symmetric:
100
104
  s, u = torch.linalg.eigh(mat)
105
+ s = torch.flip(s, dims=[0])
106
+ u = torch.flip(u, dims=[1])
101
107
  else:
102
108
  s, u = torch.linalg.eig(mat)
103
- sort_idx = torch.argsort(s, dim=0, descending=True)
104
- s = s[sort_idx]
105
- u = u[:, sort_idx]
106
109
  return u.to(dtype), s.to(dtype), None
107
110
 
108
111
  try:
@@ -60,5 +60,5 @@ def _farthest_point_sampling(
60
60
  assert not torch.any(torch.isnan(X)), "X contains NaN"
61
61
  assert not torch.any(torch.isinf(X)), "X contains Inf"
62
62
 
63
- samples_idx = sample_idx(X.cpu(), n_sample).numpy()
63
+ samples_idx = sample_idx(X.cpu(), n_sample)
64
64
  return samples_idx
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev4
3
+ Version: 3.0.0.dev6
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ncut_pytorch"
7
- version = "3.0.0dev4"
7
+ version = "3.0.0dev6"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  ]