nystrom-ncut 0.0.1__tar.gz → 0.0.3__tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nystrom_ncut
- Version: 0.0.1
+ Version: 0.0.3
  Summary: Normalized Cut and Nyström Approximation
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "nystrom_ncut"
- version = "0.0.1"
+ version = "0.0.3"
  authors = [
      { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
      { name = "Wentinn Liao", email = "wentinn.liao@gmail.com" },
@@ -3,4 +3,5 @@ scikit-learn
  umap-learn
  fpsample>=0.3.2
  pycolormap-2d
- tqdm
+ tqdm
+ torch
@@ -1,4 +1,7 @@
- from .ncut_pytorch import NCUT
+ from .ncut_pytorch import (
+     NCUT,
+     axis_align,
+ )
  from .propagation_utils import (
      affinity_from_features,
      propagate_eigenvectors,
@@ -6,7 +9,6 @@ from .propagation_utils import (
      quantile_normalize,
  )
  from .visualize_utils import (
-     eigenvector_to_rgb,
      rgb_from_tsne_3d,
      rgb_from_umap_sphere,
      rgb_from_tsne_2d,
@@ -18,5 +20,3 @@ from .visualize_utils import (
      propagate_rgb_color,
      get_mask,
  )
- from .ncut_pytorch import nystrom_ncut, ncut
- from .ncut_pytorch import kway_ncut, axis_align
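With this change, the public surface is the `NCUT` class plus `axis_align`; the 0.0.1 functional aliases (`nystrom_ncut`, `ncut`, `kway_ncut`) and `eigenvector_to_rgb` are no longer re-exported. A minimal sketch of the new entry points (the feature tensor is illustrative; `fit_transform` is assumed to accept a bare feature tensor, as the hunks below show):

    import torch
    from nystrom_ncut import NCUT, axis_align

    features = torch.randn(1000, 32)                      # illustrative input
    V, L = NCUT(n_components=20).fit_transform(features)  # eigenvectors, eigenvalues
    onehot, R = axis_align(V)                             # discretized assignments, rotation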
@@ -0,0 +1,20 @@
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as Fn
+
+
+ def ceildiv(a: int, b: int) -> int:
+     return -(-a // b)
+
+
+ def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:
+     numel = np.prod(x.shape[:-1])
+     n = min(n, numel)
+     random_indices = torch.randperm(numel)[:n]
+     _x = x.flatten(0, -2)[random_indices]
+     if torch.allclose(torch.norm(_x, **normalize_kwargs), torch.ones(n, device=x.device)):
+         return x
+     else:
+         return Fn.normalize(x, **normalize_kwargs)
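The new `common.py` gathers two helpers used throughout the package: `ceildiv`, ceiling division written as negated floor division, and `lazy_normalize`, which replaces the removed `check_if_normalized` by probing up to `n` random rows and skipping normalization when they already have unit norm. A small sketch of the semantics (the `nystrom_ncut.common` import path is inferred from this diff):

    import torch
    import torch.nn.functional as Fn
    from nystrom_ncut.common import ceildiv, lazy_normalize  # path assumed from this diff

    assert ceildiv(10, 3) == 4                # -(-10 // 3) == 4
    assert ceildiv(9, 3) == 3

    x = Fn.normalize(torch.randn(100, 8), dim=-1)
    y = lazy_normalize(x, dim=-1)             # sampled rows already unit-norm,
    assert y is x                             # so the input is returned untouched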
@@ -2,6 +2,7 @@ import logging
  from typing import Literal, Tuple

  import torch
+ import torch.nn.functional as Fn

  from .nystrom import (
      EigSolverOptions,
@@ -44,7 +45,6 @@ class LaplacianKernel(OnlineKernel):
              self.anchor_features, # [n x d]
              affinity_focal_gamma=self.affinity_focal_gamma,
              distance=self.distance,
-             fill_diagonal=False,
          ) # [n x n]
          U, L = solve_eig(
              self.A,
@@ -61,7 +61,6 @@ class LaplacianKernel(OnlineKernel):
              features, # [m x d]
              affinity_focal_gamma=self.affinity_focal_gamma,
              distance=self.distance,
-             fill_diagonal=False,
          ) # [n x m]
          b_r = torch.sum(B, dim=-1) # [n]
          b_c = torch.sum(B, dim=-2) # [m]
@@ -83,7 +82,6 @@ class LaplacianKernel(OnlineKernel):
              features, # [m x d]
              affinity_focal_gamma=self.affinity_focal_gamma,
              distance=self.distance,
-             fill_diagonal=False,
          ) # [n x m]
          b_c = torch.sum(B, dim=-2) # [m]
          colscale = b_c + B.mT @ self.Ainv @ self.b_r # [m]
@@ -91,25 +89,24 @@ class LaplacianKernel(OnlineKernel):
          return (B * scale).mT # [m x n]


- class NewNCUT(OnlineNystrom):
+ class NCUT(OnlineNystrom):
      """Nystrom Normalized Cut for large scale graph."""

      def __init__(
          self,
-         num_eig: int = 100,
+         n_components: int = 100,
          affinity_focal_gamma: float = 1.0,
          num_sample: int = 10000,
          sample_method: Literal["farthest", "random"] = "farthest",
          distance: DistanceOptions = "cosine",
          eig_solver: EigSolverOptions = "svd_lowrank",
          normalize_features: bool = None,
-         device: str = None,
          move_output_to_cpu: bool = False,
-         matmul_chunk_size: int = 8096,
+         chunk_size: int = 8192,
      ):
          """
          Args:
-             num_eig (int): number of top eigenvectors to return
+             n_components (int): number of top eigenvectors to return
              affinity_focal_gamma (float): affinity matrix temperature, lower t reduces the not-so-connected edge weights,
                  smaller t results in sharper eigenvectors.
              num_sample (int): number of samples for Nystrom-like approximation,
@@ -120,17 +117,15 @@ class NewNCUT(OnlineNystrom):
              eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
              normalize_features (bool): normalize input features before computing affinity matrix,
                  default 'None' is True for cosine distance, False for euclidean distance and rbf
-             device (str): device to use for eigen computation,
-                 move to GPU to speeds up a bit (~5x faster)
              move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
-             matmul_chunk_size (int): chunk size for large-scale matrix multiplication
+             chunk_size (int): chunk size for large-scale matrix multiplication
          """
          OnlineNystrom.__init__(
              self,
-             n_components=num_eig,
+             n_components=n_components,
              kernel=LaplacianKernel(affinity_focal_gamma, distance, eig_solver),
              eig_solver=eig_solver,
-             chunk_size=matmul_chunk_size,
+             chunk_size=chunk_size,
          )
          self.num_sample = num_sample
          self.sample_method = sample_method
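Taken together, the constructor changes amount to a small migration for callers: `num_eig` becomes `n_components`, `matmul_chunk_size` becomes `chunk_size` (with the default corrected from 8096 to the power of two 8192), and the `device` argument is dropped, so computation now follows the device of the input tensor. A hedged before/after sketch:

    from nystrom_ncut import NCUT

    # 0.0.1 (per this diff): NCUT(num_eig=50, matmul_chunk_size=8096, device="cuda")
    # 0.0.3: move the features tensor to the desired device instead
    model = NCUT(n_components=50, chunk_size=8192)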
@@ -142,19 +137,14 @@ class NewNCUT(OnlineNystrom):
          if distance in ["euclidean", "rbf"]:
              self.normalize_features = False

-         self.device = device
          self.move_output_to_cpu = move_output_to_cpu
-         self.matmul_chunk_size = matmul_chunk_size
+         self.chunk_size = chunk_size

      def _fit_helper(
          self,
          features: torch.Tensor,
          precomputed_sampled_indices: torch.Tensor,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
-         # move subgraph gpu to speed up
-         original_device = features.device
-         device = original_device if self.device is None else self.device
-
          _n = features.shape[0]
          if self.num_sample >= _n:
              logging.info(
@@ -186,13 +176,13 @@ class NewNCUT(OnlineNystrom):
              num_sample=self.num_sample,
              sample_method=self.sample_method,
          )
-         sampled_features = features[sampled_indices].to(device)
+         sampled_features = features[sampled_indices]
          OnlineNystrom.fit(self, sampled_features)

          _n_not_sampled = _n - len(sampled_features)
          if _n_not_sampled > 0:
-             unsampled_indices = torch.full((_n,), True).scatter(0, sampled_indices, False)
-             unsampled_features = features[unsampled_indices].to(device)
+             unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, sampled_indices, False)
+             unsampled_features = features[unsampled_indices]
              V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
          else:
              unsampled_indices = V_unsampled = None
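The rewritten mask construction is worth noting: it allocates an all-True vector on the features' device and uses the in-place `scatter_` (rather than the out-of-place `scatter`) to mark sampled positions False, yielding a boolean index of the unsampled rows. The idiom in isolation:

    import torch

    n = 10
    sampled_indices = torch.tensor([0, 3, 7])
    unsampled = torch.full((n,), True).scatter_(0, sampled_indices, False)
    # tensor([False,  True,  True, False,  True,  True,  True, False,  True,  True])
    assert unsampled.sum() == n - len(sampled_indices)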
@@ -211,7 +201,7 @@ class NewNCUT(OnlineNystrom):
          Returns:
              (NCUT): self
          """
-         NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
+         NCUT._fit_helper(self, features, precomputed_sampled_indices)
          return self

      def fit_transform(
@@ -229,13 +219,60 @@ class NewNCUT(OnlineNystrom):
              (torch.Tensor): eigen_vectors, shape (n_samples, num_eig)
              (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
          """
-         unsampled_indices, V_unsampled = NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
+         unsampled_indices, V_unsampled = NCUT._fit_helper(self, features, precomputed_sampled_indices)
          V_sampled, L = OnlineNystrom.transform(self)

          if unsampled_indices is not None:
-             V = torch.zeros((len(unsampled_indices), self.n_components))
+             V = torch.zeros((len(unsampled_indices), self.n_components), device=features.device)
              V[~unsampled_indices] = V_sampled
              V[unsampled_indices] = V_unsampled
          else:
              V = V_sampled
          return V, L
+
+
+ def axis_align(eigen_vectors: torch.Tensor, max_iter=300):
+     """Multiclass Spectral Clustering, SX Yu, J Shi, 2003
+
+     Args:
+         eigen_vectors (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
+         max_iter (int, optional): Maximum number of iterations.
+
+     Returns:
+         torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
+     """
+     # Normalize eigenvectors
+     n, k = eigen_vectors.shape
+     eigen_vectors = Fn.normalize(eigen_vectors, p=2, dim=-1)
+
+     # Initialize R matrix with the first column from a random row of EigenVectors
+     R = torch.empty((k, k), device=eigen_vectors.device)
+     R[0] = eigen_vectors[torch.randint(0, n, (1,))].squeeze()
+
+     # Loop to populate R with k orthogonal directions
+     c = torch.zeros(n, device=eigen_vectors.device)
+     for i in range(1, k):
+         c += torch.abs(eigen_vectors @ R[i - 1])
+         R[i] = eigen_vectors[torch.argmin(c, dim=0)]
+
+     # Iterative optimization loop
+     eps = torch.finfo(torch.float32).eps
+     prev_objective = torch.inf
+     for _ in range(max_iter):
+         # Discretize the projected eigenvectors
+         idx = torch.argmax(eigen_vectors @ R.mT, dim=-1)
+         M = torch.zeros((k, k)).index_add_(0, idx, eigen_vectors)
+
+         # Compute the NCut value
+         objective = torch.norm(M)
+
+         # Check for convergence
+         if torch.abs(objective - prev_objective) < eps:
+             break
+         prev_objective = objective
+
+         # SVD decomposition
+         U, S, Vh = torch.linalg.svd(M, full_matrices=False)
+         R = U @ Vh
+
+     return Fn.one_hot(idx, num_classes=k).to(torch.float), R
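With `axis_align` now exported beside `NCUT`, the intended pipeline is: compute continuous eigenvectors with `fit_transform`, then discretize them with the Yu–Shi rotation. A hedged end-to-end sketch (sizes are illustrative; `precomputed_sampled_indices` is assumed to default to None):

    import torch
    from nystrom_ncut import NCUT, axis_align

    features = torch.randn(5000, 64)
    V, L = NCUT(n_components=10, num_sample=1000).fit_transform(features)
    onehot, R = axis_align(V, max_iter=300)   # (5000, 10) one-hot rows, (10, 10) rotation
    labels = onehot.argmax(dim=-1)            # hard cluster id per sample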
@@ -2,6 +2,8 @@ from typing import Literal, Tuple

  import torch

+ from .common import ceildiv
+

  EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]

@@ -75,7 +77,7 @@ class OnlineNystrom:
          return U[:, :self.n_components], L[:self.n_components] # [n x n_components], [n_components]

      def update(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-         n_chunks = -(-len(features) // self.chunk_size)
+         n_chunks = ceildiv(len(features), self.chunk_size)
          if n_chunks > 1:
              """ Chunked version """
              chunks = torch.chunk(features, n_chunks, dim=0)
@@ -111,7 +113,7 @@ class OnlineNystrom:
          if features is None:
              VS = self.A @ self.transform_matrix # [n x n_components]
          else:
-             n_chunks = -(-len(features) // self.chunk_size)
+             n_chunks = ceildiv(len(features), self.chunk_size)
              if n_chunks > 1:
                  """ Chunked version """
                  chunks = torch.chunk(features, n_chunks, dim=0)
@@ -1,11 +1,12 @@
  import logging
- import math
  from typing import Literal

  import numpy as np
  import torch
  import torch.nn.functional as F

+ from .common import ceildiv, lazy_normalize
+

  @torch.no_grad()
  def run_subgraph_sampling(
@@ -42,7 +43,7 @@ def run_subgraph_sampling(
          sampled_indices = torch.randperm(features.shape[0])[:num_sample]
      else:
          raise ValueError("sample_method should be 'farthest' or 'random'")
-     return sampled_indices
+     return sampled_indices.to(features.device)


  def farthest_point_sampling(
@@ -60,14 +61,12 @@ def farthest_point_sampling(
      # PCA to reduce the dimension
      if features.shape[1] > 8:
          u, s, v = torch.pca_lowrank(features, q=8)
-         _n = features.shape[0]
-         s /= math.sqrt(_n)
          features = u @ torch.diag(s)

      h = min(h, int(np.log2(features.shape[0])))

      kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
-         features.cpu().numpy(), num_sample, h
+         features.numpy(force=True), num_sample, h
      ).astype(np.int64)
      return torch.from_numpy(kdline_fps_samples_idx)
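`Tensor.numpy(force=True)` (available since PyTorch 1.13) folds the detach and CPU-transfer steps into one call, so the FPS entry point no longer assumes the tensor is already a grad-free CPU tensor. Roughly the equivalence:

    import torch

    t = torch.randn(4, 3, requires_grad=True)
    a = t.numpy(force=True)           # works: detaches and moves to CPU as needed
    b = t.detach().cpu().numpy()      # the manual spelling this replaces
    assert (a == b).all()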
 
@@ -76,26 +75,19 @@ def distance_from_features(
      features: torch.Tensor,
      features_B: torch.Tensor,
      distance: Literal["cosine", "euclidean", "rbf"],
-     fill_diagonal: bool,
  ):
      """Compute distance matrix from input features.
      Args:
          features (torch.Tensor): input features, shape (n_samples, n_features)
          features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
-         affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
-             on weak connections, default 1.0
          distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
-         normalize_features (bool): normalize input features before computing affinity matrix
-
      Returns:
          (torch.Tensor): distance matrix, shape (n_samples, n_samples)
      """
      # compute distance matrix from input features
      if distance == "cosine":
-         if not check_if_normalized(features):
-             features = F.normalize(features, dim=-1)
-         if not check_if_normalized(features_B):
-             features_B = F.normalize(features_B, dim=-1)
+         features = lazy_normalize(features, dim=-1)
+         features_B = lazy_normalize(features_B, dim=-1)
          D = 1 - features @ features_B.T
      elif distance == "euclidean":
          D = torch.cdist(features, features_B, p=2)
@@ -105,8 +97,6 @@ def distance_from_features(
      else:
          raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")

-     if fill_diagonal:
-         D[torch.arange(D.shape[0]), torch.arange(D.shape[0])] = 0
      return D


@@ -115,7 +105,6 @@ def affinity_from_features(
      features_B: torch.Tensor = None,
      affinity_focal_gamma: float = 1.0,
      distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-     fill_diagonal: bool = True,
  ):
      """Compute affinity matrix from input features.

@@ -125,8 +114,6 @@ def affinity_from_features(
          affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
              on weak connections, default 1.0
          distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
-         normalize_features (bool): normalize input features before computing affinity matrix
-
      Returns:
          (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
      """
@@ -134,12 +121,10 @@ def affinity_from_features(

      # if feature_B is not provided, compute affinity matrix on features x features
      # if feature_B is provided, compute affinity matrix on features x feature_B
-     if features_B is not None:
-         assert not fill_diagonal, "fill_diagonal should be False when feature_B is None"
      features_B = features if features_B is None else features_B

      # compute distance matrix from input features
-     D = distance_from_features(features, features_B, distance, fill_diagonal)
+     D = distance_from_features(features, features_B, distance)

      # torch.exp makes the affinity matrix positive definite;
      # lower affinity_focal_gamma reduces the weak edge weights
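With `fill_diagonal` gone, the square (features × features) and rectangular (features × features_B) cases go through the identical code path. A short sketch against the exported function, using the signature shown in the hunks above:

    import torch
    from nystrom_ncut import affinity_from_features

    X, Y = torch.randn(100, 16), torch.randn(30, 16)
    A_sq = affinity_from_features(X, affinity_focal_gamma=1.0, distance="cosine")
    A_rect = affinity_from_features(X, Y, affinity_focal_gamma=1.0, distance="cosine")
    assert A_sq.shape == (100, 100) and A_rect.shape == (100, 30)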
@@ -154,9 +139,8 @@ def propagate_knn(
      knn: int = 10,
      distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
      affinity_focal_gamma: float = 1.0,
-     chunk_size: int = 8096,
+     chunk_size: int = 8192,
      device: str = None,
-     use_tqdm: bool = False,
      move_output_to_cpu: bool = False,
  ):
      """A generic function to propagate new nodes using KNN.
@@ -169,8 +153,6 @@ def propagate_knn(
          distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
          chunk_size (int): chunk size for matrix multiplication
          device (str): device to use for computation, if None, will not change device
-         use_tqdm (bool): show progress bar when propagating eigenvectors from subgraph to full graph
-
      Returns:
          torch.Tensor: propagated eigenvectors, shape (new_num_samples, D)

@@ -197,24 +179,16 @@ def propagate_knn(
      # used in nystrom_ncut
      # propagate eigen_vector from subgraph to full graph
      subgraph_output = subgraph_output.to(device)
-     V_list = []
-     iterator = range(0, inp_features.shape[0], chunk_size)
-     try:
-         assert use_tqdm
-         from tqdm import tqdm
-         iterator = tqdm(iterator, "propagate by KNN")
-     except (AssertionError, ImportError):
-         pass

-     subgraph_features = subgraph_features.to(device)
-     for i in iterator:
-         end = min(i + chunk_size, inp_features.shape[0])
-         _v = inp_features[i:end].to(device)
-         _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance, False).mT
+     n_chunks = ceildiv(inp_features.shape[0], chunk_size)
+     V_list = []
+     for _v in torch.chunk(inp_features, n_chunks, dim=0):
+         _v = _v.to(device)
+         _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT

          if knn is not None:
              mask = torch.full_like(_A, True, dtype=torch.bool)
-             mask[torch.arange(end - i)[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
+             mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
              _A[mask] = 0.0
              _A = F.normalize(_A, p=1, dim=-1)
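The rewritten loop swaps manual start/end index arithmetic (and the optional tqdm wrapper) for `torch.chunk` over `ceildiv(n, chunk_size)` pieces, with `len(_v)` giving each chunk's true row count. The chunking idiom in isolation:

    import torch

    def ceildiv(a: int, b: int) -> int:
        return -(-a // b)

    inp = torch.randn(20000, 8)
    n_chunks = ceildiv(inp.shape[0], 8192)            # 3
    for _v in torch.chunk(inp, n_chunks, dim=0):
        print(len(_v))                                # 6667, 6667, 6666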
 
@@ -232,16 +206,14 @@ def propagate_nearest(
      inp_features: torch.Tensor,
      subgraph_features: torch.Tensor,
      distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-     chunk_size: int = 8096,
+     chunk_size: int = 8192,
      device: str = None,
      move_output_to_cpu: bool = False,
  ):
      device = subgraph_output.device if device is None else device
      if distance == 'cosine':
-         if not check_if_normalized(inp_features):
-             inp_features = F.normalize(inp_features, dim=-1)
-         if not check_if_normalized(subgraph_features):
-             subgraph_features = F.normalize(subgraph_features, dim=-1)
+         inp_features = lazy_normalize(inp_features, dim=-1)
+         subgraph_features = lazy_normalize(subgraph_features, dim=-1)

      # used in nystrom_tsne, equivalent to propagate_by_knn with knn=1
      # propagate tSNE from subgraph to full graph
@@ -250,7 +222,7 @@ def propagate_nearest(
      for i in range(0, inp_features.shape[0], chunk_size):
          end = min(i + chunk_size, inp_features.shape[0])
          _v = inp_features[i:end].to(device)
-         _A = -distance_from_features(subgraph_features, _v, distance, False).mT
+         _A = -distance_from_features(subgraph_features, _v, distance).mT

          # keep top1 for each row
          top_idx = _A.argmax(dim=-1).cpu()
@@ -273,7 +245,6 @@ def propagate_eigenvectors(
      sample_method: Literal["farthest", "random"],
      chunk_size: int,
      device: str,
-     use_tqdm: bool,
  ):
      """Propagate eigenvectors to new nodes using KNN. Note: this is equivalent to the class API `NCUT.transform(new_features)`, except that the sampling is re-done in this function.
      Args:
@@ -283,10 +254,8 @@ def propagate_eigenvectors(
          knn (int): number of KNN to propagate eigenvectors, default 3
          num_sample (int): number of samples for subgraph sampling, default 50000
          sample_method (str): sample method, 'farthest' (default) or 'random'
-         chunk_size (int): chunk size for matrix multiplication, default 8096
+         chunk_size (int): chunk size for matrix multiplication, default 8192
          device (str): device to use for computation, if None, will not change device
-         use_tqdm (bool): show progress bar when propagating eigenvectors from subgraph to full graph
-
      Returns:
          torch.Tensor: propagated eigenvectors, shape (n_new_samples, num_eig)

@@ -319,21 +288,10 @@ def propagate_eigenvectors(
          knn=knn,
          chunk_size=chunk_size,
          device=device,
-         use_tqdm=use_tqdm,
      )
-
      return new_eigenvectors


- def check_if_normalized(x, n=1000):
-     """check if the input tensor is normalized (unit norm)"""
-     n = min(n, x.shape[0])
-     random_indices = torch.randperm(x.shape[0])[:n]
-     _x = x[random_indices]
-     flag = torch.allclose(torch.norm(_x, dim=-1), torch.ones(n, device=x.device))
-     return flag
-
-
  def quantile_min_max(x, q1=0.01, q2=0.99, n_sample=10000):
      if x.shape[0] > n_sample:
          np.random.seed(0)