nystrom-ncut 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/__init__.py +22 -0
- nystrom_ncut/ncut_pytorch.py +561 -0
- nystrom_ncut/new_ncut_pytorch.py +241 -0
- nystrom_ncut/nystrom.py +170 -0
- nystrom_ncut/propagation_utils.py +371 -0
- nystrom_ncut/visualize_utils.py +655 -0
- nystrom_ncut-0.0.1.dist-info/LICENSE +19 -0
- nystrom_ncut-0.0.1.dist-info/METADATA +164 -0
- nystrom_ncut-0.0.1.dist-info/RECORD +11 -0
- nystrom_ncut-0.0.1.dist-info/WHEEL +5 -0
- nystrom_ncut-0.0.1.dist-info/top_level.txt +1 -0
nystrom_ncut/visualize_utils.py
@@ -0,0 +1,655 @@
import logging
from typing import Any, Callable, Dict, Literal, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.base import BaseEstimator

from .propagation_utils import (
    run_subgraph_sampling,
    propagate_knn,
    propagate_eigenvectors,
    check_if_normalized,
    quantile_min_max,
    quantile_normalize,
)


def _identity(X: torch.Tensor) -> torch.Tensor:
    return X


def eigenvector_to_rgb(
    eigen_vector: torch.Tensor,
    method: Literal["tsne_2d", "tsne_3d", "umap_sphere", "umap_2d", "umap_3d"] = "tsne_3d",
    num_sample: int = 1000,
    perplexity: int = 150,
    n_neighbors: int = 150,
    min_distance: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    q: float = 0.95,
    knn: int = 10,
    seed: int = 0,
):
    """Use t-SNE or UMAP to convert eigenvectors (more than 3 dimensions) to RGB colors (a 3D RGB cube).

    Args:
        eigen_vector (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        method (str): method to convert eigenvectors to RGB,
            choices are: ['tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d']
        num_sample (int): number of samples for the Nystrom-like approximation; increase for a better approximation
        perplexity (int): perplexity for t-SNE; increase for more global structure
        n_neighbors (int): number of neighbors for UMAP; increase for more global structure
        min_distance (float): minimum distance for UMAP
        metric (str): distance metric, default 'cosine'
        device (str): device to use for computation; if None, the device is left unchanged
        q (float): quantile for RGB normalization, default 0.95; a lower q gives sharper colors
        knn (int): number of nearest neighbors for propagating eigenvectors from the subgraph
            to the full graph, default 10. A smaller knn gives sharper colors; knn > 1
            smooths out the embedding in the t-SNE or UMAP space.
        seed (int): random seed for t-SNE or UMAP

    Examples:
        >>> from ncut_pytorch import eigenvector_to_rgb
        >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method='tsne_3d')
        >>> print(X_3d.shape, rgb.shape)
        >>> # (10000, 3) (10000, 3)

    Returns:
        (np.ndarray): t-SNE or UMAP embedding, shape (n_samples, 2) or (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    kwargs = {
        "num_sample": num_sample,
        "perplexity": perplexity,
        "n_neighbors": n_neighbors,
        "min_dist": min_distance,  # forwarded under the name the rgb_from_umap_* functions expect
        "metric": metric,
        "device": device,
        "q": q,
        "knn": knn,
        "seed": seed,
    }

    if method == "tsne_2d":
        embed, rgb = rgb_from_tsne_2d(eigen_vector, **kwargs)
    elif method == "tsne_3d":
        embed, rgb = rgb_from_tsne_3d(eigen_vector, **kwargs)
    elif method == "umap_sphere":
        embed, rgb = rgb_from_umap_sphere(eigen_vector, **kwargs)
    elif method == "umap_2d":
        embed, rgb = rgb_from_umap_2d(eigen_vector, **kwargs)
    elif method == "umap_3d":
        embed, rgb = rgb_from_umap_3d(eigen_vector, **kwargs)
    else:
        raise ValueError(
            "method should be one of 'tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d'"
        )

    return embed, rgb
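

# Hedged usage sketch (not part of the original file): synthetic eigenvectors stand in
# for a real NCUT output, to show the expected call signature and output shapes.
def _demo_eigenvector_to_rgb():
    eigenvectors = torch.randn(1000, 20)  # placeholder for NCUT eigenvectors
    x3d, rgb = eigenvector_to_rgb(eigenvectors, method="tsne_3d", num_sample=300, perplexity=100)
    # x3d: (1000, 3) numpy t-SNE embedding; rgb: (1000, 3) torch colors in [0, 1]
    assert x3d.shape == (1000, 3) and rgb.shape == (1000, 3)
    return x3d, rgb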


def _rgb_with_dimensionality_reduction(
    features: torch.Tensor,
    num_sample: int,
    metric: Literal["cosine", "euclidean"],
    rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
    q: float, knn: int,
    seed: int, device: str,
    reduction: Callable[..., BaseEstimator],
    reduction_dim: int,
    reduction_kwargs: Dict[str, Any],
    transform_func: Callable[[torch.Tensor], torch.Tensor] = _identity,
) -> Tuple[np.ndarray, torch.Tensor]:
    # Nystrom-like approximation: fit the (expensive) dimensionality reduction
    # on a farthest-point subsample only, ...
    subgraph_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method="farthest",
    )

    _inp = features[subgraph_indices].cpu().numpy()
    _subgraph_embed = reduction(
        n_components=reduction_dim,
        metric=metric,
        random_state=seed,
        **reduction_kwargs
    ).fit_transform(_inp)

    _subgraph_embed = torch.tensor(_subgraph_embed, dtype=torch.float32)
    # ... then extend the subsample embedding to all samples by KNN interpolation.
    X_nd = transform_func(propagate_knn(
        _subgraph_embed,
        features,
        features[subgraph_indices],
        distance=metric,
        knn=knn,
        device=device,
        move_output_to_cpu=True,
    ))
    rgb = rgb_func(X_nd, q)
    return X_nd.numpy(force=True), rgb


def rgb_from_tsne_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 2D, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install it with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=2, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x2d, rgb
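

# Hedged sketch (illustrative, not from the original file): color samples directly from
# their feature vectors with the 2D t-SNE + 2D-colormap pipeline
# (requires the optional `pycolormap-2d` dependency).
def _demo_rgb_from_tsne_2d():
    features = torch.randn(500, 64)  # placeholder features, e.g. per-pixel embeddings
    x2d, rgb = rgb_from_tsne_2d(features, num_sample=200, perplexity=50)
    # x2d: (500, 2) numpy embedding; rgb: (500, 3) torch colors from the 2D colormap
    return x2d, rgb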


def rgb_from_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install it with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb


def rgb_from_cosine_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install it with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    def cosine_to_rbf(X: torch.Tensor) -> torch.Tensor:  # [B... x N x 3]
        # Classical MDS: embed the cosine-distance geometry of X into 3D Euclidean
        # coordinates, using the first point as the origin.
        normalized_X = X / torch.norm(X, p=2, dim=-1, keepdim=True)  # [B... x N x 3]
        D = 1 - normalized_X @ normalized_X.mT  # [B... x N x N] pairwise cosine distances

        # Gram matrix from squared distances: G_ij = (d_0i^2 + d_0j^2 - d_ij^2) / 2
        G = (D[..., :1, 1:] ** 2 + D[..., 1:, :1] ** 2 - D[..., 1:, 1:] ** 2) / 2  # [B... x (N - 1) x (N - 1)]
        L, V = torch.linalg.eigh(G)  # [B... x (N - 1)], [B... x (N - 1) x (N - 1)]
        sqrtG = V[..., -3:] * (L[..., None, -3:] ** 0.5)  # [B... x (N - 1) x 3], top-3 components

        Y = torch.cat((torch.zeros_like(sqrtG[..., :1, :]), sqrtG), dim=-2)  # [B... x N x 3]
        Y = Y - torch.mean(Y, dim=-2, keepdim=True)
        return Y

    def rgb_from_cosine(X_3d: torch.Tensor, q: float) -> torch.Tensor:
        return rgb_from_3d_rgb_cube(cosine_to_rbf(X_3d), q=q)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric="cosine",
        rgb_func=rgb_from_cosine,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb
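

# Hedged verification sketch (not part of the original file): for coordinates built with
# the same Gram-matrix construction as cosine_to_rbf above, pairwise Euclidean distances
# should reproduce the input distance matrix whenever it is embeddable in 3 dimensions.
def _demo_mds_identity():
    # Build an exactly-3D-embeddable distance matrix from random 3D points ...
    P = torch.randn(10, 3).double()
    D = torch.cdist(P, P)
    # ... and recover coordinates via G_ij = (d_0i^2 + d_0j^2 - d_ij^2) / 2.
    G = (D[:1, 1:] ** 2 + D[1:, :1] ** 2 - D[1:, 1:] ** 2) / 2
    L, V = torch.linalg.eigh(G)
    Y = V[:, -3:] * (L[None, -3:].clamp(min=0) ** 0.5)
    Y = torch.cat((torch.zeros_like(Y[:1]), Y), dim=0)  # prepend the origin (point 0)
    assert torch.allclose(torch.cdist(Y, Y), D, atol=1e-6)
    return Y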


def rgb_from_umap_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 2D, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install it with `pip install umap-learn`")

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x2d, rgb


def rgb_from_umap_sphere(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D (on the unit sphere), shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install it with `pip install umap-learn`")

    def transform_func(X: torch.Tensor) -> torch.Tensor:
        # Convert the 2D spherical (haversine) embedding to 3D Cartesian unit vectors.
        return torch.stack((
            torch.sin(X[:, 0]) * torch.cos(X[:, 1]),
            torch.sin(X[:, 0]) * torch.sin(X[:, 1]),
            torch.cos(X[:, 0]),
        ), dim=1)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "output_metric": "haversine",
        },
        transform_func=transform_func,
    )

    return x3d, rgb
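

# Hedged check (illustrative, not from the original file): the spherical-to-Cartesian
# transform used by rgb_from_umap_sphere always lands on the unit sphere, so the RGB
# cube mapping receives unit-norm 3D coordinates regardless of the UMAP angles.
def _demo_sphere_transform_is_unit_norm():
    import math
    angles = torch.rand(100, 2) * torch.tensor([math.pi, 2 * math.pi])
    xyz = torch.stack((
        torch.sin(angles[:, 0]) * torch.cos(angles[:, 1]),
        torch.sin(angles[:, 0]) * torch.sin(angles[:, 1]),
        torch.cos(angles[:, 0]),
    ), dim=1)
    assert torch.allclose(xyz.norm(dim=1), torch.ones(100), atol=1e-6)
    return xyz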


def rgb_from_umap_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install it with `pip install umap-learn`")

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=3, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x3d, rgb


def flatten_sphere(X_3d):
    # Unwrap unit-sphere points into (azimuth, negative polar angle) coordinates.
    x = np.arctan2(X_3d[:, 0], X_3d[:, 1])
    y = -np.arccos(X_3d[:, 2])
    X_2d = np.stack([x, y], axis=1)
    return X_2d
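

# Hedged sketch (not from the original file): flatten_sphere unwraps unit-sphere points
# into (azimuth, -polar angle); e.g. the "equator" point (0, 1, 0) maps to (0, -pi/2).
def _demo_flatten_sphere():
    import math
    X_3d = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    X_2d = flatten_sphere(X_3d)
    assert np.allclose(X_2d[0], [0.0, -math.pi / 2])
    return X_2d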


def rotate_rgb_cube(rgb, position=1):
    """Rotate the RGB cube to a different position.

    Args:
        rgb (torch.Tensor): RGB color space [0, 1], shape (*, 3)
        position (int): position to rotate, 0, 1, 2, 3, 4, 5, 6

    Returns:
        torch.Tensor: RGB color space, shape (*, 3)
    """
    assert position in range(0, 7), "position should be 0, 1, 2, 3, 4, 5, 6"
    # Cyclic permutation of the color channels, applied `position % 3` times.
    rotation_matrix = torch.tensor(
        [
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ]
    ).float()
    n_mul = position % 3
    rotation_matrix = torch.matrix_power(rotation_matrix, n_mul)
    rgb = rgb @ rotation_matrix
    if position > 3:
        # Positions 4-6 additionally invert the colors (mirror the cube).
        rgb = 1 - rgb
    return rgb
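

# Hedged check (illustrative, not from the original file): position=1 is a cyclic
# channel permutation (R, G, B) -> (B, R, G), and positions above 3 also invert.
def _demo_rotate_rgb_cube():
    rgb = torch.rand(8, 3)
    rotated = rotate_rgb_cube(rgb, position=1)
    assert torch.allclose(rotated, rgb[:, [2, 0, 1]])
    inverted = rotate_rgb_cube(rgb, position=4)
    assert torch.allclose(inverted, 1 - rgb[:, [2, 0, 1]])
    return rotated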


def rgb_from_3d_rgb_cube(X_3d, q=0.95):
    """Convert a 3D embedding (e.g. t-SNE or UMAP) to RGB color space.

    Args:
        X_3d (torch.Tensor): 3D embedding, shape (n_samples, 3)
        q (float): quantile, default 0.95

    Returns:
        torch.Tensor: RGB color space, shape (n_samples, 3)
    """
    assert len(X_3d.shape) == 2, "input should be (n_samples, 3)"
    assert X_3d.shape[1] == 3, "input should be (n_samples, 3)"
    # Quantile-normalize each axis independently into the RGB cube.
    rgb = []
    for i in range(3):
        rgb.append(quantile_normalize(X_3d[:, i], q=q))
    rgb = torch.stack(rgb, dim=-1)
    return rgb
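

# Hedged sketch (not from the original file): per-channel quantile normalization maps
# each embedding axis into the RGB cube (see propagation_utils.quantile_normalize),
# so any (n, 3) embedding becomes directly plottable colors.
def _demo_rgb_from_3d_rgb_cube():
    X_3d = torch.randn(1000, 3)  # placeholder 3D embedding
    rgb = rgb_from_3d_rgb_cube(X_3d, q=0.95)
    assert rgb.shape == (1000, 3)
    return rgb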


def convert_to_lab_color(rgb, full_range=True):
    """Interpret the input channels as CIELAB (L, a, b) values in [0, 1] and convert them to RGB."""
    from skimage import color
    import copy

    if isinstance(rgb, torch.Tensor):
        rgb = rgb.cpu().numpy()
    _rgb = copy.deepcopy(rgb)
    # Rescale [0, 1] channels to LAB ranges: L in [0, 100], a/b centered around 0.
    _rgb[..., 0] = _rgb[..., 0] * 100
    if full_range:
        _rgb[..., 1] = _rgb[..., 1] * 255 - 128
        _rgb[..., 2] = _rgb[..., 2] * 255 - 128
    else:
        _rgb[..., 1] = _rgb[..., 1] * 100 - 50
        _rgb[..., 2] = _rgb[..., 2] * 100 - 50
    lab_rgb = color.lab2rgb(_rgb)
    return lab_rgb


def rgb_from_2d_colormap(X_2d, q=0.95):
    xy = X_2d.clone()
    for i in range(2):
        xy[:, i] = quantile_normalize(xy[:, i], q=q)

    try:
        from pycolormap_2d import (
            ColorMap2DBremm,
            ColorMap2DZiegler,
            ColorMap2DCubeDiagonal,
            ColorMap2DSchumann,
        )
    except ImportError:
        raise ImportError(
            "pycolormap_2d import failed, please install it with `pip install pycolormap-2d`"
        )

    cmap = ColorMap2DCubeDiagonal()
    xy = xy.cpu().numpy()
    # Look up each normalized (x, y) pair in the 2D colormap image.
    len_x, len_y = cmap._cmap_data.shape[:2]
    x = (xy[:, 0] * (len_x - 1)).astype(int)
    y = (xy[:, 1] * (len_y - 1)).astype(int)
    rgb = cmap._cmap_data[x, y]
    rgb = torch.tensor(rgb, dtype=torch.float32) / 255
    return rgb


def propagate_rgb_color(
    rgb: torch.Tensor,
    eigenvectors: torch.Tensor,
    new_eigenvectors: torch.Tensor,
    knn: int = 10,
    num_sample: int = 1000,
    sample_method: Literal["farthest", "random"] = "farthest",
    chunk_size: int = 8096,
    device: str = None,
    use_tqdm: bool = False,
):
    """Propagate RGB colors to new nodes using KNN.

    Args:
        rgb (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
        eigenvectors (torch.Tensor): eigenvectors of the existing nodes, shape (n_samples, num_eig)
        new_eigenvectors (torch.Tensor): eigenvectors of the new nodes, shape (n_new_samples, num_eig)
        knn (int): number of nearest neighbors used to propagate the RGB colors, default 10
        num_sample (int): number of samples for subgraph sampling, default 1000
        sample_method (str): sample method, 'farthest' (default) or 'random'
        chunk_size (int): chunk size for matrix multiplication, default 8096
        device (str): device to use for computation; if None, the device is left unchanged
        use_tqdm (bool): show a progress bar when propagating colors from the subgraph to the full graph

    Returns:
        torch.Tensor: propagated RGB color for each new data sample, shape (n_new_samples, 3)

    Examples:
        >>> old_rgb = torch.randn(3000, 3)
        >>> old_eigenvectors = torch.randn(3000, 20)
        >>> new_eigenvectors = torch.randn(200, 20)
        >>> new_rgb = propagate_rgb_color(old_rgb, old_eigenvectors, new_eigenvectors)
        >>> # new_rgb.shape = (200, 3)
    """
    return propagate_eigenvectors(
        eigenvectors=rgb,
        features=eigenvectors,
        new_features=new_eigenvectors,
        knn=knn,
        num_sample=num_sample,
        sample_method=sample_method,
        chunk_size=chunk_size,
        device=device,
        use_tqdm=use_tqdm,
    )
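

# Hedged usage sketch (not from the original file): colors computed on a low-resolution
# grid are lifted to a higher-resolution grid via KNN in eigenvector space.
def _demo_propagate_rgb_color():
    low_res_rgb = torch.rand(64 * 64, 3)          # colors computed at 64x64
    low_res_eigvecs = torch.randn(64 * 64, 20)    # placeholder eigenvectors at 64x64
    high_res_eigvecs = torch.randn(128 * 128, 20) # placeholder eigenvectors at 128x128
    high_res_rgb = propagate_rgb_color(low_res_rgb, low_res_eigvecs, high_res_eigvecs)
    assert high_res_rgb.shape == (128 * 128, 3)
    return high_res_rgb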


# application: get a segmentation mask from a reference eigenvector (point prompt)
def _transform_heatmap(heatmap, gamma=1.0):
    """Transform the distance heatmap: standardize, apply gamma scaling, then min-max normalize.

    Args:
        heatmap (torch.Tensor): distance heatmap, shape (B, H, W)
        gamma (float, optional): scaling factor, higher means a smaller mask. Defaults to 1.0.

    Returns:
        torch.Tensor: transformed heatmap, shape (B, H, W)
    """
    # standardize the heatmap
    heatmap = (heatmap - heatmap.mean()) / heatmap.std()
    heatmap = torch.exp(heatmap)
    # transform the heatmap using gamma;
    # a large gamma focuses on the high values, hence a smaller mask
    heatmap = 1 / heatmap ** gamma
    # min-max normalization to [0, 1]
    vmin, vmax = quantile_min_max(heatmap.flatten())
    heatmap = (heatmap - vmin) / (vmax - vmin)
    return heatmap


def _clean_mask(mask, min_area=500):
    """Clean a binary mask by removing small connected components.

    Args:
        mask: a numpy image of a binary mask, with 255 for the object and 0 for the background.
        min_area: minimum area for a connected component to be considered valid (default 500).

    Returns:
        cleaned_mask: a numpy image of the cleaned mask, with small components removed.
        bounding_boxes: list of bounding boxes for valid objects (x, y, width, height).
    """

    import cv2
    # Find connected components in the mask
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Initialize an empty mask to store the final cleaned mask
    final_cleaned_mask = np.zeros_like(mask)

    # Collect bounding boxes for components larger than the threshold and update the cleaned mask
    bounding_boxes = []
    for i in range(1, num_labels):  # Skip label 0 (background)
        x, y, w, h, area = stats[i]
        if area >= min_area:
            # Add the bounding box of the valid component
            bounding_boxes.append((x, y, w, h))
            # Keep the valid component in the final cleaned mask
            final_cleaned_mask[labels == i] = 255

    return final_cleaned_mask, bounding_boxes


def get_mask(
    all_eigvecs: torch.Tensor, prompt_eigvec: torch.Tensor,
    threshold: float = 0.5, gamma: float = 1.0,
    denoise: bool = True, denoise_area_th: int = 3):
    """Segmentation mask from one prompt eigenvector (at a clicked latent pixel).

    The mask is computed from the cosine similarity between the clicked eigenvector
    and all the eigenvectors in the latent space:
    1. Compute the cosine similarity between the clicked eigenvector and all the eigenvectors.
    2. Transform the distance heatmap: normalize and apply scaling (gamma).
    3. Threshold the heatmap to get the mask.
    4. Optionally denoise the mask by removing small connected components.

    Args:
        all_eigvecs (torch.Tensor): (B, H, W, num_eig)
        prompt_eigvec (torch.Tensor): (num_eig,)
        threshold (float, optional): mask threshold, higher means a smaller mask. Defaults to 0.5.
        gamma (float, optional): mask scaling factor, higher means a smaller mask. Defaults to 1.0.
        denoise (bool, optional): mask denoising flag. Defaults to True.
        denoise_area_th (int, optional): mask denoising area threshold, higher means more aggressive denoising. Defaults to 3.

    Returns:
        np.ndarray: masks (B, H, W), nonzero for object, 0 for background

    Examples:
        >>> all_eigvecs = torch.randn(10, 64, 64, 20)
        >>> prompt_eigvec = all_eigvecs[0, 32, 32]  # center pixel
        >>> masks = get_mask(all_eigvecs, prompt_eigvec, threshold=0.5, gamma=1.0, denoise=True, denoise_area_th=3)
        >>> # masks.shape = (10, 64, 64)
    """

    # normalize the eigenvectors to unit norm, to compute cosine similarity
    if not check_if_normalized(all_eigvecs.reshape(-1, all_eigvecs.shape[-1])):
        all_eigvecs = F.normalize(all_eigvecs, p=2, dim=-1)

    prompt_eigvec = F.normalize(prompt_eigvec, p=2, dim=-1)

    # compute the cosine similarity
    cos_sim = all_eigvecs @ prompt_eigvec.unsqueeze(-1)  # (B, H, W, 1)
    cos_sim = cos_sim.squeeze(-1)  # (B, H, W)

    # cosine distance: 0 where identical to the prompt
    heatmap = 1 - cos_sim

    # transform the heatmap: normalize and apply scaling (gamma)
    heatmap = _transform_heatmap(heatmap, gamma=gamma)

    masks = heatmap > threshold
    masks = masks.cpu().numpy().astype(np.uint8)

    if denoise:
        cleaned_masks = []
        for mask in masks:
            cleaned_mask, _ = _clean_mask(mask, min_area=denoise_area_th)
            cleaned_masks.append(cleaned_mask)
        cleaned_masks = np.stack(cleaned_masks)
        return cleaned_masks

    return masks
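

# Hedged end-to-end sketch (not from the original file): take the eigenvector at the
# "clicked" center pixel of the first image and lift it into masks for the whole batch
# (denoise=False keeps the sketch free of the optional cv2 dependency).
def _demo_get_mask():
    all_eigvecs = torch.randn(4, 32, 32, 20)  # placeholder NCUT eigenvectors
    prompt_eigvec = all_eigvecs[0, 16, 16]
    masks = get_mask(all_eigvecs, prompt_eigvec, threshold=0.5, gamma=1.0, denoise=False)
    assert masks.shape == (4, 32, 32)
    return masks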

nystrom_ncut-0.0.1.dist-info/LICENSE
@@ -0,0 +1,19 @@
Copyright (c) 2018 The Python Packaging Authority

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.