nystrom-ncut 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- nystrom_ncut/__init__.py +4 -4
- nystrom_ncut/common.py +20 -0
- nystrom_ncut/ncut_pytorch.py +192 -467
- nystrom_ncut/nystrom.py +4 -2
- nystrom_ncut/propagation_utils.py +15 -57
- nystrom_ncut/visualize_utils.py +9 -98
- {nystrom_ncut-0.0.1.dist-info → nystrom_ncut-0.0.2.dist-info}/METADATA +1 -1
- nystrom_ncut-0.0.2.dist-info/RECORD +11 -0
- nystrom_ncut/new_ncut_pytorch.py +0 -241
- nystrom_ncut-0.0.1.dist-info/RECORD +0 -11
- {nystrom_ncut-0.0.1.dist-info → nystrom_ncut-0.0.2.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.0.1.dist-info → nystrom_ncut-0.0.2.dist-info}/WHEEL +0 -0
- {nystrom_ncut-0.0.1.dist-info → nystrom_ncut-0.0.2.dist-info}/top_level.txt +0 -0
nystrom_ncut/__init__.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
from .ncut_pytorch import
|
1
|
+
from .ncut_pytorch import (
|
2
|
+
NCUT,
|
3
|
+
axis_align,
|
4
|
+
)
|
2
5
|
from .propagation_utils import (
|
3
6
|
affinity_from_features,
|
4
7
|
propagate_eigenvectors,
|
@@ -6,7 +9,6 @@ from .propagation_utils import (
|
|
6
9
|
quantile_normalize,
|
7
10
|
)
|
8
11
|
from .visualize_utils import (
|
9
|
-
eigenvector_to_rgb,
|
10
12
|
rgb_from_tsne_3d,
|
11
13
|
rgb_from_umap_sphere,
|
12
14
|
rgb_from_tsne_2d,
|
@@ -18,5 +20,3 @@ from .visualize_utils import (
|
|
18
20
|
propagate_rgb_color,
|
19
21
|
get_mask,
|
20
22
|
)
|
21
|
-
from .ncut_pytorch import nystrom_ncut, ncut
|
22
|
-
from .ncut_pytorch import kway_ncut, axis_align
|
nystrom_ncut/common.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import torch.nn.functional as Fn
|
6
|
+
|
7
|
+
|
8
|
+
def ceildiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` using exact integer arithmetic (no float rounding).

    ``divmod`` yields the floor quotient; a non-zero remainder means the true
    quotient was not an integer, so the ceiling is one more than the floor.
    Raises ``ZeroDivisionError`` when ``b == 0``.
    """
    quotient, remainder = divmod(a, b)
    if remainder:
        return quotient + 1
    return quotient
|
10
|
+
|
11
|
+
|
12
|
+
def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:
    """Apply ``Fn.normalize`` to ``x`` unless it already appears normalized.

    When normalization is along the last axis, up to ``n`` randomly chosen
    rows (all leading dimensions flattened) are checked. If every sampled row
    already has unit ``p``-norm, ``x`` is returned unchanged to skip the full
    normalization pass.

    Args:
        x: Input tensor.
        n: Maximum number of rows sampled for the already-normalized check.
        **normalize_kwargs: Forwarded unchanged to ``Fn.normalize``
            (e.g. ``p``, ``dim``, ``eps``). ``p`` and ``dim`` use the same
            defaults as ``Fn.normalize`` (2.0 and 1) for the check.

    Returns:
        ``x`` itself if the sampled rows already have unit norm, otherwise
        ``Fn.normalize(x, **normalize_kwargs)``.
    """
    # Mirror Fn.normalize's own defaults so the check measures the same norm
    # that a real normalization would enforce.
    p = normalize_kwargs.get("p", 2.0)
    dim = normalize_kwargs.get("dim", 1)

    # The row-sampling shortcut is only valid when normalizing along the last
    # axis; for any other axis, fall through to a full normalization.
    if x.ndim > 0 and dim in (-1, x.ndim - 1):
        # int(): np.prod(()) returns the float 1.0 for 1-D inputs.
        numel = int(np.prod(x.shape[:-1]))
        n = min(n, numel)
        random_indices = torch.randperm(numel)[:n]
        # Explicit reshape (not flatten(0, -2)) so 1-D inputs become a single row.
        _x = x.reshape(numel, x.shape[-1])[random_indices]
        # Only p/dim are valid for torch.norm; extra kwargs such as eps are not.
        norms = torch.norm(_x, p=p, dim=-1)
        # ones_like matches dtype/device; allclose rejects mixed dtypes.
        if torch.allclose(norms, torch.ones_like(norms)):
            return x
    return Fn.normalize(x, **normalize_kwargs)
|