SURE-tools 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic.
- SURE/SURE.py +1203 -0
- SURE/__init__.py +7 -0
- SURE/assembly/__init__.py +3 -0
- SURE/assembly/assembly.py +511 -0
- SURE/assembly/atlas.py +575 -0
- SURE/codebook/__init__.py +4 -0
- SURE/codebook/codebook.py +472 -0
- SURE/utils/__init__.py +19 -0
- SURE/utils/custom_mlp.py +209 -0
- SURE/utils/queue.py +50 -0
- SURE/utils/utils.py +308 -0
- SURE_tools-1.0.1.dist-info/LICENSE +21 -0
- SURE_tools-1.0.1.dist-info/METADATA +68 -0
- SURE_tools-1.0.1.dist-info/RECORD +17 -0
- SURE_tools-1.0.1.dist-info/WHEEL +5 -0
- SURE_tools-1.0.1.dist-info/entry_points.txt +2 -0
- SURE_tools-1.0.1.dist-info/top_level.txt +1 -0
SURE/codebook/codebook.py
ADDED
@@ -0,0 +1,472 @@
import torch
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist

import numpy as np
import scipy as sp
import pandas as pd
from scipy.stats import gaussian_kde
from sklearn.preprocessing import LabelBinarizer
from sklearn.neighbors import NearestNeighbors
import scanpy as sc

import multiprocessing as mp
from tqdm import tqdm

from ..utils import convert_to_tensor, tensor_to_numpy
from ..utils import CustomDataset2

from typing import Literal, List, Tuple, Dict
from functools import partial


def codebook_generate(sure_model, n_samples):
    code_weights = convert_to_tensor(sure_model.codebook_weights, dtype=sure_model.dtype, device=sure_model.get_device())
    ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])

    codebook_loc, codebook_scale = sure_model.get_codebook()
    codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
    codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())

    loc = torch.matmul(ns, codebook_loc)
    scale = torch.matmul(ns, codebook_scale)
    zs = dist.Normal(loc, scale).to_event(1).sample()
    return tensor_to_numpy(zs), tensor_to_numpy(ns)
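
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# Assumes `sure_model` is a trained SURE model exposing `codebook_weights`, `dtype`,
# `get_device()` and `get_codebook()`, exactly as codebook_generate uses above.
def _example_codebook_generate(sure_model, n_samples=1000):
    zs, ns = codebook_generate(sure_model, n_samples)
    # zs: (n_samples, latent_dim) positions sampled around the chosen codebook entries
    # ns: (n_samples, n_codes) one-hot record of which metacell code produced each draw
    return zs, ns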


def codebook_sample(sure_model, xs, n_samples, even_sample=False, filter=True):
    xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
    assigns = sure_model.soft_assignments(xs)
    code_assigns = np.argmax(assigns, axis=1)

    if even_sample:
        repeat = n_samples // assigns.shape[1]
        remainder = n_samples % assigns.shape[1]
        ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
        # append the remaining elements (repeat the first `remainder` codes one extra time each)
        if remainder > 0:
            ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
        ns_id -= 1

        ns = LabelBinarizer().fit_transform(ns_id)
        ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
    else:
        code_weights = codebook_weights(assigns)
        code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
        ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
        ns_id = np.argmax(tensor_to_numpy(ns), axis=1)

    codebook_loc, codebook_scale = sure_model.get_codebook()
    codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
    codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())

    loc = torch.matmul(ns, codebook_loc)
    scale = torch.matmul(ns, codebook_scale)
    zs = dist.Normal(loc, scale).to_event(1).sample()

    xs_zs = sure_model.get_cell_coordinates(xs)
    #xs_zs = convert_to_tensor(xs_zs, dtype=sure_model.dtype, device=sure_model.get_device())
    #xs_dist = torch.cdist(zs, xs_zs)
    #idx = xs_dist.argmin(dim=1)

    #nbrs = NearestNeighbors(n_jobs=-1, n_neighbors=1)
    #nbrs.fit(tensor_to_numpy(xs_zs))
    #idx = nbrs.kneighbors(tensor_to_numpy(zs), return_distance=False)
    #idx_ = idx.flatten()
    #idx = [idx_[i] for i in np.arange(n_samples) if np.array_equal(code_assigns[idx_[i]], ns[i])]
    #df = pd.DataFrame({'idx':idx_,
    #                   'to':code_assigns[idx_],
    #                   'from':ns_id})
    #if filter:
    #    filtered_df = df[df['from'] != df['to']]
    #else:
    #    filtered_df = df
    #idx = filtered_df.loc[:,'idx'].values
    #ns_id = filtered_df.loc[:,'from'].values

    nbrs = NearestNeighbors(n_neighbors=50, n_jobs=-1)
    nbrs.fit(tensor_to_numpy(xs_zs))

    distances, ids = nbrs.kneighbors(tensor_to_numpy(zs), return_distance=True)

    idx, ns_list = [], []
    with tqdm(total=n_samples, desc='Sketching', unit='sketch') as pbar:
        for i in np.arange(n_samples):
            distances_i = distances[i]
            weights_i = distance_to_softmax_weights(distances_i)
            cell_i_ = weighted_sample(ids[i], weights_i, sample_size=1, replace=False)

            df = pd.DataFrame({'idx': [cell_i_],
                               'to': [code_assigns[cell_i_]],
                               'from': [ns_id[i]]})
            if filter:
                filtered_df = df[df['from'] != df['to']]
            else:
                filtered_df = df
            cells_i = filtered_df.loc[:, 'idx'].values
            ns_i = filtered_df.loc[:, 'from'].unique()

            idx.extend(cells_i)
            ns_list.extend(ns_i)

            pbar.update(1)

    return tensor_to_numpy(xs[idx].squeeze()), np.array(idx).squeeze(), ns_list
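
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# Sketches `xs` by sampling metacell codes and picking, for each sampled code, one
# nearby real cell in the model's latent space. Assumes `sure_model` is a trained
# SURE model and `xs` is an (n_cells, n_features) matrix accepted by the model.
def _example_codebook_sample(sure_model, xs):
    sketch_xs, sketch_idx, sketch_codes = codebook_sample(sure_model, xs, n_samples=500)
    # sketch_xs: profiles of the selected cells; sketch_idx: their row indices in `xs`;
    # sketch_codes: the codebook code each sketched cell was drawn for
    return sketch_xs, sketch_idx, sketch_codes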


def codebook_sketch(sure_model, xs, n_samples, even_sample=False):
    return codebook_sample(sure_model, xs, n_samples, even_sample)


def codebook_bootstrap_sketch(sure_model, xs, n_samples, n_neighbors=8,
                              aggregate_fun: Literal['mean', 'median', 'sum'] = 'mean',
                              even_sample=False, replace=True, filter=True):
    xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
    xs_zs = sure_model.get_cell_coordinates(xs)
    xs_zs = tensor_to_numpy(xs_zs)

    # generate samples that follow the metacell distribution of the given data
    assigns = sure_model.soft_assignments(xs)
    code_assigns = np.argmax(assigns, axis=1)
    if even_sample:
        repeat = n_samples // assigns.shape[1]
        remainder = n_samples % assigns.shape[1]
        ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
        # append the remaining elements (repeat the first `remainder` codes one extra time each)
        if remainder > 0:
            ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
        ns_id -= 1

        ns = LabelBinarizer().fit_transform(ns_id)
        ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
    else:
        code_weights = codebook_weights(assigns)
        code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
        ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
        ns_id = np.argmax(tensor_to_numpy(ns), axis=1)

    codebook_loc, codebook_scale = sure_model.get_codebook()
    codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
    codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())

    loc = torch.matmul(ns, codebook_loc)
    scale = torch.matmul(ns, codebook_scale)
    zs = dist.Normal(loc, scale).to_event(1).sample()
    zs = tensor_to_numpy(zs)

    # find the neighbors of the sampled data in the real data space
    nbrs = NearestNeighbors(n_neighbors=50, n_jobs=-1)
    nbrs.fit(xs_zs)

    xs_list = []
    ns_list = []
    distances, ids = nbrs.kneighbors(zs, return_distance=True)
    #dist_pdf = gaussian_kde(distances.flatten())

    xs = tensor_to_numpy(xs)
    sketch_cells = dict()
    with tqdm(total=n_samples, desc='Sketching', unit='sketch') as pbar:
        for i in np.arange(n_samples):
            #cells_i_ = ids[i, dist_pdf(distances[i]) > pval]
            #cells_i = [c for c in cells_i_ if np.array_equal(code_assigns[c], ns[i])]
            distances_i = distances[i]
            weights_i = distance_to_softmax_weights(distances_i)
            cells_i_ = weighted_sample(ids[i], weights_i, sample_size=n_neighbors, replace=replace)

            df = pd.DataFrame({'idx': cells_i_,
                               'to': code_assigns[cells_i_],
                               'from': [ns_id[i]] * len(cells_i_)})
            if filter:
                filtered_df = df[df['from'] != df['to']]
            else:
                filtered_df = df
            cells_i = filtered_df.loc[:, 'idx'].values
            ns_i = filtered_df.loc[:, 'from'].unique()

            if len(cells_i) > 0:
                xs_i = xs[cells_i]
                if aggregate_fun == 'mean':
                    xs_i = np.mean(xs_i, axis=0, keepdims=True)
                elif aggregate_fun == 'median':
                    xs_i = np.median(xs_i, axis=0, keepdims=True)
                elif aggregate_fun == 'sum':
                    xs_i = np.sum(xs_i, axis=0, keepdims=True)

                xs_list.append(xs_i)
                ns_list.extend(ns_i)
                sketch_cells[i] = cells_i

            pbar.update(1)

    return np.vstack(xs_list), sketch_cells, ns_list
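
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# Each sketch point aggregates `n_neighbors` nearby real cells (mean/median/sum),
# which smooths the sketch instead of copying single cells. `sure_model` and `xs`
# are assumed as in the example above.
def _example_codebook_bootstrap_sketch(sure_model, xs):
    sketch_xs, sketch_cells, sketch_codes = codebook_bootstrap_sketch(
        sure_model, xs, n_samples=500, n_neighbors=8, aggregate_fun='mean')
    # sketch_xs: (n_kept, n_features) aggregated profiles
    # sketch_cells: dict mapping sketch index -> indices of the aggregated cells
    return sketch_xs, sketch_cells, sketch_codes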


def process_chunk(chunk_indices: np.ndarray,
                  zs: np.ndarray,
                  xs_zs: np.ndarray,
                  xs: np.ndarray,
                  code_assigns: np.ndarray,
                  ns_id: np.ndarray,
                  n_neighbors: int,
                  replace: bool,
                  filter: bool,
                  aggregate_fun: str) -> Tuple[List[np.ndarray], Dict[int, np.ndarray], List[int]]:
    """
    Process one chunk of samples.

    Args:
        chunk_indices: indices of the samples contained in the current chunk
        (the remaining arguments are the same as in the main function)

    Returns:
        (xs_chunk, sketch_cells_chunk, ns_list_chunk)
    """
    xs_chunk = []
    sketch_cells_chunk = {}
    ns_list_chunk = []

    # each chunk builds its own NearestNeighbors instance to avoid conflicts between processes
    nbrs = NearestNeighbors(n_neighbors=50, n_jobs=1)  # single-threaded mode
    nbrs.fit(xs_zs)

    for i in chunk_indices:
        distances, ids = nbrs.kneighbors(zs[i:i+1], return_distance=True)
        distances_i = distances[0]
        weights_i = distance_to_softmax_weights(distances_i)
        cells_i_ = weighted_sample(ids[0], weights_i, sample_size=n_neighbors, replace=replace)

        df = pd.DataFrame({
            'idx': cells_i_,
            'to': code_assigns[cells_i_],
            'from': [ns_id[i]] * len(cells_i_)
        })

        if filter:
            filtered_df = df[df['from'] != df['to']]
        else:
            filtered_df = df

        cells_i = filtered_df.loc[:, 'idx'].values
        ns_i = filtered_df.loc[:, 'from'].unique()

        if len(cells_i) > 0:
            xs_i = xs[cells_i]

            if aggregate_fun == 'mean':
                xs_i = np.mean(xs_i, axis=0, keepdims=True)
            elif aggregate_fun == 'median':
                xs_i = np.median(xs_i, axis=0, keepdims=True)
            elif aggregate_fun == 'sum':
                xs_i = np.sum(xs_i, axis=0, keepdims=True)

            xs_chunk.append(xs_i)
            sketch_cells_chunk[i] = cells_i
            ns_list_chunk.extend(ns_i)

    return xs_chunk, sketch_cells_chunk, ns_list_chunk


def codebook_bootstrap_sketch_parallel(
    sure_model,
    xs,
    n_samples,
    n_neighbors=8,
    aggregate_fun: Literal['mean', 'median', 'sum'] = 'mean',
    even_sample=False,
    replace=True,
    filter=True,
    n_processes: int = None,
    chunk_size: int = 100
) -> Tuple[np.ndarray, Dict[int, np.ndarray], List[int]]:
    """
    Chunk-based parallel version of codebook_bootstrap_sketch.

    Additional arguments:
        n_processes: number of worker processes; None uses all available CPU cores
        chunk_size: number of samples per chunk
    """
    # convert the input data (same as the sequential version)
    xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
    xs_zs = sure_model.get_cell_coordinates(xs)
    xs_zs = tensor_to_numpy(xs_zs)

    # draw samples (same as the sequential version)
    assigns = sure_model.soft_assignments(xs)
    code_assigns = np.argmax(assigns, axis=1)

    if even_sample:
        repeat = n_samples // assigns.shape[1]
        remainder = n_samples % assigns.shape[1]
        ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
        if remainder > 0:
            ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
        ns_id -= 1
        ns = LabelBinarizer().fit_transform(ns_id)
        ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
    else:
        code_weights = codebook_weights(assigns)
        code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
        ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
        ns_id = np.argmax(tensor_to_numpy(ns), axis=1)

    # fetch the codebook (same as the sequential version)
    codebook_loc, codebook_scale = sure_model.get_codebook()
    codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
    codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())

    # generate zs (same as the sequential version)
    loc = torch.matmul(ns, codebook_loc)
    scale = torch.matmul(ns, codebook_scale)
    zs = dist.Normal(loc, scale).to_event(1).sample()
    zs = tensor_to_numpy(zs)

    # convert to numpy arrays (same as the sequential version)
    xs = tensor_to_numpy(xs)

    # prepare the result containers
    xs_list = []
    sketch_cells = {}
    ns_list = []

    # split the samples into chunks
    chunks = [np.arange(i, min(i + chunk_size, n_samples))
              for i in range(0, n_samples, chunk_size)]

    # create the process pool
    with mp.Pool(processes=n_processes) as pool:
        # freeze the shared arguments with partial
        worker = partial(
            process_chunk,
            zs=zs,
            xs_zs=xs_zs,
            xs=xs,
            code_assigns=code_assigns,
            ns_id=ns_id,
            n_neighbors=n_neighbors,
            replace=replace,
            filter=filter,
            aggregate_fun=aggregate_fun
        )

        # show progress with tqdm
        results = list(tqdm(
            pool.imap(worker, chunks),
            total=len(chunks),
            desc='Processing chunks',
            unit='chunk'
        ))

    # merge the results
    for xs_chunk, sketch_cells_chunk, ns_list_chunk in results:
        xs_list.extend(xs_chunk)
        sketch_cells.update(sketch_cells_chunk)
        ns_list.extend(ns_list_chunk)

    return (
        np.vstack(xs_list) if xs_list else np.empty((0, xs.shape[1])),
        sketch_cells,
        ns_list
    )
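
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# Same interface as codebook_bootstrap_sketch plus `n_processes`/`chunk_size`.
# Because it uses multiprocessing.Pool, call it from a script guarded by
# `if __name__ == '__main__':` on platforms that spawn worker processes.
def _example_codebook_bootstrap_sketch_parallel(sure_model, xs):
    sketch_xs, sketch_cells, sketch_codes = codebook_bootstrap_sketch_parallel(
        sure_model, xs, n_samples=2000, n_processes=4, chunk_size=100)
    return sketch_xs, sketch_cells, sketch_codes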


def codebook_summarize_(assigns, xs):
    assigns = convert_to_tensor(assigns)
    xs = convert_to_tensor(xs)
    results = torch.matmul(assigns.T, xs)
    results = results / torch.sum(assigns.T, dim=1, keepdim=True)
    return tensor_to_numpy(results)


def codebook_summarize(assigns, xs, batch_size=1024):
    assigns = convert_to_tensor(assigns)
    xs = convert_to_tensor(xs)

    dataset = CustomDataset2(assigns, xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    R = None
    W = None
    with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for A_batch, X_batch, _ in dataloader:
            r = torch.matmul(A_batch.T, X_batch)
            w = torch.sum(A_batch.T, dim=1, keepdim=True)
            if R is None:
                R = r
                W = w
            else:
                R += r
                W += w
            pbar.update(1)

    results = R / W
    return tensor_to_numpy(results)
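
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# `assigns` is an (n_cells, n_codes) soft-assignment matrix (e.g. from
# sure_model.soft_assignments) and `xs` an (n_cells, n_features) matrix; the
# result is the assignment-weighted average profile of each code, computed batch by batch.
def _example_codebook_summarize(assigns, xs):
    profiles = codebook_summarize(assigns, xs, batch_size=1024)   # shape: (n_codes, n_features)
    return profiles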


def codebook_aggregate(assigns, xs, batch_size=1024):
    assigns = convert_to_tensor(assigns)
    xs = convert_to_tensor(xs)

    dataset = CustomDataset2(assigns, xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    R = None
    with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for A_batch, X_batch, _ in dataloader:
            r = torch.matmul(A_batch.T, X_batch)
            if R is None:
                R = r
            else:
                R += r
            pbar.update(1)

    results = R
    return tensor_to_numpy(results)


def codebook_weights(assigns):
    assigns = convert_to_tensor(assigns)
    results = torch.sum(assigns, dim=0)
    results = results / torch.sum(results)
    return tensor_to_numpy(results)
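
# --- Worked example (editor's sketch, not part of the released file) ---
# codebook_weights normalises the column sums of the assignment matrix into a
# probability vector over codes.
def _example_codebook_weights():
    assigns = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
    return codebook_weights(assigns)   # -> [0.5, 0.5]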


def distance_to_softmax_weights(distances):
    """Convert a list of distances into probability weights via softmax.

    Args:
        distances: list of distances; smaller distances should receive larger weights

    Returns:
        array of probability weights that sums to 1
    """
    distances = np.array(distances)
    # negate so that smaller distances map to larger values
    negative_distances = -distances
    # compute the softmax
    exp_dist = np.exp(negative_distances - np.max(negative_distances))  # shift for numerical stability
    softmax = exp_dist / np.sum(exp_dist)
    return softmax
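
# --- Worked example (editor's sketch, not part of the released file) ---
def _example_distance_to_softmax_weights():
    w = distance_to_softmax_weights([0.0, 1.0, 2.0])
    # exp([0, -1, -2]) normalised -> approximately [0.665, 0.245, 0.090];
    # the closest neighbour always receives the largest weight
    return w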


def weighted_sample(items, weights, sample_size=1, replace=True):
    """Draw samples according to the given weights.

    Args:
        items: items to sample from
        weights: corresponding probability weights
        sample_size: number of samples to draw
        replace: whether to sample with replacement

    Returns:
        array of sampled items
    """
    return np.random.choice(
        a=items,
        size=sample_size,
        p=weights,
        replace=replace
    )
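
# --- Worked example (editor's sketch, not part of the released file) ---
def _example_weighted_sample():
    items = np.array([10, 20, 30])
    weights = distance_to_softmax_weights([0.0, 1.0, 2.0])
    # item 10 (distance 0) is drawn most often; with replace=False each item appears at most once
    return weighted_sample(items, weights, sample_size=2, replace=False)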

def split_evenly(n, m):
    """Use NumPy to split n into m parts that are as even as possible."""
    arr = np.full(m, n // m)
    arr[:n % m] += 1
    return arr.tolist()
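
# --- Worked example (editor's sketch, not part of the released file) ---
def _example_split_evenly():
    return split_evenly(10, 3)   # -> [4, 3, 3]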
SURE/utils/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Importing specific functions from modules
from .utils import tensor_to_numpy, move_to_device, convert_to_tensor
from .utils import CustomDataset, CustomDataset2, CustomDataset3
from .utils import CustomMultiOmicsDataset, CustomMultiOmicsDataset2
from .utils import pretty_print, Colors
from .utils import find_partitions_greedy

from .queue import PriorityQueue

from .custom_mlp import MLP, Exp

# Importing modules
#from . import utils
#from . import custom_mlp

#__all__ = ['tensor_to_numpy', 'move_to_device', 'convert_to_tensor',
#           'CustomDataset', 'CustomDataset2', 'CustomDataset3',
#           'MLP','Exp',
#           'custom_mlp','utils']
SURE/utils/custom_mlp.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from inspect import isclass

import torch
import torch.nn as nn

from pyro.distributions.util import broadcast_shape


class Exp(nn.Module):
    """
    a custom module for exponentiation of tensors
    """

    def __init__(self):
        super().__init__()

    def forward(self, val):
        return torch.exp(val)


class ConcatModule(nn.Module):
    """
    a custom module for concatenation of tensors
    """

    def __init__(self, allow_broadcast=False):
        self.allow_broadcast = allow_broadcast
        super().__init__()

    def forward(self, *input_args):
        # we have a single object
        if len(input_args) == 1:
            # regardless of type,
            # we don't care about single objects
            # we just index into the object
            input_args = input_args[0]

        # don't concat things that are just single objects
        if torch.is_tensor(input_args):
            return input_args
        else:
            if self.allow_broadcast:
                shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
                input_args = [s.expand(shape) for s in input_args]
            return torch.cat(input_args, dim=-1)


class ListOutModule(nn.ModuleList):
    """
    a custom module for outputting a list of tensors from a list of nn modules
    """

    def __init__(self, modules):
        super().__init__(modules)

    def forward(self, *args, **kwargs):
        # loop over modules in self, apply same args
        return [mm.forward(*args, **kwargs) for mm in self]


def call_nn_op(op):
    """
    a helper function that adds appropriate parameters when calling
    an nn module representing an operation like Softmax

    :param op: the nn.Module operation to instantiate
    :return: instantiation of the op module with appropriate parameters
    """
    if op in [nn.Softmax, nn.LogSoftmax]:
        return op(dim=1)
    else:
        return op()


class MLP(nn.Module):
    def __init__(
        self,
        mlp_sizes,
        activation=nn.ReLU,
        output_activation=None,
        post_layer_fct=lambda layer_ix, total_layers, layer: None,
        post_act_fct=lambda layer_ix, total_layers, layer: None,
        allow_broadcast=False,
        use_cuda=False,
    ):
        # init the module object
        super().__init__()

        assert len(mlp_sizes) >= 2, "Must have input and output layer sizes defined"

        # get our inputs, outputs, and hidden
        input_size, hidden_sizes, output_size = (
            mlp_sizes[0],
            mlp_sizes[1:-1],
            mlp_sizes[-1],
        )

        # assume int or list
        assert isinstance(
            input_size, (int, list, tuple)
        ), "input_size must be int, list, tuple"

        # everything in MLP will be concatted if it's multiple arguments
        last_layer_size = input_size if type(input_size) == int else sum(input_size)

        # everything sent in will be concatted together by default
        all_modules = [ConcatModule(allow_broadcast)]

        # loop over the hidden layers
        for layer_ix, layer_size in enumerate(hidden_sizes):
            assert type(layer_size) == int, "Hidden layer sizes must be ints"

            # get our nn layer module (in this case nn.Linear by default)
            cur_linear_layer = nn.Linear(last_layer_size, layer_size)

            # for numerical stability -- initialize the layer properly
            cur_linear_layer.weight.data.normal_(0, 0.001)
            cur_linear_layer.bias.data.normal_(0, 0.001)

            # use GPUs to share data during training (if available)
            if use_cuda:
                cur_linear_layer = nn.DataParallel(cur_linear_layer)

            # add our linear layer
            all_modules.append(cur_linear_layer)

            # handle post_linear
            post_linear = post_layer_fct(
                layer_ix + 1, len(hidden_sizes), all_modules[-1]
            )

            # if we send something back, add it to sequential
            # here we could return a batch norm for example
            if post_linear is not None:
                all_modules.append(post_linear)

            # handle activation (assumed no params -- deal with that later)
            all_modules.append(activation())

            # now handle after activation
            post_activation = post_act_fct(
                layer_ix + 1, len(hidden_sizes), all_modules[-1]
            )

            # handle post_activation if not null
            # could add batch norm for example
            if post_activation is not None:
                all_modules.append(post_activation)

            # save the layer size we just created
            last_layer_size = layer_size

        # now we have all of our hidden layers
        # we handle outputs
        assert isinstance(
            output_size, (int, list, tuple)
        ), "output_size must be int, list, tuple"

        if type(output_size) == int:
            all_modules.append(nn.Linear(last_layer_size, output_size))
            if output_activation is not None:
                all_modules.append(
                    call_nn_op(output_activation)
                    if isclass(output_activation)
                    else output_activation
                )
        else:

            # we're going to have a bunch of separate layers we can spit out (a tuple of outputs)
            out_layers = []

            # multiple outputs? handle separately
            for out_ix, out_size in enumerate(output_size):

                # for a single output object, we create a linear layer and some weights
                split_layer = []

                # we have an activation function
                split_layer.append(nn.Linear(last_layer_size, out_size))

                # then we get our output activation (either we repeat all or we index into a same sized array)
                act_out_fct = (
                    output_activation
                    if not isinstance(output_activation, (list, tuple))
                    else output_activation[out_ix]
                )

                if act_out_fct:
                    # we check if it's a class. if so, instantiate the object
                    # otherwise, use the object directly (e.g. pre-instantiated)
                    split_layer.append(
                        call_nn_op(act_out_fct) if isclass(act_out_fct) else act_out_fct
                    )

                # our outputs is just a sequential of the two
                out_layers.append(nn.Sequential(*split_layer))

            all_modules.append(ListOutModule(out_layers))

        # now we have all of our modules, we're ready to build our sequential!
        # process mlps in order, pretty standard here
        self.sequential_mlp = nn.Sequential(*all_modules)

    # pass through our sequential for the output!
    def forward(self, *args, **kwargs):
        return self.sequential_mlp.forward(*args, **kwargs)
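
# --- Illustrative usage (editor's sketch, not part of the released file) ---
# A 20 -> 64 -> 32 -> 5 network; passing the nn.Softmax class as output_activation
# lets call_nn_op instantiate it with dim=1.
def _example_mlp():
    mlp = MLP([20, 64, 32, 5], activation=nn.ReLU, output_activation=nn.Softmax)
    out = mlp(torch.randn(8, 20))   # shape: (8, 5); each row sums to 1
    return out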