ncut-pytorch 3.0.0.dev1__tar.gz → 3.0.0.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/ncut.py +19 -26
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/ncuts/ncut_click.py +4 -3
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/ncuts/ncut_nystrom.py +51 -30
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/math.py +29 -5
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/sigma.py +11 -17
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/pyproject.toml +1 -1
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/LICENSE +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/README.md +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/color/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/color/coloring.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/color/mspace.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/color/mspace_nopl.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/ncuts/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/api.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/patch.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/transform.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/vision_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/device.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/grad.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/sample.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/utils/torch_mod.py +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/requires.txt +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/top_level.txt +0 -0
- {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/setup.cfg +0 -0
|
@@ -14,28 +14,30 @@ class Ncut:
|
|
|
14
14
|
def __init__(
|
|
15
15
|
self,
|
|
16
16
|
n_eig: int = 100,
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
sigma: float = None,
|
|
20
|
-
repulsion_sigma: float = None,
|
|
21
|
-
repulsion_weight: float =
|
|
22
|
-
extrapolation_factor: float = 1.0,
|
|
23
|
-
device: str = None,
|
|
17
|
+
quantile_sigma: float = 0.25,
|
|
18
|
+
quantile_sigma_repulsion: float = 0.20,
|
|
19
|
+
sigma: float | None = None,
|
|
20
|
+
repulsion_sigma: float | None = None,
|
|
21
|
+
repulsion_weight: float | None = None,
|
|
24
22
|
affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
|
|
23
|
+
extrapolation_factor: float = 1.0,
|
|
24
|
+
exact_gradient: bool = False,
|
|
25
|
+
device: str | None = None,
|
|
25
26
|
**kwargs,
|
|
26
27
|
):
|
|
27
28
|
"""
|
|
28
29
|
|
|
29
30
|
Args:
|
|
30
31
|
n_eig (int): number of eigenvectors
|
|
31
|
-
|
|
32
|
-
|
|
32
|
+
n_eig (int): number of eigenvectors
|
|
33
|
+
quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
|
|
34
|
+
quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
|
|
33
35
|
sigma (float): affinity parameter, override d_sigma if provided
|
|
34
36
|
repulsion_sigma (float): (if use repulsion) repulsion sigma parameter, default None (no repulsion)
|
|
35
37
|
repulsion_weight (float): (if use repulsion) repulsion weight, default 0.2
|
|
36
|
-
extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
|
|
37
|
-
device (str): device, default 'auto' (auto detect GPU)
|
|
38
38
|
affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
|
|
39
|
+
extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
|
|
40
|
+
exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False device (str): device, default 'auto' (auto detect GPU)
|
|
39
41
|
|
|
40
42
|
Examples:
|
|
41
43
|
>>> from ncut_pytorch import Ncut
|
|
@@ -52,13 +54,14 @@ class Ncut:
|
|
|
52
54
|
>>> print(new_eigvec.shape) # (500, 20)
|
|
53
55
|
"""
|
|
54
56
|
self.n_eig = n_eig
|
|
55
|
-
self.
|
|
57
|
+
self.quantile_sigma = quantile_sigma
|
|
58
|
+
self.quantile_sigma_repulsion = quantile_sigma_repulsion
|
|
56
59
|
self.sigma = sigma
|
|
57
60
|
self.repulsion_sigma = repulsion_sigma
|
|
58
61
|
self.repulsion_weight = repulsion_weight
|
|
59
62
|
self.extrapolation_factor = extrapolation_factor
|
|
63
|
+
self.exact_gradient = exact_gradient
|
|
60
64
|
self.device = device
|
|
61
|
-
self.track_grad = track_grad
|
|
62
65
|
self.affinity_fn = affinity_fn
|
|
63
66
|
self.kwargs = kwargs
|
|
64
67
|
|
|
@@ -83,12 +86,13 @@ class Ncut:
|
|
|
83
86
|
ncut_fn(
|
|
84
87
|
X,
|
|
85
88
|
n_eig=self.n_eig,
|
|
86
|
-
|
|
89
|
+
quantile_sigma=self.quantile_sigma,
|
|
90
|
+
quantile_sigma_repulsion=self.quantile_sigma_repulsion,
|
|
87
91
|
sigma=self.sigma,
|
|
88
92
|
repulsion_sigma=self.repulsion_sigma,
|
|
89
93
|
repulsion_weight=self.repulsion_weight,
|
|
90
94
|
device=self.device,
|
|
91
|
-
|
|
95
|
+
exact_gradient=self.exact_gradient,
|
|
92
96
|
no_propagation=True,
|
|
93
97
|
affinity_fn=self.affinity_fn,
|
|
94
98
|
**self.kwargs
|
|
@@ -121,7 +125,6 @@ class Ncut:
|
|
|
121
125
|
self._nystrom_x,
|
|
122
126
|
extrapolation_factor=self.extrapolation_factor,
|
|
123
127
|
device=self.device,
|
|
124
|
-
track_grad=self.track_grad,
|
|
125
128
|
**self.kwargs
|
|
126
129
|
)
|
|
127
130
|
return eigvec
|
|
@@ -137,15 +140,5 @@ class Ncut:
|
|
|
137
140
|
"""
|
|
138
141
|
return self.fit(X).transform(X)
|
|
139
142
|
|
|
140
|
-
def __new__(cls, X: torch.Tensor = None, n_eig: int = 100, track_grad: bool = False, d_sigma: float = None,
|
|
141
|
-
device: str = None, affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
|
|
142
|
-
**kwargs) -> Union["Ncut", torch.Tensor]:
|
|
143
|
-
if X is not None:
|
|
144
|
-
# function-like behavior
|
|
145
|
-
eigvec, eigval = ncut_fn(X, n_eig=n_eig, track_grad=track_grad, d_sigma=d_sigma, device=device, affinity_fn=affinity_fn, **kwargs)
|
|
146
|
-
return eigvec
|
|
147
|
-
# normal class instantiation
|
|
148
|
-
return super().__new__(cls)
|
|
149
|
-
|
|
150
143
|
def __call__(self, X: torch.Tensor) -> torch.Tensor:
|
|
151
144
|
return self.fit_transform(X)
|
|
@@ -22,10 +22,11 @@ def ncut_click_prompt(
|
|
|
22
22
|
click_weight: float = 0.5,
|
|
23
23
|
bg_weight: float = 0.1,
|
|
24
24
|
n_eig: int = 2,
|
|
25
|
-
|
|
25
|
+
quantile_sigma: float = 0.25,
|
|
26
26
|
device: str = None,
|
|
27
27
|
sigma: float = None,
|
|
28
28
|
affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
|
|
29
|
+
exact_gradient: bool = False,
|
|
29
30
|
no_propagation: bool = False,
|
|
30
31
|
return_indices_and_sigma: bool = False,
|
|
31
32
|
**kwargs,
|
|
@@ -55,7 +56,7 @@ def ncut_click_prompt(
|
|
|
55
56
|
|
|
56
57
|
# find optimal sigma for affinity matrix
|
|
57
58
|
if sigma is None and affinity_fn == rbf_affinity:
|
|
58
|
-
sigma = find_sigma_by_degree(nystrom_X,
|
|
59
|
+
sigma = find_sigma_by_degree(nystrom_X, quantile_sigma, affinity_fn)
|
|
59
60
|
# TODO: change to std()
|
|
60
61
|
elif sigma is None and affinity_fn == cosine_affinity:
|
|
61
62
|
sigma = 0.5
|
|
@@ -76,7 +77,7 @@ def ncut_click_prompt(
|
|
|
76
77
|
|
|
77
78
|
_A = click_weight * A_click + (1 - click_weight) * A
|
|
78
79
|
|
|
79
|
-
nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
|
|
80
|
+
nystrom_eigvec, eigval = _plain_ncut(_A, n_eig, exact_gradient=exact_gradient)
|
|
80
81
|
|
|
81
82
|
if no_propagation:
|
|
82
83
|
return nystrom_eigvec, eigval, nystrom_indices, sigma
|
|
@@ -6,7 +6,7 @@ import torch
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from ncut_pytorch.utils.sigma import find_sigma_by_degree
|
|
8
8
|
from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
|
|
9
|
-
from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row
|
|
9
|
+
from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row, svd_lowrank
|
|
10
10
|
from ncut_pytorch.utils.sample import farthest_point_sampling
|
|
11
11
|
from ncut_pytorch.utils.device import auto_device
|
|
12
12
|
|
|
@@ -34,14 +34,16 @@ class NystromConfig:
|
|
|
34
34
|
def ncut_fn(
|
|
35
35
|
X: torch.Tensor,
|
|
36
36
|
n_eig: int = 100,
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
sigma: float = None,
|
|
40
|
-
repulsion_sigma: float = None,
|
|
41
|
-
repulsion_weight: float =
|
|
37
|
+
quantile_sigma: float = 0.25,
|
|
38
|
+
quantile_sigma_repulsion: float = 0.20,
|
|
39
|
+
sigma: float | None = None,
|
|
40
|
+
repulsion_sigma: float | None = None,
|
|
41
|
+
repulsion_weight: float | None = None,
|
|
42
|
+
affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
|
|
42
43
|
extrapolation_factor: float = 1.0,
|
|
44
|
+
exact_gradient: bool = False,
|
|
45
|
+
device: str | None = None,
|
|
43
46
|
make_orthogonal: bool = False,
|
|
44
|
-
affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
|
|
45
47
|
no_propagation: bool = False,
|
|
46
48
|
**kwargs,
|
|
47
49
|
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
|
|
@@ -50,14 +52,15 @@ def ncut_fn(
|
|
|
50
52
|
Args:
|
|
51
53
|
X (torch.Tensor): input features, shape (N, D)
|
|
52
54
|
n_eig (int): number of eigenvectors
|
|
53
|
-
|
|
54
|
-
|
|
55
|
+
quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
|
|
56
|
+
quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
|
|
55
57
|
sigma (float): affinity parameter, override d_sigma if provided
|
|
56
58
|
repulsion_sigma (float): (if use repulsion) repulsion sigma parameter, default None (no repulsion)
|
|
57
59
|
repulsion_weight (float): (if use repulsion) repulsion weight, default 0.2
|
|
60
|
+
affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
|
|
58
61
|
extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
|
|
62
|
+
exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False
|
|
59
63
|
make_orthogonal (bool): make eigenvectors orthogonal
|
|
60
|
-
affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
|
|
61
64
|
|
|
62
65
|
Returns:
|
|
63
66
|
eigenvectors (torch.Tensor): shape (N, n_eig)
|
|
@@ -72,37 +75,28 @@ def ncut_fn(
|
|
|
72
75
|
"""
|
|
73
76
|
config = NystromConfig()
|
|
74
77
|
config.update(kwargs)
|
|
75
|
-
|
|
76
|
-
# use GPU if available
|
|
77
78
|
device = auto_device(X.device, device)
|
|
78
79
|
|
|
79
|
-
# check if enough data for nystrom approximation
|
|
80
|
-
is_enough_data = X.shape[0] > config.n_sample
|
|
81
|
-
|
|
82
80
|
# subsample for nystrom approximation
|
|
81
|
+
is_enough_data = X.shape[0] > config.n_sample
|
|
83
82
|
n_sample = min(config.n_sample, int(X.shape[0]*config.n_sample_max_ratio))
|
|
84
83
|
nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device) if is_enough_data else np.arange(X.shape[0])
|
|
85
84
|
nystrom_X = X[nystrom_indices].to(device)
|
|
86
85
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
sigma = 0.5
|
|
93
|
-
else:
|
|
94
|
-
raise ValueError(f"`sigma` needs to be provided for affinity function {affinity_fn}, (sigma=0.5)")
|
|
95
|
-
|
|
96
|
-
if repulsion_sigma is not None:
|
|
97
|
-
nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma_attraction=sigma, sigma_repulsion=repulsion_sigma, repulsion_weight=repulsion_weight, affinity_fn=affinity_fn)
|
|
86
|
+
sigma, repulsion_sigma = find_optimal_sigma(nystrom_X, quantile_sigma, quantile_sigma_repulsion, sigma, repulsion_sigma, affinity_fn)
|
|
87
|
+
|
|
88
|
+
if repulsion_sigma and repulsion_weight:
|
|
89
|
+
nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma,
|
|
90
|
+
repulsion_sigma, repulsion_weight, affinity_fn, exact_gradient)
|
|
98
91
|
else:
|
|
99
92
|
A = affinity_fn(nystrom_X, sigma=sigma)
|
|
100
|
-
nystrom_eigvec, eigval = _plain_ncut(A, n_eig)
|
|
93
|
+
nystrom_eigvec, eigval = _plain_ncut(A, n_eig, exact_gradient)
|
|
101
94
|
|
|
102
95
|
if no_propagation:
|
|
103
96
|
return nystrom_eigvec, eigval, nystrom_indices, sigma
|
|
104
97
|
|
|
105
98
|
if not is_enough_data:
|
|
99
|
+
# skip nystrom approximation if not enough data, use exact ncut
|
|
106
100
|
return nystrom_eigvec, eigval
|
|
107
101
|
|
|
108
102
|
# propagate eigenvectors from subgraph to full graph
|
|
@@ -123,6 +117,25 @@ def ncut_fn(
|
|
|
123
117
|
|
|
124
118
|
return eigvec, eigval
|
|
125
119
|
|
|
120
|
+
def find_optimal_sigma(
|
|
121
|
+
X: torch.Tensor,
|
|
122
|
+
quantile_sigma: float = 0.25,
|
|
123
|
+
quantile_sigma_repulsion: float = 0.20,
|
|
124
|
+
sigma: float | None = None,
|
|
125
|
+
repulsion_sigma: float | None = None,
|
|
126
|
+
affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
|
|
127
|
+
):
|
|
128
|
+
"""Find optimal sigma for affinity matrix and repulsion matrix."""
|
|
129
|
+
if affinity_fn == rbf_affinity:
|
|
130
|
+
sigma = sigma or find_sigma_by_degree(X, quantile_sigma, affinity_fn)
|
|
131
|
+
repulsion_sigma = repulsion_sigma or find_sigma_by_degree(X, quantile_sigma_repulsion, affinity_fn, init_sigma=sigma)
|
|
132
|
+
elif affinity_fn == cosine_affinity:
|
|
133
|
+
sigma = sigma or 0.5
|
|
134
|
+
repulsion_sigma = repulsion_sigma or 0.3
|
|
135
|
+
else:
|
|
136
|
+
if sigma is None:
|
|
137
|
+
raise ValueError(f"`sigma` need to be provided for affinity function {affinity_fn}, (sigma=0.5, repulsion_sigma=0.3)")
|
|
138
|
+
return sigma, repulsion_sigma
|
|
126
139
|
|
|
127
140
|
def ncut_with_repulsion(
|
|
128
141
|
X: torch.Tensor,
|
|
@@ -131,6 +144,7 @@ def ncut_with_repulsion(
|
|
|
131
144
|
sigma_repulsion: float = None,
|
|
132
145
|
repulsion_weight: float = 0.2,
|
|
133
146
|
affinity_fn: Union["rbf_affinity", "cosine_affinity"] = cosine_affinity,
|
|
147
|
+
exact_gradient: bool = False,
|
|
134
148
|
eps: float = 1e-8,
|
|
135
149
|
):
|
|
136
150
|
A = affinity_fn(X, sigma=sigma_attraction)
|
|
@@ -141,7 +155,10 @@ def ncut_with_repulsion(
|
|
|
141
155
|
D = D_A + D_R
|
|
142
156
|
W = A - R + torch.diag(D_R)
|
|
143
157
|
W = W / D[:, None]
|
|
144
|
-
|
|
158
|
+
if exact_gradient:
|
|
159
|
+
eigvec, eigval, _ = grad_safe_eig_solve(W, n_eig)
|
|
160
|
+
else:
|
|
161
|
+
eigvec, eigval, _ = svd_lowrank(W, n_eig)
|
|
145
162
|
eigvec = correct_rotation(eigvec)
|
|
146
163
|
return eigvec, eigval
|
|
147
164
|
|
|
@@ -149,9 +166,13 @@ def ncut_with_repulsion(
|
|
|
149
166
|
def _plain_ncut(
|
|
150
167
|
A: torch.Tensor,
|
|
151
168
|
n_eig: int = 100,
|
|
169
|
+
exact_gradient: bool = False,
|
|
152
170
|
):
|
|
153
171
|
A = normalize_affinity(A)
|
|
154
|
-
|
|
172
|
+
if exact_gradient:
|
|
173
|
+
eigvec, eigval, _ = grad_safe_eig_solve(A, n_eig)
|
|
174
|
+
else:
|
|
175
|
+
eigvec, eigval, _ = svd_lowrank(A, n_eig)
|
|
155
176
|
eigvec = eigvec[:, :n_eig]
|
|
156
177
|
eigval = eigval[:n_eig]
|
|
157
178
|
eigvec = correct_rotation(eigvec)
|
|
@@ -191,7 +212,7 @@ def nystrom_propagate(
|
|
|
191
212
|
nystrom_out = nystrom_out[indices].to(device)
|
|
192
213
|
nystrom_X = nystrom_X[indices].to(device)
|
|
193
214
|
|
|
194
|
-
sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity)
|
|
215
|
+
sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity, quantile_sigma=0.25)
|
|
195
216
|
sigma = sigma * extrapolation_factor
|
|
196
217
|
|
|
197
218
|
D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
|
|
@@ -18,14 +18,20 @@ import logging
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
|
|
22
|
+
_GAMMA_DEPRECATION_WARNED = False
|
|
23
|
+
|
|
24
|
+
from .torch_mod import svd_lowrank as my_svd_lowrank
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
def check_gamma_deprecated(gamma: float | None) -> float:
|
|
28
|
+
global _GAMMA_DEPRECATION_WARNED
|
|
25
29
|
if gamma is not None:
|
|
26
|
-
|
|
30
|
+
if not _GAMMA_DEPRECATION_WARNED:
|
|
31
|
+
logging.getLogger(__name__).warning("gamma is deprecated, use sigma instead")
|
|
32
|
+
_GAMMA_DEPRECATION_WARNED = True
|
|
27
33
|
sigma = np.sqrt(gamma)
|
|
28
|
-
|
|
34
|
+
return sigma
|
|
29
35
|
|
|
30
36
|
|
|
31
37
|
def rbf_affinity(
|
|
@@ -92,10 +98,11 @@ def grad_safe_eig_solve(
|
|
|
92
98
|
is_symmetric = mat.shape[0] == mat.shape[1]
|
|
93
99
|
if is_symmetric:
|
|
94
100
|
s, u = torch.linalg.eigh(mat)
|
|
95
|
-
s = s.flip(dims=[0])
|
|
96
|
-
u = u.flip(dims=[1])
|
|
97
101
|
else:
|
|
98
102
|
s, u = torch.linalg.eig(mat)
|
|
103
|
+
sort_idx = torch.argsort(s, dim=0, descending=True)
|
|
104
|
+
s = s[sort_idx]
|
|
105
|
+
u = u[:, sort_idx]
|
|
99
106
|
return u.to(dtype), s.to(dtype), None
|
|
100
107
|
|
|
101
108
|
try:
|
|
@@ -122,6 +129,23 @@ def pca_lowrank(
|
|
|
122
129
|
return u @ torch.diag(s)
|
|
123
130
|
|
|
124
131
|
|
|
132
|
+
def svd_lowrank(mat: torch.Tensor, q: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
133
|
+
"""SVD lowrank implementation for float16 and bfloat16."""
|
|
134
|
+
dtype = mat.dtype
|
|
135
|
+
try:
|
|
136
|
+
with torch.autocast(device_type=mat.device.type, enabled=False):
|
|
137
|
+
if dtype == torch.float16 or dtype == torch.bfloat16:
|
|
138
|
+
mat = mat.float() # svd_lowrank does not support float16
|
|
139
|
+
u, s, v = my_svd_lowrank(mat, q=q + 10)
|
|
140
|
+
except RuntimeError:
|
|
141
|
+
if dtype == torch.float16 or dtype == torch.bfloat16:
|
|
142
|
+
mat = mat.float()
|
|
143
|
+
u, s, v = my_svd_lowrank(mat, q=q + 10)
|
|
144
|
+
|
|
145
|
+
u, s, v = u[:, :q], s[:q], v[:, :q]
|
|
146
|
+
return u.to(dtype), s.to(dtype), v.to(dtype)
|
|
147
|
+
|
|
148
|
+
|
|
125
149
|
def quantile_min_max(
|
|
126
150
|
x: torch.Tensor,
|
|
127
151
|
q1: float = 0.01,
|
|
@@ -9,7 +9,7 @@ from .sample import farthest_point_sampling
|
|
|
9
9
|
@torch.no_grad()
|
|
10
10
|
def _find_sigma_by_degree(
|
|
11
11
|
X: torch.Tensor, # [n_samples, n_features]
|
|
12
|
-
|
|
12
|
+
quantile_sigma: float = 0.25,
|
|
13
13
|
affinity_fn: callable = rbf_affinity,
|
|
14
14
|
X2: torch.Tensor | None = None,
|
|
15
15
|
init_sigma: float = 0.5,
|
|
@@ -17,27 +17,21 @@ def _find_sigma_by_degree(
|
|
|
17
17
|
max_iter: int = 100,
|
|
18
18
|
) -> float:
|
|
19
19
|
"""Binary search for optimal sigma to achieve target mean edge weight."""
|
|
20
|
-
if
|
|
21
|
-
|
|
20
|
+
if quantile_sigma <= 0 or quantile_sigma >= 1:
|
|
21
|
+
raise ValueError(f"quantile_sigma must be between 0 and 1, got {quantile_sigma}")
|
|
22
22
|
sigma = init_sigma
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
current_degrees = affinity_fn(X, X2=X2, sigma=scale_inv_sigma).mean(1)
|
|
28
|
-
for _ in range(2):
|
|
29
|
-
current_degree = current_degrees.mean().item()
|
|
30
|
-
mask = current_degrees < current_degree
|
|
31
|
-
current_degrees = current_degrees[mask]
|
|
32
|
-
d_sigma = current_degrees.mean().item()
|
|
24
|
+
scale_inv_sigma = X.std(0).sum()
|
|
25
|
+
current_degrees = affinity_fn(X, X2=X2, sigma=scale_inv_sigma).mean(1)
|
|
26
|
+
target_degree = current_degrees.float().quantile(quantile_sigma).item()
|
|
33
27
|
|
|
34
28
|
# Binary search for sigma
|
|
35
29
|
current_degree = affinity_fn(X, X2=X2, sigma=sigma).mean().item()
|
|
36
30
|
low, high = 0, float('inf')
|
|
37
|
-
tol = r_tol *
|
|
31
|
+
tol = r_tol * target_degree
|
|
38
32
|
i_iter = 0
|
|
39
|
-
while abs(current_degree -
|
|
40
|
-
if current_degree >
|
|
33
|
+
while abs(current_degree - target_degree) > tol and i_iter < max_iter:
|
|
34
|
+
if current_degree > target_degree:
|
|
41
35
|
high = sigma
|
|
42
36
|
sigma = (low + sigma) / 2
|
|
43
37
|
else:
|
|
@@ -52,7 +46,7 @@ def _find_sigma_by_degree(
|
|
|
52
46
|
@torch.no_grad()
|
|
53
47
|
def find_sigma_by_degree(
|
|
54
48
|
X: torch.Tensor, # [n_samples, n_features]
|
|
55
|
-
|
|
49
|
+
quantile_sigma: float = 0.25,
|
|
56
50
|
affinity_fn: callable = rbf_affinity,
|
|
57
51
|
X2: torch.Tensor | None = None,
|
|
58
52
|
init_sigma: float = 0.5,
|
|
@@ -62,4 +56,4 @@ def find_sigma_by_degree(
|
|
|
62
56
|
) -> float:
|
|
63
57
|
"""Find sigma after FPS-based downsampling for efficiency."""
|
|
64
58
|
indices = farthest_point_sampling(X, n_sample)
|
|
65
|
-
return _find_sigma_by_degree(X[indices],
|
|
59
|
+
return _find_sigma_by_degree(X[indices], quantile_sigma, affinity_fn, X2=X2, init_sigma=init_sigma, r_tol=r_tol, max_iter=max_iter)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/hires_dino.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/lowres_dino.py
RENAMED
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino/transform.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/dino_predictor.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/jafar_predictor.py
RENAMED
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch/predictor/vision_predictor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev3}/ncut_pytorch.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|