HMAP-tool 1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hmap_tool-1.0/HMAP/HMAP.py +1298 -0
- hmap_tool-1.0/HMAP/__init__.py +8 -0
- hmap_tool-1.0/HMAP/atac/__init__.py +1 -0
- hmap_tool-1.0/HMAP/atac/utils.py +151 -0
- hmap_tool-1.0/HMAP/graph/__init__.py +1 -0
- hmap_tool-1.0/HMAP/graph/graph.py +323 -0
- hmap_tool-1.0/HMAP/utils/__init__.py +17 -0
- hmap_tool-1.0/HMAP/utils/custom_mlp.py +209 -0
- hmap_tool-1.0/HMAP/utils/utils.py +308 -0
- hmap_tool-1.0/HMAP_tool.egg-info/PKG-INFO +70 -0
- hmap_tool-1.0/HMAP_tool.egg-info/SOURCES.txt +17 -0
- hmap_tool-1.0/HMAP_tool.egg-info/dependency_links.txt +1 -0
- hmap_tool-1.0/HMAP_tool.egg-info/entry_points.txt +2 -0
- hmap_tool-1.0/HMAP_tool.egg-info/requires.txt +13 -0
- hmap_tool-1.0/HMAP_tool.egg-info/top_level.txt +1 -0
- hmap_tool-1.0/PKG-INFO +70 -0
- hmap_tool-1.0/README.md +35 -0
- hmap_tool-1.0/setup.cfg +4 -0
- hmap_tool-1.0/setup.py +30 -0
|
@@ -0,0 +1,1298 @@
|
|
|
1
|
+
import pyro
|
|
2
|
+
import pyro.distributions as dist
|
|
3
|
+
from pyro.optim import ExponentialLR
|
|
4
|
+
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
|
|
10
|
+
from torch.distributions import constraints
|
|
11
|
+
from torch.distributions.transforms import SoftmaxTransform
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import os
|
|
15
|
+
import time as tm
|
|
16
|
+
import random
|
|
17
|
+
import itertools
|
|
18
|
+
import pandas as pd
|
|
19
|
+
import numpy as np
|
|
20
|
+
import datatable as dt
|
|
21
|
+
import networkx as nx
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
import warnings
|
|
24
|
+
warnings.filterwarnings("ignore")
|
|
25
|
+
|
|
26
|
+
from .utils.utils import convert_to_tensor, tensor_to_numpy, CustomDataset, CustomDataset2, CustomDataset4
|
|
27
|
+
from .utils.custom_mlp import MLP, Exp
|
|
28
|
+
from .graph.graph import compute_metacell_diffusion_kernel,visualize_metacell_igraph_with_fa2
|
|
29
|
+
from .atac import binarize
|
|
30
|
+
|
|
31
|
+
from tqdm import tqdm
|
|
32
|
+
from typing import Literal
|
|
33
|
+
import dill as pickle
|
|
34
|
+
import gzip
|
|
35
|
+
import scanpy as sc
|
|
36
|
+
from scipy import sparse
|
|
37
|
+
|
|
38
|
+
def set_random_seed(seed):
    """Seed every random number generator used by this module.

    Seeds PyTorch (plus the CUDA generators when a GPU is available), NumPy,
    Python's ``random`` module and Pyro with the same ``seed`` so that runs
    can be reproduced.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        # Current CUDA device first, then every device for multi-GPU setups.
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [np.random.seed, random.seed, pyro.set_rng_seed]

    for seeder in seeders:
        seeder(seed)
|
|
55
|
+
|
|
56
|
+
class HMAP(nn.Module):
|
|
57
|
+
def __init__(self,
             input_size: int,
             undesired_size: int = 0,
             codebook_size: int = 30, # size of metacell codebook
             supervised_mode: bool = False,
             use_cell_factor: bool = False,
             d_dim: int = 3, # dimension of a metacell variable
             d_dist: Literal['normal','laplacian','cauchy','studentt','vonmises','gumbel'] = 'normal',
             z_dim: int = 10,
             z_dist: Literal['normal','laplacian','cauchy','studentt','gumbel'] = 'laplacian',
             loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
             hidden_layers: list =[300],
             hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
             inverse_dispersion: float = 10.0,
             nn_dropout: float = 0.1,
             post_layer_fct: list = ['layernorm'],
             post_act_fct: list = None,
             config_enum: str = 'parallel',
             use_cuda: bool = False,
             seed: int = 42,
             dtype = torch.float32, # type: ignore
             ):
    """Build an HMAP model and all of its networks.

    Args:
        input_size: number of input features per cell.
        undesired_size: width of the undesired covariates (``ks2``) fed to the
            decoder; 0 disables them.
        codebook_size: number of metacell codes.
        supervised_mode: if True, ``encoder_n`` maps the inputs directly to
            code logits instead of mapping the metacell variable ``d``.
        use_cell_factor: add the output of a per-cell network to the decoder
            output.
        d_dim: dimension of the metacell variable ``d``.
        d_dist: distribution of ``d``. The misspelling 'caucy', advertised
            by earlier type hints, is accepted as 'cauchy'.
        z_dim: dimension of the cell embedding ``zn``.
        z_dist: distribution of ``zn``.
        loss_func: likelihood used for the observed data.
        hidden_layers: hidden sizes shared by every MLP. It is copied, so the
            default list is never shared between instances.
        hidden_layer_activation: activation used in the hidden layers.
        inverse_dispersion: initial negative-binomial total count.
        nn_dropout: dropout rate used when 'dropout' is requested.
        post_layer_fct: names among 'layernorm'/'layer_norm',
            'batchnorm'/'batch_norm' and 'dropout' applied after each layer.
        post_act_fct: same vocabulary, applied after each activation.
        config_enum: 'parallel' enables broadcasting inside the MLPs.
        use_cuda: move the networks to the GPU.
        seed: seed passed to ``set_random_seed``; None skips seeding.
        dtype: stored as ``self.dtype``.

    Raises:
        ValueError: if ``loss_func``, ``d_dist`` or ``z_dist`` is not supported.
    """
    super().__init__()

    if d_dist == 'caucy':
        # Earlier type hints listed this typo; model/model2 only know 'cauchy'.
        d_dist = 'cauchy'

    # Validate eagerly: unsupported distribution names otherwise surface much
    # later inside model()/model2() as an UnboundLocalError.
    choices = {
        'loss_func': (loss_func, ('poisson', 'multinomial', 'negbinomial', 'bernoulli')),
        'd_dist': (d_dist, ('normal', 'laplacian', 'cauchy', 'studentt', 'vonmises', 'gumbel')),
        'z_dist': (z_dist, ('normal', 'laplacian', 'cauchy', 'studentt', 'gumbel')),
    }
    for arg_name, (value, allowed) in choices.items():
        if value not in allowed:
            raise ValueError(f"{arg_name} must be one of {allowed}, got {value!r}")

    # initialize the class with all arguments provided to the constructor
    self.input_size = input_size
    self.undesired_size = undesired_size
    self.inverse_dispersion = inverse_dispersion
    self.z_dim = z_dim
    # Copy so the shared default list can never be mutated via an instance.
    self.hidden_layers = list(hidden_layers)
    self.use_undesired = undesired_size > 0
    self.allow_broadcast = config_enum == 'parallel'
    self.use_cuda = use_cuda
    self.loss_func = loss_func
    self.options = None  # filled with dtype/device of the data by model()/model2()
    self.code_size = codebook_size
    self.D_size = d_dim
    self.z_dist = z_dist
    self.d_dist = d_dist
    self.G = None
    self.supervised_mode = supervised_mode
    self.dtype = dtype
    self.use_cell_factor = use_cell_factor
    self.normalize = True

    self.nn_dropout = nn_dropout
    self.post_layer_fct = post_layer_fct
    self.post_act_fct = post_act_fct
    self.hidden_layer_activation = hidden_layer_activation

    if seed is not None:
        set_random_seed(seed)

    # define and instantiate the neural networks representing
    # the parameters of various distributions in the model
    self.setup_networks()
|
|
115
|
+
|
|
116
|
+
def setup_networks(self):
    """Instantiate every network used by the model and the guides.

    Creates ``encoder_n``, ``encoder_zn``, ``encoder_d``, optionally
    ``cell_factor``, ``decoder_concentrate``, ``codebook`` and ``decoder_zn``.
    Every MLP shares the hidden sizes, the hidden activation and the
    post-layer / post-activation hooks derived from ``self.post_layer_fct``
    and ``self.post_act_fct``. The module is moved to CUDA when
    ``self.use_cuda`` is set.

    Raises:
        ValueError: if ``self.hidden_layer_activation`` is not one of
            'relu', 'softplus', 'leakyrelu' or 'linear'.
    """
    z_dim = self.z_dim
    hidden_sizes = list(self.hidden_layers)

    def _hook(names):
        # Turn a list of hook names into the (layer_ix, total_layers, layer)
        # callback that MLP expects.
        names = () if names is None else names
        use_dropout = 'dropout' in names
        use_batch_norm = ('batchnorm' in names) or ('batch_norm' in names)
        use_layer_norm = ('layernorm' in names) or ('layer_norm' in names)

        def build(layer_ix, total_layers, layer):
            # The stacking order Dropout -> BatchNorm -> LayerNorm, a bare module
            # when only one is selected, and None when nothing is selected are
            # deliberate: they fix the submodule layout and so the state_dict keys.
            stack = []
            if use_dropout:
                stack.append(nn.Dropout(self.nn_dropout))
            if use_batch_norm:
                stack.append(nn.BatchNorm1d(layer.module.out_features))
            if use_layer_norm:
                stack.append(nn.LayerNorm(layer.module.out_features))
            if not stack:
                return None
            return stack[0] if len(stack) == 1 else nn.Sequential(*stack)

        return build

    activations = {
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'leakyrelu': nn.LeakyReLU,
        'linear': nn.Identity,
    }
    if self.hidden_layer_activation not in activations:
        # Previously this left activate_fct unbound and failed with UnboundLocalError.
        raise ValueError(
            f"hidden_layer_activation must be one of {tuple(activations)}, "
            f"got {self.hidden_layer_activation!r}"
        )
    activate_fct = activations[self.hidden_layer_activation]
    post_layer_fct = _hook(self.post_layer_fct)
    post_act_fct = _hook(self.post_act_fct)

    def _mlp(in_size, out_size, output_activation=None):
        # Shared MLP factory: [in] + hidden sizes + [out] with common settings.
        return MLP(
            [in_size] + hidden_sizes + [out_size],
            activation=activate_fct,
            output_activation=output_activation,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

    # define the neural networks used later in the model and the guide.
    # The construction order is kept unchanged: MLP presumably draws initial
    # weights from the global (seeded) RNG, so reordering would change them.
    encoder_n_in = self.input_size if self.supervised_mode else self.D_size
    self.encoder_n = _mlp(encoder_n_in, self.code_size)

    # (loc, scale) heads; Exp keeps the scale positive.
    self.encoder_zn = _mlp(self.input_size, [z_dim, z_dim], [None, Exp])
    self.encoder_d = _mlp(self.z_dim, [self.D_size, self.D_size], [None, Exp])

    if self.use_cell_factor:
        self.cell_factor = _mlp(self.input_size, 1)

    decoder_in = self.undesired_size + self.z_dim if self.use_undesired else self.z_dim
    self.decoder_concentrate = _mlp(decoder_in, self.input_size)

    # Student-t heads return (dof, loc); Exp keeps the dof positive.
    if self.d_dist == 'studentt':
        self.codebook = _mlp(self.code_size, [self.D_size, self.D_size], [Exp, None])
    else:
        self.codebook = _mlp(self.code_size, self.D_size)

    if self.z_dist == 'studentt':
        self.decoder_zn = _mlp(self.D_size, [self.z_dim, self.z_dim], [Exp, None])
    else:
        self.decoder_zn = _mlp(self.D_size, self.z_dim)

    # using GPUs for faster training of the networks
    if self.use_cuda:
        self.cuda()
|
|
295
|
+
|
|
296
|
+
def cutoff(self, xs, thresh=None):
    """Clamp ``xs`` from below and replace NaNs.

    The floor is the machine epsilon of ``xs.dtype``, raised to ``thresh``
    when ``thresh`` is given and larger. Values below the floor, and NaN
    entries, become the floor. A new tensor is returned.
    """
    floor = torch.finfo(xs.dtype).eps
    if thresh is not None:
        floor = max(floor, thresh)

    clipped = xs.clamp(min=floor)
    # clamp propagates NaN, so patch those entries explicitly.
    return clipped.masked_fill(torch.isnan(clipped), floor)
|
|
309
|
+
|
|
310
|
+
def softmax(self, xs):
    """Map ``xs`` onto the probability simplex along its last dimension.

    Uses ``torch.distributions.transforms.SoftmaxTransform``, whose event
    dimension is the last axis.
    """
    to_simplex = SoftmaxTransform()
    return to_simplex(xs)
|
|
317
|
+
|
|
318
|
+
def sigmoid(self, xs):
    """Apply the logistic sigmoid, clamped to ``[eps, 1 - eps]`` via ``clamp_probs``."""
    return clamp_probs(torch.sigmoid(xs))
|
|
323
|
+
|
|
324
|
+
def softmax_logit(self, xs):
    """Return the elementwise logit of ``self.softmax(xs)``.

    The probabilities are clamped to ``[eps, 1 - eps]`` (``eps`` taken from the
    dtype of ``xs``) inside ``torch.logit`` so the result stays finite.
    """
    machine_eps = torch.finfo(xs.dtype).eps
    probabilities = self.softmax(xs)
    return torch.logit(probabilities, eps=machine_eps)
|
|
329
|
+
|
|
330
|
+
def logit(self, xs):
    """Elementwise logit with inputs clamped to ``[eps, 1 - eps]`` for ``xs.dtype``."""
    return torch.logit(xs, eps=torch.finfo(xs.dtype).eps)
|
|
334
|
+
|
|
335
|
+
def dirimulti_param(self, xs):
    """Scale ``self.sigmoid(xs)`` by ``self.dirimulti_mass``.

    NOTE(review): ``dirimulti_mass`` is never assigned in ``__init__`` as seen
    here; calling this without setting it raises AttributeError. Confirm
    where it is meant to be defined.
    """
    squashed = self.sigmoid(xs)
    return self.dirimulti_mass * squashed
|
|
338
|
+
|
|
339
|
+
def multi_param(self, xs):
    """Return multinomial probabilities for ``xs`` (``self.softmax`` of it)."""
    return self.softmax(xs)
|
|
342
|
+
|
|
343
|
+
def get_device(self):
    """Return the device of the module's first parameter (StopIteration if it has none)."""
    first_param = next(iter(self.parameters()))
    return first_param.device
|
|
345
|
+
|
|
346
|
+
def model(self, xs, embeds=None, ks2=None):
    """Generative model with a latent metacell assignment.

    For each cell in the batch:
      1. draw a one-hot metacell code ``n`` from a uniform categorical over
         ``code_size`` codes;
      2. draw the metacell variable ``d`` around the codebook row picked by
         ``n`` (``d_dist``);
      3. decode ``d`` into the location of the cell embedding ``zn`` and draw
         ``zn`` (``z_dist``), or observe it when ``embeds`` is given;
      4. decode ``zn`` (together with ``ks2`` when undesired covariates are
         used) into per-feature concentrations and score ``xs`` under
         ``self.loss_func``.

    Args:
        xs: observed batch; ``xs.size(0)`` is the batch size.
        embeds: optional cell embeddings to condition ``zn`` on.
        ks2: undesired covariates, only used when ``self.use_undesired``.

    Raises:
        ValueError: if ``d_dist`` or ``z_dist`` names an unsupported distribution.
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module('HMAP', self)

    batch_size = xs.size(0)
    self.options = dict(dtype=xs.dtype, device=xs.device)

    def _latent_dist(kind, loc, scale, dof, allowed):
        # Build the named distribution. Unknown names used to fall through
        # every branch and fail later with an UnboundLocalError.
        if kind not in allowed:
            raise ValueError(f"unsupported distribution {kind!r}; expected one of {allowed}")
        if kind == 'studentt':
            return dist.StudentT(df=dof, loc=loc, scale=scale)
        families = {
            'normal': dist.Normal,
            'laplacian': dist.Laplace,
            'cauchy': dist.Cauchy,
            'vonmises': dist.VonMises,
            'gumbel': dist.Gumbel,
        }
        return families[kind](loc, scale)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion", self.inverse_dispersion * torch.ones(1, self.input_size, **self.options),
                                 constraint=constraints.positive)

    acs_scale = pyro.param("codebook_scale", torch.ones(1, self.D_size, **self.options),
                           constraint=constraints.positive)
    zn_scale = pyro.param("z_scale", torch.ones(1, self.z_dim, **self.options),
                          constraint=constraints.positive)

    # Build the identity on the data's device/dtype (as _codebook does); a
    # default CPU tensor cannot be fed to networks living on the GPU.
    I = torch.eye(self.code_size, **self.options)
    if self.d_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(I)
    else:
        acs_loc = self.codebook(I)

    with pyro.plate('data'):
        # p(n): uniform prior over the metacell codes
        prior = torch.zeros(batch_size, self.code_size, **self.options)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))

        # p(d | n): centred on the codebook row selected by n
        d_dof = torch.matmul(ns, acs_dof) if self.d_dist == 'studentt' else None
        d_loc = torch.matmul(ns, acs_loc)
        d_prior = _latent_dist(self.d_dist, d_loc, acs_scale, d_dof,
                               ('normal', 'laplacian', 'cauchy', 'vonmises', 'gumbel', 'studentt'))
        ds = pyro.sample('d', d_prior.to_event(1))

        # p(zn | d): obs=None leaves zn latent, so one call covers both cases
        if self.z_dist == 'studentt':
            zn_dof, zn_loc = self.decoder_zn(ds)
        else:
            zn_dof, zn_loc = None, self.decoder_zn(ds)
        zn_prior = _latent_dist(self.z_dist, zn_loc, zn_scale, zn_dof,
                                ('laplacian', 'cauchy', 'normal', 'studentt', 'gumbel'))
        zns = pyro.sample('zn', zn_prior.to_event(1), obs=embeds)

        # p(x | zn, ks2)
        zs = [ks2, zns] if self.use_undesired else zns
        concentrate = self.decoder_concentrate(zs)
        if self.use_cell_factor:
            # out-of-place so the decoder output tensor is not modified in place
            concentrate = concentrate + self.cell_factor(xs)

        if self.loss_func == 'poisson':
            pyro.sample('x', dist.Poisson(rate=concentrate.exp()).to_event(1), obs=xs.round())
            return

        if self.normalize:
            # The mean of a one-draw Dirichlet-multinomial is the concentration
            # vector normalised to sum to one.
            params = dict(probs=dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean)
        else:
            params = dict(logits=concentrate)

        if self.loss_func == 'negbinomial':
            pyro.sample('x', dist.NegativeBinomial(total_count=total_count, **params).to_event(1), obs=xs)
        elif self.loss_func == 'multinomial':
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), **params), obs=xs)
        elif self.loss_func == 'bernoulli':
            pyro.sample('x', dist.Bernoulli(**params).to_event(1), obs=xs)
|
|
461
|
+
|
|
462
|
+
def model2(self, xs, ys=None, embeds=None, ks2=None):
    """Generative model with optionally observed metacell labels.

    Same hierarchy as the unlabelled model (n -> d -> zn -> x). When ``ys``
    is given, ``n`` is observed as ``ys`` under logits from
    ``self.encoder_n(xs)``; otherwise ``n`` has a uniform prior.

    Args:
        xs: observed batch; ``xs.size(0)`` is the batch size.
        ys: optional one-hot metacell labels to condition ``n`` on.
        embeds: optional cell embeddings to condition ``zn`` on.
        ks2: undesired covariates, only used when ``self.use_undesired``.

    Raises:
        ValueError: if ``d_dist`` or ``z_dist`` names an unsupported distribution.
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module('HMAP', self)

    batch_size = xs.size(0)
    self.options = dict(dtype=xs.dtype, device=xs.device)

    def _latent_dist(kind, loc, scale, dof, allowed):
        # Build the named distribution. Unknown names used to fall through
        # every branch and fail later with an UnboundLocalError.
        if kind not in allowed:
            raise ValueError(f"unsupported distribution {kind!r}; expected one of {allowed}")
        if kind == 'studentt':
            return dist.StudentT(df=dof, loc=loc, scale=scale)
        families = {
            'normal': dist.Normal,
            'laplacian': dist.Laplace,
            'cauchy': dist.Cauchy,
            'vonmises': dist.VonMises,
            'gumbel': dist.Gumbel,
        }
        return families[kind](loc, scale)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion", self.inverse_dispersion * torch.ones(1, self.input_size, **self.options),
                                 constraint=constraints.positive)
    acs_scale = pyro.param("codebook_scale", torch.ones(1, self.D_size, **self.options),
                           constraint=constraints.positive)
    zn_scale = pyro.param("z_scale", torch.ones(1, self.z_dim, **self.options),
                          constraint=constraints.positive)

    # Build the identity on the data's device/dtype (as _codebook does); a
    # default CPU tensor cannot be fed to networks living on the GPU.
    I = torch.eye(self.code_size, **self.options)
    if self.d_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(I)
    else:
        acs_loc = self.codebook(I)

    with pyro.plate('data'):
        # p(n): uniform prior, or encoder logits when labels are observed
        if ys is None:
            prior = torch.zeros(batch_size, self.code_size, **self.options)
        else:
            # NOTE(review): encoder_n takes input_size features only in
            # supervised_mode; otherwise it is built for D_size inputs.
            # Confirm ys is only passed when supervised_mode=True.
            prior = self.encoder_n(xs)
        ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

        # p(d | n): centred on the codebook row selected by n
        d_dof = torch.matmul(ns, acs_dof) if self.d_dist == 'studentt' else None
        d_loc = torch.matmul(ns, acs_loc)
        d_prior = _latent_dist(self.d_dist, d_loc, acs_scale, d_dof,
                               ('normal', 'laplacian', 'cauchy', 'vonmises', 'gumbel', 'studentt'))
        ds = pyro.sample('d', d_prior.to_event(1))

        # p(zn | d): obs=None leaves zn latent, so one call covers both cases
        if self.z_dist == 'studentt':
            zn_dof, zn_loc = self.decoder_zn(ds)
        else:
            zn_dof, zn_loc = None, self.decoder_zn(ds)
        zn_prior = _latent_dist(self.z_dist, zn_loc, zn_scale, zn_dof,
                                ('laplacian', 'cauchy', 'normal', 'studentt', 'gumbel'))
        zns = pyro.sample('zn', zn_prior.to_event(1), obs=embeds)

        # p(x | zn, ks2)
        zs = [ks2, zns] if self.use_undesired else zns
        concentrate = self.decoder_concentrate(zs)
        if self.use_cell_factor:
            # out-of-place so the decoder output tensor is not modified in place
            concentrate = concentrate + self.cell_factor(xs)

        if self.loss_func == 'poisson':
            pyro.sample('x', dist.Poisson(rate=concentrate.exp()).to_event(1), obs=xs.round())
            return

        if self.normalize:
            # The mean of a one-draw Dirichlet-multinomial is the concentration
            # vector normalised to sum to one.
            params = dict(probs=dist.DirichletMultinomial(total_count=1, concentration=concentrate.exp()).mean)
        else:
            params = dict(logits=concentrate)

        if self.loss_func == 'negbinomial':
            pyro.sample('x', dist.NegativeBinomial(total_count=total_count, **params).to_event(1), obs=xs)
        elif self.loss_func == 'multinomial':
            pyro.sample('x', dist.Multinomial(total_count=int(1e8), **params), obs=xs)
        elif self.loss_func == 'bernoulli':
            pyro.sample('x', dist.Bernoulli(**params).to_event(1), obs=xs)
|
|
580
|
+
|
|
581
|
+
def guide(self, xs, embeds=None, ks2=None):
    """Variational guide q(zn, d, n | x) for the latent sites of ``model``.

    ``zn`` is sampled from ``encoder_zn(xs)`` unless ``embeds`` is supplied, in
    which case those embeddings are used directly. ``d`` is drawn from
    ``encoder_d(zn)`` and ``n`` from the logits ``encoder_n(d)``. ``ks2`` is
    accepted for signature compatibility and is unused.
    """
    # inform Pyro that the cells in the batch are conditionally independent
    with pyro.plate('data'):
        # q(zn | x)
        if embeds is not None:
            zns = embeds
        else:
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

        # q(d | zn)
        d_loc, d_scale = self.encoder_d(zns)
        ds = pyro.sample('d', dist.Normal(d_loc, d_scale).to_event(1))

        # q(n | d)
        code_logits = self.encoder_n(ds)
        pyro.sample('n', dist.OneHotCategorical(logits=code_logits))
|
|
596
|
+
|
|
597
|
+
def guide2(self, xs, ys=None, embeds=None, ks2=None):
    """Variational guide for ``model2``.

    ``zn`` comes from ``encoder_zn(xs)`` (or is taken from ``embeds``) and
    ``d`` from ``encoder_d(zn)``. ``n`` is sampled from ``encoder_n(d)`` only
    when ``ys`` is None, because ``model2`` observes ``n`` when labels are
    given. ``ks2`` is accepted for signature compatibility and is unused.
    """
    # inform Pyro that the cells in the batch are conditionally independent
    with pyro.plate('data'):
        # q(zn | x)
        if embeds is not None:
            zns = embeds
        else:
            zn_loc, zn_scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))

        # q(d | zn)
        d_loc, d_scale = self.encoder_d(zns)
        ds = pyro.sample('d', dist.Normal(d_loc, d_scale).to_event(1))

        # q(n | d), only for unlabelled batches
        if ys is not None:
            return
        code_logits = self.encoder_n(ds)
        pyro.sample('n', dist.OneHotCategorical(logits=code_logits))
|
|
613
|
+
|
|
614
|
+
def _codebook(self):
|
|
615
|
+
I = torch.eye(self.code_size, **self.options)
|
|
616
|
+
if self.d_dist == 'studentt':
|
|
617
|
+
_,loc = self.codebook(I)
|
|
618
|
+
else:
|
|
619
|
+
loc = self.codebook(I)
|
|
620
|
+
return loc
|
|
621
|
+
|
|
622
|
+
def get_codebook(self):
    """Return the codebook and its decoded cell-embedding locations as NumPy arrays.

    Returns:
        (cb, zs): ``cb`` is the location output of ``self.codebook`` for the
        ``code_size`` identity (one row per code). ``zs`` is ``cb`` passed
        through ``self.decoder_zn``; for Student-t heads only the location
        part of each network's output is kept.

    ``self.options`` is only populated once ``model``/``model2`` has run. Before
    that, the first parameter's device and dtype are used. Evaluation runs
    without gradient tracking and with dropout/batch-norm layers in inference
    mode; the previous training mode is restored afterwards.
    """
    options = self.options
    if options is None:
        first_param = next(self.parameters())
        options = dict(dtype=first_param.dtype, device=first_param.device)

    was_training = self.training
    # Base-class call so any `train` method defined on this class is not hit.
    nn.Module.train(self, False)
    try:
        with torch.no_grad():
            I = torch.eye(self.code_size, **options)
            cb = self.codebook(I)
            if self.d_dist == 'studentt':
                _, cb = cb
            zs = self.decoder_zn(cb)
            if self.z_dist == 'studentt':
                _, zs = zs
    finally:
        nn.Module.train(self, was_training)
    return tensor_to_numpy(cb), tensor_to_numpy(zs)
|
|
634
|
+
|
|
635
|
+
def _code(self, xs):
|
|
636
|
+
if self.supervised_mode:
|
|
637
|
+
alpha = self.encoder_n(xs)
|
|
638
|
+
else:
|
|
639
|
+
zns,_ = self.encoder_zn(xs)
|
|
640
|
+
ds,_ = self.encoder_d(zns)
|
|
641
|
+
alpha = self.encoder_n(ds)
|
|
642
|
+
return alpha
|
|
643
|
+
|
|
644
|
+
def _l0_embedding(self,xs):
|
|
645
|
+
ns = self._soft_assignments(xs)
|
|
646
|
+
acs = self._codebook()
|
|
647
|
+
return torch.matmul(ns, acs)
|
|
648
|
+
|
|
649
|
+
def _l1_embedding(self,xs,zs=None):
|
|
650
|
+
if zs is None:
|
|
651
|
+
zs = self._l2_embedding(xs)
|
|
652
|
+
ds,_ = self.encoder_d(zs)
|
|
653
|
+
return ds
|
|
654
|
+
|
|
655
|
+
def _l2_embedding(self, xs):
|
|
656
|
+
zns, _ = self.encoder_zn(xs)
|
|
657
|
+
return zns
|
|
658
|
+
|
|
659
|
+
def get_l0_embedding(self, xs, batch_size=1024):
    """Compute level-0 embeddings (soft assignments mixed over the codebook).

    Args:
        xs: input matrix accepted by ``convert_to_tensor``.
        batch_size: rows per inference batch.

    Returns:
        NumPy array with one row per input row, in input order.

    Inference runs without gradient tracking and with dropout/batch-norm
    layers in inference mode, so results are deterministic and batch-norm
    running statistics are left untouched. The previous mode is restored.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    was_training = self.training
    # Base-class call so any `train` method defined on this class is not hit.
    nn.Module.train(self, False)
    try:
        with torch.no_grad():
            Z = [tensor_to_numpy(self._l0_embedding(X_batch))
                 for X_batch, _ in tqdm(dataloader, desc='', unit='batch')]
    finally:
        nn.Module.train(self, was_training)

    return np.concatenate(Z)
|
|
673
|
+
|
|
674
|
+
def get_l1_embedding(self, xs, zs=None, batch_size=1024):
    """Compute level-1 (metacell-variable ``d``) embeddings.

    Args:
        xs: input matrix accepted by ``convert_to_tensor``.
        zs: optional precomputed cell embeddings aligned with ``xs``. When
            None, they are recomputed from each batch of ``xs``.
        batch_size: rows per inference batch.

    Returns:
        NumPy array with one row per input row, in input order.

    Inference runs without gradient tracking and with dropout/batch-norm
    layers in inference mode; the previous mode is restored afterwards.
    """
    device = self.get_device()
    xs = convert_to_tensor(xs, device=device)
    if zs is None:
        # Only xs needs batching, so use the plain dataset instead of passing
        # None into CustomDataset2.
        dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)
        batches = ((X_batch, None) for X_batch, _ in dataloader)
    else:
        zs = convert_to_tensor(zs, device=device)
        dataloader = DataLoader(CustomDataset2(xs, zs), batch_size=batch_size, shuffle=False)
        batches = ((X_batch, Z_batch) for X_batch, Z_batch, _ in dataloader)

    was_training = self.training
    # Base-class call so any `train` method defined on this class is not hit.
    nn.Module.train(self, False)
    try:
        with torch.no_grad():
            Z = [tensor_to_numpy(self._l1_embedding(X_batch, zs=Z_batch))
                 for X_batch, Z_batch in tqdm(batches, total=len(dataloader), desc='', unit='batch')]
    finally:
        nn.Module.train(self, was_training)

    return np.concatenate(Z)
|
|
692
|
+
|
|
693
|
+
def get_l2_embedding(self, xs, batch_size=1024):
    """Compute level-2 (cell ``zn``) embeddings.

    Args:
        xs: input matrix accepted by ``convert_to_tensor``.
        batch_size: rows per inference batch.

    Returns:
        NumPy array with one row per input row, in input order.

    Inference runs without gradient tracking and with dropout/batch-norm
    layers in inference mode; the previous mode is restored afterwards.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    was_training = self.training
    # Base-class call so any `train` method defined on this class is not hit.
    nn.Module.train(self, False)
    try:
        with torch.no_grad():
            Z = [tensor_to_numpy(self._l2_embedding(X_batch))
                 for X_batch, _ in tqdm(dataloader, desc='', unit='batch')]
    finally:
        nn.Module.train(self, was_training)

    return np.concatenate(Z)
|
|
707
|
+
|
|
708
|
+
def _soft_assignments(self, xs):
|
|
709
|
+
alpha = self._code(xs)
|
|
710
|
+
alpha = self.softmax(alpha)
|
|
711
|
+
return alpha
|
|
712
|
+
|
|
713
|
+
def soft_assignments(self, xs, batch_size=1024):
    """
    Compute softmax codebook assignments for every row of ``xs`` in mini-batches.

    Parameters
    ----------
    xs
        Input matrix; converted with ``convert_to_tensor`` onto ``self.get_device()``.
    batch_size
        Number of rows passed to ``self._soft_assignments`` per call.

    Returns
    -------
    numpy.ndarray
        Concatenated outputs of ``self._soft_assignments`` (one row per input row,
        in input order).
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataset = CustomDataset(xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    A = []
    # Inference only: disable autograd so no computation graph is built per batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            a = self._soft_assignments(X_batch)
            A.append(tensor_to_numpy(a))
            pbar.update(1)

    A = np.concatenate(A)
    return A
|
|
727
|
+
|
|
728
|
+
def _hard_assignments(self, xs):
|
|
729
|
+
alpha = self._code(xs)
|
|
730
|
+
res, ind = torch.topk(alpha, 1)
|
|
731
|
+
ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
|
|
732
|
+
return ns
|
|
733
|
+
|
|
734
|
+
def hard_assignments(self, xs, batch_size=1024):
    """
    Compute one-hot codebook assignments for every row of ``xs`` in mini-batches.

    Parameters
    ----------
    xs
        Input matrix; converted with ``convert_to_tensor`` onto ``self.get_device()``.
    batch_size
        Number of rows passed to ``self._hard_assignments`` per call.

    Returns
    -------
    numpy.ndarray
        Concatenated outputs of ``self._hard_assignments`` (one row per input row,
        in input order).
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataset = CustomDataset(xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    A = []
    # Inference only: disable autograd so no computation graph is built per batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            a = self._hard_assignments(X_batch)
            A.append(tensor_to_numpy(a))
            pbar.update(1)

    A = np.concatenate(A)
    return A
|
|
748
|
+
|
|
749
|
+
def mmd_gaussian(self, x, y, sigma=None):
    """
    Compute the squared MMD with a Gaussian kernel between samples x and y.

    The within-sample terms use the unbiased (off-diagonal) estimate when a
    sample has at least two rows; a single-row sample falls back to the plain
    kernel mean (the only defined estimate), so the result is finite instead
    of 0/0 = NaN. The cross term is always the mean of k(x_i, y_j).

    Args:
        x: Tensor of shape (n_samples, n_features)
        y: Tensor of shape (m_samples, n_features)
        sigma: Bandwidth of the Gaussian kernel. If None, uses the median
            heuristic (median of the positive pairwise distances of x and y
            combined); if every point coincides, a bandwidth of 1.0 is used.

    Returns:
        mmd: Squared MMD value (a 0-dim tensor). The unbiased estimator can be
        slightly negative.
    """
    n, m = x.size(0), y.size(0)

    if sigma is None:
        # Median heuristic for bandwidth
        xy = torch.cat([x, y], dim=0)
        pairwise_dist = torch.cdist(xy, xy, p=2)
        positive = pairwise_dist[pairwise_dist > 0]
        # With no positive distance (all points identical) the median is undefined.
        sigma = torch.median(positive).detach() if positive.numel() > 0 else 1.0

    denom = 2 * sigma ** 2

    # Kernel matrices
    xx = torch.exp(-torch.cdist(x, x, p=2) ** 2 / denom)
    yy = torch.exp(-torch.cdist(y, y, p=2) ** 2 / denom)
    xy = torch.exp(-torch.cdist(x, y, p=2) ** 2 / denom)

    def _within(k, size):
        # Unbiased estimate excludes the diagonal k(a_i, a_i); needs size >= 2.
        if size > 1:
            return (k.sum() - k.diag().sum()) / (size * (size - 1))
        return k.mean()

    # Compute MMD
    mmd = _within(xx, n) + _within(yy, m) - 2 * xy.mean()
    return mmd
|
|
779
|
+
|
|
780
|
+
def sinkhorn_distance(self, x, y, epsilon=0.01, max_iters=100):
    """
    Compute an entropy-regularized Wasserstein cost using log-domain Sinkhorn.

    Both point clouds carry uniform weights (1/n and 1/m). The dual potentials
    u, v are updated in the log domain for numerical stability so that the plan
    P_ij = exp((u_i + v_j - C_ij) / epsilon) has row sums 1/n and column sums 1/m.

    Args:
        x: Tensor of shape (n, d)
        y: Tensor of shape (m, d)
        epsilon: Regularization parameter
        max_iters: Number of Sinkhorn iterations

    Returns:
        wasserstein: The transport cost sum(P * C) under squared Euclidean cost
        (a 0-dim tensor). Returns 0 when either input has no rows.
    """
    import math

    n, m = x.size(0), y.size(0)
    C = torch.cdist(x, y, p=2) ** 2  # Cost matrix, shape (n, m)

    if n == 0 or m == 0:
        # No mass to transport.
        return torch.zeros((), dtype=C.dtype, device=C.device)

    # Log of the uniform marginal weights.
    log_a = -math.log(n)
    log_b = -math.log(m)

    # Initialize dual variables
    u = torch.zeros(n, dtype=C.dtype, device=C.device)
    v = torch.zeros(m, dtype=C.dtype, device=C.device)

    for _ in range(max_iters):
        # Rescale so that P.sum(dim=1) == 1/n, then so that P.sum(dim=0) == 1/m.
        u = epsilon * (log_a - torch.logsumexp((v[None, :] - C) / epsilon, dim=1))
        v = epsilon * (log_b - torch.logsumexp((u[:, None] - C) / epsilon, dim=0))

    # Compute transport plan and distance
    P = torch.exp((u[:, None] + v[None, :] - C) / epsilon)
    return torch.sum(P * C)
|
|
806
|
+
|
|
807
|
+
def metacell_distance(self, xs, ys=None, metric:Literal['mmd','sinkhorn']='mmd', epsilon:float=0.01, max_iters:int=100):
    """
    Compute pairwise distances between metacells from the cells assigned to them.

    Each cell is assigned to the column of ``ys`` with the largest value
    (argmax). For every ordered pair (i, j) the rows of ``xs`` assigned to i and j
    are compared with ``self.mmd_gaussian`` (metric == 'mmd') or, for any other
    metric value, ``self.sinkhorn_distance``.

    Parameters
    ----------
    xs
        Cell-by-feature matrix; converted with ``convert_to_tensor``.
    ys
        Cell-by-metacell assignment matrix (NumPy array or tensor). If None,
        ``self.soft_assignments(xs)`` is used.
    metric
        'mmd' or 'sinkhorn'.
    epsilon, max_iters
        Forwarded to ``self.sinkhorn_distance``.

    Returns
    -------
    numpy.ndarray
        (n_metacells, n_metacells) distance matrix.
    """
    if ys is None:
        ys = self.soft_assignments(xs)
    if torch.is_tensor(ys):
        ys = ys.detach().cpu().numpy()
    ys = np.asarray(ys)

    n_metacells = ys.shape[1]
    dm2m = torch.zeros(n_metacells, n_metacells)
    mc = np.argmax(ys, axis=1)

    xs = convert_to_tensor(xs, device=self.get_device())
    # Gather each metacell's member rows once instead of once per pair.
    members = [
        xs[torch.as_tensor(np.flatnonzero(mc == k), dtype=torch.long, device=xs.device)]
        for k in range(n_metacells)
    ]

    with torch.no_grad():
        for i, j in itertools.product(range(n_metacells), repeat=2):
            if metric == 'mmd':
                d = self.mmd_gaussian(members[i], members[j])
            else:
                d = self.sinkhorn_distance(members[i], members[j], epsilon=epsilon, max_iters=max_iters)
            # .item() moves the scalar to host, so a GPU result can fill the CPU matrix.
            dm2m[i, j] = d.item()

    return tensor_to_numpy(dm2m)
|
|
823
|
+
|
|
824
|
+
def metacell_similarity(self, xs, ys=None, embed: Literal['l1','l2']='l1', n_neighbors=50, sigma=1, use_diffuse=False, diffusion_time=1):
    """
    Compute a metacell-by-metacell similarity matrix.

    Without diffusion, the result is ``(ys.T / column_sums) @ ys``: each
    metacell's assignment column is normalized by its total weight and then
    correlated with all columns. A metacell with zero total weight (e.g. an
    empty cluster in a hard assignment) yields an all-zero row instead of NaN.

    With ``use_diffuse=True``, the L1 or L2 embedding of ``xs`` is passed to
    ``compute_metacell_diffusion_kernel`` together with ``ys``.

    Parameters
    ----------
    xs
        Cell-by-feature matrix.
    ys
        Cell-by-metacell assignment matrix; defaults to ``self.soft_assignments(xs)``.
    embed
        'l1' or 'l2' embedding used when ``use_diffuse`` is True.
    n_neighbors, sigma, diffusion_time
        Forwarded to ``compute_metacell_diffusion_kernel``.
    use_diffuse
        Select the diffusion-kernel path.

    Returns
    -------
    numpy.ndarray or whatever ``compute_metacell_diffusion_kernel`` returns.
    """
    if ys is None:
        ys = self.soft_assignments(xs)
    if not use_diffuse:
        ys = convert_to_tensor(ys, device=self.get_device())
        totals = torch.sum(ys.T, dim=1, keepdim=True)
        # Avoid 0/0 for empty metacells; positive totals are used unchanged.
        totals = torch.where(totals > 0, totals, torch.ones_like(totals))
        m2m = torch.matmul(ys.T / totals, ys)
        m2m = tensor_to_numpy(m2m)
    else:
        if embed == 'l1':
            zs = self.get_l1_embedding(xs)
        else:
            zs = self.get_l2_embedding(xs)
        m2m = compute_metacell_diffusion_kernel(zs, ys, n_neighbors=n_neighbors, sigma=sigma, diffusion_time=diffusion_time)
    return m2m
|
|
838
|
+
|
|
839
|
+
def metacell_network(self, affinity_matrix,
                     exclude_metacells: list = None):
    """
    Build an undirected networkx graph over metacells from an affinity matrix.

    Nodes are ``0 .. self.code_size - 1`` (all are added, including excluded
    ones, which stay isolated). For every ordered pair (i, j) with i != j and
    ``affinity_matrix[i, j] > 0``, an edge with ``weight = 1 / affinity`` is
    added, so stronger affinity means a shorter edge. Pairs touching an
    excluded metacell get no edge.

    NOTE: ``add_edge`` overwrites an existing edge's weight, so for an
    asymmetric matrix the final weight of {i, j} comes from whichever row is
    processed last with a positive entry (row max(i, j) when both are positive).

    Parameters
    ----------
    affinity_matrix
        Square array indexed ``[i, j]`` (e.g. from ``metacell_similarity``).
    exclude_metacells
        Optional collection of metacell indices to leave unconnected.

    Returns
    -------
    networkx.Graph
        The graph, also stored as ``self.G``.
    """
    # A set gives O(1) membership tests in the inner loop.
    excluded = set() if exclude_metacells is None else set(exclude_metacells)

    self.G = nx.Graph()
    self.G.add_nodes_from(np.arange(self.code_size))

    for i in np.arange(self.code_size):
        if i in excluded:
            continue
        arr = affinity_matrix[i, :]
        for j in np.arange(len(arr)):
            if (arr[j] > 0) and (j != i) and (j not in excluded):
                self.G.add_edge(i, j, weight=1 / arr[j])

    return self.G
|
|
865
|
+
|
|
866
|
+
def metacell_fa2(self, G, max_iter=100):
    """Lay out the metacell graph ``G`` with ForceAtlas2, running ``max_iter`` iterations.

    Delegates to ``visualize_metacell_igraph_with_fa2`` and returns its result unchanged.
    """
    layout = visualize_metacell_igraph_with_fa2(G, iterations=max_iter)
    return layout
|
|
868
|
+
|
|
869
|
+
def metacell_tree(self, G, root_metacell=0):
    """
    Return a directed DFS tree rooted at ``root_metacell`` over the minimum spanning tree of ``G``.

    ``nx.minimum_spanning_tree`` uses the ``'weight'`` edge attribute (as set
    by ``metacell_network``). If ``G`` is disconnected the spanning forest is
    computed, and the DFS tree only covers the root's component.

    Parameters
    ----------
    G
        networkx graph of metacells.
    root_metacell
        Node used as the DFS root.

    Returns
    -------
    networkx.DiGraph
    """
    T = nx.minimum_spanning_tree(G)
    return nx.dfs_tree(T, root_metacell)
|
|
875
|
+
|
|
876
|
+
def fit(self, xs,
        ys = None,
        zs = None,
        us = None,
        num_epochs: int = 200,
        learning_rate: float = 0.0001,
        batch_size: int = 512,
        algo: Literal['adam','rmsprop','adamw'] = 'adam',
        beta_1: float = 0.9,
        weight_decay: float = 0.005,
        decay_rate: float = 0.9,
        threshold: int = 0,
        normalize: bool = True,
        config_enum: str = 'parallel',
        use_jax: bool = False):
    """
    Train the HMAP model.

    Parameters
    ----------
    xs
        Single-cell expression matrix. It should be a NumPy array or a PyTorch Tensor
        (a SciPy sparse matrix is converted to a dense array). Rows are cells and
        columns are features. When ``self.loss_func == 'bernoulli'`` it is binarized
        with ``threshold``; otherwise it is rounded to whole counts.
    ys
        Optional matrix aligned row-by-row with ``xs``. When given, ``self.model2`` /
        ``self.guide2`` are trained instead of ``self.model`` / ``self.guide``.
    zs
        Optional matrix aligned row-by-row with ``xs``; its batch slices are passed
        to the model and guide.
    us
        Undesired factor matrix. It should be a NumPy array or a PyTorch Tensor.
        Rows are cells and columns are undesired factors.
    num_epochs
        Number of training epochs.
    learning_rate
        Learning rate of the optimizer.
    batch_size
        Number of cells per mini-batch.
    algo
        Optimization algorithm: 'adam', 'rmsprop' or 'adamw' (case-insensitive).
    beta_1
        First-moment decay for Adam/AdamW (unused by RMSprop).
    weight_decay
        Weight decay of the optimizer.
    decay_rate
        ``gamma`` of the exponential learning-rate scheduler.
    threshold
        Binarization threshold used when ``self.loss_func == 'bernoulli'``.
    normalize
        Stored as ``self.normalize``.
    config_enum
        Enumeration strategy passed to ``config_enumerate`` for the guide.
    use_jax
        If toggled on, ``JitTraceEnum_ELBO`` is used instead of ``TraceEnum_ELBO``
        to speed up training. CAUTION: this may raise errors when called from a
        Python script or Jupyter notebook; it is OK when running HMAP from the
        shell command.

    Raises
    ------
    ValueError
        If ``algo`` is not one of the supported optimizers.
    """
    self.normalize = normalize

    if self.loss_func == 'bernoulli':
        ad = sc.AnnData(xs)
        binarize(ad, threshold=threshold)
        xs = ad.X.copy()
    else:
        # Densify before rounding: np.round does not handle SciPy sparse matrices.
        if sparse.issparse(xs):
            xs = xs.toarray()
        xs = torch.round(xs) if torch.is_tensor(xs) else np.round(xs)

    if sparse.issparse(xs):
        xs = xs.toarray()

    xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
    if ys is not None:
        ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
    if zs is not None:
        zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())
    if us is not None:
        us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
    else:
        self.use_undesired = False
    self.options = dict(dtype=xs.dtype, device=xs.device)

    dataset = CustomDataset4(xs, ys, zs, us)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # setup the optimizer; torch.optim.RMSprop has no `betas` argument.
    algo_name = algo.lower()
    if algo_name == 'rmsprop':
        optimizer = torch.optim.RMSprop
        optim_params = {'lr': learning_rate, 'weight_decay': weight_decay}
    elif algo_name == 'adam':
        optimizer = torch.optim.Adam
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    elif algo_name == 'adamw':
        optimizer = torch.optim.AdamW
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    else:
        raise ValueError("An optimization algorithm must be specified.")
    # NOTE(review): scheduler.step() is never called below, so the learning rate
    # never decays and `decay_rate` currently has no effect — confirm intent.
    scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

    pyro.clear_param_store()

    # set up the loss(es) for inference, wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label from the sampled discrete categorical distribution in the model
    Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
    if ys is None:
        guide = config_enumerate(self.guide, config_enum, expand=True)
        loss_basic = SVI(self.model, guide, scheduler, loss=elbo)
    else:
        guide = config_enumerate(self.guide2, config_enum, expand=True)
        loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]
    num_losses = len(losses)

    with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
        for epoch in range(num_epochs):
            epoch_losses = [0.0] * num_losses
            for batch_x, batch_y, batch_z, batch_u, _ in dataloader:
                # The dataset yields placeholders for absent inputs; replace them with None.
                if us is None:
                    batch_u = None
                if ys is None:
                    batch_y = None
                if zs is None:
                    batch_z = None

                for loss_id in range(num_losses):
                    if batch_y is None:
                        new_loss = losses[loss_id].step(batch_x, batch_z, batch_u)
                    else:
                        new_loss = losses[loss_id].step(batch_x, batch_y, batch_z, batch_u)
                    epoch_losses[loss_id] += new_loss

            avg_epoch_losses_ = map(lambda v: v / len(dataloader), epoch_losses)
            avg_epoch_losses = map(lambda v: "{:.4f}".format(v), avg_epoch_losses_)

            # store the loss
            str_loss = " ".join(map(str, avg_epoch_losses))

            # Update progress bar
            pbar.set_postfix({'loss': str_loss})
            pbar.update(1)

    # Codebook usage: share of total soft-assignment mass per codebook entry.
    assigns = self.soft_assignments(xs)
    assigns = convert_to_tensor(assigns, dtype=self.dtype, device=self.get_device())
    self.codebook_weights = torch.sum(assigns, dim=0)
    self.codebook_weights = self.codebook_weights / torch.sum(self.codebook_weights)
|
|
1007
|
+
|
|
1008
|
+
@classmethod
def save_model(cls, model, file_path, compression=False):
    """Pickle ``model`` to ``file_path`` (gzip-compressed when ``compression`` is truthy).

    The model is switched to evaluation mode via ``model.eval()`` before
    pickling, and the absolute output path is printed.
    """
    target = os.path.abspath(file_path)

    model.eval()
    opener = gzip.open if compression else open
    with opener(target, 'wb') as handle:
        pickle.dump(model, handle)

    print(f'Model saved to {target}')
|
|
1022
|
+
|
|
1023
|
+
@classmethod
def load_model(cls, file_path):
    """Load a pickled model from ``file_path`` and return it.

    Gzip compression is detected from the file's magic bytes (``1f 8b``), so
    files written by ``save_model(..., compression=True)`` load regardless of
    their extension. The absolute path is printed after a successful load.

    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    """
    file_path = os.path.abspath(file_path)

    with open(file_path, 'rb') as raw:
        is_gzip = raw.read(2) == b'\x1f\x8b'

    opener = gzip.open if is_gzip else open
    with opener(file_path, 'rb') as pickle_file:
        model = pickle.load(pickle_file)

    print(f'Model loaded from {file_path}')
    return model
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
EXAMPLE_RUN = (
    "example run: HMAP --help"
)

def parse_args():
    """Parse the HMAP command-line arguments from ``sys.argv`` and return the namespace."""
    parser = argparse.ArgumentParser(
        description="HMAP\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-undesired",
        "--undesired-factor-file",
        default=None,
        type=str,
        help="the file for the record of undesired factors",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "--z-dist",
        default='studentt',
        type=str,
        choices=['normal','laplacian','cauchy','studentt','gumbel'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="size of the tensor representing the latent variable z",
    )
    parser.add_argument(
        "-cs",
        "--codebook_size",
        default=30,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "-dd",
        "--d-dim",
        default=2,
        type=int,
        choices=[2,3],
        help="size of the vector quantization codeword",
    )
    parser.add_argument(
        "--d-dist",
        default='normal',
        type=str,
        choices=['normal','laplacian','cauchy','vonmises','gumbel','studentt'],
        help="distribution model for visual representation",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[300],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu','softplus','leakyrelu','linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        # Help text states the actual default (['layernorm']).
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or combination, default is 'layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout, default is 'none'",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate for Adam optimizer",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='poisson',
        type=str,
        choices=['negbinomial', 'multinomial', 'poisson'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )
    args = parser.parse_args()

    return args
|
|
1215
|
+
|
|
1216
|
+
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy, PyTorch (CPU and, when available, all CUDA devices) and Pyro."""
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed the current device and every other visible GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    pyro.set_rng_seed(seed)
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
def main():
    """
    Command-line entry point: train an HMAP model on a data file and optionally save it.

    Reads the expression matrix (and optional undesired-factor matrix) with
    ``datatable.fread`` (header row expected), builds ``HMAP`` from the parsed
    arguments, calls ``fit``, and pickles the model when ``--save-model`` is set.

    Raises
    ------
    SystemExit
        If no data file is given or it does not exist.
    """
    args = parse_args()

    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if args.data_file is None:
        raise SystemExit("data file must be provided")
    if not os.path.exists(args.data_file):
        raise SystemExit(f"data file not found: {args.data_file}")

    if args.float64:
        dtype = torch.float64
    else:
        dtype = torch.float32
    torch.set_default_dtype(dtype)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.undesired_factor_file is not None:
        us = dt.fread(file=args.undesired_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    undesired_size = 0 if us is None else us.shape[1]

    z_dist = args.z_dist
    d_dist = args.d_dist

    # batch_size: number of cells (and labels) to be considered in a batch
    hmap = HMAP(
        input_size=input_size,
        undesired_size=undesired_size,
        codebook_size=args.codebook_size,
        d_dim=args.d_dim,
        d_dist=d_dist,
        z_dim=args.z_dim,
        z_dist=z_dist,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        loss_func=args.likelihood,
        inverse_dispersion=args.inverse_dispersion,
        nn_dropout=args.layer_dropout_rate,
        use_cuda=args.cuda,
        config_enum=args.enum_discrete,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        dtype=dtype,
        seed=args.seed,
    )

    hmap.fit(xs, us=us,
             num_epochs=args.num_epochs,
             learning_rate=args.learning_rate,
             batch_size=args.batch_size,
             beta_1=args.beta_1,
             decay_rate=args.decay_rate,
             use_jax=args.jit,
             config_enum=args.enum_discrete,
             )

    if args.save_model is not None:
        HMAP.save_model(hmap, args.save_model)
|
|
1295
|
+
|
|
1296
|
+
|
|
1297
|
+
if __name__ == "__main__":
    # Only run the CLI when this module is executed directly, not on import.
    main()
|