ncut-pytorch 3.0.0.dev0__tar.gz → 3.0.0.dev1__tar.gz

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (40)
  1. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/PKG-INFO +1 -1
  2. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/color/coloring.py +0 -1
  3. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/color/mspace.py +1 -1
  4. ncut_pytorch-3.0.0.dev1/ncut_pytorch/ncuts/ncut_click.py +101 -0
  5. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/ncuts/ncut_nystrom.py +80 -90
  6. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/grad.py +1 -40
  7. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch.egg-info/PKG-INFO +1 -1
  8. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/pyproject.toml +1 -1
  9. ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/ncut_click.py +0 -106
  10. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/LICENSE +0 -0
  11. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/README.md +0 -0
  12. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/__init__.py +0 -0
  13. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/color/__init__.py +0 -0
  14. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/color/mspace_nopl.py +0 -0
  15. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/ncut.py +0 -0
  16. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/ncuts/__init__.py +0 -0
  17. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
  18. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/__init__.py +0 -0
  19. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  20. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/api.py +0 -0
  21. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
  22. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  23. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  24. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/patch.py +0 -0
  25. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino/transform.py +0 -0
  26. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  27. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  28. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/predictor.py +0 -0
  29. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/predictor/vision_predictor.py +0 -0
  30. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/__init__.py +0 -0
  31. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/device.py +0 -0
  32. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/math.py +0 -0
  33. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/sample.py +0 -0
  34. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/sigma.py +0 -0
  35. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch/utils/torch_mod.py +0 -0
  36. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
  37. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  38. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch.egg-info/requires.txt +0 -0
  39. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/ncut_pytorch.egg-info/top_level.txt +0 -0
  40. {ncut_pytorch-3.0.0.dev0 → ncut_pytorch-3.0.0.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev0
3
+ Version: 3.0.0.dev1
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -218,7 +218,6 @@ def _nystrom_dimension_reduction(
218
218
  X[subgraph_indices],
219
219
  n_neighbors=knn,
220
220
  device=device,
221
- move_output_to_cpu=True,
222
221
  ))
223
222
  rgb = rgb_func(X_nd, q)
224
223
  return X_nd, rgb
@@ -56,7 +56,7 @@ def ncut_wrapper(features, n_eig, sigma=None):
56
56
 
57
57
  # features.requires_grad_(True)
58
58
  sigma = sigma or features.std(0).sum().item()
59
- # eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma, track_grad=True)
59
+ eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma)
60
60
  W = rbf_affinity(features, sigma=sigma)
61
61
  # W = cosine_affinity(features, sigma=1.0)
62
62
  A = normalize_affinity(W)
@@ -0,0 +1,101 @@
1
+ __all__ = ['ncut_click_prompt']
2
+
3
+ from typing import Callable, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ncut_pytorch.utils.sigma import find_sigma_by_degree
9
+ from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity, normalize_affinity
10
+ from ncut_pytorch.utils.sample import farthest_point_sampling
11
+ from ncut_pytorch.utils.device import auto_device
12
+ from .ncut_nystrom import NystromConfig
13
+ from .ncut_nystrom import nystrom_propagate
14
+ from .ncut_nystrom import _plain_ncut
15
+
16
+
17
+ #TODO: automatically optimize click_weight based on the iou of fg and bg
18
+ def ncut_click_prompt(
19
+ X: torch.Tensor,
20
+ fg_indices: np.ndarray,
21
+ bg_indices: np.ndarray = None,
22
+ click_weight: float = 0.5,
23
+ bg_weight: float = 0.1,
24
+ n_eig: int = 2,
25
+ d_sigma: float = None,
26
+ device: str = None,
27
+ sigma: float = None,
28
+ affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
29
+ no_propagation: bool = False,
30
+ return_indices_and_sigma: bool = False,
31
+ **kwargs,
32
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
33
+
34
+ config = NystromConfig()
35
+ config.update(kwargs)
36
+
37
+ # use GPU if available
38
+ device = auto_device(X.device, device)
39
+
40
+ if bg_indices is None:
41
+ bg_indices = np.array([], dtype=np.int64)
42
+
43
+ # subsample for nystrom approximation
44
+ nystrom_indices = farthest_point_sampling(X, n_sample=config.n_sample, device=device)
45
+ nystrom_indices = torch.tensor(nystrom_indices, dtype=torch.long)
46
+ # remove fg and bg from fps_idx
47
+ nystrom_indices = nystrom_indices[~np.isin(nystrom_indices, np.concatenate([fg_indices, bg_indices]))]
48
+ # add fg and bg to fps_idx
49
+ nystrom_indices = np.concatenate([fg_indices, bg_indices, nystrom_indices])
50
+ fg_indices = np.arange(len(fg_indices))
51
+ bg_indices = np.arange(len(bg_indices)) + len(fg_indices)
52
+ n_fgbg = len(fg_indices) + len(bg_indices)
53
+
54
+ nystrom_X = X[nystrom_indices].to(device)
55
+
56
+ # find optimal sigma for affinity matrix
57
+ if sigma is None and affinity_fn == rbf_affinity:
58
+ sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
59
+ # TODO: change to std()
60
+ elif sigma is None and affinity_fn == cosine_affinity:
61
+ sigma = 0.5
62
+
63
+ # compute Ncut on the nystrom sampled subgraph
64
+ A = affinity_fn(nystrom_X, sigma=sigma)
65
+ A = normalize_affinity(A)
66
+
67
+ # modify the affinity from the clicks
68
+ X_click = 1 * A[fg_indices].mean(0)
69
+ if len(bg_indices) > 0:
70
+ X_click = X_click - bg_weight * A[bg_indices].mean(0)
71
+
72
+ X_click = X_click * A.shape[0]
73
+
74
+ A_click = affinity_fn(X_click.unsqueeze(1), sigma=0.5)
75
+ A_click = normalize_affinity(A_click)
76
+
77
+ _A = click_weight * A_click + (1 - click_weight) * A
78
+
79
+ nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
80
+
81
+ if no_propagation:
82
+ return nystrom_eigvec, eigval, nystrom_indices, sigma
83
+
84
+ # propagate eigenvectors from subgraph to full graph
85
+ eigvec, nystrom_indices2 = nystrom_propagate(
86
+ nystrom_eigvec,
87
+ X,
88
+ nystrom_X,
89
+ n_neighbors=config.n_neighbors,
90
+ n_sample=config.n_sample2,
91
+ matmul_chunk_size=config.matmul_chunk_size,
92
+ device=device,
93
+ return_indices=True,
94
+ )
95
+
96
+
97
+ if return_indices_and_sigma:
98
+ indices = nystrom_indices[nystrom_indices2]
99
+ return eigvec, eigval, indices, sigma
100
+
101
+ return eigvec, eigval
@@ -9,7 +9,6 @@ from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
9
9
  from ncut_pytorch.utils.math import gram_schmidt, normalize_affinity, grad_safe_eig_solve, correct_rotation, keep_topk_per_row
10
10
  from ncut_pytorch.utils.sample import farthest_point_sampling
11
11
  from ncut_pytorch.utils.device import auto_device
12
- from ncut_pytorch.utils.grad import grad_manager
13
12
 
14
13
 
15
14
  class NystromConfig:
@@ -23,7 +22,6 @@ class NystromConfig:
23
22
  n_neighbors = 32 # number of neighbors for eigenvector propagation, 10 is large enough for most cases
24
23
  n_neighbors_max_ratio = 1/32 # max ratio of n_neighbors to n_sample2, to avoid over smoothing
25
24
  matmul_chunk_size = 65536 # chunk size for matrix multiplication, larger chunk size is faster but requires more memory
26
- move_output_to_cpu = True # if True, will move output to cpu, saves VRAM
27
25
 
28
26
  def update(self, kwargs: dict):
29
27
  for key, value in kwargs.items():
@@ -36,7 +34,6 @@ class NystromConfig:
36
34
  def ncut_fn(
37
35
  X: torch.Tensor,
38
36
  n_eig: int = 100,
39
- track_grad: bool = False,
40
37
  d_sigma: float = None,
41
38
  device: str = None,
42
39
  sigma: float = None,
@@ -53,7 +50,6 @@ def ncut_fn(
53
50
  Args:
54
51
  X (torch.Tensor): input features, shape (N, D)
55
52
  n_eig (int): number of eigenvectors
56
- track_grad (bool): keep track of pytorch gradients
57
53
  d_sigma (float): affinity sigma parameter, lower d_sigma results in sharper eigenvectors
58
54
  device (str): device, default 'auto' (auto detect GPU)
59
55
  sigma (float): affinity parameter, override d_sigma if provided
@@ -83,52 +79,49 @@ def ncut_fn(
83
79
  # check if enough data for nystrom approximation
84
80
  is_enough_data = X.shape[0] > config.n_sample
85
81
 
86
- with grad_manager(track_grad):
87
- # subsample for nystrom approximation
88
- n_sample = min(config.n_sample, int(X.shape[0]*config.n_sample_max_ratio))
89
- nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device) if is_enough_data else np.arange(X.shape[0])
90
- nystrom_X = X[nystrom_indices].to(device)
91
-
92
- # find optimal sigma for affinity matrix
93
- if sigma is None:
94
- if affinity_fn == rbf_affinity:
95
- sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
96
- elif affinity_fn == cosine_affinity:
97
- sigma = 0.5
98
- else:
99
- raise ValueError(f"`sigma` needs to be provided for affinity function {affinity_fn}, (sigma=0.5)")
100
-
101
- if repulsion_sigma is not None:
102
- nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma_attraction=sigma, sigma_repulsion=repulsion_sigma, repulsion_weight=repulsion_weight, affinity_fn=affinity_fn)
82
+ # subsample for nystrom approximation
83
+ n_sample = min(config.n_sample, int(X.shape[0]*config.n_sample_max_ratio))
84
+ nystrom_indices = farthest_point_sampling(X, n_sample=n_sample, device=device) if is_enough_data else np.arange(X.shape[0])
85
+ nystrom_X = X[nystrom_indices].to(device)
86
+
87
+ # find optimal sigma for affinity matrix
88
+ if sigma is None:
89
+ if affinity_fn == rbf_affinity:
90
+ sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
91
+ elif affinity_fn == cosine_affinity:
92
+ sigma = 0.5
103
93
  else:
104
- A = affinity_fn(nystrom_X, sigma=sigma)
105
- nystrom_eigvec, eigval = _plain_ncut(A, n_eig)
106
-
107
- if no_propagation:
108
- return nystrom_eigvec, eigval, nystrom_indices, sigma
109
-
110
- if not is_enough_data:
111
- return nystrom_eigvec, eigval
112
-
113
- # propagate eigenvectors from subgraph to full graph
114
- eigvec = nystrom_propagate(
115
- nystrom_eigvec,
116
- X,
117
- nystrom_X,
118
- extrapolation_factor=extrapolation_factor,
119
- n_neighbors=config.n_neighbors,
120
- n_sample=config.n_sample2,
121
- matmul_chunk_size=config.matmul_chunk_size,
122
- device=device,
123
- move_output_to_cpu=config.move_output_to_cpu,
124
- track_grad=track_grad,
125
- )
126
-
127
- # post-hoc orthogonalization
128
- if make_orthogonal:
129
- eigvec = gram_schmidt(eigvec)
130
-
131
- return eigvec, eigval
94
+ raise ValueError(f"`sigma` needs to be provided for affinity function {affinity_fn}, (sigma=0.5)")
95
+
96
+ if repulsion_sigma is not None:
97
+ nystrom_eigvec, eigval = ncut_with_repulsion(nystrom_X, n_eig, sigma_attraction=sigma, sigma_repulsion=repulsion_sigma, repulsion_weight=repulsion_weight, affinity_fn=affinity_fn)
98
+ else:
99
+ A = affinity_fn(nystrom_X, sigma=sigma)
100
+ nystrom_eigvec, eigval = _plain_ncut(A, n_eig)
101
+
102
+ if no_propagation:
103
+ return nystrom_eigvec, eigval, nystrom_indices, sigma
104
+
105
+ if not is_enough_data:
106
+ return nystrom_eigvec, eigval
107
+
108
+ # propagate eigenvectors from subgraph to full graph
109
+ eigvec = nystrom_propagate(
110
+ nystrom_eigvec,
111
+ X,
112
+ nystrom_X,
113
+ extrapolation_factor=extrapolation_factor,
114
+ n_neighbors=config.n_neighbors,
115
+ n_sample=config.n_sample2,
116
+ matmul_chunk_size=config.matmul_chunk_size,
117
+ device=device,
118
+ )
119
+
120
+ # post-hoc orthogonalization
121
+ if make_orthogonal:
122
+ eigvec = gram_schmidt(eigvec)
123
+
124
+ return eigvec, eigval
132
125
 
133
126
 
134
127
  def ncut_with_repulsion(
@@ -170,7 +163,6 @@ def nystrom_propagate(
170
163
  X: torch.Tensor,
171
164
  nystrom_X: torch.Tensor,
172
165
  extrapolation_factor: float = 1.0,
173
- track_grad: bool = False,
174
166
  device: str = None,
175
167
  return_indices: bool = False,
176
168
  **kwargs,
@@ -183,7 +175,6 @@ def nystrom_propagate(
183
175
  X (torch.Tensor): input features for all nodes, shape (N, D)
184
176
  nystrom_X (torch.Tensor): input features from nystrom sampled nodes, shape (m, D)
185
177
  extrapolation_factor (float): control how far can we extrapolate, larger extrapolation_factor means we can extrapolate further, default 1.0
186
- track_grad (bool): keep track of pytorch gradients, default False
187
178
  device (str): device to use for computation, if 'auto', will detect GPU automatically
188
179
  affinity_fn (callable): affinity function, default rbf_affinity. Should accept (X1, X2=None, sigma=float) and return affinity matrix
189
180
 
@@ -194,45 +185,44 @@ def nystrom_propagate(
194
185
  config = NystromConfig()
195
186
  config.update(kwargs)
196
187
 
197
- with grad_manager(track_grad):
198
- device = auto_device(nystrom_out.device, device)
199
- indices = farthest_point_sampling(nystrom_out, config.n_sample2, device=device)
200
- nystrom_out = nystrom_out[indices].to(device)
201
- nystrom_X = nystrom_X[indices].to(device)
202
-
203
- sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity)
204
- sigma = sigma * extrapolation_factor
205
-
206
- D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
207
-
208
- all_outs = []
209
- n_chunk = config.matmul_chunk_size
210
- n_neighbors = int(min(config.n_neighbors, len(indices)*config.n_neighbors_max_ratio))
211
- n_neighbors = max(n_neighbors, 4)
212
- for i in range(0, X.shape[0], n_chunk):
213
- end = min(i + n_chunk, X.shape[0])
214
-
215
- _Ai = rbf_affinity(X[i:end].to(device), nystrom_X, sigma=sigma)
216
- _Ai, _indices = keep_topk_per_row(_Ai, n_neighbors) # (n, n_neighbors)
217
- _Di = D[_indices].sum(1)
218
- _Ai = _Ai / _Di[:, None]
219
-
220
- weights = _Ai[..., None] # (n, n_neighbors, 1)
221
- neighbors = nystrom_out[_indices.flatten()]
222
- neighbors = neighbors.reshape(-1, n_neighbors, nystrom_out.shape[-1]) # (n, n_neighbors, d)
223
- out = weights * neighbors # (n, n_neighbors, d)
224
- out = out.sum(dim=1) # (n, d)
225
-
226
- if config.move_output_to_cpu and not track_grad:
227
- out = out.to("cpu")
228
- all_outs.append(out)
229
-
230
- all_outs = torch.cat(all_outs, dim=0)
231
-
232
- if return_indices:
233
- return all_outs, indices
234
-
235
- return all_outs
188
+ device = auto_device(nystrom_out.device, device)
189
+ output_device = X.device
190
+ indices = farthest_point_sampling(nystrom_out, config.n_sample2, device=device)
191
+ nystrom_out = nystrom_out[indices].to(device)
192
+ nystrom_X = nystrom_X[indices].to(device)
193
+
194
+ sigma = find_sigma_by_degree(nystrom_X, affinity_fn=rbf_affinity)
195
+ sigma = sigma * extrapolation_factor
196
+
197
+ D = rbf_affinity(nystrom_X, sigma=sigma).mean(1)
198
+
199
+ all_outs = []
200
+ n_chunk = config.matmul_chunk_size
201
+ n_neighbors = int(min(config.n_neighbors, len(indices)*config.n_neighbors_max_ratio))
202
+ n_neighbors = max(n_neighbors, 4)
203
+ for i in range(0, X.shape[0], n_chunk):
204
+ end = min(i + n_chunk, X.shape[0])
205
+
206
+ _Ai = rbf_affinity(X[i:end].to(device), nystrom_X, sigma=sigma)
207
+ _Ai, _indices = keep_topk_per_row(_Ai, n_neighbors) # (n, n_neighbors)
208
+ _Di = D[_indices].sum(1)
209
+ _Ai = _Ai / _Di[:, None]
210
+
211
+ weights = _Ai[..., None] # (n, n_neighbors, 1)
212
+ neighbors = nystrom_out[_indices.flatten()]
213
+ neighbors = neighbors.reshape(-1, n_neighbors, nystrom_out.shape[-1]) # (n, n_neighbors, d)
214
+ out = weights * neighbors # (n, n_neighbors, d)
215
+ out = out.sum(dim=1) # (n, d)
216
+
217
+ out = out.to(output_device)
218
+ all_outs.append(out)
219
+
220
+ all_outs = torch.cat(all_outs, dim=0)
221
+
222
+ if return_indices:
223
+ return all_outs, indices
224
+
225
+ return all_outs
236
226
 
237
227
 
238
228
 
@@ -1,8 +1,6 @@
1
- __all__ = ["rbf_eigvec_manual_grad", "grad_manager"]
1
+ __all__ = ["rbf_eigvec_manual_grad"]
2
2
 
3
3
  import torch
4
- from contextlib import contextmanager
5
-
6
4
 
7
5
  @torch.no_grad()
8
6
  def rbf_eigvec_manual_grad(
@@ -115,40 +113,3 @@ def rbf_eigvec_manual_grad(
115
113
 
116
114
  return grad_u
117
115
 
118
-
119
-
120
- @contextmanager
121
- def grad_manager(enabled: bool):
122
- """Context manager to temporarily set gradient computation mode.
123
-
124
- This context manager allows you to control gradient computation for a block
125
- of code, and automatically restores the previous gradient state when exiting
126
- the context.
127
-
128
- Args:
129
- enabled (bool): If True, enables gradient tracking within the context.
130
- If False, disables gradient tracking within the context.
131
-
132
- Yields:
133
- None
134
-
135
- Examples:
136
- >>> import torch
137
- >>> from ncut_pytorch.utils.grad import set_grad_enabled
138
- >>>
139
- >>> # Disable gradients for inference
140
- >>> with set_grad_enabled(False):
141
- ... result = model(input_tensor)
142
- >>>
143
- >>> # Enable gradients for training
144
- >>> with set_grad_enabled(True):
145
- ... loss = criterion(model(input_tensor), target)
146
- ... loss.backward()
147
- """
148
- prev_grad_state = torch.is_grad_enabled()
149
- torch.set_grad_enabled(enabled)
150
- try:
151
- yield
152
- finally:
153
- torch.set_grad_enabled(prev_grad_state)
154
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev0
3
+ Version: 3.0.0.dev1
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ncut_pytorch"
7
- version = "3.0.0dev0"
7
+ version = "3.0.0dev1"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  ]
@@ -1,106 +0,0 @@
1
- __all__ = ['ncut_click_prompt']
2
-
3
- from typing import Callable, Union
4
-
5
- import numpy as np
6
- import torch
7
-
8
- from ncut_pytorch.utils.sigma import find_sigma_by_degree
9
- from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity, normalize_affinity
10
- from ncut_pytorch.utils.sample import farthest_point_sampling
11
- from ncut_pytorch.utils.device import auto_device
12
- from ncut_pytorch.utils.grad import grad_manager
13
- from .ncut_nystrom import NystromConfig
14
- from .ncut_nystrom import nystrom_propagate
15
- from .ncut_nystrom import _plain_ncut
16
-
17
-
18
- #TODO: automatically optimize click_weight based on the iou of fg and bg
19
- def ncut_click_prompt(
20
- X: torch.Tensor,
21
- fg_indices: np.ndarray,
22
- bg_indices: np.ndarray = None,
23
- click_weight: float = 0.5,
24
- bg_weight: float = 0.1,
25
- n_eig: int = 2,
26
- track_grad: bool = False,
27
- d_sigma: float = None,
28
- device: str = None,
29
- sigma: float = None,
30
- affinity_fn: Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor] = rbf_affinity,
31
- no_propagation: bool = False,
32
- return_indices_and_sigma: bool = False,
33
- **kwargs,
34
- ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]:
35
-
36
- config = NystromConfig()
37
- config.update(kwargs)
38
-
39
- # use GPU if available
40
- device = auto_device(X.device, device)
41
-
42
- with grad_manager(track_grad):
43
- if bg_indices is None:
44
- bg_indices = np.array([], dtype=np.int64)
45
-
46
- # subsample for nystrom approximation
47
- nystrom_indices = farthest_point_sampling(X, n_sample=config.n_sample, device=device)
48
- nystrom_indices = torch.tensor(nystrom_indices, dtype=torch.long)
49
- # remove fg and bg from fps_idx
50
- nystrom_indices = nystrom_indices[~np.isin(nystrom_indices, np.concatenate([fg_indices, bg_indices]))]
51
- # add fg and bg to fps_idx
52
- nystrom_indices = np.concatenate([fg_indices, bg_indices, nystrom_indices])
53
- fg_indices = np.arange(len(fg_indices))
54
- bg_indices = np.arange(len(bg_indices)) + len(fg_indices)
55
- n_fgbg = len(fg_indices) + len(bg_indices)
56
-
57
- nystrom_X = X[nystrom_indices].to(device)
58
-
59
- # find optimal sigma for affinity matrix
60
- if sigma is None and affinity_fn == rbf_affinity:
61
- sigma = find_sigma_by_degree(nystrom_X, d_sigma, affinity_fn)
62
- # TODO: change to std()
63
- elif sigma is None and affinity_fn == cosine_affinity:
64
- sigma = 0.5
65
-
66
- # compute Ncut on the nystrom sampled subgraph
67
- A = affinity_fn(nystrom_X, sigma=sigma)
68
- A = normalize_affinity(A)
69
-
70
- # modify the affinity from the clicks
71
- X_click = 1 * A[fg_indices].mean(0)
72
- if len(bg_indices) > 0:
73
- X_click = X_click - bg_weight * A[bg_indices].mean(0)
74
-
75
- X_click = X_click * A.shape[0]
76
-
77
- A_click = affinity_fn(X_click.unsqueeze(1), sigma=0.5)
78
- A_click = normalize_affinity(A_click)
79
-
80
- _A = click_weight * A_click + (1 - click_weight) * A
81
-
82
- nystrom_eigvec, eigval = _plain_ncut(_A, n_eig)
83
-
84
- if no_propagation:
85
- return nystrom_eigvec, eigval, nystrom_indices, sigma
86
-
87
- # propagate eigenvectors from subgraph to full graph
88
- eigvec, nystrom_indices2 = nystrom_propagate(
89
- nystrom_eigvec,
90
- X,
91
- nystrom_X,
92
- n_neighbors=config.n_neighbors,
93
- n_sample=config.n_sample2,
94
- matmul_chunk_size=config.matmul_chunk_size,
95
- device=device,
96
- move_output_to_cpu=config.move_output_to_cpu,
97
- track_grad=track_grad,
98
- return_indices=True,
99
- )
100
-
101
-
102
- if return_indices_and_sigma:
103
- indices = nystrom_indices[nystrom_indices2]
104
- return eigvec, eigval, indices, sigma
105
-
106
- return eigvec, eigval