ncut-pytorch 3.0.0.dev1.tar.gz → 3.0.0.dev2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/PKG-INFO +1 -1
  2. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/ncut.py +19 -26
  3. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/ncuts/ncut_click.py +4 -3
  4. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/ncuts/ncut_nystrom.py +51 -30
  5. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/math.py +18 -1
  6. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/sigma.py +11 -17
  7. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch.egg-info/PKG-INFO +1 -1
  8. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/pyproject.toml +1 -1
  9. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/LICENSE +0 -0
  10. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/README.md +0 -0
  11. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/__init__.py +0 -0
  12. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/color/__init__.py +0 -0
  13. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/color/coloring.py +0 -0
  14. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/color/mspace.py +0 -0
  15. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/color/mspace_nopl.py +0 -0
  16. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/ncuts/__init__.py +0 -0
  17. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
  18. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/__init__.py +0 -0
  19. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  20. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/api.py +0 -0
  21. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
  22. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  23. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  24. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/patch.py +0 -0
  25. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino/transform.py +0 -0
  26. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  27. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  28. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/predictor.py +0 -0
  29. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/predictor/vision_predictor.py +0 -0
  30. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/__init__.py +0 -0
  31. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/device.py +0 -0
  32. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/grad.py +0 -0
  33. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/sample.py +0 -0
  34. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch/utils/torch_mod.py +0 -0
  35. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
  36. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  37. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch.egg-info/requires.txt +0 -0
  38. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/ncut_pytorch.egg-info/top_level.txt +0 -0
  39. {ncut_pytorch-3.0.0.dev1 → ncut_pytorch-3.0.0.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev1
3
+ Version: 3.0.0.dev2
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -14,28 +14,30 @@ class Ncut:
14
14
  def __init__(
15
15
  self,
16
16
  n_eig: int = 100,
17
- track_grad: bool = False,
18
- d_sigma: float = None,
19
- sigma: float = None,
20
- repulsion_sigma: float = None,
21
- repulsion_weight: float = 0.2,
22
- extrapolation_factor: float = 1.0,
23
- device: str = None,
17
+ quantile_sigma: float = 0.25,
18
+ quantile_sigma_repulsion: float = 0.20,
19
+ sigma: float | None = None,
20
+ repulsion_sigma: float | None = None,
21
+ repulsion_weight: float | None = None,
24
22
  affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
23
+ extrapolation_factor: float = 1.0,
24
+ exact_gradient: bool = False,
25
+ device: str | None = None,
25
26
  **kwargs,
26
27
  ):
27
28
  """
28
29
 
29
30
  Args:
30
31
  n_eig (int): number of eigenvectors
31
- track_grad (bool): keep track of pytorch gradients
32
- d_sigma (float): affinity sigma parameter, lower d_sigma results in a sharper eigenvectors
32
+ n_eig (int): number of eigenvectors
33
+ quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
34
+ quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
33
35
  sigma (float): affinity parameter, override d_sigma if provided
34
36
  repulsion_sigma (float): (if use repulsion) repulsion sigma parameter, default None (no repulsion)
35
37
  repulsion_weight (float): (if use repulsion) repulsion weight, default 0.2
36
- extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
37
- device (str): device, default 'auto' (auto detect GPU)
38
38
  affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
39
+ extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
40
+ exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False device (str): device, default 'auto' (auto detect GPU)
39
41
 
40
42
  Examples:
41
43
  >>> from ncut_pytorch import Ncut
@@ -52,13 +54,14 @@ class Ncut:
52
54
  >>> print(new_eigvec.shape) # (500, 20)
53
55
  """
54
56
  self.n_eig = n_eig
55
- self.d_sigma = d_sigma
57
+ self.quantile_sigma = quantile_sigma
58
+ self.quantile_sigma_repulsion = quantile_sigma_repulsion
56
59
  self.sigma = sigma
57
60
  self.repulsion_sigma = repulsion_sigma
58
61
  self.repulsion_weight = repulsion_weight
59
62
  self.extrapolation_factor = extrapolation_factor
63
+ self.exact_gradient = exact_gradient
60
64
  self.device = device
61
- self.track_grad = track_grad
62
65
  self.affinity_fn = affinity_fn
63
66
  self.kwargs = kwargs
64
67
 
@@ -83,12 +86,13 @@ class Ncut:
83
86
  ncut_fn(
84
87
  X,
85
88
  n_eig=self.n_eig,
86
- d_sigma=self.d_sigma,
89
+ quantile_sigma=self.quantile_sigma,
90
+ quantile_sigma_repulsion=self.quantile_sigma_repulsion,
87
91
  sigma=self.sigma,
88
92
  repulsion_sigma=self.repulsion_sigma,
89
93
  repulsion_weight=self.repulsion_weight,
90
94
  device=self.device,
91
- track_grad=self.track_grad,
95
+ exact_gradient=self.exact_gradient,
92
96
  no_propagation=True,
93
97
  affinity_fn=self.affinity_fn,
94
98
  **self.kwargs
@@ -121,7 +125,6 @@ class Ncut:
121
125
  self._nystrom_x,
122
126
  extrapolation_factor=self.extrapolation_factor,
123
127
  device=self.device,
124
- track_grad=self.track_grad,
125
128
  **self.kwargs
126
129
  )
127
130
  return eigvec
@@ -137,15 +140,5 @@ class Ncut:
137
140
  """
138
141
  return self.fit(X).transform(X)
139
142
 
140
- def __new__(cls, X: torch.Tensor = None, n_eig: int = 100, track_grad: bool = False, d_sigma: float = None,
141
- device: str = None, affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
142
- **kwargs) -> Union["Ncut", torch.Tensor]:
143
- if X is not None:
144
- # function-like behavior
145
- eigvec, eigval = ncut_fn(X, n_eig=n_eig, track_grad=track_grad, d_sigma=d_sigma, device=device, affinity_fn=affinity_fn, **kwargs)
146
- return eigvec
147
- # normal class instantiation
148
- return super().__new__(cls)
149
-
150
143
  def __call__(self, X: torch.Tensor) -> torch.Tensor:
151
144
  return self.fit_transform(X)
@@ -22,10 +22,11 @@ def ncut_click_prompt(
22
22
  click_weight: float = 0.5,
23
23
  bg_weight: float = 0.1,
24
24
  n_eig: int = 2,
25
- d_sigma: float = None,
25
+ quantile_sigma: float = 0.25,
26
26
  device: str = None,
27
27
  sigma: float = None,
28
28
  affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
29
+ exact_gradient: bool = False,
29
30
  no_propagation: bool = False,
30
31
  return_indices_and_sigma: bool = False,
31
32
  **kwargs,
@@ -55,7 +56,7 @@ def ncut_click_prompt(
55
56
 
56
57
  # find optimal sigma for affinity matrix
57
58
  if sigma is None and affinity_fn == rbf_affinity:
58
- sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
59
+ sigma = find_sigma_by_degree(nystrom_X, quantile_sigma, affinity_fn)
59
60
  # TODO: change to std()
60
61
  elif sigma is None and affinity_fn == cosine_affinity:
61
62
  sigma = 0.5
@@ -76,7 +77,7 @@ def ncut_click_prompt(
76
77
 
77
78
  _A = click_weight * A_click + (1 - click_weight) * A
78
79
 
79
- nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
80
+ nystrom_eigvec, eigval = _plain_ncut(_A, n_eig, exact_gradient=exact_gradient)
80
81
 
81
82
  if no_propagation:
82
83
  return nystrom_eigvec, eigval, nystrom_indices, sigma
@@ -6,7 +6,7 @@ import torch
6
6
  import numpy as np
7
7
  from ncut_pytorch.utils.sigma import find_sigma_by_degree
8
8
  from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
9
- from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row
9
+ from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row, svd_lowrank
10
10
  from ncut_pytorch.utils.sample import farthest_point_sampling
11
11
  from ncut_pytorch.utils.device import auto_device
12
12
 
@@ -34,14 +34,16 @@ class NystromConfig:
34
34
  def ncut_fn(
35
35
  X: torch.Tensor,
36
36
  n_eig: int = 100,
37
- d_sigma: float = None,
38
- device: str = None,
39
- sigma: float = None,
40
- repulsion_sigma: float = None,
41
- repulsion_weight: float = 0.2,
37
+ quantile_sigma: float = 0.25,
38
+ quantile_sigma_repulsion: float = 0.20,
39
+ sigma: float | None = None,
40
+ repulsion_sigma: float | None = None,
41
+ repulsion_weight: float | None = None,
42
+ affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
42
43
  extrapolation_factor: float = 1.0,
44
+ exact_gradient: bool = False,
45
+ device: str | None = None,
43
46
  make_orthogonal: bool = False,
44
- affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
45
47
  no_propagation: bool = False,
46
48
  **kwargs,
47
49
  ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
@@ -50,14 +52,15 @@ def ncut_fn(
50
52
  Args:
51
53
  X (torch.Tensor): input features, shape (N, D)
52
54
  n_eig (int): number of eigenvectors
53
- d_sigma (float): affinity sigma parameter, lower d_sigma results in sharper eigenvectors
54
- device (str): device, default 'auto' (auto detect GPU)
55
+ quantile_sigma (float): quantile of affinity sigma parameter, lower quantile_sigma results in sharper eigenvectors
56
+ quantile_sigma_repulsion (float): quantile of repulsion sigma parameter, lower quantile_sigma_repulsion results in sharper eigenvectors
55
57
  sigma (float): affinity parameter, override d_sigma if provided
56
58
  repulsion_sigma (float): (if use repulsion) repulsion sigma parameter, default None (no repulsion)
57
59
  repulsion_weight (float): (if use repulsion) repulsion weight, default 0.2
60
+ affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
58
61
  extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
62
+ exact_gradient (bool): use full spectrum and exact gradient, can be slower and unstable, default False
59
63
  make_orthogonal (bool): make eigenvectors orthogonal
60
- affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
61
64
 
62
65
  Returns:
63
66
  eigenvectors (torch.Tensor): shape (N, n_eig)
@@ -72,37 +75,28 @@ def ncut_fn(
72
75
  """
73
76
  config = NystromConfig()
74
77
  config.update(kwargs)
75
-
76
- # use GPU if available
77
78
  device = auto_device(X.device, device)
78
79
 
79
- # check if enough data for nystrom approximation
80
- is_enough_data = X.shape[0] > config.n_sample
81
-
82
80
  # subsample for nystrom approximation
81
+ is_enough_data = X.shape[0] > config.n_sample
83
82
  n_sample = min(config.n_sample, int(X.shape[0]*config.n_sample_max_ratio))
84
83
  nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device) if is_enough_data else np.arange(X.shape[0])
85
84
  nystrom_X = X[nystrom_indices].to(device)
86
85
 
87
- # find optimal sigma for affinity matrix
88
- if sigma is None:
89
- if affinity_fn == rbf_affinity:
90
- sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
91
- elif affinity_fn == cosine_affinity:
92
- sigma = 0.5
93
- else:
94
- raise ValueError(f"`sigma` needs to be provided for affinity function {affinity_fn}, (sigma=0.5)")
95
-
96
- if repulsion_sigma is not None:
97
- nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma_attraction=sigma, sigma_repulsion=repulsion_sigma, repulsion_weight=repulsion_weight, affinity_fn=affinity_fn)
86
+ sigma, repulsion_sigma = find_optimal_sigma(nystrom_X, quantile_sigma, quantile_sigma_repulsion, sigma, repulsion_sigma, affinity_fn)
87
+
88
+ if repulsion_sigma and repulsion_weight:
89
+ nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma,
90
+ repulsion_sigma, repulsion_weight, affinity_fn, exact_gradient)
98
91
  else:
99
92
  A = affinity_fn(nystrom_X, sigma=sigma)
100
- nystrom_eigvec, eigval = _plain_ncut(A, n_eig)
93
+ nystrom_eigvec, eigval = _plain_ncut(A, n_eig, exact_gradient)
101
94
 
102
95
  if no_propagation:
103
96
  return nystrom_eigvec, eigval, nystrom_indices, sigma
104
97
 
105
98
  if not is_enough_data:
99
+ # skip nystrom approximation if not enough data, use exact ncut
106
100
  return nystrom_eigvec, eigval
107
101
 
108
102
  # propagate eigenvectors from subgraph to full graph
@@ -123,6 +117,25 @@ def ncut_fn(
123
117
 
124
118
  return eigvec, eigval
125
119
 
120
+ def find_optimal_sigma(
121
+ X: torch.Tensor,
122
+ quantile_sigma: float = 0.25,
123
+ quantile_sigma_repulsion: float = 0.20,
124
+ sigma: float | None = None,
125
+ repulsion_sigma: float | None = None,
126
+ affinity_fn: Union["rbf_affinity", "cosine_affinity"] = rbf_affinity,
127
+ ):
128
+ """Find optimal sigma for affinity matrix and repulsion matrix."""
129
+ if affinity_fn == rbf_affinity:
130
+ sigma = sigma or find_sigma_by_degree(X, quantile_sigma, affinity_fn)
131
+ repulsion_sigma = repulsion_sigma or find_sigma_by_degree(X, quantile_sigma_repulsion, affinity_fn, init_sigma=sigma)
132
+ elif affinity_fn == cosine_affinity:
133
+ sigma = sigma or 0.5
134
+ repulsion_sigma = repulsion_sigma or 0.3
135
+ else:
136
+ if sigma is None:
137
+ raise ValueError(f"`sigma` need to be provided for affinity function {affinity_fn}, (sigma=0.5, repulsion_sigma=0.3)")
138
+ return sigma, repulsion_sigma
126
139
 
127
140
  def ncut_with_repulsion(
128
141
  X: torch.Tensor,
@@ -131,6 +144,7 @@ def ncut_with_repulsion(
131
144
  sigma_repulsion: float = None,
132
145
  repulsion_weight: float = 0.2,
133
146
  affinity_fn: Union["rbf_affinity", "cosine_affinity"] = cosine_affinity,
147
+ exact_gradient: bool = False,
134
148
  eps: float = 1e-8,
135
149
  ):
136
150
  A = affinity_fn(X, sigma=sigma_attraction)
@@ -141,7 +155,10 @@ def ncut_with_repulsion(
141
155
  D = D_A + D_R
142
156
  W = A - R + torch.diag(D_R)
143
157
  W = W / D[:, None]
144
- eigvec, eigval, _ = grad_safe_eig_solve(W, n_eig)
158
+ if exact_gradient:
159
+ eigvec, eigval, _ = grad_safe_eig_solve(W, n_eig)
160
+ else:
161
+ eigvec, eigval, _ = svd_lowrank(W, n_eig)
145
162
  eigvec = correct_rotation(eigvec)
146
163
  return eigvec, eigval
147
164
 
@@ -149,9 +166,13 @@ def ncut_with_repulsion(
149
166
  def _plain_ncut(
150
167
  A: torch.Tensor,
151
168
  n_eig: int = 100,
169
+ exact_gradient: bool = False,
152
170
  ):
153
171
  A = normalize_affinity(A)
154
- eigvec, eigval, _ = grad_safe_eig_solve(A, n_eig)
172
+ if exact_gradient:
173
+ eigvec, eigval, _ = grad_safe_eig_solve(A, n_eig)
174
+ else:
175
+ eigvec, eigval, _ = svd_lowrank(A, n_eig)
155
176
  eigvec = eigvec[:, :n_eig]
156
177
  eigval = eigval[:n_eig]
157
178
  eigvec = correct_rotation(eigvec)
@@ -191,7 +212,7 @@ def nystrom_propagate(
191
212
  nystrom_out = nystrom_out[indices].to(device)
192
213
  nystrom_X = nystrom_X[indices].to(device)
193
214
 
194
- sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity)
215
+ sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity, quantile_sigma=0.25)
195
216
  sigma = sigma * extrapolation_factor
196
217
 
197
218
  D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
@@ -18,7 +18,7 @@ import logging
18
18
  import numpy as np
19
19
  import torch
20
20
 
21
- from .torch_mod import svd_lowrank
21
+ from .torch_mod import svd_lowrank as my_svd_lowrank
22
22
 
23
23
 
24
24
  def check_gamma_deprecated(gamma: float | None) -> float:
@@ -122,6 +122,23 @@ def pca_lowrank(
122
122
  return u @ torch.diag(s)
123
123
 
124
124
 
125
+ def svd_lowrank(mat: torch.Tensor, q: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
+ """SVD lowrank implementation for float16 and bfloat16."""
127
+ dtype = mat.dtype
128
+ try:
129
+ with torch.autocast(device_type=mat.device.type, enabled=False):
130
+ if dtype == torch.float16 or dtype == torch.bfloat16:
131
+ mat = mat.float() # svd_lowrank does not support float16
132
+ u, s, v = my_svd_lowrank(mat, q=q + 10)
133
+ except RuntimeError:
134
+ if dtype == torch.float16 or dtype == torch.bfloat16:
135
+ mat = mat.float()
136
+ u, s, v = my_svd_lowrank(mat, q=q + 10)
137
+
138
+ u, s, v = u[:, :q], s[:q], v[:, :q]
139
+ return u.to(dtype), s.to(dtype), v.to(dtype)
140
+
141
+
125
142
  def quantile_min_max(
126
143
  x: torch.Tensor,
127
144
  q1: float = 0.01,
@@ -9,7 +9,7 @@ from .sample import farthest_point_sampling
9
9
  @torch.no_grad()
10
10
  def _find_sigma_by_degree(
11
11
  X: torch.Tensor, # [n_samples, n_features]
12
- d_sigma: float | str | None = 'auto',
12
+ quantile_sigma: float = 0.25,
13
13
  affinity_fn: callable = rbf_affinity,
14
14
  X2: torch.Tensor | None = None,
15
15
  init_sigma: float = 0.5,
@@ -17,27 +17,21 @@ def _find_sigma_by_degree(
17
17
  max_iter: int = 100,
18
18
  ) -> float:
19
19
  """Binary search for optimal sigma to achieve target mean edge weight."""
20
- if isinstance(d_sigma, float):
21
- assert d_sigma > 0, "d_sigma must be positive"
20
+ if quantile_sigma <= 0 or quantile_sigma >= 1:
21
+ raise ValueError(f"quantile_sigma must be between 0 and 1, got {quantile_sigma}")
22
22
  sigma = init_sigma
23
23
 
24
- # Find d_sigma if 'auto'
25
- if d_sigma in ('auto', None):
26
- scale_inv_sigma = sigma * X.std(0).sum()
27
- current_degrees = affinity_fn(X, X2=X2, sigma=scale_inv_sigma).mean(1)
28
- for _ in range(2):
29
- current_degree = current_degrees.mean().item()
30
- mask = current_degrees < current_degree
31
- current_degrees = current_degrees[mask]
32
- d_sigma = current_degrees.mean().item()
24
+ scale_inv_sigma = X.std(0).sum()
25
+ current_degrees = affinity_fn(X, X2=X2, sigma=scale_inv_sigma).mean(1)
26
+ target_degree = current_degrees.float().quantile(quantile_sigma).item()
33
27
 
34
28
  # Binary search for sigma
35
29
  current_degree = affinity_fn(X, X2=X2, sigma=sigma).mean().item()
36
30
  low, high = 0, float('inf')
37
- tol = r_tol * d_sigma
31
+ tol = r_tol * target_degree
38
32
  i_iter = 0
39
- while abs(current_degree - d_sigma) > tol and i_iter < max_iter:
40
- if current_degree > d_sigma:
33
+ while abs(current_degree - target_degree) > tol and i_iter < max_iter:
34
+ if current_degree > target_degree:
41
35
  high = sigma
42
36
  sigma = (low + sigma) / 2
43
37
  else:
@@ -52,7 +46,7 @@ def _find_sigma_by_degree(
52
46
  @torch.no_grad()
53
47
  def find_sigma_by_degree(
54
48
  X: torch.Tensor, # [n_samples, n_features]
55
- d_sigma: float | str | None = 'auto',
49
+ quantile_sigma: float = 0.25,
56
50
  affinity_fn: callable = rbf_affinity,
57
51
  X2: torch.Tensor | None = None,
58
52
  init_sigma: float = 0.5,
@@ -62,4 +56,4 @@ def find_sigma_by_degree(
62
56
  ) -> float:
63
57
  """Find sigma after FPS-based downsampling for efficiency."""
64
58
  indices = farthest_point_sampling(X, n_sample)
65
- return _find_sigma_by_degree(X[indices], d_sigma, affinity_fn, X2=X2, init_sigma=init_sigma, r_tol=r_tol, max_iter=max_iter)
59
+ return _find_sigma_by_degree(X[indices], quantile_sigma, affinity_fn, X2=X2, init_sigma=init_sigma, r_tol=r_tol, max_iter=max_iter)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev1
3
+ Version: 3.0.0.dev2
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ncut_pytorch"
7
- version = "3.0.0dev1"
7
+ version = "3.0.0dev2"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  ]