SURE-tools 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of SURE-tools has been flagged as potentially problematic.

SURE/assembly/atlas.py ADDED
@@ -0,0 +1,575 @@
+ import os
+ import tempfile
+ import subprocess
+ import scanpy as sc
+ import numpy as np
+ import scipy as sp
+ from scipy import sparse
+ from scipy.stats import gaussian_kde
+ from packaging.version import Version
+
+ # scipy.integrate.cumtrapz was removed in SciPy 1.14; use its renamed successor there.
+ # Version() gives a proper semantic comparison (a plain string comparison would
+ # misorder e.g. '1.9.0' and '1.14.0').
+ if Version(sp.__version__) < Version('1.14.0'):
+     from scipy.integrate import cumtrapz
+ else:
+     from scipy.integrate import cumulative_trapezoid as cumtrapz
+
+ import pandas as pd
+ from functools import reduce
+
+ import datatable as dt
+ from tqdm import tqdm
+ import umap
+ import faiss
+ from sklearn.neighbors import NearestNeighbors
+ #from cuml.linear_model import LogisticRegression
+ from sklearn.preprocessing import LabelEncoder, LabelBinarizer
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ import pyro.distributions as dist
+
+ from ..SURE import SURE
+ from .assembly import assembly, get_data, get_subdata, batch_encoding, get_uns
+ from ..codebook import codebook_summarize_, codebook_generate, codebook_sketch
+ from ..utils import convert_to_tensor, tensor_to_numpy
+ from ..utils import CustomDataset
+ from ..utils import pretty_print, Colors
+ from ..utils import PriorityQueue
+
+ import networkx as nx
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
+ from matplotlib.colors import ListedColormap
+
+ import dill as pickle
+ import gzip
+ torch_version = torch.__version__
+
+ from typing import Literal
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ class SingleOmicsAtlas(nn.Module):
+     """
+     Compressed Cell Atlas
+
+     Parameters
+     ----------
+     atlas_name
+         Name of the built atlas.
+     hvgs
+         List of highly variable genes.
+     eps
+         Small positive constant used as a numerical lower bound.
+     """
+     def __init__(self,
+                  atlas_name: str = 'Atlas',
+                  hvgs: list = None,
+                  eps: float = 1e-12):
+         super().__init__()
+         self.atlas_name = atlas_name
+         self.model = None
+         self.sure_models_list = None
+         self.hvgs = hvgs
+         self.adata = None
+         self.sample_adata = None
+         self.layer = None
+         self.n_sure_models = None
+         self.umap_metric = 'euclidean'
+         self.umap = None
+         self.adj = None
+         self.subatlas_list = None
+         self.n_subatlas = 0
+         self.pheno_keys = None
+         self.nearest_neighbor_engine = None
+         self.knn_k = 5
+         self.network = None
+         self.network_pos = None
+         self.sample_network = None
+         self.sample_network_pos = None
+         self.eps = eps
+
+     def fit(self, adata_list_,
+             batch_key: str = None,
+             pheno_keys: list = None,
+             preprocessing: bool = True,
+             hvgs: list = None,
+             n_top_genes: int = 5000,
+             hvg_method: Literal['seurat','seurat_v3','cell_ranger'] = 'seurat',
+             layer: str = 'counts',
+             cuda_id: int = 0,
+             use_jax: bool = False,
+             codebook_size: int = 800,
+             codebook_size_per_adata: int = 500,
+             min_gamma: int = 20,
+             min_cell: int = 1000,
+             learning_rate: float = 0.0001,
+             batch_size: int = 200,
+             batch_size_per_adata: int = 100,
+             n_epochs: int = 200,
+             latent_dist: Literal['normal','laplacian','studentt','cauchy'] = 'studentt',
+             use_dirichlet: bool = False,
+             use_dirichlet_per_adata: bool = False,
+             zero_inflation: bool = False,
+             zero_inflation_per_adata: bool = False,
+             likelihood: Literal['negbinomial','poisson','multinomial','gaussian'] = 'multinomial',
+             likelihood_per_adata: Literal['negbinomial','poisson','multinomial','gaussian'] = 'multinomial',
+             #n_samples_per_adata: int = 10000,
+             #total_samples: int = 500000,
+             #summarize: bool = True,
+             sketch_func: Literal['mean','sum','simulate','sample','bootstrap_sample'] = 'mean',
+             n_bootstrap_neighbors: int = 8,
+             #sketching: bool = True,
+             #even_sketch: bool = True,
+             #bootstrap_sketch: bool = False,
+             #n_sketch_neighbors: int = 10,
+             n_workers: int = 1,
+             mute: bool = True,
+             edge_thresh: float = 0.001,
+             metric: Literal['euclidean','correlation','cosine'] = 'euclidean',
+             knn_k: int = 5):
+         """
+         Fit the input list of AnnData datasets.
+
+         Parameters
+         ----------
+         adata_list_
+             A list of AnnData datasets.
+         batch_key
+             Key of the unwanted (batch) factor.
+         pheno_keys
+             A list of phenotype factors whose information should be retained in the built atlas.
+         preprocessing
+             If toggled on, the input datasets go through the standard Scanpy preprocessing steps, including normalization and log1p transformation.
+         hvgs
+             If a list of highly variable genes is given, the subsequent steps rely on these genes.
+         n_top_genes
+             Parameter for Scanpy's highly_variable_genes.
+         hvg_method
+             Parameter for Scanpy's highly_variable_genes.
+         layer
+             Layer of the data used for building the atlas.
+         cuda_id
+             CUDA device id.
+         use_jax
+             If toggled on, JAX is used for acceleration.
+         codebook_size
+             Number of metacells in the built atlas.
+         codebook_size_per_adata
+             Number of metacells called within each adata.
+         min_gamma
+             Parameter passed to assembly.
+         min_cell
+             Minimum number of cells an adata must contain to be included.
+         learning_rate
+             Learning rate for optimization.
+         batch_size
+             Batch size for building the atlas.
+         batch_size_per_adata
+             Batch size for calling metacells within each adata.
+         n_epochs
+             Number of epochs.
+         latent_dist
+             Distribution for latent representations.
+         use_dirichlet
+             Use a Dirichlet model for building the atlas.
+         use_dirichlet_per_adata
+             Use a Dirichlet model for calling metacells within each adata.
+         zero_inflation
+             Use a zero-inflated model for building the atlas.
+         zero_inflation_per_adata
+             Use a zero-inflated model for calling metacells within each adata.
+         likelihood
+             Data generation model for building the atlas.
+         likelihood_per_adata
+             Data generation model for calling metacells within each adata.
+         sketch_func
+             Function used to sketch cells into metacell profiles ('mean', 'sum', 'simulate', 'sample', or 'bootstrap_sample').
+         n_bootstrap_neighbors
+             Parameter for bootstrapped sketching.
+         n_workers
+             Number of workers.
+         mute
+             If toggled on, verbose output is suppressed.
+         edge_thresh
+             Parameter for building the network.
+         metric
+             Distance metric for UMAP.
+         knn_k
+             Parameter for the k-nearest-neighbor engine.
+         """
+
+         print(Colors.YELLOW + 'Create A Distribution-Preserved Single-Cell Omics Atlas' + Colors.RESET)
+         adata_list = [ad for ad in adata_list_ if ad.shape[0] > min_cell]
+         n_adatas = len(adata_list)
+         self.layer = layer
+         self.n_sure_models = n_adatas
+         self.umap_metric = metric
+         self.pheno_keys = pheno_keys
+         #zero_inflation = True if sketching else zero_inflation
+
+         # assembly
+         print(f'{n_adatas} adata datasets are given')
+         self.model, self.submodels, self.hvgs = assembly(adata_list, batch_key,
+                 preprocessing, hvgs, n_top_genes, hvg_method, layer, cuda_id, use_jax,
+                 codebook_size, codebook_size_per_adata, min_gamma, learning_rate,
+                 batch_size, batch_size_per_adata, n_epochs, latent_dist,
+                 use_dirichlet, use_dirichlet_per_adata,
+                 zero_inflation, zero_inflation_per_adata,
+                 likelihood, likelihood_per_adata,
+                 #n_samples_per_adata, total_samples, summarize,
+                 sketch_func, n_bootstrap_neighbors,
+                 #sketching, even_sketch, bootstrap_sketch, n_sketch_neighbors,
+                 n_workers, mute)
+
+         # summarize expression
+         X, W, adj = None, None, None
+         with tqdm(total=n_adatas, desc=f'Summarize data in {layer}', unit='adata') as pbar:
+             for i in np.arange(n_adatas):
+                 #print(f'Adata {i+1} / {n_adatas}: Summarize data in {layer}')
+                 adata_i = adata_list[i][:, self.hvgs].copy()
+                 adata_i_ = adata_list[i].copy()
+
+                 xs_i = get_data(adata_i, layer).values
+                 xs_i_ = get_data(adata_i_, layer).values
+                 ws_i_sup = self.model.soft_assignments(xs_i)
+                 xs_i_sup = codebook_summarize_(ws_i_sup, xs_i_)
+
+                 if X is None:
+                     X = xs_i_sup
+                     W = np.sum(ws_i_sup.T, axis=1, keepdims=True)
+
+                     a = convert_to_tensor(ws_i_sup)
+                     a_t = a.T / torch.sum(a.T, dim=1, keepdim=True)
+                     adj = torch.matmul(a_t, a)
+                 else:
+                     X += xs_i_sup
+                     W += np.sum(ws_i_sup.T, axis=1, keepdims=True)
+
+                     a = convert_to_tensor(ws_i_sup)
+                     a_t = a.T / torch.sum(a.T, dim=1, keepdim=True)
+                     adj += torch.matmul(a_t, a)
+
+                 pbar.update(1)
+         X = X / W
+         self.adata = sc.AnnData(X)
+         self.adata.obs_names = [f'MC{x}' for x in self.adata.obs_names]
+         self.adata.var_names = adata_i_.var_names
+
+         adj = tensor_to_numpy(adj) / self.n_sure_models
+         self.adj = (adj + adj.T) / 2
+         n_nodes = adj.shape[0]
+         self.adj[np.arange(n_nodes), np.arange(n_nodes)] = 0
+
+         # summarize phenotypes
+         if pheno_keys is not None:
+             self._summarize_phenotypes_from_adatas(adata_list, pheno_keys)
+
+         # COMMENT OUT 2025.7.4
+         # compute visualization position for the atlas
+         #print('Compute the reference position of the atlas')
+         #n_samples = np.max([n_samples_per_adata * self.n_sure_models, 50000])
+         #n_samples = np.min([n_samples, total_samples])
+         #self.instantiation(n_samples)
+         #
+         # create nearest neighbor indexing
+         #self.build_nearest_neighbor_engine(knn_k)
+         #self.knn_k = knn_k
+         #
+         #self.build_network(edge_thresh=edge_thresh)
+         # END OF COMMENT OUT 2025.7.4
+
+         # the distribution of cell-to-metacell distances
+         metacells = self.model.get_metacell_coordinates()
+         self.nearest_metacell_engine = NearestNeighbors(n_neighbors=knn_k, n_jobs=-1)
+         self.nearest_metacell_engine.fit(metacells)
+
+         cell2metacell_distances = []
+         with tqdm(total=n_adatas, desc='Build cell-to-metacell distance distribution', unit='adata') as pbar:
+             for i in np.arange(n_adatas):
+                 #print(f'Adata {i+1} / {n_adatas}: Build cell-to-metacell distance distribution')
+                 adata_i = adata_list[i][:, self.hvgs].copy()
+
+                 xs_i = get_data(adata_i, layer).values
+                 zs_i = self.model.get_cell_coordinates(xs_i)
+
+                 dd_i, _ = self.nearest_metacell_engine.kneighbors(zs_i, n_neighbors=1)
+                 cell2metacell_distances.extend(dd_i.flatten())
+
+                 pbar.update(1)
+
+         self.cell2metacell_dist = gaussian_kde(cell2metacell_distances)
+
+         print(Colors.YELLOW + f'A distribution-preserved atlas has been built from {n_adatas} adata datasets.' + Colors.RESET)
+
+     def detect_outlier(self, adata_query, thresh: float = 1e-2, batch_size: int = 1024):
+
+         def batch_p_values_greater(kde, x_vector):
+             """
+             Compute P(X > x) for every element of the vector x in one batch.
+
+             Parameters:
+                 kde: a fitted gaussian_kde object
+                 x_vector: array of query values
+
+             Returns:
+                 An array of P(X > x) values with the same shape as x_vector.
+             """
+             # Build a high-resolution CDF lookup table
+             x_min = min(kde.dataset.min(), x_vector.min()) - 3 * np.std(kde.dataset)
+             x_max = max(kde.dataset.max(), x_vector.max()) + 3 * np.std(kde.dataset)
+             grid_points = max(20000, len(kde.dataset))  # make sure the grid is dense enough
+             x_grid = np.linspace(x_min, x_max, grid_points)
+
+             # Evaluate the PDF and accumulate it into a CDF
+             pdf = kde(x_grid)
+             cdf = np.cumsum(pdf)
+             cdf /= cdf[-1]  # normalize
+
+             # Interpolate to obtain P(X > x) = 1 - CDF(x) for each x
+             p_values = 1 - np.interp(x_vector, x_grid, cdf)
+
+             # Clamp values that fall outside the grid
+             p_values[x_vector < x_min] = 1.0
+             p_values[x_vector > x_max] = 0.0
+
+             return p_values
+
+         adata_query = adata_query.copy()
+         X_query = get_subdata(adata_query, self.hvgs, self.layer).values
+         Z_map = self.model.get_cell_coordinates(X_query, batch_size=batch_size)
+
+         dd, _ = self.nearest_metacell_engine.kneighbors(Z_map, n_neighbors=1)
+         #pp = self.cell2metacell_dist(dd.flatten())
+         pp = batch_p_values_greater(self.cell2metacell_dist, dd.flatten())
+         outliers = pp < thresh
+
+         print(f'{np.sum(outliers)} outliers found.')
+
+         return outliers
+
+     def map(self, adata_query,
+             batch_size: int = 1024):
+         """
+         Map query data to the atlas.
+
+         Parameters
+         ----------
+         adata_query
+             Query data. It should be an AnnData object.
+         batch_size
+             Size of batch processing.
+         """
+         adata_query = adata_query.copy()
+         X_query = get_subdata(adata_query, self.hvgs, self.layer).values
+
+         Z_map = self.model.get_cell_coordinates(X_query, batch_size=batch_size)
+         A_map = self.model.soft_assignments(X_query)
+
+         return Z_map, A_map
+
+     def phenotype_density_estimation(self, adata_list, pheno_key):
+         n_adatas = len(adata_list)
+         PH_MC_list = []
+         with tqdm(total=n_adatas, desc=f'Estimate {pheno_key} density', unit='adata') as pbar:
+             for i in np.arange(n_adatas):
+                 ad = adata_list[i][:, self.hvgs].copy()
+                 xs = get_data(ad, layer=self.layer).values
+                 X_MC_mat = self.model.soft_assignments(xs)
+                 lb = LabelBinarizer().fit(ad.obs[pheno_key].astype(str))
+                 X_PH_mat = lb.transform(ad.obs[pheno_key].astype(str))
+                 PH_MC_mat = np.matmul(X_PH_mat.T, X_MC_mat)
+
+                 p_PH = np.sum(PH_MC_mat, axis=1, keepdims=True)
+                 p_PH[p_PH == 0] = 1
+                 p_MC_PH = PH_MC_mat / p_PH
+                 PH_MC_mat = p_MC_PH * p_PH
+                 p_MC = np.sum(PH_MC_mat, axis=0, keepdims=True)
+                 p_MC[p_MC == 0] = 1
+                 p_PH_MC = PH_MC_mat / p_MC
+
+                 PH_MC_df = pd.DataFrame(p_PH_MC,
+                                         columns=[f'MC{x}' for x in np.arange(X_MC_mat.shape[1])],
+                                         index=lb.classes_)
+                 PH_MC_list.append(PH_MC_df)
+                 pbar.update(1)
+
+         if n_adatas > 1:
+             PH_MC_DF = aggregate_dataframes(PH_MC_list)
+             PH_MC_DF = PH_MC_DF.div(n_adatas)
+         else:
+             PH_MC_DF = PH_MC_list[0]
+
+         self.adata.uns[f'{pheno_key}_density'] = PH_MC_DF
+
+     def summarize_phenotypes(self, adata_list=None, pheno_keys=None):
+         self._summarize_phenotypes_from_adatas(adata_list, pheno_keys)
+
+     def _summarize_phenotypes_from_adatas(self, adata_list, pheno_keys):
+         n_adatas = len(adata_list)
+         for pheno in pheno_keys:
+             Y = list()
+             with tqdm(total=n_adatas, desc=f'Summarize data in {pheno}', unit='adata') as pbar:
+                 for i in np.arange(n_adatas):
+                     if pheno in adata_list[i].obs.columns:
+                         #print(f'Adata {i+1} / {n_adatas}: Summarize data in {pheno}')
+                         adata_i = adata_list[i][:, self.hvgs].copy()
+
+                         xs_i = get_data(adata_i, self.layer).values
+                         ws_i_sup = self.model.soft_assignments(xs_i)
+                         #ys_i = batch_encoding(adata_i, pheno)
+                         #columns_i = ys_i.columns.tolist()
+                         #ys_i = codebook_summarize_(ws_i_sup, ys_i.values)
+
+                         ws_i = np.argmax(ws_i_sup, axis=1)
+                         adata_i.obs['metacell'] = [f'MC{x}' for x in ws_i]
+                         df = adata_i.obs[['metacell', pheno]].value_counts().unstack(fill_value=0)
+
+                         Y.append(df)
+                     pbar.update(1)
+
+             Y_df_ = aggregate_dataframes(Y)
+             Y_df = pd.DataFrame(0, index=[f'MC{x}' for x in np.arange(self.adata.shape[0])], columns=Y_df_.columns)
+             Y_df = Y_df.add(Y_df_, fill_value=0)
+
+             #Y = Y_df.values
+             #Y[Y<self.eps] = 0
+             #Y = Y / Y.sum(axis=1, keepdims=True)
+
+             self.adata.uns[pheno] = Y_df
+             #self.adata.uns[f'{pheno}_columns'] = Y_df.columns
+             Y_hat_ = Y_df.idxmax(axis=1)
+             Y_hat = pd.DataFrame(pd.NA, index=[f'MC{x}' for x in np.arange(self.adata.shape[0])], columns=['id'])
+             Y_hat.loc[Y_hat_.index.tolist(), 'id'] = Y_hat_.tolist()
+             self.adata.obs[pheno] = Y_hat['id'].tolist()
+
+     def phenotype_predict(self, adata_query, pheno_key, batch_size=1024):
+         _, ws = self.map(adata_query, batch_size)
+         # ws is (cells x metacells) while the stored density table is
+         # (phenotype classes x metacells), so transpose before the product
+         # and label the result columns with the class names.
+         A = matrix_dotprod(ws, self.adata.uns[f'{pheno_key}_density'].values.T)
+         A = pd.DataFrame(A, columns=self.adata.uns[f'{pheno_key}_density'].index)
+         return A.idxmax(axis=1).tolist()
+
+     @classmethod
+     def save_model(cls, atlas, file_path, compression=False):
+         """Save the model to the specified file path."""
+         file_path = os.path.abspath(file_path)
+
+         atlas.sample_adata = None
+         atlas.eval()
+
+         if compression:
+             with gzip.open(file_path, 'wb') as pickle_file:
+                 pickle.dump(atlas, pickle_file)
+         else:
+             with open(file_path, 'wb') as pickle_file:
+                 pickle.dump(atlas, pickle_file)
+
+         print(f'Model saved to {file_path}')
+
+     @classmethod
+     def load_model(cls, file_path, n_samples=10000):
+         """Load the model from the specified file path and return an instance."""
+         print(f'Model loaded from {file_path}')
+
+         file_path = os.path.abspath(file_path)
+         if file_path.endswith('gz'):
+             with gzip.open(file_path, 'rb') as pickle_file:
+                 atlas = pickle.load(pickle_file)
+         else:
+             with open(file_path, 'rb') as pickle_file:
+                 atlas = pickle.load(pickle_file)
+
+         #xs = atlas.sample(n_samples)
+         #atlas.sample_adata = sc.AnnData(xs)
+         #atlas.sample_adata.var_names = atlas.hvgs
+         #
+         #zs = atlas.model.get_cell_coordinates(xs)
+         #ws = atlas.model.soft_assignments(xs)
+         #atlas.sample_adata.obsm['X_umap'] = atlas.umap.transform(zs)
+         #atlas.sample_adata.obsm['X_sure'] = zs
+         #atlas.sample_adata.obsm['weight'] = ws
+         #
+         return atlas
+
+
+ #def aggregate_dataframes(df_list):
+ #    n_dfs = len(df_list)
+ #    all_columns = set(df_list[0].columns)
+ #    for i in np.arange(n_dfs-1):
+ #        all_columns = all_columns.union(set(df_list[i+1].columns))
+ #
+ #    all_indexs = set(df_list[0].index.tolist())
+ #    for i in np.arange(n_dfs-1):
+ #        all_indexs = all_indexs.union(set(df_list[i+1].index.tolist()))
+ #
+ #    for col in all_columns:
+ #        for i in np.arange(n_dfs):
+ #            if col not in df_list[i]:
+ #                df_list[i][col] = 0
+ #
+ #    df = pd.DataFrame(0, index=all_indexs, columns=all_columns)
+ #    df = df_list[0]
+ #    for i in np.arange(n_dfs-1):
+ #        df += df_list[i+1]
+ #
+ #    #df /= n_dfs
+ #    return df
+
+ def aggregate_dataframes(df_list):
+     # Union of all row indexes and columns across the input frames,
+     # then element-wise sum with missing entries treated as 0.
+     all_index = reduce(lambda x, y: x.union(y), [df.index for df in df_list], pd.Index([]))
+     all_index_sorted = all_index.sort_values()
+     all_columns = reduce(lambda x, y: x.union(y), [df.columns for df in df_list], pd.Index([]))
+     result = pd.DataFrame(0, index=all_index_sorted, columns=all_columns)
+
+     for df in df_list:
+         result = result.add(df, fill_value=0)
+
+     return result
+
+ def smooth_y_over_x(xs, ys, knn_k):
+     # Average ys over each point's k nearest neighbors in xs, excluding the
+     # point itself (each query is its own first neighbor, so it is subtracted out).
+     n = xs.shape[0]
+     nbrs = NearestNeighbors(n_neighbors=knn_k, n_jobs=-1)
+     nbrs.fit(xs)
+     ids = nbrs.kneighbors(xs, return_distance=False)
+     ys_smooth = np.zeros_like(ys)
+     for i in np.arange(knn_k):
+         ys_smooth += ys[ids[:, i]]
+     ys_smooth -= ys
+     ys_smooth /= knn_k - 1
+     return ys_smooth
+
+ def matrix_dotprod(A, B, dtype=torch.float32):
+     A = convert_to_tensor(A, dtype=dtype)
+     B = convert_to_tensor(B, dtype=dtype)
+     AB = torch.matmul(A, B)
+     return tensor_to_numpy(AB)
+
+ def matrix_elemprod(A, B):
+     A = convert_to_tensor(A)
+     B = convert_to_tensor(B)
+     AB = A * B
+     return tensor_to_numpy(AB)
+
+ def cdf(density, xs, initial=0):
+     # Integrate the density with the cumulative trapezoidal rule and normalize to 1.
+     CDF = cumtrapz(density, xs, initial=initial)
+     CDF /= CDF[-1]
+     return CDF
+
+
+ class FaissKNeighbors:
+     """Minimal k-nearest-neighbor search backed by a flat (exact) faiss L2 index."""
+     def __init__(self, n_neighbors=5):
+         self.index = None
+         self.k = n_neighbors
+
+     def fit(self, X):
+         self.index = faiss.IndexFlatL2(X.shape[1])
+         self.index.add(X.astype(np.float32))
+
+     def kneighbors(self, X):
+         # faiss expects float32 input; search returns (distances, indices).
+         distances, indices = self.index.search(X.astype(np.float32), self.k)
+         return distances, indices
+
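The file above defines the full public surface of the compressed atlas. As a quick orientation, here is a minimal usage sketch; it assumes the package exposes SingleOmicsAtlas under SURE.assembly.atlas as added above, and the file names and obs keys ('batch', 'cell_type') are hypothetical placeholders.

    import scanpy as sc
    from SURE.assembly.atlas import SingleOmicsAtlas

    # Build a compressed atlas from several AnnData datasets.
    adata_list = [sc.read_h5ad(f) for f in ['donor1.h5ad', 'donor2.h5ad']]
    atlas = SingleOmicsAtlas(atlas_name='DemoAtlas')
    atlas.fit(adata_list, batch_key='batch', pheno_keys=['cell_type'], layer='counts')

    # Map a query dataset onto the atlas and flag cells that sit unusually far
    # from every metacell.
    adata_query = sc.read_h5ad('query.h5ad')
    Z_map, A_map = atlas.map(adata_query)   # latent coordinates, soft assignments
    outliers = atlas.detect_outlier(adata_query, thresh=1e-2)

    # Persist the atlas with gzip compression, then restore it.
    SingleOmicsAtlas.save_model(atlas, 'atlas.pkl.gz', compression=True)
    atlas2 = SingleOmicsAtlas.load_model('atlas.pkl.gz')
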
@@ -0,0 +1,4 @@
+ # Importing specific functions from modules
+ from .codebook import codebook_summarize_, codebook_summarize, codebook_aggregate, \
+     codebook_generate, codebook_weights, codebook_sample, codebook_sketch, \
+     codebook_bootstrap_sketch
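
The outlier test in detect_outlier rests on one idea: fit a gaussian_kde to the cell-to-metacell distances seen during training, turn it into a cumulative distribution via a dense lookup grid, and call a query cell an outlier when its nearest-metacell distance has a small tail probability P(X > x). The standalone toy sketch below reproduces that mechanism with synthetic distances; every name in it is illustrative rather than part of the package API.

    import numpy as np
    from scipy.stats import gaussian_kde

    rng = np.random.default_rng(0)
    dists = rng.gamma(shape=2.0, scale=1.0, size=5000)  # stand-in for training distances
    kde = gaussian_kde(dists)

    # Dense CDF lookup table, as in batch_p_values_greater.
    x_grid = np.linspace(dists.min() - 3 * dists.std(),
                         dists.max() + 3 * dists.std(), 20000)
    cdf = np.cumsum(kde(x_grid))
    cdf /= cdf[-1]

    # Tail probability P(X > x) for three query distances; the far-out one
    # gets a tiny value and would be flagged as an outlier at thresh=1e-2.
    query = np.array([0.5, 2.0, 15.0])
    p_greater = 1 - np.interp(query, x_grid, cdf)
    print(p_greater)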