SURE-tools 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,472 @@
+ import torch
+ from torch.utils.data import DataLoader
+
+ import pyro
+ import pyro.distributions as dist
+
+ import numpy as np
+ import scipy as sp
+ import pandas as pd
+ from scipy.stats import gaussian_kde
+ from sklearn.preprocessing import LabelBinarizer
+ from sklearn.neighbors import NearestNeighbors
+ import scanpy as sc
+
+ import multiprocessing as mp
+ from tqdm import tqdm
+
+ from ..utils import convert_to_tensor, tensor_to_numpy
+ from ..utils import CustomDataset2
+
+ from typing import Literal, List, Tuple, Dict
+ from functools import partial
+
+
+ def codebook_generate(sure_model, n_samples):
+     code_weights = convert_to_tensor(sure_model.codebook_weights, dtype=sure_model.dtype, device=sure_model.get_device())
+     ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
+
+     codebook_loc, codebook_scale = sure_model.get_codebook()
+     codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
+     codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())
+
+     loc = torch.matmul(ns, codebook_loc)
+     scale = torch.matmul(ns, codebook_scale)
+     zs = dist.Normal(loc, scale).to_event(1).sample()
+     return tensor_to_numpy(zs), tensor_to_numpy(ns)
+
+
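A minimal usage sketch of codebook_generate. It assumes `model` is a trained SURE model exposing the attributes used above (`codebook_weights`, `dtype`, `get_device()`, `get_codebook()`); the variable names are illustrative only:

    # draw 500 synthetic latent positions from the learned metacell codebook
    zs, ns = codebook_generate(model, n_samples=500)
    zs.shape   # (500, latent_dim): Gaussian draws around the sampled metacell centroids
    ns.shape   # (500, n_metacells): one-hot metacell assignment per draw
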
+ def codebook_sample(sure_model, xs, n_samples, even_sample=False, filter=True):
+     xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
+     assigns = sure_model.soft_assignments(xs)
+     code_assigns = np.argmax(assigns, axis=1)
+
+     if even_sample:
+         repeat = n_samples // assigns.shape[1]
+         remainder = n_samples % assigns.shape[1]
+         ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
+         # append the remaining elements (repeat each of the first `remainder` codes one extra time)
+         if remainder > 0:
+             ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
+         ns_id -= 1
+
+         ns = LabelBinarizer().fit_transform(ns_id)
+         ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
+     else:
+         code_weights = codebook_weights(assigns)
+         code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
+         ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
+         ns_id = np.argmax(tensor_to_numpy(ns), axis=1)
+
+     codebook_loc, codebook_scale = sure_model.get_codebook()
+     codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
+     codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())
+
+     loc = torch.matmul(ns, codebook_loc)
+     scale = torch.matmul(ns, codebook_scale)
+     zs = dist.Normal(loc, scale).to_event(1).sample()
+
+     xs_zs = sure_model.get_cell_coordinates(xs)
+     #xs_zs = convert_to_tensor(xs_zs, dtype=sure_model.dtype, device=sure_model.get_device())
+     #xs_dist = torch.cdist(zs, xs_zs)
+     #idx = xs_dist.argmin(dim=1)
+
+     #nbrs = NearestNeighbors(n_jobs=-1, n_neighbors=1)
+     #nbrs.fit(tensor_to_numpy(xs_zs))
+     #idx = nbrs.kneighbors(tensor_to_numpy(zs), return_distance=False)
+     #idx_ = idx.flatten()
+     #idx = [idx_[i] for i in np.arange(n_samples) if np.array_equal(code_assigns[idx_[i]], ns[i])]
+     #df = pd.DataFrame({'idx':idx_,
+     #                   'to':code_assigns[idx_],
+     #                   'from':ns_id})
+     #if filter:
+     #    filtered_df = df[df['from'] != df['to']]
+     #else:
+     #    filtered_df = df
+     #idx = filtered_df.loc[:,'idx'].values
+     #ns_id = filtered_df.loc[:,'from'].values
+
+     nbrs = NearestNeighbors(n_neighbors=50, n_jobs=-1)
+     nbrs.fit(tensor_to_numpy(xs_zs))
+
+     distances, ids = nbrs.kneighbors(tensor_to_numpy(zs), return_distance=True)
+
+     idx, ns_list = [], []
+     with tqdm(total=n_samples, desc='Sketching', unit='sketch') as pbar:
+         for i in np.arange(n_samples):
+             distances_i = distances[i]
+             weights_i = distance_to_softmax_weights(distances_i)
+             cell_i_ = weighted_sample(ids[i], weights_i, sample_size=1, replace=False)[0]  # single sampled neighbor index
+
+             df = pd.DataFrame({'idx': [cell_i_],
+                                'to': [code_assigns[cell_i_]],
+                                'from': [ns_id[i]]})
+             if filter:
+                 filtered_df = df[df['from'] != df['to']]
+             else:
+                 filtered_df = df
+             cells_i = filtered_df.loc[:,'idx'].values
+             ns_i = filtered_df.loc[:,'from'].unique()
+
+             idx.extend(cells_i)
+             ns_list.extend(ns_i)
+
+             pbar.update(1)
+
+     return tensor_to_numpy(xs[idx].squeeze()), np.array(idx).squeeze(), ns_list
+
+
+ def codebook_sketch(sure_model, xs, n_samples, even_sample=False):
+     return codebook_sample(sure_model, xs, n_samples, even_sample)
+
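A hypothetical call to codebook_sketch, assuming `model` is a trained SURE model and `X` is a cell-by-feature matrix (for instance the `.X` of an AnnData object); the names are illustrative, not part of the package API:

    # select ~1000 representative cells, following the metacell frequencies of X
    sketch_x, sketch_idx, sketch_codes = codebook_sketch(model, X, n_samples=1000)
    # sketch_x: expression profiles of the selected cells
    # sketch_idx: indices of those cells in X
    # sketch_codes: metacell id each sketch was drawn for
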
+ def codebook_bootstrap_sketch(sure_model, xs, n_samples, n_neighbors=8,
+                               aggregate_fun: Literal['mean','median','sum'] = 'mean',
+                               even_sample=False, replace=True, filter=True):
+     xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
+     xs_zs = sure_model.get_cell_coordinates(xs)
+     xs_zs = tensor_to_numpy(xs_zs)
+
+     # generate samples that follow the metacell distribution of the given data
+     assigns = sure_model.soft_assignments(xs)
+     code_assigns = np.argmax(assigns, axis=1)
+     if even_sample:
+         repeat = n_samples // assigns.shape[1]
+         remainder = n_samples % assigns.shape[1]
+         ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
+         # append the remaining elements (repeat each of the first `remainder` codes one extra time)
+         if remainder > 0:
+             ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
+         ns_id -= 1
+
+         ns = LabelBinarizer().fit_transform(ns_id)
+         ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
+     else:
+         code_weights = codebook_weights(assigns)
+         code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
+         ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
+         ns_id = np.argmax(tensor_to_numpy(ns), axis=1)
+
+     codebook_loc, codebook_scale = sure_model.get_codebook()
+     codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
+     codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())
+
+     loc = torch.matmul(ns, codebook_loc)
+     scale = torch.matmul(ns, codebook_scale)
+     zs = dist.Normal(loc, scale).to_event(1).sample()
+     zs = tensor_to_numpy(zs)
+
+     # find the neighbors of sample data in the real data space
+     nbrs = NearestNeighbors(n_neighbors=50, n_jobs=-1)
+     nbrs.fit(xs_zs)
+
+     xs_list = []
+     ns_list = []
+     distances, ids = nbrs.kneighbors(zs, return_distance=True)
+     #dist_pdf = gaussian_kde(distances.flatten())
+
+     xs = tensor_to_numpy(xs)
+     sketch_cells = dict()
+     with tqdm(total=n_samples, desc='Sketching', unit='sketch') as pbar:
+         for i in np.arange(n_samples):
+             #cells_i_ = ids[i, dist_pdf(distances[i]) > pval]
+             #cells_i = [c for c in cells_i_ if np.array_equal(code_assigns[c], ns[i])]
+             distances_i = distances[i]
+             weights_i = distance_to_softmax_weights(distances_i)
+             cells_i_ = weighted_sample(ids[i], weights_i, sample_size=n_neighbors, replace=replace)
+
+             df = pd.DataFrame({'idx': cells_i_,
+                                'to': code_assigns[cells_i_],
+                                'from': [ns_id[i]] * len(cells_i_)})
+             if filter:
+                 filtered_df = df[df['from'] != df['to']]
+             else:
+                 filtered_df = df
+             cells_i = filtered_df.loc[:,'idx'].values
+             ns_i = filtered_df.loc[:,'from'].unique()
+
+             if len(cells_i) > 0:
+                 xs_i = xs[cells_i]
+                 if aggregate_fun == 'mean':
+                     xs_i = np.mean(xs_i, axis=0, keepdims=True)
+                 elif aggregate_fun == 'median':
+                     xs_i = np.median(xs_i, axis=0, keepdims=True)
+                 elif aggregate_fun == 'sum':
+                     xs_i = np.sum(xs_i, axis=0, keepdims=True)
+
+                 xs_list.append(xs_i)
+                 ns_list.extend(ns_i)
+                 sketch_cells[i] = cells_i
+
+             pbar.update(1)
+
+     return np.vstack(xs_list), sketch_cells, ns_list
+
+
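A sketch of how codebook_bootstrap_sketch might be called: each requested draw aggregates up to `n_neighbors` nearby cells into one pseudo-cell. `model` and `X` are the same hypothetical trained SURE model and cell-by-feature matrix as in the previous example:

    pseudo_x, member_cells, pseudo_codes = codebook_bootstrap_sketch(
        model, X,
        n_samples=1000,       # number of pseudo-cells to build
        n_neighbors=8,        # cells aggregated into each pseudo-cell
        aggregate_fun='mean'  # 'mean', 'median' or 'sum'
    )
    # member_cells[i] lists the cells that were averaged into pseudo-cell i
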
+ def process_chunk(chunk_indices: np.ndarray,
+                   zs: np.ndarray,
+                   xs_zs: np.ndarray,
+                   xs: np.ndarray,
+                   code_assigns: np.ndarray,
+                   ns_id: np.ndarray,
+                   n_neighbors: int,
+                   replace: bool,
+                   filter: bool,
+                   aggregate_fun: str) -> Tuple[List[np.ndarray], Dict[int, np.ndarray], List[int]]:
+     """
+     Process the samples belonging to one chunk.
+
+     Args:
+         chunk_indices: indices of the samples handled by this chunk
+         The remaining arguments are the same as in the main function.
+
+     Returns:
+         (xs_chunk, sketch_cells_chunk, ns_list_chunk)
+     """
+     xs_chunk = []
+     sketch_cells_chunk = {}
+     ns_list_chunk = []
+
+     # each chunk builds its own NearestNeighbors instance to avoid clashes between processes
+     nbrs = NearestNeighbors(n_neighbors=50, n_jobs=1)  # single-threaded mode
+     nbrs.fit(xs_zs)
+
+     for i in chunk_indices:
+         distances, ids = nbrs.kneighbors(zs[i:i+1], return_distance=True)
+         distances_i = distances[0]
+         weights_i = distance_to_softmax_weights(distances_i)
+         cells_i_ = weighted_sample(ids[0], weights_i, sample_size=n_neighbors, replace=replace)
+
+         df = pd.DataFrame({
+             'idx': cells_i_,
+             'to': code_assigns[cells_i_],
+             'from': [ns_id[i]] * len(cells_i_)
+         })
+
+         if filter:
+             filtered_df = df[df['from'] != df['to']]
+         else:
+             filtered_df = df
+
+         cells_i = filtered_df.loc[:, 'idx'].values
+         ns_i = filtered_df.loc[:, 'from'].unique()
+
+         if len(cells_i) > 0:
+             xs_i = xs[cells_i]
+
+             if aggregate_fun == 'mean':
+                 xs_i = np.mean(xs_i, axis=0, keepdims=True)
+             elif aggregate_fun == 'median':
+                 xs_i = np.median(xs_i, axis=0, keepdims=True)
+             elif aggregate_fun == 'sum':
+                 xs_i = np.sum(xs_i, axis=0, keepdims=True)
+
+             xs_chunk.append(xs_i)
+             sketch_cells_chunk[i] = cells_i
+             ns_list_chunk.extend(ns_i)
+
+     return xs_chunk, sketch_cells_chunk, ns_list_chunk
+
+ def codebook_bootstrap_sketch_parallel(
+         sure_model,
+         xs,
+         n_samples,
+         n_neighbors=8,
+         aggregate_fun: Literal['mean','median','sum'] = 'mean',
+         even_sample=False,
+         replace=True,
+         filter=True,
+         n_processes: int = None,
+         chunk_size: int = 100
+ ) -> Tuple[np.ndarray, Dict[int, np.ndarray], List[int]]:
+     """
+     Chunk-based parallel version of codebook_bootstrap_sketch.
+
+     Additional args:
+         n_processes: number of worker processes; None uses all CPU cores
+         chunk_size: number of samples handled per chunk
+     """
+     # convert the input data (same as the sequential version)
+     xs = convert_to_tensor(xs, dtype=sure_model.dtype, device=sure_model.get_device())
+     xs_zs = sure_model.get_cell_coordinates(xs)
+     xs_zs = tensor_to_numpy(xs_zs)
+
+     # generate samples (same as the sequential version)
+     assigns = sure_model.soft_assignments(xs)
+     code_assigns = np.argmax(assigns, axis=1)
+
+     if even_sample:
+         repeat = n_samples // assigns.shape[1]
+         remainder = n_samples % assigns.shape[1]
+         ns_id = np.repeat(np.arange(1, assigns.shape[1] + 1), repeat)
+         if remainder > 0:
+             ns_id = np.concatenate([ns_id, np.arange(1, remainder + 1)])
+         ns_id -= 1
+         ns = LabelBinarizer().fit_transform(ns_id)
+         ns = convert_to_tensor(ns, dtype=sure_model.dtype, device=sure_model.get_device())
+     else:
+         code_weights = codebook_weights(assigns)
+         code_weights = convert_to_tensor(code_weights, dtype=sure_model.dtype, device=sure_model.get_device())
+         ns = dist.OneHotCategorical(probs=code_weights).sample([n_samples])
+         ns_id = np.argmax(tensor_to_numpy(ns), axis=1)
+
+     # fetch the codebook (same as the sequential version)
+     codebook_loc, codebook_scale = sure_model.get_codebook()
+     codebook_loc = convert_to_tensor(codebook_loc, dtype=sure_model.dtype, device=sure_model.get_device())
+     codebook_scale = convert_to_tensor(codebook_scale, dtype=sure_model.dtype, device=sure_model.get_device())
+
+     # sample zs (same as the sequential version)
+     loc = torch.matmul(ns, codebook_loc)
+     scale = torch.matmul(ns, codebook_scale)
+     zs = dist.Normal(loc, scale).to_event(1).sample()
+     zs = tensor_to_numpy(zs)
+
+     # convert to a numpy array (same as the sequential version)
+     xs = tensor_to_numpy(xs)
+
+     # prepare the result containers
+     xs_list = []
+     sketch_cells = {}
+     ns_list = []
+
+     # split the samples into chunks
+     chunks = [np.arange(i, min(i + chunk_size, n_samples))
+               for i in range(0, n_samples, chunk_size)]
+
+     # create the process pool
+     with mp.Pool(processes=n_processes) as pool:
+         # use partial to fix the shared arguments
+         worker = partial(
+             process_chunk,
+             zs=zs,
+             xs_zs=xs_zs,
+             xs=xs,
+             code_assigns=code_assigns,
+             ns_id=ns_id,
+             n_neighbors=n_neighbors,
+             replace=replace,
+             filter=filter,
+             aggregate_fun=aggregate_fun
+         )
+
+         # show progress with tqdm
+         results = list(tqdm(
+             pool.imap(worker, chunks),
+             total=len(chunks),
+             desc='Processing chunks',
+             unit='chunk'
+         ))
+
+     # merge the results
+     for xs_chunk, sketch_cells_chunk, ns_list_chunk in results:
+         xs_list.extend(xs_chunk)
+         sketch_cells.update(sketch_cells_chunk)
+         ns_list.extend(ns_list_chunk)
+
+     return (
+         np.vstack(xs_list) if xs_list else np.empty((0, xs.shape[1])),
+         sketch_cells,
+         ns_list
+     )
+
+
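Because the parallel variant spawns worker processes through multiprocessing.Pool, it should be invoked from a __main__ guard (or another proper entry point) so the workers can be re-imported safely. A minimal sketch, with `model` and `X` as the same hypothetical inputs used above:

    if __name__ == '__main__':
        pseudo_x, member_cells, pseudo_codes = codebook_bootstrap_sketch_parallel(
            model, X,
            n_samples=5000,
            n_neighbors=8,
            n_processes=4,     # None would use every available core
            chunk_size=100     # samples handled per worker task
        )
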
+ def codebook_summarize_(assigns, xs):
+     assigns = convert_to_tensor(assigns)
+     xs = convert_to_tensor(xs)
+     results = torch.matmul(assigns.T, xs)
+     results = results / torch.sum(assigns.T, dim=1, keepdim=True)
+     return tensor_to_numpy(results)
+
+
+ def codebook_summarize(assigns, xs, batch_size=1024):
+     assigns = convert_to_tensor(assigns)
+     xs = convert_to_tensor(xs)
+
+     dataset = CustomDataset2(assigns, xs)
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+     R = None
+     W = None
+     with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+         for A_batch, X_batch, _ in dataloader:
+             r = torch.matmul(A_batch.T, X_batch)
+             w = torch.sum(A_batch.T, dim=1, keepdim=True)
+             if R is None:
+                 R = r
+                 W = w
+             else:
+                 R += r
+                 W += w
+             pbar.update(1)
+
+     results = R / W
+     return tensor_to_numpy(results)
+
+ def codebook_aggregate(assigns, xs, batch_size=1024):
+     assigns = convert_to_tensor(assigns)
+     xs = convert_to_tensor(xs)
+
+     dataset = CustomDataset2(assigns, xs)
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+     R = None
+     with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+         for A_batch, X_batch, _ in dataloader:
+             r = torch.matmul(A_batch.T, X_batch)
+             if R is None:
+                 R = r
+             else:
+                 R += r
+             pbar.update(1)
+
+     results = R
+     return tensor_to_numpy(results)
+
+
+ def codebook_weights(assigns):
+     assigns = convert_to_tensor(assigns)
+     results = torch.sum(assigns, dim=0)
+     results = results / torch.sum(results)
+     return tensor_to_numpy(results)
+
+
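Taken together, the summarize/aggregate helpers reduce a cell-by-feature matrix to a metacell-by-feature matrix using the soft assignment matrix. A hedged sketch, assuming `model` and `X` as in the earlier examples, with `assigns` of shape (n_cells, n_metacells):

    assigns = model.soft_assignments(X)                  # (n_cells, n_metacells)
    metacell_means  = codebook_summarize(assigns, X)     # weighted per-metacell mean profile
    metacell_totals = codebook_aggregate(assigns, X)     # weighted per-metacell sum
    metacell_freqs  = codebook_weights(assigns)          # relative metacell abundances, sums to 1
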
+ def distance_to_softmax_weights(distances):
+     """Convert a list of distances into probability weights with a softmax.
+
+     Args:
+         distances: list of distances; smaller distances should receive larger weights
+
+     Returns:
+         array of probability weights that sums to 1
+     """
+     distances = np.array(distances)
+     # negate so that smaller distances map to larger values
+     negative_distances = -distances
+     # compute the softmax
+     exp_dist = np.exp(negative_distances - np.max(negative_distances))  # subtract the max for numerical stability
+     softmax = exp_dist / np.sum(exp_dist)
+     return softmax
+
+ def weighted_sample(items, weights, sample_size=1, replace=True):
+     """Draw a weighted random sample.
+
+     Args:
+         items: items to sample from
+         weights: corresponding probability weights
+         sample_size: number of draws
+         replace: whether to sample with replacement
+
+     Returns:
+         array of sampled items
+     """
+     return np.random.choice(
+         a=items,
+         size=sample_size,
+         p=weights,
+         replace=replace
+     )
+
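A small self-contained illustration of the two sampling helpers: closer neighbors receive exponentially larger weights, and weighted_sample then draws indices according to those weights.

    import numpy as np

    distances = [0.1, 0.5, 2.0]
    w = distance_to_softmax_weights(distances)
    # w ~ [0.55, 0.37, 0.08]  (softmax of the negated distances, summing to 1)
    picked = weighted_sample(items=np.arange(3), weights=w, sample_size=2, replace=False)
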
+ def split_evenly(n, m):
+     """Split n into m parts that are as even as possible, using NumPy."""
+     arr = np.full(m, n // m)
+     arr[:n % m] += 1
+     return arr.tolist()
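For instance, every part receives n // m and the first n % m parts get one extra:

    split_evenly(10, 3)   # -> [4, 3, 3]
    split_evenly(12, 4)   # -> [3, 3, 3, 3]
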
SURE/utils/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Importing specific functions from modules
+ from .utils import tensor_to_numpy, move_to_device, convert_to_tensor
+ from .utils import CustomDataset, CustomDataset2, CustomDataset3
+ from .utils import CustomMultiOmicsDataset, CustomMultiOmicsDataset2
+ from .utils import pretty_print, Colors
+ from .utils import find_partitions_greedy
+
+ from .queue import PriorityQueue
+
+ from .custom_mlp import MLP, Exp
+
+ # Importing modules
+ #from . import utils
+ #from . import custom_mlp
+
+ #__all__ = ['tensor_to_numpy', 'move_to_device', 'convert_to_tensor',
+ #           'CustomDataset', 'CustomDataset2', 'CustomDataset3',
+ #           'MLP', 'Exp',
+ #           'custom_mlp', 'utils']
@@ -0,0 +1,209 @@
+ # Copyright (c) 2017-2019 Uber Technologies, Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from inspect import isclass
+
+ import torch
+ import torch.nn as nn
+
+ from pyro.distributions.util import broadcast_shape
+
+
+ class Exp(nn.Module):
+     """
+     a custom module for exponentiation of tensors
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, val):
+         return torch.exp(val)
+
+
+ class ConcatModule(nn.Module):
+     """
+     a custom module for concatenation of tensors
+     """
+
+     def __init__(self, allow_broadcast=False):
+         self.allow_broadcast = allow_broadcast
+         super().__init__()
+
+     def forward(self, *input_args):
+         # we have a single object
+         if len(input_args) == 1:
+             # regardless of type,
+             # we don't care about single objects
+             # we just index into the object
+             input_args = input_args[0]
+
+         # don't concat things that are just single objects
+         if torch.is_tensor(input_args):
+             return input_args
+         else:
+             if self.allow_broadcast:
+                 shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
+                 input_args = [s.expand(shape) for s in input_args]
+             return torch.cat(input_args, dim=-1)
+
+
+ class ListOutModule(nn.ModuleList):
+     """
+     a custom module for outputting a list of tensors from a list of nn modules
+     """
+
+     def __init__(self, modules):
+         super().__init__(modules)
+
+     def forward(self, *args, **kwargs):
+         # loop over modules in self, apply same args
+         return [mm.forward(*args, **kwargs) for mm in self]
+
+
+ def call_nn_op(op):
+     """
+     a helper function that adds appropriate parameters when calling
+     an nn module representing an operation like Softmax
+
+     :param op: the nn.Module operation to instantiate
+     :return: instantiation of the op module with appropriate parameters
+     """
+     if op in [nn.Softmax, nn.LogSoftmax]:
+         return op(dim=1)
+     else:
+         return op()
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         mlp_sizes,
+         activation=nn.ReLU,
+         output_activation=None,
+         post_layer_fct=lambda layer_ix, total_layers, layer: None,
+         post_act_fct=lambda layer_ix, total_layers, layer: None,
+         allow_broadcast=False,
+         use_cuda=False,
+     ):
+         # init the module object
+         super().__init__()
+
+         assert len(mlp_sizes) >= 2, "Must have input and output layer sizes defined"
+
+         # get our inputs, outputs, and hidden
+         input_size, hidden_sizes, output_size = (
+             mlp_sizes[0],
+             mlp_sizes[1:-1],
+             mlp_sizes[-1],
+         )
+
+         # assume int or list
+         assert isinstance(
+             input_size, (int, list, tuple)
+         ), "input_size must be int, list, tuple"
+
+         # everything in MLP will be concatted if it's multiple arguments
+         last_layer_size = input_size if type(input_size) == int else sum(input_size)
+
+         # everything sent in will be concatted together by default
+         all_modules = [ConcatModule(allow_broadcast)]
+
+         # loop over the hidden layers
+         for layer_ix, layer_size in enumerate(hidden_sizes):
+             assert type(layer_size) == int, "Hidden layer sizes must be ints"
+
+             # get our nn layer module (in this case nn.Linear by default)
+             cur_linear_layer = nn.Linear(last_layer_size, layer_size)
+
+             # for numerical stability -- initialize the layer properly
+             cur_linear_layer.weight.data.normal_(0, 0.001)
+             cur_linear_layer.bias.data.normal_(0, 0.001)
+
+             # use GPUs to share data during training (if available)
+             if use_cuda:
+                 cur_linear_layer = nn.DataParallel(cur_linear_layer)
+
+             # add our linear layer
+             all_modules.append(cur_linear_layer)
+
+             # handle post_linear
+             post_linear = post_layer_fct(
+                 layer_ix + 1, len(hidden_sizes), all_modules[-1]
+             )
+
+             # if we send something back, add it to sequential
+             # here we could return a batch norm for example
+             if post_linear is not None:
+                 all_modules.append(post_linear)
+
+             # handle activation (assumed no params -- deal with that later)
+             all_modules.append(activation())
+
+             # now handle after activation
+             post_activation = post_act_fct(
+                 layer_ix + 1, len(hidden_sizes), all_modules[-1]
+             )
+
+             # handle post_activation if not null
+             # could add batch norm for example
+             if post_activation is not None:
+                 all_modules.append(post_activation)
+
+             # save the layer size we just created
+             last_layer_size = layer_size
+
+         # now we have all of our hidden layers
+         # we handle outputs
+         assert isinstance(
+             output_size, (int, list, tuple)
+         ), "output_size must be int, list, tuple"
+
+         if type(output_size) == int:
+             all_modules.append(nn.Linear(last_layer_size, output_size))
+             if output_activation is not None:
+                 all_modules.append(
+                     call_nn_op(output_activation)
+                     if isclass(output_activation)
+                     else output_activation
+                 )
+         else:
+
+             # we're going to have a bunch of separate layers we can spit out (a tuple of outputs)
+             out_layers = []
+
+             # multiple outputs? handle separately
+             for out_ix, out_size in enumerate(output_size):
+
+                 # for a single output object, we create a linear layer and some weights
+                 split_layer = []
+
+                 # we have an activation function
+                 split_layer.append(nn.Linear(last_layer_size, out_size))
+
+                 # then we get our output activation (either we repeat all or we index into a same sized array)
+                 act_out_fct = (
+                     output_activation
+                     if not isinstance(output_activation, (list, tuple))
+                     else output_activation[out_ix]
+                 )
+
+                 if act_out_fct:
+                     # we check if it's a class. if so, instantiate the object
+                     # otherwise, use the object directly (e.g. pre-instantiated)
+                     split_layer.append(
+                         call_nn_op(act_out_fct) if isclass(act_out_fct) else act_out_fct
+                     )
+
+                 # our outputs is just a sequential of the two
+                 out_layers.append(nn.Sequential(*split_layer))
+
+             all_modules.append(ListOutModule(out_layers))
+
+         # now we have all of our modules, we're ready to build our sequential!
+         # process mlps in order, pretty standard here
+         self.sequential_mlp = nn.Sequential(*all_modules)
+
+     # pass through our sequential for the output!
+     def forward(self, *args, **kwargs):
+         return self.sequential_mlp.forward(*args, **kwargs)
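A brief sketch of how this MLP wrapper is typically instantiated: `mlp_sizes` lists the input, hidden and output sizes, and passing a list/tuple as the last entry yields multiple output heads. The layer sizes below are made up for illustration:

    import torch
    import torch.nn as nn

    # single-head network: 100 -> 64 -> 32, Softplus hidden activation
    enc = MLP([100, 64, 32], activation=nn.Softplus)
    out = enc(torch.randn(8, 100))          # shape (8, 32)

    # two-head network: shared hidden layer, then a 10-dim softmax head and a 5-dim linear head
    dec = MLP([32, 64, [10, 5]], activation=nn.Softplus,
              output_activation=[nn.Softmax, None])
    probs, extra = dec(torch.randn(8, 32))  # list of two tensors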