ncut-pytorch 2.0.0.dev1__tar.gz → 2.0.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {ncut_pytorch-2.0.0.dev1/ncut_pytorch.egg-info → ncut_pytorch-2.0.0.dev2}/PKG-INFO +3 -2
  2. ncut_pytorch-2.0.0.dev2/ncut_pytorch/dino/__init__.py +4 -0
  3. ncut_pytorch-2.0.0.dev2/ncut_pytorch/dino/api.py +50 -0
  4. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/hires_dino.py +49 -54
  5. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/patch.py +2 -2
  6. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncut.py +2 -2
  7. ncut_pytorch-2.0.0.dev2/ncut_pytorch/ncuts/ncut_click.py +232 -0
  8. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/ncut_kway.py +18 -8
  9. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/ncut_nystrom.py +5 -12
  10. ncut_pytorch-2.0.0.dev2/ncut_pytorch/predictor.py +136 -0
  11. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/math_utils.py +2 -1
  12. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/sample_utils.py +14 -4
  13. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2/ncut_pytorch.egg-info}/PKG-INFO +3 -2
  14. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/SOURCES.txt +4 -6
  15. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/requires.txt +3 -1
  16. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/pyproject.toml +6 -2
  17. ncut_pytorch-2.0.0.dev2/requirements.txt +7 -0
  18. ncut_pytorch-2.0.0.dev1/ncut_pytorch/dino/__init__.py +0 -8
  19. ncut_pytorch-2.0.0.dev1/ncut_pytorch/ncuts/ncut_biased.py +0 -88
  20. ncut_pytorch-2.0.0.dev1/requirements.txt +0 -7
  21. ncut_pytorch-2.0.0.dev1/tests/test_real_images.py +0 -284
  22. ncut_pytorch-2.0.0.dev1/tests/test_real_images_densesparse.py +0 -243
  23. ncut_pytorch-2.0.0.dev1/tests/test_sample_imb copy.py +0 -100
  24. ncut_pytorch-2.0.0.dev1/tests/test_sample_imb.py +0 -106
  25. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/LICENSE +0 -0
  26. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/MANIFEST.in +0 -0
  27. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/README.md +0 -0
  28. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/__init__.py +0 -0
  29. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/__init__.py +0 -0
  30. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/coloring.py +0 -0
  31. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/mspace.py +0 -0
  32. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/transform.py +0 -0
  33. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/__init__.py +0 -0
  34. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/__init__.py +0 -0
  35. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/gamma.py +0 -0
  36. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  37. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/top_level.txt +0 -0
  38. {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 2.0.0.dev1
3
+ Version: 2.0.0.dev2
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -18,11 +18,12 @@ Description-Content-Type: text/markdown
18
18
  License-File: LICENSE
19
19
  Requires-Dist: torch~=2.0
20
20
  Requires-Dist: fpsample>=0.2.0
21
- Requires-Dist: torchvision>=0.15.0
22
21
  Requires-Dist: pytorch-lightning~=2.0
23
22
  Requires-Dist: pillow
24
23
  Requires-Dist: numpy
25
24
  Requires-Dist: tqdm
25
+ Provides-Extra: vision
26
+ Requires-Dist: torchvision>=0.15.0; extra == "vision"
26
27
  Dynamic: license-file
27
28
 
28
29
 
@@ -0,0 +1,4 @@
1
+ from .api import hires_dino_256
2
+ from .api import hires_dino_512
3
+ from .api import hires_dino_1024
4
+ from .api import hires_dinov2
@@ -0,0 +1,50 @@
1
+ from typing import Tuple
2
+ from torchvision import transforms
3
+ from .hires_dino import hires_dino
4
+ from .hires_dino import HighResDINO
5
+ from .transform import get_input_transform
6
+
7
+
8
+
9
def hires_dino_256() -> Tuple[HighResDINO, transforms.Compose]:
    """Build the 256-pixel high-resolution DINO preset.

    Returns:
        A ``(model, transform)`` pair: the ``dino_vitb8`` model created by
        ``hires_dino`` with stride 6, shifts ``[1, 2, 3]``, flip transforms,
        ``chunk_size=6`` and a 256 feature resolution, plus the input
        transform from ``get_input_transform(resize=256)``.
    """
    # One name for the resolution keeps model output and input resize in sync.
    resolution = 256
    preset = dict(
        dino_name="dino_vitb8",
        stride=6,
        shift_dists=[1, 2, 3],
        flip_transforms=True,
        chunk_size=6,
        feature_resolution=resolution,
    )
    return hires_dino(**preset), get_input_transform(resize=resolution)
18
+
19
+
20
def hires_dino_512() -> Tuple[HighResDINO, transforms.Compose]:
    """Build the 512-pixel high-resolution DINO preset.

    Returns:
        A ``(model, transform)`` pair: the ``dino_vitb8`` model created by
        ``hires_dino`` with stride 6, shifts ``[1, 2, 3]``, flip transforms,
        ``chunk_size=4`` and a 512 feature resolution, plus the input
        transform from ``get_input_transform(resize=512)``.
    """
    # One name for the resolution keeps model output and input resize in sync.
    resolution = 512
    preset = dict(
        dino_name="dino_vitb8",
        stride=6,
        shift_dists=[1, 2, 3],
        flip_transforms=True,
        chunk_size=4,
        feature_resolution=resolution,
    )
    return hires_dino(**preset), get_input_transform(resize=resolution)
29
+
30
+
31
def hires_dino_1024() -> Tuple[HighResDINO, transforms.Compose]:
    """Build the 1024-pixel high-resolution DINO preset.

    Returns:
        A ``(model, transform)`` pair: the ``dino_vitb8`` model created by
        ``hires_dino`` with stride 6, shifts ``[1, 2, 3]``, flip transforms,
        ``chunk_size=1`` and a 1024 feature resolution, plus the input
        transform from ``get_input_transform(resize=1024)``.
    """
    # One name for the resolution keeps model output and input resize in sync.
    resolution = 1024
    preset = dict(
        dino_name="dino_vitb8",
        stride=6,
        shift_dists=[1, 2, 3],
        flip_transforms=True,
        chunk_size=1,
        feature_resolution=resolution,
    )
    return hires_dino(**preset), get_input_transform(resize=resolution)
40
+
41
+
42
def hires_dinov2() -> Tuple[HighResDINO, transforms.Compose]:
    """Build the DINOv2 (``dinov2_vitb14_reg``) high-resolution preset.

    Returns:
        A ``(model, transform)`` pair: the model created by ``hires_dino``
        with stride 6, shifts ``[1, 2, 3]``, flip transforms,
        ``chunk_size=1`` and a 1008 feature resolution, plus the input
        transform from ``get_input_transform(resize=1008)``.
    """
    # One name for the resolution keeps model output and input resize in sync.
    resolution = 1008
    preset = dict(
        dino_name="dinov2_vitb14_reg",
        stride=6,
        shift_dists=[1, 2, 3],
        flip_transforms=True,
        chunk_size=1,
        feature_resolution=resolution,
    )
    return hires_dino(**preset), get_input_transform(resize=resolution)
@@ -28,10 +28,12 @@ class HighResDINO(nn.Module):
28
28
  def __init__(
29
29
  self,
30
30
  dino_name: DINONameOptions,
31
- stride: int = 6,
32
- dtype: torch.dtype | int = torch.float32,
31
+ stride: int = 5,
32
+ dtype: torch.dtype | int = torch.float16,
33
33
  track_grad: bool = False,
34
- attention_mask_ratio: float = 0.25,
34
+ attention_mask_ratio: float = 0.1,
35
+ chunk_size: int = 4,
36
+ feature_resolution: int = 1024,
35
37
  ) -> None:
36
38
  super().__init__()
37
39
 
@@ -60,7 +62,9 @@ class HighResDINO(nn.Module):
60
62
  self.feat_dim: int = feat
61
63
  self.n_heads: int = 6
62
64
  self.n_register_tokens = 4
63
-
65
+ self.chunk_size = chunk_size
66
+ self.feature_resolution = feature_resolution
67
+
64
68
  # mask out some of the attention Keys (K) to save compute
65
69
  self.attention_mask_ratio = attention_mask_ratio
66
70
 
@@ -142,7 +146,7 @@ class HighResDINO(nn.Module):
142
146
  attn_block = final_block.attn # type: ignore
143
147
  # hilariously this also works for dino i.e we can patch dino's attn block forward to
144
148
  # use the memeory efficienty attn like in dinov2
145
- attn_block.forward = MethodType(Patch._fix_mem_eff_attn(self.attention_mask_ratio), attn_block)
149
+ attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
146
150
  if "dinov2" in dino_name:
147
151
  final_block.forward = MethodType(Patch._fix_block_forward_dv2(), final_block) # type: ignore
148
152
  dino_model.forward_feats_attn = MethodType( # type: ignore
@@ -152,7 +156,7 @@ class HighResDINO(nn.Module):
152
156
  for i, blk in enumerate(dino_model.blocks):
153
157
  blk.forward = MethodType(Patch._fix_block_forward_dino(), blk)
154
158
  attn_block = blk.attn
155
- attn_block.forward = MethodType(Patch._fix_mem_eff_attn(self.attention_mask_ratio), attn_block)
159
+ attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
156
160
  final_block.forward = MethodType(Patch._fix_block_forward_dino(), final_block) # type: ignore
157
161
  dino_model.forward_feats_attn = MethodType( # type: ignore
158
162
  Patch._add_new_forward_features_dino(), dino_model
@@ -161,7 +165,7 @@ class HighResDINO(nn.Module):
161
165
  for i, blk in enumerate(dino_model.blocks):
162
166
  blk.forward = MethodType(Patch._fix_block_forward_dino(), blk)
163
167
  attn_block = blk.attn
164
- attn_block.forward = MethodType(Patch._fix_mem_eff_attn(self.attention_mask_ratio), attn_block)
168
+ attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
165
169
  final_block.forward = MethodType(Patch._fix_block_forward_dino(), final_block) # type: ignore
166
170
  dino_model.forward_feats_attn = MethodType( # type: ignore
167
171
  Patch._add_new_forward_features_vit(), dino_model
@@ -244,8 +248,8 @@ class HighResDINO(nn.Module):
244
248
  out_feature_img: torch.Tensor = torch.zeros(
245
249
  1,
246
250
  c,
247
- img_h,
248
- img_w,
251
+ self.feature_resolution,
252
+ self.feature_resolution,
249
253
  device=x.device,
250
254
  dtype=self.dtype,
251
255
  requires_grad=self.track_grad,
@@ -264,7 +268,15 @@ class HighResDINO(nn.Module):
264
268
  mode=self.interpolation_mode,
265
269
  )
266
270
  inverted: torch.Tensor = inv_transform(full_size)
267
- out_feature_img += inverted
271
+
272
+ # resize the inverted feature map to the output resolution
273
+ out = F.interpolate(
274
+ inverted,
275
+ (self.feature_resolution, self.feature_resolution),
276
+ mode=self.interpolation_mode
277
+ )
278
+
279
+ out_feature_img += out
268
280
 
269
281
  n_imgs: int = feature_batch.shape[0]
270
282
  mean = out_feature_img / n_imgs
@@ -275,13 +287,11 @@ class HighResDINO(nn.Module):
275
287
  self,
276
288
  x: torch.Tensor,
277
289
  attn_choice: AttentionOptions = "none",
278
- chunk_size: int = 6,
279
290
  ) -> torch.Tensor:
280
291
  """Feed input img $x through network and get high-res features.
281
292
 
282
293
  :param x: unbatched image tensor, (c, h, w)
283
294
  :param attn_choice: choice of attention, "none" or "q", "k", "v", "o"
284
- :param chunk_size: number of images to process in one chunk, default 6, in case of OOM
285
295
  :return: upsampled features, (c, h, w)
286
296
  """
287
297
  if self.dtype != torch.float32: # cast (i.e to f16)
@@ -291,8 +301,8 @@ class HighResDINO(nn.Module):
291
301
  N_imgs = img_batch.shape[0]
292
302
 
293
303
  all_features = []
294
- for i in range(0, N_imgs, chunk_size):
295
- _img_batch = img_batch[i:i+chunk_size]
304
+ for i in range(0, N_imgs, self.chunk_size):
305
+ _img_batch = img_batch[i:i+self.chunk_size]
296
306
  out_dict = self.dinov2.forward_feats_attn(_img_batch, None, attn_choice) # type: ignore
297
307
  if attn_choice != "none":
298
308
  feats, attn = out_dict["x_norm_patchtokens"], out_dict["x_patchattn"]
@@ -314,23 +324,24 @@ class HighResDINO(nn.Module):
314
324
  self,
315
325
  x: torch.Tensor,
316
326
  attn_choice: AttentionOptions = "none",
317
- chunk_size: int = 6,
327
+ move_to_cpu: bool = True,
318
328
  ) -> torch.Tensor:
319
329
  """Feed input img $x through network and get low and high res features.
320
330
 
321
331
  :param x: batched image tensor, (b, c, h, w)
322
332
  :param attn_choice: choice of attention, "none" or "q", "k", "v", "o"
323
- :param chunk_size: number of images to process in one chunk, default 6, in case of OOM
324
333
  :return: upsampled features, (b, c, h, w)
325
334
  :rtype: torch.Tensor
326
335
  """
327
336
  upsampled_features = []
328
337
  for i in range(x.shape[0]):
329
338
  if self.track_grad:
330
- out = self._forward_one_image(x[i], attn_choice, chunk_size)
339
+ out = self._forward_one_image(x[i], attn_choice)
331
340
  else:
332
341
  with torch.no_grad():
333
- out = self._forward_one_image(x[i], attn_choice, chunk_size)
342
+ out = self._forward_one_image(x[i], attn_choice)
343
+ if move_to_cpu:
344
+ out = out.cpu()
334
345
  upsampled_features.append(out)
335
346
  upsampled_features = torch.stack(upsampled_features, dim=0)
336
347
  return upsampled_features
@@ -338,13 +349,28 @@ class HighResDINO(nn.Module):
338
349
 
339
350
  def hires_dino(dino_name: DINONameOptions = "dino_vitb8",
340
351
  stride: int = 6,
341
- attention_mask_ratio: float = 0.1,
342
- shift_dists: List[int] = [1, 2],
352
+ shift_dists: List[int] = [1, 2, 3],
343
353
  flip_transforms: bool = True,
354
+ attention_mask_ratio: float = 0.1,
344
355
  dtype: torch.dtype | int = torch.float16,
345
- track_grad: bool = False) -> HighResDINO:
346
-
347
- model = HighResDINO(dino_name, stride, dtype, track_grad, attention_mask_ratio)
356
+ track_grad: bool = False,
357
+ chunk_size: int = 4,
358
+ feature_resolution: int = 512
359
+ ) -> HighResDINO:
360
+ """
361
+ Args:
362
+ dino_name: name of the DINO model to use
363
+ stride: stride size of the tokenization, smaller is better but slower
364
+ shift_dists: pixel shifts for multiple image transformations, more shifts means more crispy features
365
+ flip_transforms: whether to use flip transforms, remove positional features
366
+ attention_mask_ratio: ratio of attention keys to mask out
367
+ dtype: data type of the model
368
+ track_grad: whether to track gradients
369
+ chunk_size: number of images to process in one batch, in case of OOM
370
+ feature_resolution: resolution of the output features
371
+ """
372
+
373
+ model = HighResDINO(dino_name, stride, dtype, track_grad, attention_mask_ratio, chunk_size, feature_resolution)
348
374
 
349
375
  fwd_shift, inv_shift = get_shift_transforms(shift_dists)
350
376
  if flip_transforms: # add flip transforms
@@ -359,34 +385,3 @@ def hires_dino(dino_name: DINONameOptions = "dino_vitb8",
359
385
  return model
360
386
 
361
387
 
362
- # ==================== MODEL PRESETS ====================
363
-
364
- # stride: stride size of the tokenization, the original model use patch size as stride size
365
- # reduce stride size to get more tokens and crispy features
366
- # shift_dists: pixel shifts for multiple image transformations, more shifts means more crispy features
367
- # flip_transforms: whether to use flip transforms, remove positional features
368
-
369
-
370
- def hires_dino_small() -> HighResDINO:
371
- return hires_dino(dino_name="dino_vitb8",
372
- stride=6,
373
- shift_dists=[1, 2, 3],
374
- flip_transforms=True)
375
-
376
- def hires_dino_base() -> HighResDINO:
377
- return hires_dino(dino_name="dino_vitb8",
378
- stride=4,
379
- shift_dists=[1, 2, 3],
380
- flip_transforms=True)
381
-
382
- def hires_dino_large() -> HighResDINO:
383
- return hires_dino(dino_name="dino_vitb8",
384
- stride=3,
385
- shift_dists=[1, 2, 3],
386
- flip_transforms=True)
387
-
388
- def hires_dinov2() -> HighResDINO:
389
- return hires_dino(dino_name="dinov2_vitb14_reg",
390
- stride=4,
391
- shift_dists=[1, 2, 3],
392
- flip_transforms=True)
@@ -161,7 +161,7 @@ class Patch:
161
161
  return forward
162
162
 
163
163
  @staticmethod
164
- def _fix_mem_eff_attn(attention_mask_ratio: float = 0.25) -> Callable:
164
+ def _fix_attn_masking(attention_mask_ratio: float) -> Callable:
165
165
  """Replaces normal 'forward()' method of the memory efficient attention layer (block.attn)
166
166
  in the Dv2 model with an optional early return with attention. Used if xformers used.
167
167
 
@@ -196,8 +196,8 @@ class Patch:
196
196
  k.transpose(1, 2),
197
197
  v.transpose(1, 2),
198
198
  attn_mask=attn_bias)
199
-
200
199
  x = x.transpose(1, 2)
200
+
201
201
  to_append: torch.Tensor
202
202
  if attn_choice != "none":
203
203
  to_append = get_qkvo_per_head(q, k, v, x, attn_choice, self.attn_drop)
@@ -109,9 +109,9 @@ class Ncut:
109
109
  def fit_transform(self, X: torch.Tensor) -> torch.Tensor:
110
110
  return self.fit(X).transform(X)
111
111
 
112
- def __new__(cls, X: torch.Tensor = None, **kwargs):
112
+ def __new__(cls, X: torch.Tensor = None, n_eig: int = 100, track_grad: bool = False, d_gamma: float = 0.1, device: str = 'auto', **kwargs):
113
113
  if X is not None:
114
- eigvec, eigval = ncut_fn(X, **kwargs) # function-like behavior
114
+ eigvec, eigval = ncut_fn(X, n_eig=n_eig, track_grad=track_grad, d_gamma=d_gamma, device=device, **kwargs) # function-like behavior
115
115
  return eigvec
116
116
  return super().__new__(cls) # normal class instantiation
117
117
 
@@ -0,0 +1,232 @@
1
+ # %%
2
+ from sympy import Q
3
+ import torch
4
+
5
+ from ncut_pytorch.utils.gamma import find_gamma_by_degree_after_fps, find_gamma_by_degree
6
+ from ncut_pytorch.utils.math_utils import get_affinity, normalize_affinity, svd_lowrank, correct_rotation
7
+ from ncut_pytorch.utils.sample_utils import farthest_point_sampling, auto_divice
8
+ from .ncut_kway import kway_ncut
9
+ from .ncut_nystrom import _nystrom_propagate
10
+ from .ncut_nystrom import _plain_ncut
11
+ from .ncut_nystrom import _NYSTROM_CONFIG
12
+
13
+
14
def ncut_click_prompt(
    X: torch.Tensor,
    fg_indices: torch.Tensor,
    bg_indices: torch.Tensor = None,
    click_weight: float = 0.5,
    bg_weight: float = 0.1,
    n_eig: int = 2,
    track_grad: bool = False,
    d_gamma: float = 0.1,
    device: str = 'auto',
    gamma: float = None,
    no_propagation: bool = False,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Run Ncut on a Nystrom subgraph whose affinity is biased by clicks.

    The clicked rows are forced into the Nystrom sample, a second "click"
    affinity is built from how strongly every sampled node connects to the
    clicked rows, and the two affinities are blended before the
    eigendecomposition.

    Args:
        X: feature matrix; rows are selected with ``X[indices]``.
        fg_indices: row indices of the foreground clicks (must be non-empty).
        bg_indices: optional row indices of the background clicks.
        click_weight: blend factor; the final affinity is
            ``click_weight * A_click + (1 - click_weight) * A``.
        bg_weight: scale applied to the mean background affinity row before
            it is subtracted from the mean foreground affinity row.
        n_eig: number of eigenvectors requested from ``_plain_ncut``.
        track_grad: grad mode used while this function runs.
        d_gamma: forwarded to ``find_gamma_by_degree_after_fps`` when
            ``gamma`` is None.
        device: forwarded to ``auto_divice`` to pick the compute device.
        gamma: affinity parameter; searched automatically when None.
        no_propagation: if True, return the subgraph result without
            propagating it to all rows of ``X``.
        **kwargs: overrides for entries of ``_NYSTROM_CONFIG``.

    Returns:
        ``(eigvec, eigval)`` after propagation, or
        ``(nystrom_eigvec, eigval, nystrom_indices, gamma)`` when
        ``no_propagation`` is True. In the latter case ``nystrom_indices``
        starts with the foreground, then the background indices.

    Raises:
        ValueError: if ``fg_indices`` is empty (the mean foreground affinity
            would be NaN).
    """
    _config = _NYSTROM_CONFIG.copy()
    _config.update(kwargs)

    # use GPU if available
    device = auto_divice(X.device, device)

    # Index bookkeeping is done on CPU so that the concatenations below work
    # no matter which device the caller's click indices live on.
    fg_indices = torch.as_tensor(fg_indices, dtype=torch.long, device='cpu').reshape(-1)
    if fg_indices.numel() == 0:
        raise ValueError("fg_indices must contain at least one index")
    if bg_indices is None:
        bg_indices = torch.tensor([], dtype=torch.long)
    else:
        bg_indices = torch.as_tensor(bg_indices, dtype=torch.long, device='cpu').reshape(-1)

    # The context manager restores the previous grad mode on every exit path,
    # including exceptions raised by the helpers.
    with torch.set_grad_enabled(track_grad):
        # subsample for nystrom approximation
        nystrom_indices = farthest_point_sampling(X, n_sample=_config['n_sample'], device=device)
        nystrom_indices = torch.as_tensor(nystrom_indices, dtype=torch.long, device='cpu')

        # Put the clicks first and drop them from the sampled set so that no
        # clicked row appears twice.
        click_indices = torch.cat([fg_indices, bg_indices])
        nystrom_indices = nystrom_indices[~torch.isin(nystrom_indices, click_indices)]
        nystrom_indices = torch.cat([click_indices, nystrom_indices])

        # From here on the clicks are addressed by position in the subgraph.
        n_fg, n_bg = len(fg_indices), len(bg_indices)
        fg_indices = torch.arange(n_fg)
        bg_indices = torch.arange(n_bg) + n_fg

        nystrom_X = X[nystrom_indices].to(device)

        # find optimal gamma for affinity matrix
        if gamma is None:
            gamma = find_gamma_by_degree_after_fps(nystrom_X, d_gamma)

        # affinity of the nystrom sampled subgraph
        A = get_affinity(nystrom_X, gamma=gamma)
        A = normalize_affinity(A)

        # One scalar per node: mean affinity to the foreground clicks, minus
        # a down-weighted mean affinity to the background clicks.
        X_click = A[fg_indices].mean(0)
        if n_bg > 0:
            X_click = X_click - bg_weight * A[bg_indices].mean(0)
        X_click = X_click * A.shape[0]

        # Second affinity built on that 1-D click response (fixed gamma).
        A_click = get_affinity(X_click.unsqueeze(1), gamma=0.5)
        A_click = normalize_affinity(A_click)

        _A = click_weight * A_click + (1 - click_weight) * A

        nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)

        if no_propagation:
            return nystrom_eigvec, eigval, nystrom_indices, gamma

        # propagate eigenvectors from subgraph to full graph
        eigvec = _nystrom_propagate(
            nystrom_eigvec,
            X,
            nystrom_X,
            n_neighbors=_config['n_neighbors'],
            n_sample=_config['n_sample2'],
            gamma=gamma,
            chunk_size=_config['matmul_chunk_size'],
            device=device,
            move_output_to_cpu=_config['move_output_to_cpu'],
            track_grad=track_grad,
        )

    return eigvec, eigval
104
+
105
+
106
def get_mask_and_heatmap(eigvecs, fg_indices, n_cluster=2, device='auto'):
    """Turn Ncut eigenvectors into a foreground mask and a signed heatmap.

    The first ``n_cluster`` eigenvector columns are axis-aligned with
    ``kway_ncut``; the cluster with the largest mean response over the
    foreground rows is taken as foreground.

    Args:
        eigvecs: eigenvectors, one row per node.
        fg_indices: rows known to be foreground (used to pick the cluster).
        n_cluster: number of leading eigenvector columns to use (>= 2).
        device: forwarded to ``auto_divice`` and ``kway_ncut``.

    Returns:
        ``(mask, heatmap)``: ``mask`` is True where the foreground cluster is
        the argmax; ``heatmap`` is the foreground response minus the
        strongest competing cluster response. For two clusters this is
        exactly ``fg - bg``.

    Raises:
        ValueError: if fewer than two eigenvector columns are available.
    """
    device = auto_divice(eigvecs.device, device)
    eigvecs = eigvecs[:, :n_cluster]
    if eigvecs.shape[1] < 2:
        raise ValueError("get_mask_and_heatmap needs at least 2 eigenvector columns")

    eigvecs = kway_ncut(eigvecs, device=device)

    # find which cluster is the foreground
    fg_idx = eigvecs[fg_indices].mean(0).argmax().item()

    # discretize the eigvecs
    mask = eigvecs.argmax(dim=-1) == fg_idx

    # Compare against the best non-foreground cluster so the heatmap's sign
    # agrees with the mask for any number of clusters, not only two.
    is_other = torch.ones(eigvecs.shape[1], dtype=torch.bool, device=eigvecs.device)
    is_other[fg_idx] = False
    heatmap = eigvecs[:, fg_idx] - eigvecs[:, is_other].max(dim=1).values

    return mask, heatmap
122
+
123
+
124
+
125
+ from ncut_pytorch.utils.math_utils import keep_topk_per_row
126
+
127
def ncut_click_prompt_cached(
    nystrom_indices: torch.Tensor,
    gamma: float,
    X: torch.Tensor,
    fg_indices: torch.Tensor,
    bg_indices: torch.Tensor = None,
    click_weight: float = 0.5,
    bg_weight: float = 0.1,
    n_eig: int = 2,
    track_grad: bool = False,
    device: str = 'auto',
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Click-biased Ncut on a previously chosen Nystrom sample.

    Same affinity construction as ``ncut_click_prompt``, but the sample
    indices and ``gamma`` are supplied by the caller. The clicked rows are
    used only to build the click affinity and are removed again before the
    eigendecomposition, so the result has one row per entry of the
    ``nystrom_indices`` argument, in the same order.

    Args:
        nystrom_indices: row indices of ``X`` forming the Nystrom sample.
        gamma: affinity parameter for ``get_affinity``.
        X: feature matrix; rows are selected with ``X[indices]``.
        fg_indices: row indices of the foreground clicks (must be non-empty).
        bg_indices: optional row indices of the background clicks.
        click_weight: blend factor between click affinity and feature affinity.
        bg_weight: scale of the subtracted mean background affinity row.
        n_eig: number of eigenvectors requested from ``_plain_ncut``.
        track_grad: grad mode used while this function runs.
        device: forwarded to ``auto_divice`` to pick the compute device.
        **kwargs: accepted for signature parity; not used by this function.

    Returns:
        ``(nystrom_eigvec, eigval)`` for the sampled subgraph.

    Raises:
        ValueError: if ``fg_indices`` is empty (the mean foreground affinity
            would be NaN).
    """
    # use GPU if available
    device = auto_divice(X.device, device)

    # Index bookkeeping is done on CPU so the concatenation below works no
    # matter which device the caller's indices live on. ``as_tensor`` avoids
    # the copy (and warning) that ``torch.tensor`` gives for tensor input.
    nystrom_indices = torch.as_tensor(nystrom_indices, dtype=torch.long, device='cpu')
    fg_indices = torch.as_tensor(fg_indices, dtype=torch.long, device='cpu').reshape(-1)
    if fg_indices.numel() == 0:
        raise ValueError("fg_indices must contain at least one index")
    if bg_indices is None:
        bg_indices = torch.tensor([], dtype=torch.long)
    else:
        bg_indices = torch.as_tensor(bg_indices, dtype=torch.long, device='cpu').reshape(-1)

    # The context manager restores the previous grad mode on every exit path,
    # including exceptions raised by the helpers.
    with torch.set_grad_enabled(track_grad):
        # Prepend the clicks; they are addressed by position from here on.
        all_indices = torch.cat([fg_indices, bg_indices, nystrom_indices])
        n_fg, n_bg = len(fg_indices), len(bg_indices)
        n_fgbg = n_fg + n_bg
        fg_indices = torch.arange(n_fg)
        bg_indices = torch.arange(n_bg) + n_fg

        nystrom_X = X[all_indices].to(device)

        # affinity of the clicks + nystrom sampled subgraph
        A = get_affinity(nystrom_X, gamma=gamma)
        A = normalize_affinity(A)

        # One scalar per node: mean affinity to the foreground clicks, minus
        # a down-weighted mean affinity to the background clicks.
        X_click = A[fg_indices].mean(0)
        if n_bg > 0:
            X_click = X_click - bg_weight * A[bg_indices].mean(0)
        X_click = X_click * A.shape[0]

        A_click = get_affinity(X_click.unsqueeze(1), gamma=0.5)
        A_click = normalize_affinity(A_click)

        _A = click_weight * A_click + (1 - click_weight) * A
        # Drop the click rows/columns again: only the cached sample is solved.
        _A = _A[n_fgbg:, n_fgbg:]

        nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)

    return nystrom_eigvec, eigval
187
+
188
+
189
def _build_nystrom_graph(
    X: torch.Tensor,
    nystrom_X: torch.Tensor,
    gamma: float = 1.0,
    device: str = 'auto',
    **kwargs,
):
    """Build the sparse-in-value weight matrix from all nodes to the Nystrom nodes.

    For every row of ``X`` the affinity to each Nystrom node is computed in
    chunks, only the entries selected by ``keep_topk_per_row`` are kept, and
    they are written into a dense, zero-initialised matrix.

    Args:
        X (torch.Tensor): input features for all nodes, shape (N, D)
        nystrom_X (torch.Tensor): input features of the nystrom sampled nodes, shape (m, D)
        gamma (float): affinity parameter, default 1.0
        device (str): device to use for computation, forwarded to ``auto_divice``
        **kwargs: overrides for ``_NYSTROM_CONFIG``; ``matmul_chunk_size`` and
            ``n_neighbors`` are the entries read here

    Returns:
        torch.Tensor: weights of shape (N, m) on the compute device, zero
        everywhere except the kept neighbors of each row
    """
    _config = _NYSTROM_CONFIG.copy()
    _config.update(kwargs)

    device = auto_divice(X.device, device)
    nystrom_X = nystrom_X.to(device)

    n_chunk = _config['matmul_chunk_size']
    n_neighbors = _config['n_neighbors']
    cached_weights = torch.zeros((X.shape[0], nystrom_X.shape[0]),
                                 device=device, dtype=X.dtype)
    for i in range(0, X.shape[0], n_chunk):
        end = min(i + n_chunk, X.shape[0])

        _Ai = get_affinity(X[i:end].to(device), nystrom_X, gamma=gamma)
        # presumably both outputs are (chunk, n_neighbors) - verify against keep_topk_per_row
        _Ai, _indices = keep_topk_per_row(_Ai, n_neighbors)
        # Row ids are created on the compute device and shaped like the
        # column ids so the pair addresses exactly the kept entries.
        row_indices = torch.arange(i, end, device=device).unsqueeze(1).expand_as(_indices)
        cached_weights[row_indices, _indices] = _Ai

    return cached_weights
@@ -1,10 +1,12 @@
1
+ from re import U
1
2
  import torch
2
3
  import torch.nn.functional as F
3
4
 
4
5
  from ncut_pytorch.utils.sample_utils import farthest_point_sampling
6
+ from ncut_pytorch.utils.sample_utils import auto_divice
5
7
 
6
8
 
7
- def kway_ncut(eigvec: torch.Tensor, **kwargs):
9
+ def kway_ncut(eigvec: torch.Tensor, device: str = 'auto', **kwargs):
8
10
  """
9
11
  Args:
10
12
  eigvec (torch.Tensor): eigenvectors from Ncut output, shape (n, k)
@@ -13,14 +15,16 @@ def kway_ncut(eigvec: torch.Tensor, **kwargs):
13
15
  eigvec.argmax(dim=1) is the cluster assignment.
14
16
  eigvec.argmax(dim=0) is the cluster centroids.
15
17
  """
16
- R = axis_align(eigvec, **kwargs)
18
+ # __check_input_tensor(eigvec)
19
+
20
+ R = axis_align(eigvec, device=device, **kwargs)
17
21
  eigvec = F.normalize(eigvec, dim=1)
18
22
  eigvec = eigvec @ R
19
23
  return eigvec
20
24
 
21
25
 
22
26
  @torch.no_grad()
23
- def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
27
+ def axis_align(eigvec: torch.Tensor, device: str = 'auto', max_iter=1000, n_sample=10240):
24
28
  """Multiclass Spectral Clustering, SX Yu, J Shi, 2003
25
29
 
26
30
  Args:
@@ -33,16 +37,20 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
33
37
 
34
38
  # subsample the eigenvectors, to speed up the computation
35
39
  n, k = eigvec.shape
36
- n_sample = max(n_sample, k)
37
- sample_idx = farthest_point_sampling(eigvec, n_sample)
40
+ sample_idx = farthest_point_sampling(eigvec, n_sample, device=device)
38
41
  eigvec = eigvec[sample_idx]
39
42
 
40
43
  eigvec = F.normalize(eigvec, dim=1)
41
44
 
42
45
  # Initialize R matrix with the first column from Farthest Point Sampling
43
- _sample_idx = farthest_point_sampling(eigvec, k)
46
+ _sample_idx = farthest_point_sampling(eigvec, k, device=device)
44
47
  R = eigvec[_sample_idx].T
45
48
 
49
+ original_device = eigvec.device
50
+ device = auto_divice(original_device, device)
51
+ eigvec = eigvec.to(device=device)
52
+ R = R.to(device=device)
53
+
46
54
  # Iterative optimization loop
47
55
  last_objective_value = 0
48
56
  exit_loop = False
@@ -54,12 +62,13 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
54
62
  # Discretize the projected eigenvectors
55
63
  _eigenvectors_continuous = eigvec @ R
56
64
  _eigenvectors_discrete = _onehot_discretize(_eigenvectors_continuous)
57
- _eigenvectors_discrete = _eigenvectors_discrete.to(device=eigvec.device, dtype=eigvec.dtype)
65
+ _eigenvectors_discrete = _eigenvectors_discrete.to(device=device, dtype=eigvec.dtype)
58
66
 
59
67
  # SVD decomposition
60
68
  _out = _eigenvectors_discrete.T @ eigvec
61
69
  U, S, Vh = torch.linalg.svd(_out, full_matrices=False)
62
70
  V = Vh.T
71
+ # U, S, V = svd_lowrank(_out, 100)
63
72
 
64
73
  # Compute the Ncut value
65
74
  ncut_value = 2 * (n - torch.sum(S))
@@ -71,7 +80,8 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
71
80
  else:
72
81
  last_objective_value = ncut_value
73
82
  R = V @ U.T
74
-
83
+
84
+ R = R.to(device=original_device)
75
85
  return R
76
86
 
77
87
 
@@ -12,7 +12,6 @@ _NYSTROM_CONFIG = {
12
12
  'n_sample2': 1024, # number of samples for eigenvector propagation, 1024 is large enough for most cases
13
13
  'n_neighbors': 10, # number of neighbors for eigenvector propagation, 10 is large enough for most cases
14
14
  'matmul_chunk_size': 16384, # chunk size for matrix multiplication, larger chunk size is faster but requires more memory
15
- 'sample_method': "farthest", # sample method for nystrom approximation, 'farthest' is FPS(Farthest Point Sampling)
16
15
  'move_output_to_cpu': True, # if True, will move output to cpu, which saves memory but loses gradients
17
16
  }
18
17
 
@@ -102,17 +101,6 @@ def _plain_ncut(
102
101
  A: torch.Tensor,
103
102
  n_eig: int = 100,
104
103
  ):
105
- """Normalized Cut.
106
-
107
- Args:
108
- A (torch.Tensor): affinity matrix, shape (N, N)
109
- n_eig (int): number of eigenvectors to return
110
-
111
- Returns:
112
- (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (N, n_eig)
113
- (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
114
- """
115
-
116
104
  # normalization; A = D^(-1/2) A D^(-1/2)
117
105
  A = normalize_affinity(A)
118
106
 
@@ -120,6 +108,11 @@ def _plain_ncut(
120
108
 
121
109
  # correct the random rotation (flipping sign) of eigenvectors
122
110
  eigvec = correct_rotation(eigvec)
111
+
112
+ assert not torch.any(torch.isnan(eigvec)), "eigvec contains NaN"
113
+ assert not torch.any(torch.isinf(eigvec)), "eigvec contains Inf"
114
+ assert not torch.any(torch.isnan(eigval)), "eigval contains NaN"
115
+ assert not torch.any(torch.isinf(eigval)), "eigval contains Inf"
123
116
 
124
117
  return eigvec, eigval
125
118