ncut-pytorch 2.0.0.dev1__tar.gz → 2.0.0.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ncut_pytorch-2.0.0.dev1/ncut_pytorch.egg-info → ncut_pytorch-2.0.0.dev2}/PKG-INFO +3 -2
- ncut_pytorch-2.0.0.dev2/ncut_pytorch/dino/__init__.py +4 -0
- ncut_pytorch-2.0.0.dev2/ncut_pytorch/dino/api.py +50 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/hires_dino.py +49 -54
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/patch.py +2 -2
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncut.py +2 -2
- ncut_pytorch-2.0.0.dev2/ncut_pytorch/ncuts/ncut_click.py +232 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/ncut_kway.py +18 -8
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/ncut_nystrom.py +5 -12
- ncut_pytorch-2.0.0.dev2/ncut_pytorch/predictor.py +136 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/math_utils.py +2 -1
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/sample_utils.py +14 -4
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2/ncut_pytorch.egg-info}/PKG-INFO +3 -2
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/SOURCES.txt +4 -6
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/requires.txt +3 -1
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/pyproject.toml +6 -2
- ncut_pytorch-2.0.0.dev2/requirements.txt +7 -0
- ncut_pytorch-2.0.0.dev1/ncut_pytorch/dino/__init__.py +0 -8
- ncut_pytorch-2.0.0.dev1/ncut_pytorch/ncuts/ncut_biased.py +0 -88
- ncut_pytorch-2.0.0.dev1/requirements.txt +0 -7
- ncut_pytorch-2.0.0.dev1/tests/test_real_images.py +0 -284
- ncut_pytorch-2.0.0.dev1/tests/test_real_images_densesparse.py +0 -243
- ncut_pytorch-2.0.0.dev1/tests/test_sample_imb copy.py +0 -100
- ncut_pytorch-2.0.0.dev1/tests/test_sample_imb.py +0 -106
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/LICENSE +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/MANIFEST.in +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/README.md +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/__init__.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/__init__.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/coloring.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/color/mspace.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/dino/transform.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/ncuts/__init__.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/__init__.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch/utils/gamma.py +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/ncut_pytorch.egg-info/top_level.txt +0 -0
- {ncut_pytorch-2.0.0.dev1 → ncut_pytorch-2.0.0.dev2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ncut_pytorch
|
|
3
|
-
Version: 2.0.0.
|
|
3
|
+
Version: 2.0.0.dev2
|
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -18,11 +18,12 @@ Description-Content-Type: text/markdown
|
|
|
18
18
|
License-File: LICENSE
|
|
19
19
|
Requires-Dist: torch~=2.0
|
|
20
20
|
Requires-Dist: fpsample>=0.2.0
|
|
21
|
-
Requires-Dist: torchvision>=0.15.0
|
|
22
21
|
Requires-Dist: pytorch-lightning~=2.0
|
|
23
22
|
Requires-Dist: pillow
|
|
24
23
|
Requires-Dist: numpy
|
|
25
24
|
Requires-Dist: tqdm
|
|
25
|
+
Provides-Extra: vision
|
|
26
|
+
Requires-Dist: torchvision>=0.15.0; extra == "vision"
|
|
26
27
|
Dynamic: license-file
|
|
27
28
|
|
|
28
29
|
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
from torchvision import transforms
|
|
3
|
+
from .hires_dino import hires_dino
|
|
4
|
+
from .hires_dino import HighResDINO
|
|
5
|
+
from .transform import get_input_transform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def hires_dino_256() -> Tuple[HighResDINO, transforms.Compose]:
|
|
10
|
+
model = hires_dino(dino_name="dino_vitb8",
|
|
11
|
+
stride=6,
|
|
12
|
+
shift_dists=[1, 2, 3],
|
|
13
|
+
flip_transforms=True,
|
|
14
|
+
chunk_size=6,
|
|
15
|
+
feature_resolution=256)
|
|
16
|
+
transform = get_input_transform(resize=256)
|
|
17
|
+
return model, transform
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def hires_dino_512() -> Tuple[HighResDINO, transforms.Compose]:
|
|
21
|
+
model = hires_dino(dino_name="dino_vitb8",
|
|
22
|
+
stride=6,
|
|
23
|
+
shift_dists=[1, 2, 3],
|
|
24
|
+
flip_transforms=True,
|
|
25
|
+
chunk_size=4,
|
|
26
|
+
feature_resolution=512)
|
|
27
|
+
transform = get_input_transform(resize=512)
|
|
28
|
+
return model, transform
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def hires_dino_1024() -> Tuple[HighResDINO, transforms.Compose]:
|
|
32
|
+
model = hires_dino(dino_name="dino_vitb8",
|
|
33
|
+
stride=6,
|
|
34
|
+
shift_dists=[1, 2, 3],
|
|
35
|
+
flip_transforms=True,
|
|
36
|
+
chunk_size=1,
|
|
37
|
+
feature_resolution=1024)
|
|
38
|
+
transform = get_input_transform(resize=1024)
|
|
39
|
+
return model, transform
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def hires_dinov2() -> Tuple[HighResDINO, transforms.Compose]:
|
|
43
|
+
model = hires_dino(dino_name="dinov2_vitb14_reg",
|
|
44
|
+
stride=6,
|
|
45
|
+
shift_dists=[1, 2, 3],
|
|
46
|
+
flip_transforms=True,
|
|
47
|
+
chunk_size=1,
|
|
48
|
+
feature_resolution=1008)
|
|
49
|
+
transform = get_input_transform(resize=1008)
|
|
50
|
+
return model, transform
|
|
@@ -28,10 +28,12 @@ class HighResDINO(nn.Module):
|
|
|
28
28
|
def __init__(
|
|
29
29
|
self,
|
|
30
30
|
dino_name: DINONameOptions,
|
|
31
|
-
stride: int =
|
|
32
|
-
dtype: torch.dtype | int = torch.
|
|
31
|
+
stride: int = 5,
|
|
32
|
+
dtype: torch.dtype | int = torch.float16,
|
|
33
33
|
track_grad: bool = False,
|
|
34
|
-
attention_mask_ratio: float = 0.
|
|
34
|
+
attention_mask_ratio: float = 0.1,
|
|
35
|
+
chunk_size: int = 4,
|
|
36
|
+
feature_resolution: int = 1024,
|
|
35
37
|
) -> None:
|
|
36
38
|
super().__init__()
|
|
37
39
|
|
|
@@ -60,7 +62,9 @@ class HighResDINO(nn.Module):
|
|
|
60
62
|
self.feat_dim: int = feat
|
|
61
63
|
self.n_heads: int = 6
|
|
62
64
|
self.n_register_tokens = 4
|
|
63
|
-
|
|
65
|
+
self.chunk_size = chunk_size
|
|
66
|
+
self.feature_resolution = feature_resolution
|
|
67
|
+
|
|
64
68
|
# mask out some of the attention Keys (K) to save compute
|
|
65
69
|
self.attention_mask_ratio = attention_mask_ratio
|
|
66
70
|
|
|
@@ -142,7 +146,7 @@ class HighResDINO(nn.Module):
|
|
|
142
146
|
attn_block = final_block.attn # type: ignore
|
|
143
147
|
# hilariously this also works for dino i.e we can patch dino's attn block forward to
|
|
144
148
|
# use the memeory efficienty attn like in dinov2
|
|
145
|
-
attn_block.forward = MethodType(Patch.
|
|
149
|
+
attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
|
|
146
150
|
if "dinov2" in dino_name:
|
|
147
151
|
final_block.forward = MethodType(Patch._fix_block_forward_dv2(), final_block) # type: ignore
|
|
148
152
|
dino_model.forward_feats_attn = MethodType( # type: ignore
|
|
@@ -152,7 +156,7 @@ class HighResDINO(nn.Module):
|
|
|
152
156
|
for i, blk in enumerate(dino_model.blocks):
|
|
153
157
|
blk.forward = MethodType(Patch._fix_block_forward_dino(), blk)
|
|
154
158
|
attn_block = blk.attn
|
|
155
|
-
attn_block.forward = MethodType(Patch.
|
|
159
|
+
attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
|
|
156
160
|
final_block.forward = MethodType(Patch._fix_block_forward_dino(), final_block) # type: ignore
|
|
157
161
|
dino_model.forward_feats_attn = MethodType( # type: ignore
|
|
158
162
|
Patch._add_new_forward_features_dino(), dino_model
|
|
@@ -161,7 +165,7 @@ class HighResDINO(nn.Module):
|
|
|
161
165
|
for i, blk in enumerate(dino_model.blocks):
|
|
162
166
|
blk.forward = MethodType(Patch._fix_block_forward_dino(), blk)
|
|
163
167
|
attn_block = blk.attn
|
|
164
|
-
attn_block.forward = MethodType(Patch.
|
|
168
|
+
attn_block.forward = MethodType(Patch._fix_attn_masking(self.attention_mask_ratio), attn_block)
|
|
165
169
|
final_block.forward = MethodType(Patch._fix_block_forward_dino(), final_block) # type: ignore
|
|
166
170
|
dino_model.forward_feats_attn = MethodType( # type: ignore
|
|
167
171
|
Patch._add_new_forward_features_vit(), dino_model
|
|
@@ -244,8 +248,8 @@ class HighResDINO(nn.Module):
|
|
|
244
248
|
out_feature_img: torch.Tensor = torch.zeros(
|
|
245
249
|
1,
|
|
246
250
|
c,
|
|
247
|
-
|
|
248
|
-
|
|
251
|
+
self.feature_resolution,
|
|
252
|
+
self.feature_resolution,
|
|
249
253
|
device=x.device,
|
|
250
254
|
dtype=self.dtype,
|
|
251
255
|
requires_grad=self.track_grad,
|
|
@@ -264,7 +268,15 @@ class HighResDINO(nn.Module):
|
|
|
264
268
|
mode=self.interpolation_mode,
|
|
265
269
|
)
|
|
266
270
|
inverted: torch.Tensor = inv_transform(full_size)
|
|
267
|
-
|
|
271
|
+
|
|
272
|
+
# resize the inverted feature map to the output resolution
|
|
273
|
+
out = F.interpolate(
|
|
274
|
+
inverted,
|
|
275
|
+
(self.feature_resolution, self.feature_resolution),
|
|
276
|
+
mode=self.interpolation_mode
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
out_feature_img += out
|
|
268
280
|
|
|
269
281
|
n_imgs: int = feature_batch.shape[0]
|
|
270
282
|
mean = out_feature_img / n_imgs
|
|
@@ -275,13 +287,11 @@ class HighResDINO(nn.Module):
|
|
|
275
287
|
self,
|
|
276
288
|
x: torch.Tensor,
|
|
277
289
|
attn_choice: AttentionOptions = "none",
|
|
278
|
-
chunk_size: int = 6,
|
|
279
290
|
) -> torch.Tensor:
|
|
280
291
|
"""Feed input img $x through network and get high-res features.
|
|
281
292
|
|
|
282
293
|
:param x: unbatched image tensor, (c, h, w)
|
|
283
294
|
:param attn_choice: choice of attention, "none" or "q", "k", "v", "o"
|
|
284
|
-
:param chunk_size: number of images to process in one chunk, default 6, in case of OOM
|
|
285
295
|
:return: upsampled features, (c, h, w)
|
|
286
296
|
"""
|
|
287
297
|
if self.dtype != torch.float32: # cast (i.e to f16)
|
|
@@ -291,8 +301,8 @@ class HighResDINO(nn.Module):
|
|
|
291
301
|
N_imgs = img_batch.shape[0]
|
|
292
302
|
|
|
293
303
|
all_features = []
|
|
294
|
-
for i in range(0, N_imgs, chunk_size):
|
|
295
|
-
_img_batch = img_batch[i:i+chunk_size]
|
|
304
|
+
for i in range(0, N_imgs, self.chunk_size):
|
|
305
|
+
_img_batch = img_batch[i:i+self.chunk_size]
|
|
296
306
|
out_dict = self.dinov2.forward_feats_attn(_img_batch, None, attn_choice) # type: ignore
|
|
297
307
|
if attn_choice != "none":
|
|
298
308
|
feats, attn = out_dict["x_norm_patchtokens"], out_dict["x_patchattn"]
|
|
@@ -314,23 +324,24 @@ class HighResDINO(nn.Module):
|
|
|
314
324
|
self,
|
|
315
325
|
x: torch.Tensor,
|
|
316
326
|
attn_choice: AttentionOptions = "none",
|
|
317
|
-
|
|
327
|
+
move_to_cpu: bool = True,
|
|
318
328
|
) -> torch.Tensor:
|
|
319
329
|
"""Feed input img $x through network and get low and high res features.
|
|
320
330
|
|
|
321
331
|
:param x: batched image tensor, (b, c, h, w)
|
|
322
332
|
:param attn_choice: choice of attention, "none" or "q", "k", "v", "o"
|
|
323
|
-
:param chunk_size: number of images to process in one chunk, default 6, in case of OOM
|
|
324
333
|
:return: upsampled features, (b, c, h, w)
|
|
325
334
|
:rtype: torch.Tensor
|
|
326
335
|
"""
|
|
327
336
|
upsampled_features = []
|
|
328
337
|
for i in range(x.shape[0]):
|
|
329
338
|
if self.track_grad:
|
|
330
|
-
out = self._forward_one_image(x[i], attn_choice
|
|
339
|
+
out = self._forward_one_image(x[i], attn_choice)
|
|
331
340
|
else:
|
|
332
341
|
with torch.no_grad():
|
|
333
|
-
out = self._forward_one_image(x[i], attn_choice
|
|
342
|
+
out = self._forward_one_image(x[i], attn_choice)
|
|
343
|
+
if move_to_cpu:
|
|
344
|
+
out = out.cpu()
|
|
334
345
|
upsampled_features.append(out)
|
|
335
346
|
upsampled_features = torch.stack(upsampled_features, dim=0)
|
|
336
347
|
return upsampled_features
|
|
@@ -338,13 +349,28 @@ class HighResDINO(nn.Module):
|
|
|
338
349
|
|
|
339
350
|
def hires_dino(dino_name: DINONameOptions = "dino_vitb8",
|
|
340
351
|
stride: int = 6,
|
|
341
|
-
|
|
342
|
-
shift_dists: List[int] = [1, 2],
|
|
352
|
+
shift_dists: List[int] = [1, 2, 3],
|
|
343
353
|
flip_transforms: bool = True,
|
|
354
|
+
attention_mask_ratio: float = 0.1,
|
|
344
355
|
dtype: torch.dtype | int = torch.float16,
|
|
345
|
-
track_grad: bool = False
|
|
346
|
-
|
|
347
|
-
|
|
356
|
+
track_grad: bool = False,
|
|
357
|
+
chunk_size: int = 4,
|
|
358
|
+
feature_resolution: int = 512
|
|
359
|
+
) -> HighResDINO:
|
|
360
|
+
"""
|
|
361
|
+
Args:
|
|
362
|
+
dino_name: name of the DINO model to use
|
|
363
|
+
stride: stride size of the tokenization, smaller is better but slower
|
|
364
|
+
shift_dists: pixel shifts for multiple image transformations, more shifts means more crispy features
|
|
365
|
+
flip_transforms: whether to use flip transforms, remove positional features
|
|
366
|
+
attention_mask_ratio: ratio of attention keys to mask out
|
|
367
|
+
dtype: data type of the model
|
|
368
|
+
track_grad: whether to track gradients
|
|
369
|
+
chunk_size: number of images to process in one batch, in case of OOM
|
|
370
|
+
feature_resolution: resolution of the output features
|
|
371
|
+
"""
|
|
372
|
+
|
|
373
|
+
model = HighResDINO(dino_name, stride, dtype, track_grad, attention_mask_ratio, chunk_size, feature_resolution)
|
|
348
374
|
|
|
349
375
|
fwd_shift, inv_shift = get_shift_transforms(shift_dists)
|
|
350
376
|
if flip_transforms: # add flip transforms
|
|
@@ -359,34 +385,3 @@ def hires_dino(dino_name: DINONameOptions = "dino_vitb8",
|
|
|
359
385
|
return model
|
|
360
386
|
|
|
361
387
|
|
|
362
|
-
# ==================== MODEL PRESETS ====================
|
|
363
|
-
|
|
364
|
-
# stride: stride size of the tokenization, the original model use patch size as stride size
|
|
365
|
-
# reduce stride size to get more tokens and crispy features
|
|
366
|
-
# shift_dists: pixel shifts for multiple image transformations, more shifts means more crispy features
|
|
367
|
-
# flip_transforms: whether to use flip transforms, remove positional features
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
def hires_dino_small() -> HighResDINO:
|
|
371
|
-
return hires_dino(dino_name="dino_vitb8",
|
|
372
|
-
stride=6,
|
|
373
|
-
shift_dists=[1, 2, 3],
|
|
374
|
-
flip_transforms=True)
|
|
375
|
-
|
|
376
|
-
def hires_dino_base() -> HighResDINO:
|
|
377
|
-
return hires_dino(dino_name="dino_vitb8",
|
|
378
|
-
stride=4,
|
|
379
|
-
shift_dists=[1, 2, 3],
|
|
380
|
-
flip_transforms=True)
|
|
381
|
-
|
|
382
|
-
def hires_dino_large() -> HighResDINO:
|
|
383
|
-
return hires_dino(dino_name="dino_vitb8",
|
|
384
|
-
stride=3,
|
|
385
|
-
shift_dists=[1, 2, 3],
|
|
386
|
-
flip_transforms=True)
|
|
387
|
-
|
|
388
|
-
def hires_dinov2() -> HighResDINO:
|
|
389
|
-
return hires_dino(dino_name="dinov2_vitb14_reg",
|
|
390
|
-
stride=4,
|
|
391
|
-
shift_dists=[1, 2, 3],
|
|
392
|
-
flip_transforms=True)
|
|
@@ -161,7 +161,7 @@ class Patch:
|
|
|
161
161
|
return forward
|
|
162
162
|
|
|
163
163
|
@staticmethod
|
|
164
|
-
def
|
|
164
|
+
def _fix_attn_masking(attention_mask_ratio: float) -> Callable:
|
|
165
165
|
"""Replaces normal 'forward()' method of the memory efficient attention layer (block.attn)
|
|
166
166
|
in the Dv2 model with an optional early return with attention. Used if xformers used.
|
|
167
167
|
|
|
@@ -196,8 +196,8 @@ class Patch:
|
|
|
196
196
|
k.transpose(1, 2),
|
|
197
197
|
v.transpose(1, 2),
|
|
198
198
|
attn_mask=attn_bias)
|
|
199
|
-
|
|
200
199
|
x = x.transpose(1, 2)
|
|
200
|
+
|
|
201
201
|
to_append: torch.Tensor
|
|
202
202
|
if attn_choice != "none":
|
|
203
203
|
to_append = get_qkvo_per_head(q, k, v, x, attn_choice, self.attn_drop)
|
|
@@ -109,9 +109,9 @@ class Ncut:
|
|
|
109
109
|
def fit_transform(self, X: torch.Tensor) -> torch.Tensor:
|
|
110
110
|
return self.fit(X).transform(X)
|
|
111
111
|
|
|
112
|
-
def __new__(cls, X: torch.Tensor = None, **kwargs):
|
|
112
|
+
def __new__(cls, X: torch.Tensor = None, n_eig: int = 100, track_grad: bool = False, d_gamma: float = 0.1, device: str = 'auto', **kwargs):
|
|
113
113
|
if X is not None:
|
|
114
|
-
eigvec, eigval = ncut_fn(X, **kwargs) # function-like behavior
|
|
114
|
+
eigvec, eigval = ncut_fn(X, n_eig=n_eig, track_grad=track_grad, d_gamma=d_gamma, device=device, **kwargs) # function-like behavior
|
|
115
115
|
return eigvec
|
|
116
116
|
return super().__new__(cls) # normal class instantiation
|
|
117
117
|
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# %%
|
|
2
|
+
from sympy import Q
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ncut_pytorch.utils.gamma import find_gamma_by_degree_after_fps, find_gamma_by_degree
|
|
6
|
+
from ncut_pytorch.utils.math_utils import get_affinity, normalize_affinity, svd_lowrank, correct_rotation
|
|
7
|
+
from ncut_pytorch.utils.sample_utils import farthest_point_sampling, auto_divice
|
|
8
|
+
from .ncut_kway import kway_ncut
|
|
9
|
+
from .ncut_nystrom import _nystrom_propagate
|
|
10
|
+
from .ncut_nystrom import _plain_ncut
|
|
11
|
+
from .ncut_nystrom import _NYSTROM_CONFIG
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def ncut_click_prompt(
|
|
15
|
+
X: torch.Tensor,
|
|
16
|
+
fg_indices: torch.Tensor,
|
|
17
|
+
bg_indices: torch.Tensor = None,
|
|
18
|
+
click_weight: float = 0.5,
|
|
19
|
+
bg_weight: float = 0.1,
|
|
20
|
+
n_eig: int = 2,
|
|
21
|
+
track_grad: bool = False,
|
|
22
|
+
d_gamma: float = 0.1,
|
|
23
|
+
device: str = 'auto',
|
|
24
|
+
gamma: float = None,
|
|
25
|
+
no_propagation: bool = False,
|
|
26
|
+
**kwargs,
|
|
27
|
+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
|
|
28
|
+
|
|
29
|
+
_config = _NYSTROM_CONFIG.copy()
|
|
30
|
+
_config.update(kwargs)
|
|
31
|
+
|
|
32
|
+
# use GPU if available
|
|
33
|
+
device = auto_divice(X.device, device)
|
|
34
|
+
|
|
35
|
+
# skip pytorch gradient computation if track_grad is False
|
|
36
|
+
prev_grad_state = torch.is_grad_enabled()
|
|
37
|
+
torch.set_grad_enabled(track_grad)
|
|
38
|
+
|
|
39
|
+
if bg_indices is None:
|
|
40
|
+
bg_indices = torch.tensor([], dtype=torch.long)
|
|
41
|
+
|
|
42
|
+
# subsample for nystrom approximation
|
|
43
|
+
nystrom_indices = farthest_point_sampling(X, n_sample=_config['n_sample'], device=device)
|
|
44
|
+
nystrom_indices = torch.tensor(nystrom_indices, dtype=torch.long)
|
|
45
|
+
# remove fg and bg from fps_idx
|
|
46
|
+
nystrom_indices = nystrom_indices[~torch.isin(nystrom_indices, torch.cat([fg_indices, bg_indices]))]
|
|
47
|
+
# add fg and bg to fps_idx
|
|
48
|
+
nystrom_indices = torch.cat([fg_indices, bg_indices, nystrom_indices])
|
|
49
|
+
fg_indices = torch.arange(len(fg_indices))
|
|
50
|
+
bg_indices = torch.arange(len(bg_indices)) + len(fg_indices)
|
|
51
|
+
n_fgbg = len(fg_indices) + len(bg_indices)
|
|
52
|
+
|
|
53
|
+
nystrom_X = X[nystrom_indices].to(device)
|
|
54
|
+
|
|
55
|
+
# find optimal gamma for affinity matrix
|
|
56
|
+
if gamma is None:
|
|
57
|
+
gamma = find_gamma_by_degree_after_fps(nystrom_X, d_gamma)
|
|
58
|
+
|
|
59
|
+
# compute Ncut on the nystrom sampled subgraph
|
|
60
|
+
A = get_affinity(nystrom_X, gamma=gamma)
|
|
61
|
+
A = normalize_affinity(A)
|
|
62
|
+
|
|
63
|
+
# modify the affinity from the clicks
|
|
64
|
+
X_click = 1 * A[fg_indices].mean(0)
|
|
65
|
+
if len(bg_indices) > 0:
|
|
66
|
+
X_click = X_click - bg_weight * A[bg_indices].mean(0)
|
|
67
|
+
|
|
68
|
+
X_click = X_click * A.shape[0]
|
|
69
|
+
|
|
70
|
+
# gamma2 = find_gamma_by_degree(X_click.unsqueeze(1), d_gamma)
|
|
71
|
+
# A_click = get_affinity(X_click.unsqueeze(1), gamma=gamma2)
|
|
72
|
+
A_click = get_affinity(X_click.unsqueeze(1), gamma=0.5)
|
|
73
|
+
# A_click = - torch.cdist(X_click.unsqueeze(1), X_click.unsqueeze(1))
|
|
74
|
+
A_click = normalize_affinity(A_click)
|
|
75
|
+
|
|
76
|
+
_A = click_weight * A_click + (1 - click_weight) * A
|
|
77
|
+
# _A = _A[n_fgbg:, n_fgbg:]
|
|
78
|
+
# nystrom_indices = nystrom_indices[n_fgbg:]
|
|
79
|
+
# nystrom_X = nystrom_X[n_fgbg:]
|
|
80
|
+
|
|
81
|
+
nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
|
|
82
|
+
|
|
83
|
+
if no_propagation:
|
|
84
|
+
torch.set_grad_enabled(prev_grad_state)
|
|
85
|
+
return nystrom_eigvec, eigval, nystrom_indices, gamma
|
|
86
|
+
|
|
87
|
+
# propagate eigenvectors from subgraph to full graph
|
|
88
|
+
eigvec = _nystrom_propagate(
|
|
89
|
+
nystrom_eigvec,
|
|
90
|
+
X,
|
|
91
|
+
nystrom_X,
|
|
92
|
+
n_neighbors=_config['n_neighbors'],
|
|
93
|
+
n_sample=_config['n_sample2'],
|
|
94
|
+
gamma=gamma,
|
|
95
|
+
chunk_size=_config['matmul_chunk_size'],
|
|
96
|
+
device=device,
|
|
97
|
+
move_output_to_cpu=_config['move_output_to_cpu'],
|
|
98
|
+
track_grad=track_grad,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
torch.set_grad_enabled(prev_grad_state)
|
|
102
|
+
|
|
103
|
+
return eigvec, eigval
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def get_mask_and_heatmap(eigvecs, fg_indices, n_cluster=2, device='auto'):
|
|
107
|
+
device = auto_divice(eigvecs.device, device)
|
|
108
|
+
eigvecs = eigvecs[:, :n_cluster]
|
|
109
|
+
|
|
110
|
+
eigvecs = kway_ncut(eigvecs, device=device)
|
|
111
|
+
# find which cluster is the foreground
|
|
112
|
+
fg_eigvecs = eigvecs[fg_indices]
|
|
113
|
+
fg_idx = fg_eigvecs.mean(0).argmax().item()
|
|
114
|
+
bg_idx = 1 if fg_idx == 0 else 0
|
|
115
|
+
|
|
116
|
+
# discretize the eigvecs
|
|
117
|
+
mask = eigvecs.argmax(dim=-1) == fg_idx
|
|
118
|
+
|
|
119
|
+
heatmap = eigvecs[:, fg_idx] - eigvecs[:, bg_idx]
|
|
120
|
+
|
|
121
|
+
return mask, heatmap
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
from ncut_pytorch.utils.math_utils import keep_topk_per_row
|
|
126
|
+
|
|
127
|
+
def ncut_click_prompt_cached(
|
|
128
|
+
nystrom_indices: torch.Tensor,
|
|
129
|
+
gamma: float,
|
|
130
|
+
X: torch.Tensor,
|
|
131
|
+
fg_indices: torch.Tensor,
|
|
132
|
+
bg_indices: torch.Tensor = None,
|
|
133
|
+
click_weight: float = 0.5,
|
|
134
|
+
bg_weight: float = 0.1,
|
|
135
|
+
n_eig: int = 2,
|
|
136
|
+
track_grad: bool = False,
|
|
137
|
+
device: str = 'auto',
|
|
138
|
+
**kwargs,
|
|
139
|
+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
|
|
140
|
+
|
|
141
|
+
_config = _NYSTROM_CONFIG.copy()
|
|
142
|
+
_config.update(kwargs)
|
|
143
|
+
|
|
144
|
+
# use GPU if available
|
|
145
|
+
device = auto_divice(X.device, device)
|
|
146
|
+
|
|
147
|
+
# skip pytorch gradient computation if track_grad is False
|
|
148
|
+
prev_grad_state = torch.is_grad_enabled()
|
|
149
|
+
torch.set_grad_enabled(track_grad)
|
|
150
|
+
|
|
151
|
+
if bg_indices is None:
|
|
152
|
+
bg_indices = torch.tensor([], dtype=torch.long)
|
|
153
|
+
|
|
154
|
+
# subsample for nystrom approximation
|
|
155
|
+
nystrom_indices = torch.tensor(nystrom_indices, dtype=torch.long)
|
|
156
|
+
# add fg and bg to fps_idx
|
|
157
|
+
nystrom_indices = torch.cat([fg_indices, bg_indices, nystrom_indices])
|
|
158
|
+
fg_indices = torch.arange(len(fg_indices))
|
|
159
|
+
bg_indices = torch.arange(len(bg_indices)) + len(fg_indices)
|
|
160
|
+
n_fgbg = len(fg_indices) + len(bg_indices)
|
|
161
|
+
|
|
162
|
+
nystrom_X = X[nystrom_indices].to(device)
|
|
163
|
+
|
|
164
|
+
# compute Ncut on the nystrom sampled subgraph
|
|
165
|
+
A = get_affinity(nystrom_X, gamma=gamma)
|
|
166
|
+
A = normalize_affinity(A)
|
|
167
|
+
|
|
168
|
+
# modify the affinity from the clicks
|
|
169
|
+
X_click = 1 * A[fg_indices].mean(0)
|
|
170
|
+
if len(bg_indices) > 0:
|
|
171
|
+
X_click = X_click - bg_weight * A[bg_indices].mean(0)
|
|
172
|
+
|
|
173
|
+
X_click = X_click * A.shape[0]
|
|
174
|
+
|
|
175
|
+
A_click = get_affinity(X_click.unsqueeze(1), gamma=0.5)
|
|
176
|
+
A_click = normalize_affinity(A_click)
|
|
177
|
+
|
|
178
|
+
_A = click_weight * A_click + (1 - click_weight) * A
|
|
179
|
+
_A = _A[n_fgbg:, n_fgbg:]
|
|
180
|
+
nystrom_indices = nystrom_indices[n_fgbg:]
|
|
181
|
+
nystrom_X = nystrom_X[n_fgbg:]
|
|
182
|
+
|
|
183
|
+
nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
|
|
184
|
+
|
|
185
|
+
torch.set_grad_enabled(prev_grad_state)
|
|
186
|
+
return nystrom_eigvec, eigval
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _build_nystrom_graph(
|
|
190
|
+
X: torch.Tensor,
|
|
191
|
+
nystrom_X: torch.Tensor,
|
|
192
|
+
gamma: float = 1.0,
|
|
193
|
+
device: str = 'auto',
|
|
194
|
+
**kwargs,
|
|
195
|
+
):
|
|
196
|
+
"""propagate output from nystrom sampled nodes to all nodes,
|
|
197
|
+
use a weighted sum of the nearest neighbors to propagate the output.
|
|
198
|
+
|
|
199
|
+
Args:
|
|
200
|
+
nystrom_out (torch.Tensor): output from nystrom sampled nodes, shape (m, D)
|
|
201
|
+
X (torch.Tensor): input features for all nodes, shape (N, D)
|
|
202
|
+
nystrom_X (torch.Tensor): input features from nystrom sampled nodes, shape (m, D)
|
|
203
|
+
gamma (float): affinity parameter, default 1.0
|
|
204
|
+
track_grad (bool): keep track of pytorch gradients, default False
|
|
205
|
+
device (str): device to use for computation, if 'auto', will detect GPU automatically
|
|
206
|
+
_config (dict): configuration for nystrom approximation, default _NYSTROM_CONFIG
|
|
207
|
+
|
|
208
|
+
Returns:
|
|
209
|
+
torch.Tensor: output propagated by nearest neighbors, shape (N, D)
|
|
210
|
+
"""
|
|
211
|
+
|
|
212
|
+
_config = _NYSTROM_CONFIG.copy()
|
|
213
|
+
_config.update(kwargs)
|
|
214
|
+
|
|
215
|
+
device = auto_divice(X.device, device)
|
|
216
|
+
nystrom_X = nystrom_X.to(device)
|
|
217
|
+
|
|
218
|
+
all_outs = []
|
|
219
|
+
n_chunk = _config['matmul_chunk_size']
|
|
220
|
+
n_neighbors = _config['n_neighbors']
|
|
221
|
+
cached_weights = torch.zeros((X.shape[0], nystrom_X.shape[0]),
|
|
222
|
+
device=device, dtype=X.dtype)
|
|
223
|
+
for i in range(0, X.shape[0], n_chunk):
|
|
224
|
+
end = min(i + n_chunk, X.shape[0])
|
|
225
|
+
|
|
226
|
+
_Ai = get_affinity(X[i:end].to(device), nystrom_X, gamma=gamma)
|
|
227
|
+
_Ai, _indices = keep_topk_per_row(_Ai, n_neighbors) # (n, n_neighbors)
|
|
228
|
+
row_indices = torch.arange(i, end).unsqueeze(1).expand(-1, n_neighbors) # shape (N, 10)
|
|
229
|
+
cached_weights[row_indices, _indices] = _Ai
|
|
230
|
+
print((cached_weights[i] > 0).sum())
|
|
231
|
+
|
|
232
|
+
return cached_weights
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
from re import U
|
|
1
2
|
import torch
|
|
2
3
|
import torch.nn.functional as F
|
|
3
4
|
|
|
4
5
|
from ncut_pytorch.utils.sample_utils import farthest_point_sampling
|
|
6
|
+
from ncut_pytorch.utils.sample_utils import auto_divice
|
|
5
7
|
|
|
6
8
|
|
|
7
|
-
def kway_ncut(eigvec: torch.Tensor, **kwargs):
|
|
9
|
+
def kway_ncut(eigvec: torch.Tensor, device: str = 'auto', **kwargs):
|
|
8
10
|
"""
|
|
9
11
|
Args:
|
|
10
12
|
eigvec (torch.Tensor): eigenvectors from Ncut output, shape (n, k)
|
|
@@ -13,14 +15,16 @@ def kway_ncut(eigvec: torch.Tensor, **kwargs):
|
|
|
13
15
|
eigvec.argmax(dim=1) is the cluster assignment.
|
|
14
16
|
eigvec.argmax(dim=0) is the cluster centroids.
|
|
15
17
|
"""
|
|
16
|
-
|
|
18
|
+
# __check_input_tensor(eigvec)
|
|
19
|
+
|
|
20
|
+
R = axis_align(eigvec, device=device, **kwargs)
|
|
17
21
|
eigvec = F.normalize(eigvec, dim=1)
|
|
18
22
|
eigvec = eigvec @ R
|
|
19
23
|
return eigvec
|
|
20
24
|
|
|
21
25
|
|
|
22
26
|
@torch.no_grad()
|
|
23
|
-
def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
|
|
27
|
+
def axis_align(eigvec: torch.Tensor, device: str = 'auto', max_iter=1000, n_sample=10240):
|
|
24
28
|
"""Multiclass Spectral Clustering, SX Yu, J Shi, 2003
|
|
25
29
|
|
|
26
30
|
Args:
|
|
@@ -33,16 +37,20 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
|
|
|
33
37
|
|
|
34
38
|
# subsample the eigenvectors, to speed up the computation
|
|
35
39
|
n, k = eigvec.shape
|
|
36
|
-
|
|
37
|
-
sample_idx = farthest_point_sampling(eigvec, n_sample)
|
|
40
|
+
sample_idx = farthest_point_sampling(eigvec, n_sample, device=device)
|
|
38
41
|
eigvec = eigvec[sample_idx]
|
|
39
42
|
|
|
40
43
|
eigvec = F.normalize(eigvec, dim=1)
|
|
41
44
|
|
|
42
45
|
# Initialize R matrix with the first column from Farthest Point Sampling
|
|
43
|
-
_sample_idx = farthest_point_sampling(eigvec, k)
|
|
46
|
+
_sample_idx = farthest_point_sampling(eigvec, k, device=device)
|
|
44
47
|
R = eigvec[_sample_idx].T
|
|
45
48
|
|
|
49
|
+
original_device = eigvec.device
|
|
50
|
+
device = auto_divice(original_device, device)
|
|
51
|
+
eigvec = eigvec.to(device=device)
|
|
52
|
+
R = R.to(device=device)
|
|
53
|
+
|
|
46
54
|
# Iterative optimization loop
|
|
47
55
|
last_objective_value = 0
|
|
48
56
|
exit_loop = False
|
|
@@ -54,12 +62,13 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
|
|
|
54
62
|
# Discretize the projected eigenvectors
|
|
55
63
|
_eigenvectors_continuous = eigvec @ R
|
|
56
64
|
_eigenvectors_discrete = _onehot_discretize(_eigenvectors_continuous)
|
|
57
|
-
_eigenvectors_discrete = _eigenvectors_discrete.to(device=
|
|
65
|
+
_eigenvectors_discrete = _eigenvectors_discrete.to(device=device, dtype=eigvec.dtype)
|
|
58
66
|
|
|
59
67
|
# SVD decomposition
|
|
60
68
|
_out = _eigenvectors_discrete.T @ eigvec
|
|
61
69
|
U, S, Vh = torch.linalg.svd(_out, full_matrices=False)
|
|
62
70
|
V = Vh.T
|
|
71
|
+
# U, S, V = svd_lowrank(_out, 100)
|
|
63
72
|
|
|
64
73
|
# Compute the Ncut value
|
|
65
74
|
ncut_value = 2 * (n - torch.sum(S))
|
|
@@ -71,7 +80,8 @@ def axis_align(eigvec: torch.Tensor, max_iter=1000, n_sample=10240):
|
|
|
71
80
|
else:
|
|
72
81
|
last_objective_value = ncut_value
|
|
73
82
|
R = V @ U.T
|
|
74
|
-
|
|
83
|
+
|
|
84
|
+
R = R.to(device=original_device)
|
|
75
85
|
return R
|
|
76
86
|
|
|
77
87
|
|
|
@@ -12,7 +12,6 @@ _NYSTROM_CONFIG = {
|
|
|
12
12
|
'n_sample2': 1024, # number of samples for eigenvector propagation, 1024 is large enough for most cases
|
|
13
13
|
'n_neighbors': 10, # number of neighbors for eigenvector propagation, 10 is large enough for most cases
|
|
14
14
|
'matmul_chunk_size': 16384, # chunk size for matrix multiplication, larger chunk size is faster but requires more memory
|
|
15
|
-
'sample_method': "farthest", # sample method for nystrom approximation, 'farthest' is FPS(Farthest Point Sampling)
|
|
16
15
|
'move_output_to_cpu': True, # if True, will move output to cpu, which saves memory but loses gradients
|
|
17
16
|
}
|
|
18
17
|
|
|
@@ -102,17 +101,6 @@ def _plain_ncut(
|
|
|
102
101
|
A: torch.Tensor,
|
|
103
102
|
n_eig: int = 100,
|
|
104
103
|
):
|
|
105
|
-
"""Normalized Cut.
|
|
106
|
-
|
|
107
|
-
Args:
|
|
108
|
-
A (torch.Tensor): affinity matrix, shape (N, N)
|
|
109
|
-
n_eig (int): number of eigenvectors to return
|
|
110
|
-
|
|
111
|
-
Returns:
|
|
112
|
-
(torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (N, n_eig)
|
|
113
|
-
(torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
|
|
114
|
-
"""
|
|
115
|
-
|
|
116
104
|
# normalization; A = D^(-1/2) A D^(-1/2)
|
|
117
105
|
A = normalize_affinity(A)
|
|
118
106
|
|
|
@@ -120,6 +108,11 @@ def _plain_ncut(
|
|
|
120
108
|
|
|
121
109
|
# correct the random rotation (flipping sign) of eigenvectors
|
|
122
110
|
eigvec = correct_rotation(eigvec)
|
|
111
|
+
|
|
112
|
+
assert not torch.any(torch.isnan(eigvec)), "eigvec contains NaN"
|
|
113
|
+
assert not torch.any(torch.isinf(eigvec)), "eigvec contains Inf"
|
|
114
|
+
assert not torch.any(torch.isnan(eigval)), "eigval contains NaN"
|
|
115
|
+
assert not torch.any(torch.isinf(eigval)), "eigval contains Inf"
|
|
123
116
|
|
|
124
117
|
return eigvec, eigval
|
|
125
118
|
|