nystrom-ncut 0.0.1__py3-none-any.whl

import logging
from typing import Any, Callable, Dict, Literal, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.base import BaseEstimator

from .propagation_utils import (
    run_subgraph_sampling,
    propagate_knn,
    propagate_eigenvectors,
    check_if_normalized,
    quantile_min_max,
    quantile_normalize,
)


def _identity(X: torch.Tensor) -> torch.Tensor:
    return X


def eigenvector_to_rgb(
    eigen_vector: torch.Tensor,
    method: Literal["tsne_2d", "tsne_3d", "umap_sphere", "umap_2d", "umap_3d"] = "tsne_3d",
    num_sample: int = 1000,
    perplexity: int = 150,
    n_neighbors: int = 150,
    min_distance: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    q: float = 0.95,
    knn: int = 10,
    seed: int = 0,
):
    """Use t-SNE or UMAP to convert eigenvectors (more than 3) to RGB color (3D RGB cube).

    Args:
        eigen_vector (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        method (str): method to convert eigenvectors to RGB,
            choices are: ['tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d']
        num_sample (int): number of samples for Nystrom-like approximation, increase for better approximation
        perplexity (int): perplexity for t-SNE, increase for more global structure
        n_neighbors (int): number of neighbors for UMAP, increase for more global structure
        min_distance (float): minimum distance for UMAP
        metric (str): distance metric, default 'cosine'
        device (str): device to use for computation, if None, will not change device
        q (float): quantile for RGB normalization, default 0.95; lower q gives sharper colors
        knn (int): number of nearest neighbors for propagating eigenvectors from subgraph
            to full graph, default 10; smaller knn gives sharper colors, knn > 1 smooths
            out the embedding in the t-SNE or UMAP space.
        seed (int): random seed for t-SNE or UMAP

    Examples:
        >>> from ncut_pytorch import eigenvector_to_rgb
        >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method='tsne_3d')
        >>> print(X_3d.shape, rgb.shape)
        >>> # (10000, 3) (10000, 3)

    Returns:
        (np.ndarray): t-SNE or UMAP embedding, shape (n_samples, 2) or (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    kwargs = {
        "num_sample": num_sample,
        "perplexity": perplexity,
        "n_neighbors": n_neighbors,
        "min_dist": min_distance,  # the UMAP helpers expect `min_dist`, not `min_distance`
        "metric": metric,
        "device": device,
        "q": q,
        "knn": knn,
        "seed": seed,
    }

    if method == "tsne_2d":
        embed, rgb = rgb_from_tsne_2d(eigen_vector, **kwargs)
    elif method == "tsne_3d":
        embed, rgb = rgb_from_tsne_3d(eigen_vector, **kwargs)
    elif method == "umap_sphere":
        embed, rgb = rgb_from_umap_sphere(eigen_vector, **kwargs)
    elif method == "umap_2d":
        embed, rgb = rgb_from_umap_2d(eigen_vector, **kwargs)
    elif method == "umap_3d":
        embed, rgb = rgb_from_umap_3d(eigen_vector, **kwargs)
    else:
        raise ValueError(
            "method should be one of 'tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d'"
        )

    return embed, rgb


def _rgb_with_dimensionality_reduction(
    features: torch.Tensor,
    num_sample: int,
    metric: Literal["cosine", "euclidean"],
    rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
    q: float, knn: int,
    seed: int, device: str,
    reduction: Callable[..., BaseEstimator],
    reduction_dim: int,
    reduction_kwargs: Dict[str, Any],
    transform_func: Callable[[torch.Tensor], torch.Tensor] = _identity,
) -> Tuple[np.ndarray, torch.Tensor]:
    # 1) subsample the graph with farthest-point sampling
    subgraph_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method="farthest",
    )

    # 2) run the (expensive) dimensionality reduction on the subgraph only
    _inp = features[subgraph_indices].cpu().numpy()
    _subgraph_embed = reduction(
        n_components=reduction_dim,
        metric=metric,
        random_state=seed,
        **reduction_kwargs
    ).fit_transform(_inp)

    # 3) propagate the subgraph embedding to the full graph by KNN interpolation
    _subgraph_embed = torch.tensor(_subgraph_embed, dtype=torch.float32)
    X_nd = transform_func(propagate_knn(
        _subgraph_embed,
        features,
        features[subgraph_indices],
        distance=metric,
        knn=knn,
        device=device,
        move_output_to_cpu=True,
    ))
    rgb = rgb_func(X_nd, q)
    return X_nd.numpy(force=True), rgb


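# Illustrative sketch (not part of the API): `_rgb_with_dimensionality_reduction`
# follows a sample -> reduce -> propagate pattern. Spelled out with sklearn's TSNE;
# the inputs and shapes below are hypothetical:
#
# >>> from sklearn.manifold import TSNE
# >>> features = torch.randn(10000, 20)
# >>> idx = run_subgraph_sampling(features, num_sample=300, sample_method="farthest")
# >>> sub_embed = TSNE(n_components=3, metric="cosine", perplexity=30,
# ...                  random_state=0).fit_transform(features[idx].cpu().numpy())
# >>> full_embed = propagate_knn(torch.tensor(sub_embed, dtype=torch.float32),
# ...                            features, features[idx], distance="cosine", knn=10)
# >>> full_embed.shape  # the expensive reduction ran on only 300 points
# torch.Size([10000, 3])
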
def rgb_from_tsne_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 2D, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=2, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x2d, rgb


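# Example usage (illustrative; `eigenvectors` is a hypothetical input):
#
# >>> eigenvectors = torch.randn(10000, 20)
# >>> x2d, rgb = rgb_from_tsne_2d(eigenvectors, num_sample=300, perplexity=100)
# >>> x2d.shape, rgb.shape
# ((10000, 2), torch.Size([10000, 3]))
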
def rgb_from_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb


def rgb_from_cosine_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install with `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    def cosine_to_rbf(X: torch.Tensor) -> torch.Tensor:  # [B... x N x 3]
        # Classical MDS on the cosine distance matrix: re-embed the points so
        # that Euclidean distances match the cosine distances.
        normalized_X = X / torch.norm(X, p=2, dim=-1, keepdim=True)  # [B... x N x 3]
        D = 1 - normalized_X @ normalized_X.mT  # [B... x N x N]

        # Gram matrix with point 0 as the origin, then its top-3 eigenpairs
        G = (D[..., :1, 1:] ** 2 + D[..., 1:, :1] ** 2 - D[..., 1:, 1:] ** 2) / 2  # [B... x (N - 1) x (N - 1)]
        L, V = torch.linalg.eigh(G)  # [B... x (N - 1)], [B... x (N - 1) x (N - 1)]
        sqrtG = V[..., -3:] * (L[..., None, -3:] ** 0.5)  # [B... x (N - 1) x 3]

        # re-attach the origin and center the embedding
        Y = torch.cat((torch.zeros_like(sqrtG[..., :1, :]), sqrtG), dim=-2)  # [B... x N x 3]
        Y = Y - torch.mean(Y, dim=-2, keepdim=True)
        return Y

    def rgb_from_cosine(X_3d: torch.Tensor, q: float) -> torch.Tensor:
        return rgb_from_3d_rgb_cube(cosine_to_rbf(X_3d), q=q)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric="cosine",
        rgb_func=rgb_from_cosine,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb


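# Note (illustrative): `cosine_to_rbf` is classical MDS on the cosine distance
# matrix D, using point 0 as the origin: G_ij = (D_0i**2 + D_0j**2 - D_ij**2) / 2,
# then the top-3 eigenpairs of G give 3D coordinates whose Euclidean distances
# approximate the cosine distances. Hypothetical usage:
#
# >>> eigenvectors = torch.randn(10000, 20)
# >>> x3d, rgb = rgb_from_cosine_tsne_3d(eigenvectors, num_sample=300, perplexity=100)
# >>> rgb.shape
# torch.Size([10000, 3])
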
def rgb_from_umap_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 2D, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install with `pip install umap-learn`")

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x2d, rgb


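# Example usage (illustrative; requires `umap-learn`, input is hypothetical):
#
# >>> eigenvectors = torch.randn(10000, 20)
# >>> x2d, rgb = rgb_from_umap_2d(eigenvectors, num_sample=300, n_neighbors=50)
# >>> x2d.shape, rgb.shape
# ((10000, 2), torch.Size([10000, 3]))
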
def rgb_from_umap_sphere(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install with `pip install umap-learn`")

    def transform_func(X: torch.Tensor) -> torch.Tensor:
        # haversine output is (lat, lon); convert to 3D Cartesian coordinates
        # on the unit sphere
        return torch.stack((
            torch.sin(X[:, 0]) * torch.cos(X[:, 1]),
            torch.sin(X[:, 0]) * torch.sin(X[:, 1]),
            torch.cos(X[:, 0]),
        ), dim=1)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "output_metric": "haversine",
        },
        transform_func=transform_func,
    )

    return x3d, rgb


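# Example usage (illustrative; requires `umap-learn`). UMAP is fit with
# `output_metric="haversine"`, so the subgraph embedding is (lat, lon) on a
# sphere; after KNN propagation, `transform_func` converts it to 3D Cartesian
# unit-sphere coordinates, which then map onto the RGB cube:
#
# >>> eigenvectors = torch.randn(10000, 20)
# >>> x3d, rgb = rgb_from_umap_sphere(eigenvectors, num_sample=300)
# >>> x3d.shape  # points lie on the unit sphere
# (10000, 3)
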
def rgb_from_umap_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): Embedding in 3D, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install with `pip install umap-learn`")

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=3, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x3d, rgb


def flatten_sphere(X_3d):
    # unwrap 3D unit-sphere coordinates into a 2D (longitude, latitude) plane
    x = np.arctan2(X_3d[:, 0], X_3d[:, 1])
    y = -np.arccos(X_3d[:, 2])
    X_2d = np.stack([x, y], axis=1)
    return X_2d


def rotate_rgb_cube(rgb, position=1):
    """rotate RGB cube to different position

    Args:
        rgb (torch.Tensor): RGB color space [0, 1], shape (*, 3)
        position (int): position to rotate, 0, 1, 2, 3, 4, 5, 6

    Returns:
        torch.Tensor: RGB color space, shape (n_samples, 3)
    """
    assert position in range(0, 7), "position should be 0, 1, 2, 3, 4, 5, 6"
    # cyclic permutation of the (R, G, B) channels
    rotation_matrix = torch.tensor(
        [
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ]
    ).float()
    n_mul = position % 3
    rotation_matrix = torch.matrix_power(rotation_matrix, n_mul)
    rgb = rgb @ rotation_matrix
    if position > 3:
        # positions 4-6 additionally invert the colors
        rgb = 1 - rgb
    return rgb


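# Example (illustrative): position 1 cyclically permutes the channels,
# (R, G, B) -> (B, R, G); positions 4-6 apply the same permutations plus inversion.
#
# >>> rgb = torch.tensor([[1.0, 0.5, 0.0]])
# >>> rotate_rgb_cube(rgb, position=1)
# tensor([[0.0000, 1.0000, 0.5000]])
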
def rgb_from_3d_rgb_cube(X_3d, q=0.95):
    """convert 3D embedding (t-SNE or UMAP) to RGB color space

    Args:
        X_3d (torch.Tensor): 3D embedding, shape (n_samples, 3)
        q (float): quantile, default 0.95

    Returns:
        torch.Tensor: RGB color space, shape (n_samples, 3)
    """
    assert X_3d.shape[1] == 3, "input should be (n_samples, 3)"
    assert len(X_3d.shape) == 2, "input should be (n_samples, 3)"
    # quantile-normalize each axis independently, then read the three
    # coordinates as the R, G, B channels
    rgb = []
    for i in range(3):
        rgb.append(quantile_normalize(X_3d[:, i], q=q))
    rgb = torch.stack(rgb, dim=-1)
    return rgb


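# Example (illustrative): each axis of the 3D embedding becomes one color
# channel after quantile normalization:
#
# >>> rgb = rgb_from_3d_rgb_cube(torch.randn(1000, 3), q=0.95)
# >>> rgb.shape
# torch.Size([1000, 3])
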
def convert_to_lab_color(rgb, full_range=True):
    # interpret the [0, 1] input as scaled LAB coordinates, then convert LAB -> RGB
    from skimage import color
    import copy

    if isinstance(rgb, torch.Tensor):
        rgb = rgb.cpu().numpy()
    _rgb = copy.deepcopy(rgb)
    _rgb[..., 0] = _rgb[..., 0] * 100  # L channel: [0, 100]
    if full_range:
        _rgb[..., 1] = _rgb[..., 1] * 255 - 128  # a, b channels: [-128, 127]
        _rgb[..., 2] = _rgb[..., 2] * 255 - 128
    else:
        _rgb[..., 1] = _rgb[..., 1] * 100 - 50
        _rgb[..., 2] = _rgb[..., 2] * 100 - 50
    lab_rgb = color.lab2rgb(_rgb)
    return lab_rgb


def rgb_from_2d_colormap(X_2d, q=0.95):
    xy = X_2d.clone()
    for i in range(2):
        xy[:, i] = quantile_normalize(xy[:, i], q=q)

    try:
        # other 2D colormaps (e.g. ColorMap2DBremm, ColorMap2DZiegler,
        # ColorMap2DSchumann) are also available in pycolormap_2d
        from pycolormap_2d import ColorMap2DCubeDiagonal
    except ImportError:
        raise ImportError(
            "pycolormap_2d import failed, please install with `pip install pycolormap-2d`"
        )

    cmap = ColorMap2DCubeDiagonal()
    xy = xy.cpu().numpy()
    len_x, len_y = cmap._cmap_data.shape[:2]
    x = (xy[:, 0] * (len_x - 1)).astype(int)
    y = (xy[:, 1] * (len_y - 1)).astype(int)
    rgb = cmap._cmap_data[x, y]
    rgb = torch.tensor(rgb, dtype=torch.float32) / 255
    return rgb


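# Example (illustrative): a 2D embedding is colored by indexing into a 2D
# colormap lookup table, so nearby points in the plane get similar colors:
#
# >>> rgb = rgb_from_2d_colormap(torch.randn(1000, 2), q=0.95)
# >>> rgb.shape
# torch.Size([1000, 3])
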
def propagate_rgb_color(
    rgb: torch.Tensor,
    eigenvectors: torch.Tensor,
    new_eigenvectors: torch.Tensor,
    knn: int = 10,
    num_sample: int = 1000,
    sample_method: Literal["farthest", "random"] = "farthest",
    chunk_size: int = 8096,
    device: str = None,
    use_tqdm: bool = False,
):
    """Propagate RGB color to new nodes using KNN.

    Args:
        rgb (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
        eigenvectors (torch.Tensor): eigenvectors from existing nodes, shape (n_samples, num_eig)
        new_eigenvectors (torch.Tensor): eigenvectors from new nodes, shape (n_new_samples, num_eig)
        knn (int): number of KNN to propagate RGB color, default 10
        num_sample (int): number of samples for subgraph sampling, default 1000
        sample_method (str): sample method, 'farthest' (default) or 'random'
        chunk_size (int): chunk size for matrix multiplication, default 8096
        device (str): device to use for computation, if None, will not change device
        use_tqdm (bool): show progress bar when propagating RGB color from subgraph to full graph

    Returns:
        torch.Tensor: propagated RGB color for each data sample, shape (n_new_samples, 3)

    Examples:
        >>> old_rgb = torch.randn(3000, 3)
        >>> old_eigenvectors = torch.randn(3000, 20)
        >>> new_eigenvectors = torch.randn(200, 20)
        >>> new_rgb = propagate_rgb_color(old_rgb, old_eigenvectors, new_eigenvectors)
        >>> # new_rgb.shape = (200, 3)
    """
    return propagate_eigenvectors(
        eigenvectors=rgb,
        features=eigenvectors,
        new_features=new_eigenvectors,
        knn=knn,
        num_sample=num_sample,
        sample_method=sample_method,
        chunk_size=chunk_size,
        device=device,
        use_tqdm=use_tqdm,
    )


# application: get segmentation mask from a reference eigenvector (point prompt)
def _transform_heatmap(heatmap, gamma=1.0):
    """Transform the heatmap: standardize, apply gamma scaling, then min-max normalize.

    Args:
        heatmap (torch.Tensor): distance heatmap, shape (B, H, W)
        gamma (float, optional): scaling factor, higher means smaller mask. Defaults to 1.0.

    Returns:
        torch.Tensor: transformed heatmap, shape (B, H, W)
    """
    # standardize the heatmap, then map through exp
    heatmap = (heatmap - heatmap.mean()) / heatmap.std()
    heatmap = torch.exp(heatmap)
    # transform the heatmap using gamma:
    # large gamma means more focus on the high values, hence smaller mask
    heatmap = 1 / heatmap ** gamma
    # min-max normalization to [0, 1] using robust quantile bounds
    vmin, vmax = quantile_min_max(heatmap.flatten())
    heatmap = (heatmap - vmin) / (vmax - vmin)
    return heatmap


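# Note (illustrative): with z the standardized distance, the transform is
# 1 / exp(z) ** gamma = exp(-gamma * z), an exponential similarity where a
# larger gamma sharpens the falloff (hence a smaller mask). Hypothetical check:
#
# >>> out = _transform_heatmap(torch.rand(2, 8, 8), gamma=2.0)
# >>> out.shape
# torch.Size([2, 8, 8])
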
def _clean_mask(mask, min_area=500):
    """clean the binary mask by removing small connected components.

    Args:
        mask: A numpy image of a binary mask, nonzero for the object and 0 for the background.
        min_area: Minimum area for a connected component to be considered valid (default 500).

    Returns:
        cleaned_mask: A numpy image of the cleaned mask (1 for the object, 0 for the
            background), with small components removed.
        bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
    """

    import cv2
    # Find connected components in the mask (any nonzero pixel is foreground)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Initialize an empty mask to store the final cleaned mask
    final_cleaned_mask = np.zeros_like(mask)

    # Collect bounding boxes for components that are larger than the threshold and update the cleaned mask
    bounding_boxes = []
    for i in range(1, num_labels):  # Skip label 0 (background)
        x, y, w, h, area = stats[i]
        if area >= min_area:
            # Add the bounding box of the valid component
            bounding_boxes.append((x, y, w, h))
            # Keep the valid components in the final cleaned mask
            # (use 1, not 255, to match the 0/1 contract of `get_mask`)
            final_cleaned_mask[labels == i] = 1

    return final_cleaned_mask, bounding_boxes


def get_mask(
    all_eigvecs: torch.Tensor, prompt_eigvec: torch.Tensor,
    threshold: float = 0.5, gamma: float = 1.0,
    denoise: bool = True, denoise_area_th: int = 3,
):
    """Segmentation mask from one prompt eigenvector (at a clicked latent pixel).

    The mask is computed by measuring the cosine similarity between the clicked
    eigenvector and all the eigenvectors in the latent space:

    1. Compute the cosine similarity between the clicked eigenvector and all the eigenvectors in the latent space.
    2. Transform the heatmap, normalize and apply scaling (gamma).
    3. Threshold the heatmap to get the mask.
    4. Optionally denoise the mask by removing small connected components.

    Args:
        all_eigvecs (torch.Tensor): (B, H, W, num_eig)
        prompt_eigvec (torch.Tensor): (num_eig,)
        threshold (float, optional): mask threshold, higher means smaller mask. Defaults to 0.5.
        gamma (float, optional): mask scaling factor, higher means smaller mask. Defaults to 1.0.
        denoise (bool, optional): mask denoising flag. Defaults to True.
        denoise_area_th (int, optional): mask denoising area threshold, higher means more aggressive denoising. Defaults to 3.

    Returns:
        np.ndarray: masks (B, H, W), 1 for object, 0 for background

    Examples:
        >>> all_eigvecs = torch.randn(10, 64, 64, 20)
        >>> prompt_eigvec = all_eigvecs[0, 32, 32]  # center pixel
        >>> masks = get_mask(all_eigvecs, prompt_eigvec, threshold=0.5, gamma=1.0, denoise=True, denoise_area_th=3)
        >>> # masks.shape = (10, 64, 64)
    """

    # normalize the eigenvectors to unit norm, to compute cosine similarity
    if not check_if_normalized(all_eigvecs.reshape(-1, all_eigvecs.shape[-1])):
        all_eigvecs = F.normalize(all_eigvecs, p=2, dim=-1)

    prompt_eigvec = F.normalize(prompt_eigvec, p=2, dim=-1)

    # compute the cosine similarity
    cos_sim = all_eigvecs @ prompt_eigvec.unsqueeze(-1)  # (B, H, W, 1)
    cos_sim = cos_sim.squeeze(-1)  # (B, H, W)

    heatmap = 1 - cos_sim

    # transform the heatmap, normalize and apply scaling (gamma)
    heatmap = _transform_heatmap(heatmap, gamma=gamma)

    masks = heatmap > threshold
    masks = masks.cpu().numpy().astype(np.uint8)

    if denoise:
        cleaned_masks = []
        for mask in masks:
            cleaned_mask, _ = _clean_mask(mask, min_area=denoise_area_th)
            cleaned_masks.append(cleaned_mask)
        cleaned_masks = np.stack(cleaned_masks)
        return cleaned_masks

    return masks


LICENSE

Copyright (c) 2018 The Python Packaging Authority

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.