nystrom_ncut-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,655 @@
import logging
from typing import Any, Callable, Dict, Literal, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.base import BaseEstimator

from .propagation_utils import (
    run_subgraph_sampling,
    propagate_knn,
    propagate_eigenvectors,
    check_if_normalized,
    quantile_min_max,
    quantile_normalize,
)


def _identity(X: torch.Tensor) -> torch.Tensor:
    return X


def eigenvector_to_rgb(
    eigen_vector: torch.Tensor,
    method: Literal["tsne_2d", "tsne_3d", "umap_sphere", "umap_2d", "umap_3d"] = "tsne_3d",
    num_sample: int = 1000,
    perplexity: int = 150,
    n_neighbors: int = 150,
    min_distance: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    q: float = 0.95,
    knn: int = 10,
    seed: int = 0,
):
    """Use t-SNE or UMAP to convert eigenvectors (more than 3) to RGB colors (3D RGB cube).

    Args:
        eigen_vector (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        method (str): method to convert eigenvectors to RGB,
            choices are: ['tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d']
        num_sample (int): number of samples for the Nystrom-like approximation; increase for a better approximation
        perplexity (int): perplexity for t-SNE; increase for more global structure
        n_neighbors (int): number of neighbors for UMAP; increase for more global structure
        min_distance (float): minimum distance for UMAP
        metric (str): distance metric, default 'cosine'
        device (str): device to use for computation; if None, the current device is kept
        q (float): quantile for RGB normalization, default 0.95; a lower q gives sharper colors
        knn (int): number of nearest neighbors for propagating the embedding from the subgraph
            to the full graph, default 10; a smaller knn gives sharper colors, while knn > 1
            smooths the embedding in the t-SNE or UMAP space
        seed (int): random seed for t-SNE or UMAP

    Examples:
        >>> from ncut_pytorch import eigenvector_to_rgb
        >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method='tsne_3d')
        >>> print(X_3d.shape, rgb.shape)
        >>> # (10000, 3) (10000, 3)

    Returns:
        (np.ndarray): t-SNE or UMAP embedding, shape (n_samples, 2) or (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    kwargs = {
        "num_sample": num_sample,
        "perplexity": perplexity,
        "n_neighbors": n_neighbors,
        "min_dist": min_distance,  # the UMAP helpers below expect the key `min_dist`
        "metric": metric,
        "device": device,
        "q": q,
        "knn": knn,
        "seed": seed,
    }

    if method == "tsne_2d":
        embed, rgb = rgb_from_tsne_2d(eigen_vector, **kwargs)
    elif method == "tsne_3d":
        embed, rgb = rgb_from_tsne_3d(eigen_vector, **kwargs)
    elif method == "umap_sphere":
        embed, rgb = rgb_from_umap_sphere(eigen_vector, **kwargs)
    elif method == "umap_2d":
        embed, rgb = rgb_from_umap_2d(eigen_vector, **kwargs)
    elif method == "umap_3d":
        embed, rgb = rgb_from_umap_3d(eigen_vector, **kwargs)
    else:
        raise ValueError(
            "method should be one of 'tsne_2d', 'tsne_3d', 'umap_sphere', 'umap_2d', 'umap_3d'"
        )

    return embed, rgb


def _rgb_with_dimensionality_reduction(
    features: torch.Tensor,
    num_sample: int,
    metric: Literal["cosine", "euclidean"],
    rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
    q: float, knn: int,
    seed: int, device: str,
    reduction: Callable[..., BaseEstimator],
    reduction_dim: int,
    reduction_kwargs: Dict[str, Any],
    transform_func: Callable[[torch.Tensor], torch.Tensor] = _identity,
) -> Tuple[np.ndarray, torch.Tensor]:
    # sample a small subgraph by farthest-point sampling
    subgraph_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method="farthest",
    )

    # run the (expensive) dimensionality reduction only on the subgraph
    _inp = features[subgraph_indices].cpu().numpy()
    _subgraph_embed = reduction(
        n_components=reduction_dim,
        metric=metric,
        random_state=seed,
        **reduction_kwargs
    ).fit_transform(_inp)

    # propagate the low-dimensional embedding from the subgraph to all nodes by KNN
    _subgraph_embed = torch.tensor(_subgraph_embed, dtype=torch.float32)
    X_nd = transform_func(propagate_knn(
        _subgraph_embed,
        features,
        features[subgraph_indices],
        distance=metric,
        knn=knn,
        device=device,
        move_output_to_cpu=True,
    ))
    rgb = rgb_func(X_nd, q)
    return X_nd.numpy(force=True), rgb


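# [editor's sketch, not part of the package] The helper above is the module's core
# Nystrom-style trick: run the expensive reduction (t-SNE/UMAP) on a small
# farthest-point sample, then extend the low-dimensional coordinates to every node
# by KNN regression in the original feature space. A minimal standalone analogue,
# assuming only scikit-learn (all names here are hypothetical):
def _example_nystrom_like_reduction():
    import numpy as np
    from sklearn.manifold import TSNE
    from sklearn.neighbors import KNeighborsRegressor

    rng = np.random.default_rng(0)
    features = rng.normal(size=(5000, 32))              # full graph: too big for plain t-SNE
    sample = rng.choice(5000, size=500, replace=False)  # stand-in for farthest-point sampling

    # reduce only the sample...
    embed_sample = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(features[sample])
    # ...then regress the 2D coordinates onto the full set via KNN in feature space
    knn = KNeighborsRegressor(n_neighbors=10).fit(features[sample], embed_sample)
    return knn.predict(features)                        # (5000, 2)

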
def rgb_from_tsne_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 2D t-SNE embedding, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=2, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x2d, rgb


def rgb_from_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 3D t-SNE embedding, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb


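# [editor's note, not part of the package] Typical use of the two t-SNE wrappers
# above, assuming `eigenvectors` is an (n_samples, num_eig) tensor from NCut:
# >>> x3d, rgb = rgb_from_tsne_3d(eigenvectors, num_sample=1000, perplexity=150)
# >>> x3d.shape, rgb.shape
# ((n_samples, 3), (n_samples, 3))

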
def rgb_from_cosine_tsne_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    perplexity: int = 150,
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 3D t-SNE embedding, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        raise ImportError(
            "sklearn import failed, please install `pip install scikit-learn`"
        )
    num_sample = min(num_sample, features.shape[0])
    if perplexity > num_sample // 2:
        logging.warning(
            f"perplexity is larger than num_sample // 2, setting perplexity to {num_sample // 2}"
        )
        perplexity = num_sample // 2

    def cosine_to_rbf(X: torch.Tensor) -> torch.Tensor:  # [B... x N x 3]
        # classical MDS: recover 3D Euclidean coordinates whose geometry
        # approximates the pairwise cosine distances
        normalized_X = X / torch.norm(X, p=2, dim=-1, keepdim=True)  # [B... x N x 3]
        D = 1 - normalized_X @ normalized_X.mT  # [B... x N x N]

        # Gram matrix of the points relative to point 0
        G = (D[..., :1, 1:] ** 2 + D[..., 1:, :1] ** 2 - D[..., 1:, 1:] ** 2) / 2  # [B... x (N - 1) x (N - 1)]
        L, V = torch.linalg.eigh(G)  # [B... x (N - 1)], [B... x (N - 1) x (N - 1)]
        sqrtG = V[..., -3:] * (L[..., None, -3:] ** 0.5)  # [B... x (N - 1) x 3]

        # prepend point 0 at the origin, then center the embedding
        Y = torch.cat((torch.zeros_like(sqrtG[..., :1, :]), sqrtG), dim=-2)  # [B... x N x 3]
        Y = Y - torch.mean(Y, dim=-2, keepdim=True)
        return Y

    def rgb_from_cosine(X_3d: torch.Tensor, q: float) -> torch.Tensor:
        return rgb_from_3d_rgb_cube(cosine_to_rbf(X_3d), q=q)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric="cosine",
        rgb_func=rgb_from_cosine,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=TSNE, reduction_dim=3, reduction_kwargs={
            "perplexity": perplexity,
        },
    )

    return x3d, rgb


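# [editor's sketch, not part of the package] `cosine_to_rbf` above is the classical
# multidimensional scaling (MDS) construction: from a distance matrix D it forms the
# Gram matrix G_ij = (d_0i^2 + d_0j^2 - d_ij^2) / 2 relative to point 0,
# eigendecomposes it, and keeps the top-3 eigenpairs as 3D coordinates. For true
# Euclidean distances the reconstruction is exact; applied to cosine distances, as
# above, it is an approximation. A quick check of the Euclidean case:
def _example_classical_mds_roundtrip():
    import torch
    X = torch.randn(16, 3)   # ground-truth 3D points
    D = torch.cdist(X, X)    # Euclidean distance matrix
    G = (D[:1, 1:] ** 2 + D[1:, :1] ** 2 - D[1:, 1:] ** 2) / 2
    L, V = torch.linalg.eigh(G)
    Y = V[:, -3:] * L.clamp(min=0)[-3:] ** 0.5
    Y = torch.cat((torch.zeros(1, 3), Y), dim=0)
    # pairwise distances are recovered up to numerical error
    assert torch.allclose(torch.cdist(Y, Y), D, atol=1e-3)

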
def rgb_from_umap_2d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 2D UMAP embedding, shape (n_samples, 2)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install `pip install umap-learn`")

    x2d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_2d_colormap,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x2d, rgb


def rgb_from_umap_sphere(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 3D embedding on the unit sphere, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install `pip install umap-learn`")

    # convert the angles returned by the haversine output metric to
    # Cartesian coordinates on the unit sphere
    def transform_func(X: torch.Tensor) -> torch.Tensor:
        return torch.stack((
            torch.sin(X[:, 0]) * torch.cos(X[:, 1]),
            torch.sin(X[:, 0]) * torch.sin(X[:, 1]),
            torch.cos(X[:, 0]),
        ), dim=1)

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=2, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "output_metric": "haversine",
        },
        transform_func=transform_func,
    )

    return x3d, rgb


def rgb_from_umap_3d(
    features: torch.Tensor,
    num_sample: int = 1000,
    n_neighbors: int = 150,
    min_dist: float = 0.1,
    metric: Literal["cosine", "euclidean"] = "cosine",
    device: str = None,
    seed: int = 0,
    q: float = 0.95,
    knn: int = 10,
    **kwargs: Any,
):
    """
    Returns:
        (np.ndarray): 3D UMAP embedding, shape (n_samples, 3)
        (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
    """
    try:
        from umap import UMAP
    except ImportError:
        raise ImportError("umap import failed, please install `pip install umap-learn`")

    x3d, rgb = _rgb_with_dimensionality_reduction(
        features=features,
        num_sample=num_sample,
        metric=metric,
        rgb_func=rgb_from_3d_rgb_cube,
        q=q, knn=knn,
        seed=seed, device=device,
        reduction=UMAP, reduction_dim=3, reduction_kwargs={
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
        },
    )

    return x3d, rgb


def flatten_sphere(X_3d):
    # map unit-sphere points back to a 2D (longitude, latitude) chart for plotting
    x = np.arctan2(X_3d[:, 0], X_3d[:, 1])
    y = -np.arccos(X_3d[:, 2])
    X_2d = np.stack([x, y], axis=1)
    return X_2d


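# [editor's sketch, not part of the package] In rgb_from_umap_sphere, UMAP's
# haversine output metric yields angle pairs, and `transform_func` there maps them to
# Cartesian coordinates (sin t * cos p, sin t * sin p, cos t), which always lie on
# the unit sphere; `flatten_sphere` above maps such points back to a 2D chart.
def _example_sphere_coords():
    import torch
    theta_phi = torch.rand(100, 2) * torch.tensor([torch.pi, 2 * torch.pi])
    xyz = torch.stack((
        torch.sin(theta_phi[:, 0]) * torch.cos(theta_phi[:, 1]),
        torch.sin(theta_phi[:, 0]) * torch.sin(theta_phi[:, 1]),
        torch.cos(theta_phi[:, 0]),
    ), dim=1)
    # every transformed point has unit norm, i.e. a valid point on the RGB sphere
    assert torch.allclose(xyz.norm(dim=1), torch.ones(100), atol=1e-5)

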
def rotate_rgb_cube(rgb, position=1):
    """Rotate the RGB cube to a different position.

    Args:
        rgb (torch.Tensor): RGB color space [0, 1], shape (*, 3)
        position (int): position to rotate to, one of 0, 1, 2, 3, 4, 5, 6

    Returns:
        torch.Tensor: RGB color space, shape (n_samples, 3)
    """
    assert position in range(0, 7), "position should be 0, 1, 2, 3, 4, 5, 6"
    rotation_matrix = torch.tensor(
        [
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ]
    ).float()
    n_mul = position % 3
    rotation_matrix = torch.matrix_power(rotation_matrix, n_mul)
    rgb = rgb @ rotation_matrix
    if position > 3:
        rgb = 1 - rgb
    return rgb


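# [editor's sketch, not part of the package] Usage of rotate_rgb_cube: positions 1
# and 2 cyclically permute the color channels, and positions 4-6 additionally invert
# the cube (1 - rgb), so one embedding yields several palette variants.
def _example_rotate_rgb_cube():
    import torch
    rgb = torch.rand(10, 3)
    rotated = rotate_rgb_cube(rgb, position=1)
    # with the permutation matrix above, the new R channel carries the old B channel
    assert torch.allclose(rotated[:, 0], rgb[:, 2])

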
def rgb_from_3d_rgb_cube(X_3d, q=0.95):
    """Convert a 3D embedding (t-SNE or UMAP) to RGB color space.

    Args:
        X_3d (torch.Tensor): 3D embedding, shape (n_samples, 3)
        q (float): quantile, default 0.95

    Returns:
        torch.Tensor: RGB color space, shape (n_samples, 3)
    """
    assert X_3d.shape[1] == 3, "input should be (n_samples, 3)"
    assert len(X_3d.shape) == 2, "input should be (n_samples, 3)"
    # each axis is independently quantile-normalized into [0, 1],
    # so outliers saturate instead of stretching the whole colormap
    rgb = []
    for i in range(3):
        rgb.append(quantile_normalize(X_3d[:, i], q=q))
    rgb = torch.stack(rgb, dim=-1)
    return rgb


def convert_to_lab_color(rgb, full_range=True):
    """Reinterpret an RGB tensor in [0, 1] as LAB coordinates and convert back to RGB."""
    from skimage import color
    import copy

    if isinstance(rgb, torch.Tensor):
        rgb = rgb.cpu().numpy()
    _rgb = copy.deepcopy(rgb)
    # rescale [0, 1] channels to LAB ranges: L in [0, 100], a/b centered at 0
    _rgb[..., 0] = _rgb[..., 0] * 100
    if full_range:
        _rgb[..., 1] = _rgb[..., 1] * 255 - 128
        _rgb[..., 2] = _rgb[..., 2] * 255 - 128
    else:
        _rgb[..., 1] = _rgb[..., 1] * 100 - 50
        _rgb[..., 2] = _rgb[..., 2] * 100 - 50
    lab_rgb = color.lab2rgb(_rgb)
    return lab_rgb


def rgb_from_2d_colormap(X_2d, q=0.95):
    xy = X_2d.clone()
    for i in range(2):
        xy[:, i] = quantile_normalize(xy[:, i], q=q)

    try:
        from pycolormap_2d import (
            ColorMap2DBremm,
            ColorMap2DZiegler,
            ColorMap2DCubeDiagonal,
            ColorMap2DSchumann,
        )
    except ImportError:
        raise ImportError(
            "pycolormap_2d import failed, please install `pip install pycolormap-2d`"
        )

    # look up each quantile-normalized (x, y) coordinate in the 2D colormap image
    cmap = ColorMap2DCubeDiagonal()
    xy = xy.cpu().numpy()
    len_x, len_y = cmap._cmap_data.shape[:2]
    x = (xy[:, 0] * (len_x - 1)).astype(int)
    y = (xy[:, 1] * (len_y - 1)).astype(int)
    rgb = cmap._cmap_data[x, y]
    rgb = torch.tensor(rgb, dtype=torch.float32) / 255
    return rgb


def propagate_rgb_color(
    rgb: torch.Tensor,
    eigenvectors: torch.Tensor,
    new_eigenvectors: torch.Tensor,
    knn: int = 10,
    num_sample: int = 1000,
    sample_method: Literal["farthest", "random"] = "farthest",
    chunk_size: int = 8096,
    device: str = None,
    use_tqdm: bool = False,
):
    """Propagate RGB color to new nodes using KNN.

    Args:
        rgb (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
        eigenvectors (torch.Tensor): eigenvectors from existing nodes, shape (n_samples, num_eig)
        new_eigenvectors (torch.Tensor): eigenvectors from new nodes, shape (n_new_samples, num_eig)
        knn (int): number of nearest neighbors used to propagate RGB color, default 10
        num_sample (int): number of samples for subgraph sampling, default 1000
        sample_method (str): sample method, 'farthest' (default) or 'random'
        chunk_size (int): chunk size for matrix multiplication, default 8096
        device (str): device to use for computation; if None, the current device is kept
        use_tqdm (bool): show progress bar when propagating RGB color from subgraph to full graph

    Returns:
        torch.Tensor: propagated RGB color for each data sample, shape (n_new_samples, 3)

    Examples:
        >>> old_rgb = torch.randn(3000, 3)
        >>> old_eigenvectors = torch.randn(3000, 20)
        >>> new_eigenvectors = torch.randn(200, 20)
        >>> new_rgb = propagate_rgb_color(old_rgb, old_eigenvectors, new_eigenvectors)
        >>> # new_rgb.shape = (200, 3)
    """
    return propagate_eigenvectors(
        eigenvectors=rgb,
        features=eigenvectors,
        new_features=new_eigenvectors,
        knn=knn,
        num_sample=num_sample,
        sample_method=sample_method,
        chunk_size=chunk_size,
        device=device,
        use_tqdm=use_tqdm,
    )


# application: get a segmentation mask from a reference eigenvector (point prompt)
def _transform_heatmap(heatmap, gamma=1.0):
    """Transform the heatmap: normalize, apply gamma scaling, then min-max normalize.

    Args:
        heatmap (torch.Tensor): distance heatmap, shape (B, H, W)
        gamma (float, optional): scaling factor, higher means a smaller mask. Defaults to 1.0.

    Returns:
        torch.Tensor: transformed heatmap, shape (B, H, W)
    """
    # normalize the heatmap
    heatmap = (heatmap - heatmap.mean()) / heatmap.std()
    heatmap = torch.exp(heatmap)
    # transform the heatmap using gamma
    # large gamma means more focus on the high values, hence a smaller mask
    heatmap = 1 / heatmap ** gamma
    # min-max normalization to [0, 1]
    vmin, vmax = quantile_min_max(heatmap.flatten())
    heatmap = (heatmap - vmin) / (vmax - vmin)
    return heatmap


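# [editor's note, not part of the package] The exp-then-reciprocal steps above
# collapse to one closed form: with z = (heatmap - mean) / std, the transform is
# 1 / exp(z)**gamma = exp(-gamma * z), so gamma simply scales the decay rate before
# the quantile min-max scaling. A quick numerical check:
def _example_heatmap_closed_form():
    import torch
    z = torch.randn(2, 8, 8)
    gamma = 2.0
    assert torch.allclose(1 / torch.exp(z) ** gamma, torch.exp(-gamma * z), rtol=1e-5, atol=1e-6)

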
def _clean_mask(mask, min_area=500):
    """Clean a binary mask by removing small connected components.

    Args:
        mask: a numpy image of a binary mask, 255 for the object and 0 for the background.
        min_area: minimum area for a connected component to be considered valid (default 500).

    Returns:
        cleaned_mask: a numpy array of the cleaned mask, with small components removed.
        bounding_boxes: list of bounding boxes for valid objects (x, y, width, height).
    """

    import cv2
    # find connected components in the mask
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # initialize an empty mask to store the final cleaned mask
    final_cleaned_mask = np.zeros_like(mask)

    # collect bounding boxes for components larger than the threshold and update the cleaned mask
    bounding_boxes = []
    for i in range(1, num_labels):  # skip label 0 (background)
        x, y, w, h, area = stats[i]
        if area >= min_area:
            # add the bounding box of the valid component
            bounding_boxes.append((x, y, w, h))
            # keep the valid component in the final cleaned mask
            final_cleaned_mask[labels == i] = 255

    return final_cleaned_mask, bounding_boxes


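# [editor's sketch, not part of the package] _clean_mask usage, assuming OpenCV is
# installed (`pip install opencv-python`): components below min_area are erased,
# larger ones are kept and reported with their bounding boxes.
def _example_clean_mask():
    import numpy as np
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[10:40, 10:40] = 255   # one large square: kept
    mask[50:52, 50:52] = 255   # one tiny blob: removed as noise
    cleaned, boxes = _clean_mask(mask, min_area=100)
    assert (cleaned[50:52, 50:52] == 0).all()
    assert len(boxes) == 1 and tuple(map(int, boxes[0])) == (10, 10, 30, 30)

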
def get_mask(
    all_eigvecs: torch.Tensor, prompt_eigvec: torch.Tensor,
    threshold: float = 0.5, gamma: float = 1.0,
    denoise: bool = True, denoise_area_th: int = 3):
    """Segmentation mask from one prompt eigenvector (at a clicked latent pixel).

    1. Compute the cosine similarity between the clicked eigenvector and all the eigenvectors in the latent space.
    2. Transform the resulting heatmap: normalize and apply scaling (gamma).
    3. Threshold the heatmap to get the mask.
    4. Optionally denoise the mask by removing small connected components.

    Args:
        all_eigvecs (torch.Tensor): (B, H, W, num_eig)
        prompt_eigvec (torch.Tensor): (num_eig,)
        threshold (float, optional): mask threshold, higher means a smaller mask. Defaults to 0.5.
        gamma (float, optional): mask scaling factor, higher means a smaller mask. Defaults to 1.0.
        denoise (bool, optional): mask denoising flag. Defaults to True.
        denoise_area_th (int, optional): mask denoising area threshold, higher means more aggressive denoising. Defaults to 3.

    Returns:
        np.ndarray: masks (B, H, W), non-zero for object, 0 for background
            (0/1 without denoising, 0/255 after denoising)

    Examples:
        >>> all_eigvecs = torch.randn(10, 64, 64, 20)
        >>> prompt_eigvec = all_eigvecs[0, 32, 32]  # center pixel
        >>> masks = get_mask(all_eigvecs, prompt_eigvec, threshold=0.5, gamma=1.0, denoise=True, denoise_area_th=3)
        >>> # masks.shape = (10, 64, 64)
    """

    # normalize the eigenvectors to unit norm, to compute cosine similarity
    if not check_if_normalized(all_eigvecs.reshape(-1, all_eigvecs.shape[-1])):
        all_eigvecs = F.normalize(all_eigvecs, p=2, dim=-1)

    prompt_eigvec = F.normalize(prompt_eigvec, p=2, dim=-1)

    # compute the cosine similarity
    cos_sim = all_eigvecs @ prompt_eigvec.unsqueeze(-1)  # (B, H, W, 1)
    cos_sim = cos_sim.squeeze(-1)  # (B, H, W)

    heatmap = 1 - cos_sim

    # transform the heatmap, normalize and apply scaling (gamma)
    heatmap = _transform_heatmap(heatmap, gamma=gamma)

    masks = heatmap > threshold
    masks = masks.cpu().numpy().astype(np.uint8)

    if denoise:
        cleaned_masks = []
        for mask in masks:
            cleaned_mask, _ = _clean_mask(mask, min_area=denoise_area_th)
            cleaned_masks.append(cleaned_mask)
        cleaned_masks = np.stack(cleaned_masks)
        return cleaned_masks

    return masks

@@ -0,0 +1,19 @@
Copyright (c) 2018 The Python Packaging Authority

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.