ncut-pytorch 2.3.3__tar.gz → 3.0.0.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/PKG-INFO +3 -3
  2. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/README.md +1 -1
  3. ncut_pytorch-3.0.0.dev0/ncut_pytorch/__init__.py +4 -0
  4. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/color/coloring.py +2 -2
  5. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/color/mspace.py +104 -96
  6. ncut_pytorch-3.0.0.dev0/ncut_pytorch/color/mspace_nopl.py +550 -0
  7. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/ncut.py +17 -17
  8. ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/__init__.py +0 -0
  9. ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/ncut_click.py +106 -0
  10. ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/ncut_kway.py +163 -0
  11. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/ncuts/ncut_nystrom.py +50 -44
  12. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/dinov3.py +27 -0
  13. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/predictor.py +20 -13
  14. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/vision_predictor.py +4 -3
  15. ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/__init__.py +0 -0
  16. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/utils/device.py +2 -0
  17. ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/grad.py +154 -0
  18. ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/math.py +217 -0
  19. ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/sample.py +64 -0
  20. ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/sigma.py +65 -0
  21. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/utils/torch_mod.py +2 -0
  22. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/PKG-INFO +3 -3
  23. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/SOURCES.txt +2 -1
  24. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/requires.txt +1 -1
  25. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/pyproject.toml +2 -2
  26. ncut_pytorch-2.3.3/ncut_pytorch/__init__.py +0 -15
  27. ncut_pytorch-2.3.3/ncut_pytorch/color/__init__.py +0 -6
  28. ncut_pytorch-2.3.3/ncut_pytorch/ncuts/__init__.py +0 -3
  29. ncut_pytorch-2.3.3/ncut_pytorch/ncuts/ncut_click.py +0 -107
  30. ncut_pytorch-2.3.3/ncut_pytorch/ncuts/ncut_kway.py +0 -110
  31. ncut_pytorch-2.3.3/ncut_pytorch/utils/gamma.py +0 -60
  32. ncut_pytorch-2.3.3/ncut_pytorch/utils/grad.py +0 -41
  33. ncut_pytorch-2.3.3/ncut_pytorch/utils/math.py +0 -339
  34. ncut_pytorch-2.3.3/ncut_pytorch/utils/sample.py +0 -93
  35. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/LICENSE +0 -0
  36. {ncut_pytorch-2.3.3/ncut_pytorch/utils → ncut_pytorch-3.0.0.dev0/ncut_pytorch/color}/__init__.py +0 -0
  37. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/__init__.py +0 -0
  38. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  39. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/api.py +0 -0
  40. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  41. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  42. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/patch.py +0 -0
  43. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/transform.py +0 -0
  44. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  45. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  46. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  47. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/top_level.txt +0 -0
  48. {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 2.3.3
3
+ Version: 3.0.0.dev0
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -17,7 +17,7 @@ Requires-Python: >=3
17
17
  Description-Content-Type: text/markdown
18
18
  License-File: LICENSE
19
19
  Requires-Dist: numpy<2.0
20
- Requires-Dist: fpsample>=0.2.0
20
+ Requires-Dist: torch-quickfps
21
21
  Requires-Dist: tqdm
22
22
  Requires-Dist: pillow
23
23
  Requires-Dist: opencv-python
@@ -41,7 +41,7 @@ Requires-Dist: ncut_pytorch[cmap,dev,torch]; extra == "all"
41
41
  Dynamic: license-file
42
42
 
43
43
 
44
- ### [🌐Documentation (old version)](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
44
+ ### [🌐Documentation](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
45
45
 
46
46
 
47
47
  ## Nyström Normalized Cut
@@ -1,5 +1,5 @@
1
1
 
2
- ### [🌐Documentation (old version)](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
2
+ ### [🌐Documentation](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
3
3
 
4
4
 
5
5
  ## Nyström Normalized Cut
@@ -0,0 +1,4 @@
1
+ from ncut_pytorch.ncut import Ncut
2
+ from ncut_pytorch.ncuts.ncut_nystrom import ncut_fn
3
+ from ncut_pytorch.ncuts.ncut_kway import kway_ncut, axis_align, quick_kway
4
+ from ncut_pytorch.ncuts.ncut_click import ncut_click_prompt
@@ -7,7 +7,6 @@ from numba.core.types import none
7
7
  import numpy as np
8
8
  import torch
9
9
 
10
- from .mspace import mspace_viz_transform
11
10
  from ncut_pytorch.ncuts.ncut_nystrom import nystrom_propagate
12
11
  from ncut_pytorch.utils.math import quantile_normalize
13
12
  from ncut_pytorch.utils.sample import farthest_point_sampling
@@ -20,7 +19,7 @@ def _identity(X: torch.Tensor) -> torch.Tensor:
20
19
  def mspace_color(
21
20
  X: torch.Tensor,
22
21
  q: float = 0.95,
23
- n_eig: Optional[int] = 16,
22
+ n_eig: Optional[int] = 8,
24
23
  n_dim: int = 3,
25
24
  training_steps: int = 1000,
26
25
  progress_bar: bool = False,
@@ -31,6 +30,7 @@ def mspace_color(
31
30
  (torch.Tensor): Embedding in 2D, shape (n_samples, 2)
32
31
  (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
33
32
  """
33
+ from .mspace import mspace_viz_transform
34
34
 
35
35
  low_dim_embedding = mspace_viz_transform(
36
36
  X=X,
@@ -4,8 +4,9 @@ import logging
4
4
  from collections import defaultdict
5
5
  from functools import partial
6
6
  import warnings
7
-
8
7
  import pytorch_lightning as pl
8
+
9
+ import numpy as np
9
10
  import torch
10
11
  import torch.nn as nn
11
12
  import torch.nn.functional as F
@@ -13,23 +14,8 @@ from torch.utils.data import TensorDataset
13
14
  from tqdm import tqdm
14
15
 
15
16
  from ncut_pytorch.utils.device import auto_device
16
-
17
- # disable lightning logs
18
- logging.getLogger('lightning').setLevel(0)
19
- logging.getLogger("pytorch_lightning").setLevel(0)
20
- class IgnorePLFilter(logging.Filter):
21
- def filter(self, record):
22
- keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
23
- return not any(keyword in record.getMessage() for keyword in keywords)
24
- logging.getLogger('pytorch_lightning.utilities.rank_zero').addFilter(IgnorePLFilter())
25
- logging.getLogger('pytorch_lightning.accelerators.cuda').addFilter(IgnorePLFilter())
26
-
27
-
28
- from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
29
- from ncut_pytorch.ncuts.ncut_nystrom import _plain_ncut
30
- from ncut_pytorch.ncuts.ncut_kway import kway_ncut
31
- from ncut_pytorch.utils.gamma import find_gamma_by_degree_after_fps
32
- from ncut_pytorch.utils.math import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss, compute_axis_align_loss, compute_attraction_loss, find_elbow
17
+ from ncut_pytorch import ncut_fn
18
+ from ncut_pytorch.utils.math import grad_safe_eig_solve, normalize_affinity, rbf_affinity, svd_lowrank, cosine_affinity
33
19
 
34
20
 
35
21
  def _flag_ncut_loss(eigvec_gt, eigvec_hat, n_eig, weight):
@@ -65,10 +51,21 @@ def filter_closeby_eigval(eigvec, eigval, threshold=1e-3):
65
51
  keep_idx = torch.where(eigval_diff > threshold)[0]
66
52
  return eigvec[:, keep_idx], eigval[keep_idx]
67
53
 
68
- def ncut_wrapper(features, n_eig, gamma=1.0):
69
- A = rbf_affinity(features, gamma=gamma)
70
- eigvec, eigval = _plain_ncut(A, n_eig)
71
- eigvec, eigval = filter_closeby_eigval(eigvec, eigval)
54
+ def ncut_wrapper(features, n_eig, sigma=None):
55
+ # TODO: gradient is not stable.
56
+
57
+ # features.requires_grad_(True)
58
+ sigma = sigma or features.std(0).sum().item()
59
+ # eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma, track_grad=True)
60
+ W = rbf_affinity(features, sigma=sigma)
61
+ # W = cosine_affinity(features, sigma=1.0)
62
+ A = normalize_affinity(W)
63
+ eigvec, eigval, _ = svd_lowrank(A, n_eig)
64
+ # eigval, eigvec = torch.linalg.eigh(A)
65
+ # eigvec = eigvec.flip(dims=[1])
66
+ # eigval = eigval.flip(dims=[0])
67
+ # eigvec = eigvec[:, :n_eig]
68
+ # eigval = eigval[:n_eig]
72
69
  return eigvec, eigval
73
70
 
74
71
 
@@ -332,11 +329,11 @@ class TrainEncoder(pl.LightningModule):
332
329
  input_feats, output_feats = batch
333
330
 
334
331
  stored_eigvec_gt = {}
335
- with torch.no_grad():
336
- # Compute eigvec_gt only once for each set of iterations
337
- if self.eigvec_loss != 0:
338
- eigvec_gt, eigval_gt = ncut_wrapper(input_feats, self.n_eig)
339
- stored_eigvec_gt = eigvec_gt
332
+ # with torch.no_grad():
333
+ # Compute eigvec_gt only once for each set of iterations
334
+ if self.eigvec_loss != 0:
335
+ eigvec_gt, eigval_gt = ncut_wrapper(input_feats, self.n_eig, sigma=input_feats.std(0).sum().item())
336
+ stored_eigvec_gt = eigvec_gt
340
337
 
341
338
 
342
339
  # Run the same batch 10 times, updating parameters after each iteration
@@ -348,7 +345,9 @@ class TrainEncoder(pl.LightningModule):
348
345
 
349
346
  total_loss = 0
350
347
  if self.eigvec_loss != 0:
351
- eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.n_eig)
348
+ assert feats_compressed.requires_grad, "feats_compressed must be a tensor with requires_grad=True"
349
+ eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.n_eig, sigma=feats_compressed.std(0).sum().item())
350
+ assert eigvec_hat.requires_grad, "eigvec_hat must be a tensor with requires_grad=True"
352
351
  eigvec_loss = flag_space_loss(stored_eigvec_gt, eigvec_hat, n_eig=self.n_eig, weight=None)
353
352
  eigvec_loss = eigvec_loss * self.eigvec_loss
354
353
  total_loss += eigvec_loss
@@ -361,37 +360,6 @@ class TrainEncoder(pl.LightningModule):
361
360
  total_loss += recon_loss
362
361
  self._log_loss(recon_loss, "recon", log_grad_norm=log_grad_norm)
363
362
 
364
-
365
- if self.riemann_curvature_loss > 0:
366
- riemann_curvature_loss = compute_riemann_curvature_loss(feats_compressed)
367
- riemann_curvature_loss = riemann_curvature_loss * self.riemann_curvature_loss
368
- total_loss += riemann_curvature_loss
369
- self._log_loss(riemann_curvature_loss, "riemann_curvature", log_grad_norm=log_grad_norm)
370
-
371
- if self.axis_align_loss > 0:
372
- axis_align_loss = compute_axis_align_loss(feats_compressed)
373
- axis_align_loss = axis_align_loss * self.axis_align_loss
374
- total_loss += axis_align_loss
375
- self._log_loss(axis_align_loss, "axis_align", log_grad_norm=log_grad_norm)
376
-
377
- if self.repulsion_loss > 0:
378
- repulsion_loss = compute_repulsion_loss(feats_compressed)
379
- repulsion_loss = repulsion_loss * self.repulsion_loss
380
- total_loss += repulsion_loss
381
- self._log_loss(repulsion_loss, "repulsion", log_grad_norm=log_grad_norm)
382
-
383
- if self.attraction_loss > 0:
384
- attraction_loss = compute_attraction_loss(feats_compressed)
385
- attraction_loss = attraction_loss * self.attraction_loss
386
- total_loss += attraction_loss
387
- self._log_loss(attraction_loss, "attraction", log_grad_norm=log_grad_norm)
388
-
389
- if self.boundary_loss > 0:
390
- boundary_loss = compute_boundary_loss(feats_compressed)
391
- boundary_loss = boundary_loss * self.boundary_loss
392
- total_loss += boundary_loss
393
- self._log_loss(boundary_loss, "boundary", log_grad_norm=log_grad_norm)
394
-
395
363
  if self.zero_center_loss > 0:
396
364
  # zero_center_loss = feats_compressed.abs().mean()
397
365
  zero_center_loss = (feats_compressed ** 2).mean()
@@ -499,48 +467,87 @@ class BestModelsAvgCallback(pl.Callback):
499
467
  self.best_loss = float('inf')
500
468
 
501
469
 
470
+ def compute_repulsion_loss(
471
+ points: torch.Tensor, # [N, D]
472
+ ) -> torch.Tensor:
473
+ """Computes repulsion loss between points to prevent collapse."""
474
+ dist_matrix = torch.cdist(points, points)
475
+ mask = torch.eye(points.shape[0], device=points.device).bool()
476
+ dist_matrix = dist_matrix + mask * 1e10
477
+ nearest_dists, _ = torch.min(dist_matrix, dim=1)
478
+ repulsion = 1.0 / (nearest_dists + 0.01)
479
+ return torch.mean(repulsion)
480
+
481
+
502
482
  def train_mspace_model(compress_feats, uncompress_feats, training_steps=500, decoder_training_steps=1000,
503
483
  batch_size=1000, return_trainer=False, progress_bar=True,
504
484
  logger=False, use_wandb=False, model_avg_window=3, **model_kwargs):
505
- compress_feats = compress_feats.float().cpu()
506
- uncompress_feats = uncompress_feats.float().cpu()
507
- l, c_in = compress_feats.shape
508
- c_out = uncompress_feats.shape[1]
509
-
510
- model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
485
+ # Disable lightning logs only for this function
486
+ original_lightning_level = logging.getLogger('lightning').level
487
+ original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
511
488
 
512
- dataset = TensorDataset(compress_feats, uncompress_feats)
513
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
514
-
515
- trainer_args = {
516
- 'accelerator': auto_device(),
517
- 'devices': 1,
518
- 'enable_checkpointing': False,
519
- 'enable_progress_bar': False,
520
- 'enable_model_summary': False,
521
- 'callbacks': [BestModelsAvgCallback(top_k=model_avg_window)],
522
- }
523
-
524
- if use_wandb and not logger:
525
- logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
526
-
527
- # train the autoencoder jointly
528
- trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
529
- trainer.fit(model, dataloader)
489
+ logging.getLogger('lightning').setLevel(0)
490
+ logging.getLogger("pytorch_lightning").setLevel(0)
530
491
 
531
- mspace_ae = model.mspace_ae
532
-
533
- if decoder_training_steps > 0:
534
- # train the decoder only
535
- trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
536
- model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
537
- trainer.fit(model2, dataloader)
538
- model.mspace_ae = model2.mspace_ae
539
-
540
- if return_trainer:
541
- return model, trainer
492
+ class IgnorePLFilter(logging.Filter):
493
+ def filter(self, record):
494
+ keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
495
+ return not any(keyword in record.getMessage() for keyword in keywords)
496
+
497
+ pl_filter = IgnorePLFilter()
498
+ logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
499
+ logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
500
+ logger_rank_zero.addFilter(pl_filter)
501
+ logger_cuda.addFilter(pl_filter)
502
+
503
+ try:
504
+ compress_feats = compress_feats.float().cpu()
505
+ uncompress_feats = uncompress_feats.float().cpu()
506
+ l, c_in = compress_feats.shape
507
+ c_out = uncompress_feats.shape[1]
542
508
 
543
- return model
509
+ model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
510
+
511
+ dataset = TensorDataset(compress_feats, uncompress_feats)
512
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
513
+
514
+ trainer_args = {
515
+ 'accelerator': auto_device(),
516
+ 'devices': 1,
517
+ 'enable_checkpointing': False,
518
+ 'enable_progress_bar': False,
519
+ 'enable_model_summary': False,
520
+ 'callbacks': [BestModelsAvgCallback(top_k=model_avg_window)],
521
+ }
522
+
523
+ if use_wandb and not logger:
524
+ logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
525
+
526
+ # train the autoencoder jointly
527
+ trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
528
+ trainer.fit(model, dataloader)
529
+
530
+ mspace_ae = model.mspace_ae
531
+
532
+ if decoder_training_steps > 0:
533
+ # train the decoder only
534
+ trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
535
+ model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
536
+ trainer.fit(model2, dataloader)
537
+ model.mspace_ae = model2.mspace_ae
538
+
539
+ if return_trainer:
540
+ result = model, trainer
541
+ else:
542
+ result = model
543
+ finally:
544
+ # Restore original logging configuration
545
+ logging.getLogger('lightning').setLevel(original_lightning_level)
546
+ logging.getLogger("pytorch_lightning").setLevel(original_pytorch_lightning_level)
547
+ logger_rank_zero.removeFilter(pl_filter)
548
+ logger_cuda.removeFilter(pl_filter)
549
+
550
+ return result
544
551
 
545
552
 
546
553
  def try_train_mspace(*args, **kwargs):
@@ -576,7 +583,8 @@ def try_train_mspace(*args, **kwargs):
576
583
 
577
584
  def mspace_viz_transform(X, return_model=False, **kwargs):
578
585
  X = X.float().cpu()
579
- model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
586
+ # model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
587
+ model, trainer = train_mspace_model(X, X, return_trainer=True, **kwargs)
580
588
 
581
589
  batch_size = kwargs.get('batch_size', 1000)
582
590
  test_loader = torch.utils.data.DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=False, num_workers=0)