nystrom-ncut 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,3 +1,5 @@
1
+ from typing import Tuple
2
+
1
3
  import torch
2
4
 
3
5
  from .nystrom import (
@@ -125,3 +127,15 @@ class DistanceRealization(OnlineNystromSubsampleFit):
125
127
  chunk_size=chunk_size,
126
128
  )
127
129
  self.distance: DistanceOptions = distance
130
+
131
+ def fit_transform(
132
+ self,
133
+ features: torch.Tensor,
134
+ precomputed_sampled_indices: torch.Tensor = None,
135
+ ) -> torch.Tensor:
136
+ V, L = OnlineNystromSubsampleFit.fit_transform(self, features, precomputed_sampled_indices)
137
+ return V * (L ** 0.5)
138
+
139
+ def transform(self, features: torch.Tensor = None) -> torch.Tensor:
140
+ V, L = OnlineNystromSubsampleFit.transform(features)
141
+ return V * (L ** 0.5)
@@ -276,5 +276,7 @@ def solve_eig(
276
276
  eigen_vector = eigen_vector[:, indices]
277
277
 
278
278
  # correct the random rotation (flipping sign) of eigenvectors
279
- eigen_vector = eigen_vector * torch.sum(eigen_vector, dim=0).sign()
279
+ sign = torch.sum(eigen_vector, dim=0).sign()
280
+ sign[sign == 0] = 1.0
281
+ eigen_vector = eigen_vector * sign
280
282
  return eigen_vector, eigen_value
@@ -11,6 +11,9 @@ from .common import (
11
11
  quantile_min_max,
12
12
  quantile_normalize,
13
13
  )
14
+ from .nystrom import (
15
+ DistanceRealization,
16
+ )
14
17
  from .propagation_utils import (
15
18
  run_subgraph_sampling,
16
19
  extrapolate_knn,
@@ -192,18 +195,9 @@ def rgb_from_cosine_tsne_3d(
192
195
  )
193
196
  perplexity = num_sample // 2
194
197
 
195
-
196
- def cosine_to_rbf(X: torch.Tensor) -> torch.Tensor: # [B... x N x 3]
197
- normalized_X = X / torch.norm(X, p=2, dim=-1, keepdim=True) # [B... x N x 3]
198
- D = 1 - normalized_X @ normalized_X.mT # [B... x N x N]
199
-
200
- G = (D[..., :1, 1:] ** 2 + D[..., 1:, :1] ** 2 - D[..., 1:, 1:] ** 2) / 2 # [B... x (N - 1) x (N - 1)]
201
- L, V = torch.linalg.eigh(G) # [B... x (N - 1)], [B... x (N - 1) x (N - 1)]
202
- sqrtG = V[..., -3:] * (L[..., None, -3:] ** 0.5) # [B... x (N - 1) x 3]
203
-
204
- Y = torch.cat((torch.zeros_like(sqrtG[..., :1, :]), sqrtG), dim=-2) # [B... x N x 3]
205
- Y = Y - torch.mean(Y, dim=-2, keepdim=True)
206
- return Y
198
+ def cosine_to_rbf(X: torch.Tensor) -> torch.Tensor:
199
+ dr = DistanceRealization(n_components=3, num_sample=20000, distance="cosine", eig_solver="svd_lowrank")
200
+ return dr.fit_transform(X)
207
201
 
208
202
  def rgb_from_cosine(X_3d: torch.Tensor, q: float) -> torch.Tensor:
209
203
  return rgb_from_3d_rgb_cube(cosine_to_rbf(X_3d), q=q)
@@ -379,7 +373,6 @@ def rotate_rgb_cube(rgb, position=1):
379
373
 
380
374
  def rgb_from_3d_rgb_cube(X_3d, q=0.95):
381
375
  """convert 3D t-SNE to RGB color space
382
-
383
376
  Args:
384
377
  X_3d (torch.Tensor): 3D t-SNE embedding, shape (n_samples, 3)
385
378
  q (float): quantile, default 0.95
@@ -389,10 +382,10 @@ def rgb_from_3d_rgb_cube(X_3d, q=0.95):
389
382
  """
390
383
  assert X_3d.shape[1] == 3, "input should be (n_samples, 3)"
391
384
  assert len(X_3d.shape) == 2, "input should be (n_samples, 3)"
392
- rgb = []
393
- for i in range(3):
394
- rgb.append(quantile_normalize(X_3d[:, i], q=q))
395
- rgb = torch.stack(rgb, dim=-1)
385
+ rgb = torch.stack([
386
+ quantile_normalize(x, q=q)
387
+ for x in torch.unbind(X_3d, dim=1)
388
+ ], dim=-1)
396
389
  return rgb
397
390
 
398
391
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.0.10
3
+ Version: 0.0.12
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -0,0 +1,14 @@
1
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ nystrom_ncut/__init__.py,sha256=JKfF6atok5T9V692RhlhgeRO5a2cN-bfAVa9irmTLfs,463
3
+ nystrom_ncut/common.py,sha256=RMPQvg9R2s7V-q7zAStN9YCZt7gpc5Ut-KSKtvELBQ4,1934
4
+ nystrom_ncut/propagation_utils.py,sha256=WeWKxRBm01ITILMgjsit5_fCe9oW1kJOPmAjjcmliMo,10340
5
+ nystrom_ncut/visualize_utils.py,sha256=JkDyWML6k7k6S2Z7xnpbUvMWssEXcXqXu7gBy8wnids,16809
6
+ nystrom_ncut/nystrom/__init__.py,sha256=4EpxD3Cmc8Fif4vo8DG-6FpTfCnNanD5zCZxK3WrMwQ,121
7
+ nystrom_ncut/nystrom/distance_realization.py,sha256=MWSdfPfUnr7BdiKFkogjQvcGagvj7OzLQklnVp0fPx8,6000
8
+ nystrom_ncut/nystrom/normalized_cut.py,sha256=_U3zrbe6V-5TQ4uWmqckxs2JTIhygQlnRDTFBI1ghD4,7194
9
+ nystrom_ncut/nystrom/nystrom.py,sha256=nL-zxbEE_ygJEZEPmeNrUpVeffvxdrsTcbxFanFuXQY,12613
10
+ nystrom_ncut-0.0.12.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
11
+ nystrom_ncut-0.0.12.dist-info/METADATA,sha256=pM-WT6Ly-IKYJ3DV2d-oOyc--K4VOyArB0sT5gHfHL4,6059
12
+ nystrom_ncut-0.0.12.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
13
+ nystrom_ncut-0.0.12.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
14
+ nystrom_ncut-0.0.12.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- nystrom_ncut/__init__.py,sha256=JKfF6atok5T9V692RhlhgeRO5a2cN-bfAVa9irmTLfs,463
3
- nystrom_ncut/common.py,sha256=RMPQvg9R2s7V-q7zAStN9YCZt7gpc5Ut-KSKtvELBQ4,1934
4
- nystrom_ncut/propagation_utils.py,sha256=WeWKxRBm01ITILMgjsit5_fCe9oW1kJOPmAjjcmliMo,10340
5
- nystrom_ncut/visualize_utils.py,sha256=Z_bcoxwmWpTxhQ_yoAXqTnYDf269IuT0b0Sm2EVQpRw,17422
6
- nystrom_ncut/nystrom/__init__.py,sha256=4EpxD3Cmc8Fif4vo8DG-6FpTfCnNanD5zCZxK3WrMwQ,121
7
- nystrom_ncut/nystrom/distance_realization.py,sha256=8AWUlZKZEPfhQHxYTZt0uzKedVp8ZB1wb__7M2Fy-Eo,5529
8
- nystrom_ncut/nystrom/normalized_cut.py,sha256=_U3zrbe6V-5TQ4uWmqckxs2JTIhygQlnRDTFBI1ghD4,7194
9
- nystrom_ncut/nystrom/nystrom.py,sha256=VJPA17I8cVvjILUABJjkVA5kkXbTmHDyrtcWvu5xs-0,12571
10
- nystrom_ncut-0.0.10.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
11
- nystrom_ncut-0.0.10.dist-info/METADATA,sha256=sqs2WHdNbJeT5zvlq_WWHHRvHTz1mHVbDL3PsE1NMBI,6059
12
- nystrom_ncut-0.0.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
13
- nystrom_ncut-0.0.10.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
14
- nystrom_ncut-0.0.10.dist-info/RECORD,,