HMAP-tool 1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1298 @@
1
+ import pyro
2
+ import pyro.distributions as dist
3
+ from pyro.optim import ExponentialLR
4
+ from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_probs
10
+ from torch.distributions import constraints
11
+ from torch.distributions.transforms import SoftmaxTransform
12
+
13
+ import argparse
14
+ import os
15
+ import time as tm
16
+ import random
17
+ import itertools
18
+ import pandas as pd
19
+ import numpy as np
20
+ import datatable as dt
21
+ import networkx as nx
22
+ from pathlib import Path
23
+ import warnings
24
+ warnings.filterwarnings("ignore")
25
+
26
+ from .utils.utils import convert_to_tensor, tensor_to_numpy, CustomDataset, CustomDataset2, CustomDataset4
27
+ from .utils.custom_mlp import MLP, Exp
28
+ from .graph.graph import compute_metacell_diffusion_kernel,visualize_metacell_igraph_with_fa2
29
+ from .atac import binarize
30
+
31
+ from tqdm import tqdm
32
+ from typing import Literal
33
+ import dill as pickle
34
+ import gzip
35
+ import scanpy as sc
36
+ from scipy import sparse
37
+
38
def set_random_seed(seed):
    """Seed every random number generator this module relies on.

    Seeding order: PyTorch (CPU); when CUDA is available, the current CUDA
    device and then all CUDA devices; NumPy's global RNG; Python's
    ``random`` module; and finally Pyro.

    Args:
        seed: integer seed forwarded unchanged to every seeding call.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        # Current device first, then every visible GPU (multi-GPU setups).
        seeders.extend([torch.cuda.manual_seed, torch.cuda.manual_seed_all])
    seeders.extend([np.random.seed, random.seed, pyro.set_rng_seed])

    for seeder in seeders:
        seeder(seed)
55
+
56
+ class HMAP(nn.Module):
57
def __init__(self,
             input_size: int,
             undesired_size: int = 0,
             codebook_size: int = 30,  # size of metacell codebook
             supervised_mode: bool = False,
             use_cell_factor: bool = False,
             d_dim: int = 3,  # dimension of a metacell variable
             d_dist: Literal['normal', 'laplacian', 'cauchy', 'studentt', 'vonmises', 'gumbel'] = 'normal',
             z_dim: int = 10,
             z_dist: Literal['normal', 'laplacian', 'cauchy', 'studentt', 'gumbel'] = 'laplacian',
             loss_func: Literal['negbinomial', 'poisson', 'multinomial', 'bernoulli'] = 'multinomial',
             hidden_layers: list = (300,),
             hidden_layer_activation: Literal['relu', 'softplus', 'leakyrelu', 'linear'] = 'relu',
             inverse_dispersion: float = 10.0,
             nn_dropout: float = 0.1,
             post_layer_fct: list = ('layernorm',),
             post_act_fct: list = None,
             config_enum: str = 'parallel',
             use_cuda: bool = False,
             seed: int = 42,
             dtype=torch.float32,  # type: ignore
             ):
    """Configure the HMAP model and build its networks.

    Args:
        input_size: number of input features; width of the encoder inputs
            and of the decoder output.
        undesired_size: width of the undesired-factor covariates (``ks2`` in
            the model); values > 0 make the decoder read them alongside ``zn``.
        codebook_size: number of metacell codebook entries.
        supervised_mode: when True, ``encoder_n`` reads the raw inputs
            instead of the metacell embedding ``d``.
        use_cell_factor: when True, a per-cell scalar (``cell_factor`` MLP)
            is added to the decoder output.
        d_dim: dimension of the metacell variable ``d``.
        d_dist: prior family for ``d``.
        z_dim: dimension of the cell embedding ``zn``.
        z_dist: prior family for ``zn``.
        loss_func: observation likelihood.
        hidden_layers: hidden widths shared by every MLP (copied into a list).
        hidden_layer_activation: activation between hidden layers.
        inverse_dispersion: initial value of the negative-binomial
            ``total_count`` parameter.
        nn_dropout: dropout probability used when dropout is requested.
        post_layer_fct: options applied after each linear layer
            ('layernorm'/'layer_norm', 'batchnorm'/'batch_norm', 'dropout').
        post_act_fct: same options, applied after each activation.
        config_enum: 'parallel' enables ``allow_broadcast`` on the MLPs.
        use_cuda: move the module to the GPU after building it.
        seed: seed passed to ``set_random_seed``; None skips seeding.
        dtype: stored as ``self.dtype`` (not otherwise referenced in the
            code visible here).

    Raises:
        ValueError: if ``loss_func`` is not a supported likelihood.
    """
    super().__init__()

    if loss_func not in ('poisson', 'multinomial', 'negbinomial', 'bernoulli'):
        raise ValueError(
            "loss_func must be one of 'poisson', 'multinomial', 'negbinomial', "
            f"'bernoulli'; got {loss_func!r}")

    # initialize the class with all arguments provided to the constructor
    self.input_size = input_size
    self.undesired_size = undesired_size
    self.inverse_dispersion = inverse_dispersion
    self.z_dim = z_dim
    # Copy so the (immutable) default and caller-owned lists are never shared.
    self.hidden_layers = list(hidden_layers)
    self.use_undesired = self.undesired_size > 0
    self.allow_broadcast = config_enum == 'parallel'
    self.use_cuda = use_cuda
    self.loss_func = loss_func
    self.options = None  # set to the data's dtype/device by model()/model2()
    self.code_size = codebook_size
    self.D_size = d_dim
    self.z_dist = z_dist
    self.d_dist = d_dist
    self.G = None
    self.supervised_mode = supervised_mode
    self.dtype = dtype
    self.use_cell_factor = use_cell_factor
    self.normalize = True

    self.nn_dropout = nn_dropout
    # Normalise sequence options to fresh lists; None (and any other value)
    # is stored unchanged because setup_networks treats None as "no options".
    self.post_layer_fct = list(post_layer_fct) if isinstance(post_layer_fct, (list, tuple)) else post_layer_fct
    self.post_act_fct = list(post_act_fct) if isinstance(post_act_fct, (list, tuple)) else post_act_fct
    self.hidden_layer_activation = hidden_layer_activation

    if seed is not None:
        set_random_seed(seed)

    # define and instantiate the neural networks representing
    # the parameters of various distributions in the model
    self.setup_networks()
115
+
116
def setup_networks(self):
    """Instantiate every MLP used by the model and the guide.

    Networks are created in a fixed order (encoder_n, encoder_zn, encoder_d,
    optional cell_factor, decoder_concentrate, codebook, decoder_zn). The
    order matters because module construction consumes the RNG seeded in
    ``__init__``.

    Raises:
        ValueError: if ``hidden_layer_activation`` is not 'relu', 'softplus',
            'leakyrelu' or 'linear'.
    """
    z_dim = self.z_dim
    hidden_sizes = list(self.hidden_layers)

    def _flags(spec):
        # (layer_norm, batch_norm, dropout) requested by an option list; None -> nothing.
        if spec is None:
            return False, False, False
        return (('layernorm' in spec) or ('layer_norm' in spec),
                ('batchnorm' in spec) or ('batch_norm' in spec),
                'dropout' in spec)

    def _make_post_fct(layer_norm, batch_norm, dropout):
        # Callback MLP invokes per layer. Modules are always ordered
        # Dropout -> BatchNorm1d -> LayerNorm; a single module is returned bare
        # and no module yields None, so parameter names match the former
        # hand-written combinations (saved state_dicts stay loadable).
        def post_fct(layer_ix, total_layers, layer):
            modules = []
            if dropout:
                modules.append(nn.Dropout(self.nn_dropout))
            if batch_norm:
                modules.append(nn.BatchNorm1d(layer.module.out_features))
            if layer_norm:
                modules.append(nn.LayerNorm(layer.module.out_features))
            if not modules:
                return None
            if len(modules) == 1:
                return modules[0]
            return nn.Sequential(*modules)
        return post_fct

    post_layer_fct = _make_post_fct(*_flags(self.post_layer_fct))
    post_act_fct = _make_post_fct(*_flags(self.post_act_fct))

    activations = {
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'leakyrelu': nn.LeakyReLU,
        'linear': nn.Identity,
    }
    if self.hidden_layer_activation not in activations:
        raise ValueError(
            f"Unsupported hidden_layer_activation: {self.hidden_layer_activation!r}; "
            f"expected one of {sorted(activations)}")
    activate_fct = activations[self.hidden_layer_activation]

    def _mlp(sizes, output_activation=None):
        # Every network shares activation, post-layer hooks and device settings.
        return MLP(
            sizes,
            activation=activate_fct,
            output_activation=output_activation,
            post_layer_fct=post_layer_fct,
            post_act_fct=post_act_fct,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )

    # define the neural networks used later in the model and the guide.
    # encoder_n: supervised mode classifies raw inputs, otherwise the metacell variable d.
    n_input = self.input_size if self.supervised_mode else self.D_size
    self.encoder_n = _mlp([n_input] + hidden_sizes + [self.code_size])

    # Two-headed encoders return (loc, positive scale).
    self.encoder_zn = _mlp([self.input_size] + hidden_sizes + [[z_dim, z_dim]], [None, Exp])
    self.encoder_d = _mlp([self.z_dim] + hidden_sizes + [[self.D_size, self.D_size]], [None, Exp])

    if self.use_cell_factor:
        self.cell_factor = _mlp([self.input_size] + hidden_sizes + [1])

    if self.use_undesired:
        decoder_in = self.undesired_size + self.z_dim
    else:
        decoder_in = self.z_dim
    self.decoder_concentrate = _mlp([decoder_in] + hidden_sizes + [self.input_size])

    # Student-t variants emit (positive dof, loc).
    if self.d_dist == 'studentt':
        self.codebook = _mlp([self.code_size] + hidden_sizes + [[self.D_size, self.D_size]], [Exp, None])
    else:
        self.codebook = _mlp([self.code_size] + hidden_sizes + [self.D_size])

    if self.z_dist == 'studentt':
        self.decoder_zn = _mlp([self.D_size] + hidden_sizes + [[self.z_dim, self.z_dim]], [Exp, None])
    else:
        self.decoder_zn = _mlp([self.D_size] + hidden_sizes + [self.z_dim])

    # using GPUs for faster training of the networks
    if self.use_cuda:
        self.cuda()
295
+
296
def cutoff(self, xs, thresh=None):
    """Floor ``xs`` at a small positive value and replace NaNs with that floor.

    The floor is the machine epsilon of ``xs.dtype``, raised to ``thresh``
    when ``thresh`` is given and larger.

    Args:
        xs: floating-point tensor.
        thresh: optional minimum floor.

    Returns:
        A new tensor with the same shape and dtype as ``xs``.
    """
    floor = torch.finfo(xs.dtype).eps
    if thresh is not None:
        floor = max(floor, thresh)

    # clamp() propagates NaN, so NaN entries are filled separately.
    clamped = xs.clamp(min=floor)
    return clamped.masked_fill(torch.isnan(clamped), floor)
309
+
310
def softmax(self, xs):
    """Normalise ``xs`` to probabilities along its last dimension.

    Uses ``torch.distributions.transforms.SoftmaxTransform``.
    """
    to_simplex = SoftmaxTransform()
    return to_simplex(xs)
317
+
318
def sigmoid(self, xs):
    """Elementwise sigmoid of ``xs``, clamped away from exactly 0 and 1 by ``clamp_probs``."""
    return clamp_probs(torch.sigmoid(xs))
323
+
324
def softmax_logit(self, xs):
    """Return ``logit(self.softmax(xs))``.

    The logit's ``eps`` is the machine epsilon of ``xs.dtype``.
    """
    tolerance = torch.finfo(xs.dtype).eps
    return torch.logit(self.softmax(xs), eps=tolerance)
329
+
330
def logit(self, xs):
    """Elementwise logit of ``xs``.

    ``torch.logit`` is called with ``eps`` set to the machine epsilon of
    ``xs.dtype``, so 0 and 1 map to finite values.
    """
    return torch.logit(xs, eps=torch.finfo(xs.dtype).eps)
334
+
335
def dirimulti_param(self, xs):
    """Scale ``self.sigmoid(xs)`` by ``self.dirimulti_mass``.

    NOTE(review): ``dirimulti_mass`` is not assigned in the ``__init__``
    visible here. Calling this on an instance where it was never set
    raises AttributeError — confirm where it is meant to be initialised.
    """
    mass = self.dirimulti_mass
    return mass * self.sigmoid(xs)
338
+
339
def multi_param(self, xs):
    """Multinomial probabilities for ``xs``: simply ``self.softmax(xs)``."""
    return self.softmax(xs)
342
+
343
def get_device(self):
    """Device of the module's first parameter (raises StopIteration if it has none)."""
    first_param = next(self.parameters())
    return first_param.device
345
+
346
def model(self, xs, embeds=None, ks2=None):
    """Generative model without metacell labels.

    This is exactly ``model2`` with ``ys=None``: the metacell assignment
    ``n`` gets a uniform categorical prior. Delegating keeps the two model
    definitions from drifting apart.

    Args:
        xs: observed matrix, one row per cell.
        embeds: optional observed values for the latent ``zn``; when None,
            ``zn`` is sampled.
        ks2: undesired-factor covariates fed to the decoder together with
            ``zn`` when ``undesired_size > 0``.
    """
    return self.model2(xs, ys=None, embeds=embeds, ks2=ks2)
461
+
462
def model2(self, xs, ys=None, embeds=None, ks2=None):
    """Pyro generative model over metacell n -> metacell variable d -> cell embedding zn -> x.

    Args:
        xs: observed matrix, one row per cell.
        ys: optional one-hot metacell labels. When given, ``n`` is observed
            and its prior logits come from ``encoder_n(xs)``; otherwise ``n``
            has a uniform prior.
        embeds: optional observed values for ``zn``; when None, ``zn`` is latent.
        ks2: undesired-factor covariates fed to the decoder when
            ``undesired_size > 0``.

    Raises:
        ValueError: if ``d_dist`` or ``z_dist`` names an unsupported distribution.
    """
    # Location/scale families; 'studentt' is handled separately because it
    # also needs a degrees-of-freedom tensor.
    d_families = {
        'normal': dist.Normal,
        'laplacian': dist.Laplace,
        'cauchy': dist.Cauchy,
        'vonmises': dist.VonMises,  # the scale tensor is passed as the concentration
        'gumbel': dist.Gumbel,
    }
    z_families = {
        'normal': dist.Normal,
        'laplacian': dist.Laplace,
        'cauchy': dist.Cauchy,
        'gumbel': dist.Gumbel,
    }
    if self.d_dist != 'studentt' and self.d_dist not in d_families:
        raise ValueError(f"Unsupported d_dist: {self.d_dist!r}")
    if self.z_dist != 'studentt' and self.z_dist not in z_families:
        raise ValueError(f"Unsupported z_dist: {self.z_dist!r}")

    # register this pytorch module and all of its sub-modules with pyro
    pyro.module('HMAP', self)

    batch_size = xs.size(0)
    self.options = dict(dtype=xs.dtype, device=xs.device)

    if self.loss_func == 'negbinomial':
        total_count = pyro.param("inverse_dispersion",
                                 self.inverse_dispersion * torch.ones(1, self.input_size, **self.options),
                                 constraint=constraints.positive)
    acs_scale = pyro.param("codebook_scale", torch.ones(1, self.D_size, **self.options),
                           constraint=constraints.positive)
    zn_scale = pyro.param("z_scale", torch.ones(1, self.z_dim, **self.options),
                          constraint=constraints.positive)

    # The identity selects every codebook entry. Build it on the data's
    # device/dtype: a bare torch.eye lives on the CPU and breaks CUDA runs.
    I = torch.eye(self.code_size, **self.options)
    if self.d_dist == 'studentt':
        acs_dof, acs_loc = self.codebook(I)
    else:
        acs_loc = self.codebook(I)

    with pyro.plate('data'):
        # p(n)
        if ys is None:
            prior = torch.zeros(batch_size, self.code_size, **self.options)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior))
        else:
            prior = self.encoder_n(xs)
            ns = pyro.sample('n', dist.OneHotCategorical(logits=prior), obs=ys)

        # p(d | n): the one-hot n picks its codebook entry.
        d_loc = torch.matmul(ns, acs_loc)
        d_scale = acs_scale
        if self.d_dist == 'studentt':
            d_prior = dist.StudentT(df=torch.matmul(ns, acs_dof), loc=d_loc, scale=d_scale)
        else:
            d_prior = d_families[self.d_dist](d_loc, d_scale)
        ds = pyro.sample('d', d_prior.to_event(1))

        # p(zn | d); obs=None leaves zn latent, otherwise it is conditioned on embeds.
        if self.z_dist == 'studentt':
            zn_dof, zn_loc = self.decoder_zn(ds)
            zn_prior = dist.StudentT(df=zn_dof, loc=zn_loc, scale=zn_scale)
        else:
            zn_loc = self.decoder_zn(ds)
            zn_prior = z_families[self.z_dist](zn_loc, zn_scale)
        zns = pyro.sample('zn', zn_prior.to_event(1), obs=embeds)

        # p(x | zn, ks2)
        zs = [ks2, zns] if self.use_undesired else zns
        concentrate = self.decoder_concentrate(zs)
        if self.use_cell_factor:
            # Out-of-place add: never mutate a network output autograd may need.
            concentrate = concentrate + self.cell_factor(xs)

        if self.normalize:
            rate = concentrate.exp()
            if self.loss_func != 'poisson':
                # Mean of a DirichletMultinomial with total_count=1: rate normalised per row.
                theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
        elif self.loss_func == 'poisson':
            rate = concentrate.exp()
        else:
            logits = concentrate

        if self.loss_func == 'negbinomial':
            if self.normalize:
                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
            else:
                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=logits).to_event(1), obs=xs)
        elif self.loss_func == 'multinomial':
            if self.normalize:
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
            else:
                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=logits), obs=xs)
        elif self.loss_func == 'bernoulli':
            if self.normalize:
                pyro.sample('x', dist.Bernoulli(probs=theta).to_event(1), obs=xs)
            else:
                pyro.sample('x', dist.Bernoulli(logits=logits).to_event(1), obs=xs)
        elif self.loss_func == 'poisson':
            pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
580
+
581
def guide(self, xs, embeds=None, ks2=None):
    """Variational guide paired with ``model``.

    This is exactly ``guide2`` with ``ys=None`` (the metacell assignment
    ``n`` is inferred). Delegating avoids keeping two copies in sync.

    Args:
        xs: observed matrix, one row per cell.
        embeds: optional precomputed ``zn``; when given, ``zn`` is not sampled.
        ks2: accepted so the signature matches ``model``; not used by the guide.
    """
    return self.guide2(xs, ys=None, embeds=embeds, ks2=ks2)
596
+
597
def guide2(self, xs, ys=None, embeds=None, ks2=None):
    """Variational guide q(zn, d, n | x) paired with ``model2``.

    Args:
        xs: observed matrix, one row per cell.
        ys: optional metacell labels; when given, ``n`` is observed in the
            model and is not sampled here.
        embeds: optional precomputed ``zn``; when given, ``zn`` is not sampled.
        ks2: accepted so the signature matches ``model2``; unused here.
    """
    # every cell in the batch is conditionally independent
    with pyro.plate('data'):
        # q(zn | x)
        if embeds is not None:
            zns = embeds
        else:
            loc, scale = self.encoder_zn(xs)
            zns = pyro.sample('zn', dist.Normal(loc, scale).to_event(1))

        # q(d | zn)
        loc, scale = self.encoder_d(zns)
        ds = pyro.sample('d', dist.Normal(loc, scale).to_event(1))

        # q(n | d), only while n is latent
        if ys is None:
            pyro.sample('n', dist.OneHotCategorical(logits=self.encoder_n(ds)))
613
+
614
+ def _codebook(self):
615
+ I = torch.eye(self.code_size, **self.options)
616
+ if self.d_dist == 'studentt':
617
+ _,loc = self.codebook(I)
618
+ else:
619
+ loc = self.codebook(I)
620
+ return loc
621
+
622
def get_codebook(self):
    """Return the codebook locations and their decoded cell embeddings as NumPy arrays.

    Returns:
        (cb, zs): ``cb`` is ``self._codebook()``; ``zs`` is ``decoder_zn``
        applied to ``cb``, keeping only the loc output for the Student-t
        variant.
    """
    # Reuse _codebook instead of duplicating the identity/codebook logic.
    cb = self._codebook()
    if self.z_dist == 'studentt':
        _, zs = self.decoder_zn(cb)
    else:
        zs = self.decoder_zn(cb)
    return tensor_to_numpy(cb), tensor_to_numpy(zs)
634
+
635
+ def _code(self, xs):
636
+ if self.supervised_mode:
637
+ alpha = self.encoder_n(xs)
638
+ else:
639
+ zns,_ = self.encoder_zn(xs)
640
+ ds,_ = self.encoder_d(zns)
641
+ alpha = self.encoder_n(ds)
642
+ return alpha
643
+
644
+ def _l0_embedding(self,xs):
645
+ ns = self._soft_assignments(xs)
646
+ acs = self._codebook()
647
+ return torch.matmul(ns, acs)
648
+
649
+ def _l1_embedding(self,xs,zs=None):
650
+ if zs is None:
651
+ zs = self._l2_embedding(xs)
652
+ ds,_ = self.encoder_d(zs)
653
+ return ds
654
+
655
+ def _l2_embedding(self, xs):
656
+ zns, _ = self.encoder_zn(xs)
657
+ return zns
658
+
659
def get_l0_embedding(self, xs, batch_size=1024):
    """Return ``_l0_embedding`` for every row of ``xs``, computed in batches.

    Args:
        xs: input matrix, one row per cell.
        batch_size: rows per batch.

    Returns:
        np.ndarray with one row per input row.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    Z = []
    # Inference only: skip building autograd graphs for every batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            Z.append(tensor_to_numpy(self._l0_embedding(X_batch)))
            pbar.update(1)

    return np.concatenate(Z)
673
+
674
def get_l1_embedding(self, xs, zs=None, batch_size=1024):
    """Return ``_l1_embedding`` (loc of q(d | zn)) for every row of ``xs``.

    Args:
        xs: input matrix, one row per cell.
        zs: optional precomputed ``zn`` rows aligned with ``xs``; when None,
            ``zn`` is computed from ``xs`` batch by batch.
        batch_size: rows per batch.

    Returns:
        np.ndarray with one row per input row.
    """
    device = self.get_device()
    xs = convert_to_tensor(xs, device=device)
    if zs is not None:
        zs = convert_to_tensor(zs, device=device)
    dataloader = DataLoader(CustomDataset2(xs, zs), batch_size=batch_size, shuffle=False)

    Z = []
    # Inference only: skip building autograd graphs for every batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, Z_batch, _ in dataloader:
            # When zs is None the dataset's second item is ignored
            # (what CustomDataset2 yields in that case isn't visible here).
            zns = self._l1_embedding(X_batch, zs=Z_batch if zs is not None else None)
            Z.append(tensor_to_numpy(zns))
            pbar.update(1)

    return np.concatenate(Z)
692
+
693
def get_l2_embedding(self, xs, batch_size=1024):
    """Return ``_l2_embedding`` (loc of q(zn | x)) for every row of ``xs``.

    Args:
        xs: input matrix, one row per cell.
        batch_size: rows per batch.

    Returns:
        np.ndarray with one row per input row.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    Z = []
    # Inference only: skip building autograd graphs for every batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            Z.append(tensor_to_numpy(self._l2_embedding(X_batch)))
            pbar.update(1)

    return np.concatenate(Z)
707
+
708
+ def _soft_assignments(self, xs):
709
+ alpha = self._code(xs)
710
+ alpha = self.softmax(alpha)
711
+ return alpha
712
+
713
def soft_assignments(self, xs, batch_size=1024):
    """Return per-cell metacell probabilities (``_soft_assignments``) for every row of ``xs``.

    Args:
        xs: input matrix, one row per cell.
        batch_size: rows per batch.

    Returns:
        np.ndarray with one row per input row and one column per metacell.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    A = []
    # Inference only: skip building autograd graphs for every batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            A.append(tensor_to_numpy(self._soft_assignments(X_batch)))
            pbar.update(1)

    return np.concatenate(A)
727
+
728
+ def _hard_assignments(self, xs):
729
+ alpha = self._code(xs)
730
+ res, ind = torch.topk(alpha, 1)
731
+ ns = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
732
+ return ns
733
+
734
def hard_assignments(self, xs, batch_size=1024):
    """Return one-hot metacell assignments (``_hard_assignments``) for every row of ``xs``.

    Args:
        xs: input matrix, one row per cell.
        batch_size: rows per batch.

    Returns:
        np.ndarray with one row per input row and one column per metacell.
    """
    xs = convert_to_tensor(xs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(xs), batch_size=batch_size, shuffle=False)

    A = []
    # Inference only: skip building autograd graphs for every batch.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            A.append(tensor_to_numpy(self._hard_assignments(X_batch)))
            pbar.update(1)

    return np.concatenate(A)
748
+
749
def mmd_gaussian(self, x, y, sigma=None):
    """
    Compute MMD with Gaussian kernel between samples x and y.

    Args:
        x: Tensor of shape (n_samples, n_features)
        y: Tensor of shape (m_samples, n_features)
        sigma: Bandwidth of the Gaussian kernel. If None, uses the median of
            the non-zero pairwise distances of the pooled samples, or 1.0
            when there are none (all points identical).

    Returns:
        mmd: Squared MMD value (0-d tensor). Each within-sample term uses the
        unbiased estimator (diagonal excluded) when that sample has at least
        two rows, and the plain kernel mean for a single row, where the
        unbiased form would divide by zero. Empty inputs give NaN.
    """
    n, m = x.size(0), y.size(0)

    if sigma is None:
        # Median heuristic for bandwidth
        pooled = torch.cat([x, y], dim=0)
        pairwise_dist = torch.cdist(pooled, pooled, p=2)
        positive = pairwise_dist[pairwise_dist > 0]
        sigma = torch.median(positive).detach() if positive.numel() > 0 else 1.0

    def _kernel(a, b):
        return torch.exp(-torch.cdist(a, b, p=2) ** 2 / (2 * sigma ** 2))

    def _within(k, size):
        if size > 1:
            return (k.sum() - k.diag().sum()) / (size * (size - 1))
        return k.mean()  # size 1: single self-similarity; size 0: NaN

    xx = _kernel(x, x)
    yy = _kernel(y, y)
    xy = _kernel(x, y)
    return _within(xx, n) + _within(yy, m) - 2 * xy.mean()
779
+
780
def sinkhorn_distance(self, x, y, epsilon=0.01, max_iters=100):
    """
    Compute regularized Wasserstein distance using Sinkhorn algorithm.

    Both point clouds get uniform weights (1/n and 1/m). The ground cost is
    the squared Euclidean distance, and the dual potentials are updated in
    the log domain for numerical stability.

    Args:
        x: Tensor of shape (n, d)
        y: Tensor of shape (m, d)
        epsilon: Regularization parameter
        max_iters: Number of Sinkhorn iterations

    Returns:
        wasserstein: transport cost sum(P * C) of the entropic plan P
        (0-d tensor); NaN when either input has no rows.
    """
    n, m = x.size(0), y.size(0)
    if n == 0 or m == 0:
        return torch.tensor(float('nan'), dtype=x.dtype, device=x.device)

    C = torch.cdist(x, y, p=2) ** 2  # Cost matrix
    log_a = torch.full((n,), -np.log(n), dtype=C.dtype, device=C.device)
    log_b = torch.full((m,), -np.log(m), dtype=C.dtype, device=C.device)

    # Dual potentials; the u-update makes rows sum to 1/n, the v-update columns to 1/m.
    u = torch.zeros(n, dtype=C.dtype, device=C.device)
    v = torch.zeros(m, dtype=C.dtype, device=C.device)
    for _ in range(max_iters):
        u = epsilon * (log_a - torch.logsumexp((v[None, :] - C) / epsilon, dim=1))
        v = epsilon * (log_b - torch.logsumexp((u[:, None] - C) / epsilon, dim=0))

    # Compute transport plan and distance
    P = torch.exp((u[:, None] + v[None, :] - C) / epsilon)
    return torch.sum(P * C)
806
+
807
def metacell_distance(self, xs, ys=None, metric: Literal['mmd', 'sinkhorn'] = 'mmd', epsilon: float = 0.01, max_iters: int = 100):
    """Pairwise distances between metacells, measured on the rows of ``xs``.

    Each cell is assigned to the column of ``ys`` with the largest value.
    Entry (i, j) compares the rows of metacell i with those of metacell j
    using ``mmd_gaussian`` (``metric='mmd'``) or, for any other metric
    value, ``sinkhorn_distance``.

    Args:
        xs: matrix with one row per cell.
        ys: assignment matrix (cells x metacells). Required.
        metric: 'mmd' or 'sinkhorn'.
        epsilon, max_iters: forwarded to ``sinkhorn_distance``.

    Returns:
        np.ndarray of shape (n_metacells, n_metacells).

    Raises:
        ValueError: if ``ys`` is None.
    """
    if ys is None:
        raise ValueError("metacell_distance requires the assignment matrix ys")

    n_metacells = ys.shape[1]
    dm2m = torch.zeros(n_metacells, n_metacells)
    mc = np.argmax(ys, axis=1)
    # Group cells once instead of rescanning mc for every (i, j) pair.
    members = [np.where(mc == k) for k in range(n_metacells)]

    xs = convert_to_tensor(xs, device=self.get_device())
    with torch.no_grad():
        for i, j in itertools.product(range(n_metacells), repeat=2):
            if metric == 'mmd':
                value = self.mmd_gaussian(xs[members[i]], xs[members[j]])
            else:
                value = self.sinkhorn_distance(xs[members[i]], xs[members[j]], epsilon=epsilon, max_iters=max_iters)
            # float() so a GPU-resident scalar can be stored in the CPU matrix.
            dm2m[i, j] = float(value)

    return tensor_to_numpy(dm2m)
823
+
824
def metacell_similarity(self, xs, ys=None, embed: Literal['l1', 'l2'] = 'l1', n_neighbors=50, sigma=1, use_diffuse=False, diffusion_time=1):
    """Metacell-by-metacell similarity matrix.

    ``ys`` defaults to ``self.soft_assignments(xs)``.

    Without diffusion, entry (a, b) is the sum over cells of
    ys[c, a] * ys[c, b], divided by the total assignment to a.
    With ``use_diffuse``, the result of
    ``compute_metacell_diffusion_kernel`` is returned instead, computed on the
    l1 (``embed='l1'``) or l2 embedding of ``xs``.
    """
    if ys is None:
        ys = self.soft_assignments(xs)

    if use_diffuse:
        zs = self.get_l1_embedding(xs) if embed == 'l1' else self.get_l2_embedding(xs)
        return compute_metacell_diffusion_kernel(zs, ys, n_neighbors=n_neighbors, sigma=sigma, diffusion_time=diffusion_time)

    weights = convert_to_tensor(ys, device=self.get_device())
    weights_t = weights.T
    row_normalised = weights_t / torch.sum(weights_t, dim=1, keepdim=True)
    return tensor_to_numpy(torch.matmul(row_normalised, weights))
838
+
839
def metacell_network(self, affinity_matrix, exclude_metacells: list = None):
    """Build an undirected networkx graph over metacells from an affinity matrix.

    Nodes are ``0 .. code_size - 1``. For each row i and column j != i with
    ``affinity_matrix[i, j] > 0``, edge i-j gets weight
    ``1 / affinity_matrix[i, j]``. Rows are visited in increasing order, so
    when both directions qualify the weight from the later row is the one
    kept. Metacells in ``exclude_metacells`` get no edges. The graph is
    stored on ``self.G`` and returned.
    """
    def _kept(k):
        return exclude_metacells is None or k not in exclude_metacells

    graph = nx.Graph()
    self.G = graph
    graph.add_nodes_from(np.arange(self.code_size))

    for i in np.arange(self.code_size):
        if not _kept(i):
            continue
        row = affinity_matrix[i, :]
        for j in np.arange(len(row)):
            if j != i and row[j] > 0 and _kept(j):
                graph.add_edge(i, j, weight=1 / row[j])

    return self.G
865
+
866
def metacell_fa2(self, G, max_iter=100):
    """Run ``visualize_metacell_igraph_with_fa2`` on a metacell graph.

    This is a thin wrapper. ``max_iter`` is forwarded as the helper's
    ``iterations`` argument, and the helper's result is returned unchanged.
    The "fa2" naming suggests a ForceAtlas2 layout; confirm against the
    helper.
    """
    layout_result = visualize_metacell_igraph_with_fa2(G, iterations=max_iter)
    return layout_result
868
+
869
def metacell_tree(self, G, root_metacell=0):
    """Extract a rooted tree over metacells from a weighted metacell graph.

    First computes a minimum spanning tree of ``G`` with
    ``nx.minimum_spanning_tree``, which uses the ``weight`` edge attribute by
    default. That tree is then oriented by a depth-first search from
    ``root_metacell`` with ``nx.dfs_tree``.

    When ``G`` comes from ``metacell_network`` (weight = 1 / affinity), the
    spanning tree keeps the highest-affinity connections.

    If ``G`` is disconnected, networkx builds a spanning forest. The returned
    tree then only contains the component reachable from ``root_metacell``.

    Parameters
    ----------
    G
        Undirected weighted networkx graph.
    root_metacell
        Node used as the DFS root; it must be a node of ``G``.

    Returns
    -------
    networkx.DiGraph
        Tree with edges oriented away from ``root_metacell``.
    """
    spanning = nx.minimum_spanning_tree(G)
    return nx.dfs_tree(spanning, root_metacell)
875
+
876
def fit(self, xs,
        ys = None,
        zs = None,
        us = None,
        num_epochs: int = 200,
        learning_rate: float = 0.0001,
        batch_size: int = 512,
        algo: Literal['adam','rmsprop','adamw'] = 'adam',
        beta_1: float = 0.9,
        weight_decay: float = 0.005,
        decay_rate: float = 0.9,
        threshold: int = 0,
        normalize: bool = True,
        config_enum: str = 'parallel',
        use_jax: bool = False):
    """
    Train the HMAP model with stochastic variational inference.

    Parameters
    ----------
    xs
        Single-cell expression matrix: a NumPy array, a PyTorch tensor or a
        SciPy sparse matrix. Rows are cells and columns are features. Values
        are binarized when ``self.loss_func == 'bernoulli'``. For every other
        likelihood they are rounded to integers.
    ys
        Optional per-cell matrix. When given, ``self.model2``/``self.guide2``
        are trained instead of ``self.model``/``self.guide``. It is presumably
        a matrix of known metacell assignments; confirm against ``model2``.
    zs
        Optional per-cell matrix passed batch-wise to the model and guide.
    us
        Undesired factor matrix (NumPy array or PyTorch tensor). Rows are cells
        and columns are undesired factors. When ``None``,
        ``self.use_undesired`` is set to ``False``.
    num_epochs
        Number of training epochs.
    learning_rate
        Initial learning rate.
    batch_size
        Number of cells per mini-batch.
    algo
        Optimizer: ``'adam'``, ``'rmsprop'`` or ``'adamw'`` (case-insensitive).
    beta_1
        First-moment coefficient for Adam/AdamW. It is ignored for RMSprop,
        which has no ``betas`` argument.
    weight_decay
        Weight decay passed to the optimizer.
    decay_rate
        Multiplicative factor (``gamma``) of the exponential learning-rate
        schedule.
    threshold
        Binarization threshold forwarded to ``binarize`` for the Bernoulli
        likelihood; unused otherwise.
    normalize
        Stored as ``self.normalize``.
    config_enum
        Enumeration strategy passed to ``config_enumerate`` for the guide.
    use_jax
        If true, use ``JitTraceEnum_ELBO`` (PyTorch JIT tracing, not JAX)
        instead of ``TraceEnum_ELBO``. CAUTION: this has been observed to
        raise errors when called from a Python script or Jupyter notebook. It
        works when HMAP is run from the shell command.

    Raises
    ------
    ValueError
        If ``algo`` is not one of the supported optimizers.

    Notes
    -----
    After training, ``self.codebook_weights`` holds the column sums of
    ``self.soft_assignments(xs)``, normalized to sum to one.
    """
    self.normalize = normalize

    # Likelihood-specific preprocessing.
    if self.loss_func == 'bernoulli':
        ad = sc.AnnData(xs)
        binarize(ad, threshold=threshold)
        xs = ad.X.copy()
    else:
        if sparse.issparse(xs):
            # np.round cannot operate on SciPy sparse matrices; densify first
            # (the data is densified below anyway).
            xs = xs.toarray()
        xs = np.round(xs)

    if sparse.issparse(xs):
        xs = xs.toarray()

    xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
    if ys is not None:
        ys = convert_to_tensor(ys, dtype=self.dtype, device=self.get_device())
    if zs is not None:
        zs = convert_to_tensor(zs, dtype=self.dtype, device=self.get_device())
    if us is not None:
        us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
    else:
        self.use_undesired = False
    self.options = dict(dtype=xs.dtype, device=xs.device)

    dataset = CustomDataset4(xs, ys, zs, us)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Set up the optimizer. Pyro instantiates `optimizer(params, **optim_args)`,
    # so only arguments the chosen optimizer accepts may be passed. RMSprop
    # has no `betas` argument and previously failed with a TypeError.
    algo_key = algo.lower()
    if algo_key == 'rmsprop':
        optimizer = torch.optim.RMSprop
        optim_params = {'lr': learning_rate, 'weight_decay': weight_decay}
    elif algo_key == 'adam':
        optimizer = torch.optim.Adam
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    elif algo_key == 'adamw':
        optimizer = torch.optim.AdamW
        optim_params = {'lr': learning_rate, 'betas': (beta_1, 0.999), 'weight_decay': weight_decay}
    else:
        raise ValueError("An optimization algorithm must be specified.")
    scheduler = ExponentialLR({'optimizer': optimizer, 'optim_args': optim_params, 'gamma': decay_rate})

    pyro.clear_param_store()

    # Wrapping the guide in config_enumerate makes the ELBO enumerate the
    # discrete latent assignments instead of sampling them.
    Elbo = JitTraceEnum_ELBO if use_jax else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
    if ys is None:
        guide = config_enumerate(self.guide, config_enum, expand=True)
        loss_basic = SVI(self.model, guide, scheduler, loss=elbo)
    else:
        guide = config_enumerate(self.guide2, config_enum, expand=True)
        loss_basic = SVI(self.model2, guide, scheduler, loss=elbo)

    losses = [loss_basic]
    num_batches = len(dataloader)

    with tqdm(total=num_epochs, desc='Training', unit='epoch') as pbar:
        for _ in range(num_epochs):
            epoch_losses = [0.0] * len(losses)
            for batch_x, batch_y, batch_z, batch_u, _ in dataloader:
                # The dataset yields placeholders for missing inputs; the model
                # expects None for them.
                if us is None:
                    batch_u = None
                if ys is None:
                    batch_y = None
                if zs is None:
                    batch_z = None

                for loss_id, svi in enumerate(losses):
                    if batch_y is None:
                        new_loss = svi.step(batch_x, batch_z, batch_u)
                    else:
                        new_loss = svi.step(batch_x, batch_y, batch_z, batch_u)
                    epoch_losses[loss_id] += new_loss

            # Mean loss per batch, formatted for the progress bar.
            str_loss = " ".join("{:.4f}".format(v / num_batches) for v in epoch_losses)
            pbar.set_postfix({'loss': str_loss})
            pbar.update(1)

    assigns = self.soft_assignments(xs)
    assigns = convert_to_tensor(assigns, dtype=self.dtype, device=self.get_device())
    self.codebook_weights = torch.sum(assigns, dim=0)
    self.codebook_weights = self.codebook_weights / torch.sum(self.codebook_weights)
1007
+
1008
@classmethod
def save_model(cls, model, file_path, compression=False):
    """Serialize ``model`` to ``file_path`` with ``pickle.dump``.

    In this module ``pickle`` is ``dill``. The model is put into evaluation
    mode with ``model.eval()`` before it is written. When ``compression`` is
    true the file is written through ``gzip.open``; otherwise it is a plain
    binary file. The absolute path is printed afterwards.
    """
    target = os.path.abspath(file_path)
    model.eval()
    opener = gzip.open if compression else open
    with opener(target, 'wb') as sink:
        pickle.dump(model, sink)
    print(f'Model saved to {target}')
1022
+
1023
@classmethod
def load_model(cls, file_path):
    """Load and return a model serialized with ``pickle`` (``dill`` here).

    Gzip-compressed files are recognised by their two-byte gzip magic number
    (``0x1f 0x8b``) rather than by the file extension. A file written with
    ``compression=True`` therefore loads whatever its name, and a plain
    pickle whose name happens to end in ``gz`` still loads.

    SECURITY: unpickling can execute arbitrary code embedded in the file.
    Only load model files from trusted sources.

    Parameters
    ----------
    file_path
        Path to the serialized model.

    Returns
    -------
    The unpickled object.
    """
    resolved_path = os.path.abspath(file_path)

    # Sniff the gzip magic number instead of trusting the file name.
    with open(resolved_path, 'rb') as probe:
        is_gzip = probe.read(2) == b'\x1f\x8b'

    opener = gzip.open if is_gzip else open
    with opener(resolved_path, 'rb') as pickle_file:
        model = pickle.load(pickle_file)

    # Report only after the load has actually succeeded.
    print(f'Model loaded from {file_path}')
    return model
1037
+
1038
+
1039
# Usage hint appended to the argparse description shown by `--help`.
EXAMPLE_RUN = "example run: HMAP --help"
1042
+
1043
def parse_args(argv=None):
    """Parse the HMAP command-line options.

    Parameters
    ----------
    argv
        Sequence of argument strings to parse. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, which is what a no-argument call does.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(
        description="HMAP\n{}".format(EXAMPLE_RUN))

    parser.add_argument(
        "--cuda", action="store_true", help="use GPU(s) to speed up training"
    )
    parser.add_argument(
        "--jit", action="store_true", help="use PyTorch jit to speed up training"
    )
    parser.add_argument(
        "-n", "--num-epochs", default=200, type=int, help="number of epochs to run"
    )
    parser.add_argument(
        "-enum",
        "--enum-discrete",
        default="parallel",
        help="parallel, sequential or none. uses parallel enumeration by default",
    )
    parser.add_argument(
        "-data",
        "--data-file",
        default=None,
        type=str,
        help="the data file",
    )
    parser.add_argument(
        "-undesired",
        "--undesired-factor-file",
        default=None,
        type=str,
        help="the file for the record of undesired factors",
    )
    parser.add_argument(
        "-64",
        "--float64",
        action="store_true",
        help="use double float precision",
    )
    parser.add_argument(
        "--z-dist",
        default='studentt',
        type=str,
        choices=['normal','laplacian','cauchy','studentt','gumbel'],
        help="distribution model for latent representation",
    )
    parser.add_argument(
        "-zd",
        "--z-dim",
        default=10,
        type=int,
        help="size of the tensor representing the latent variable z",
    )
    parser.add_argument(
        "-cs",
        "--codebook-size",
        # Original underscore spelling, kept for backward compatibility.
        "--codebook_size",
        dest="codebook_size",
        default=30,
        type=int,
        help="size of vector quantization codebook",
    )
    parser.add_argument(
        "-dd",
        "--d-dim",
        default=2,
        type=int,
        choices=[2,3],
        help="size of the vector quantization codeword",
    )
    parser.add_argument(
        "--d-dist",
        default='normal',
        type=str,
        choices=['normal','laplacian','cauchy','vonmises','gumbel','studentt'],
        help="distribution model for visual representation",
    )
    parser.add_argument(
        "-hl",
        "--hidden-layers",
        nargs="+",
        default=[300],
        type=int,
        help="a tuple (or list) of MLP layers to be used in the neural networks "
        "representing the parameters of the distributions in our model",
    )
    parser.add_argument(
        "-hla",
        "--hidden-layer-activation",
        default='relu',
        type=str,
        choices=['relu','softplus','leakyrelu','linear'],
        help="activation function for hidden layers",
    )
    parser.add_argument(
        "-plf",
        "--post-layer-function",
        nargs="+",
        default=['layernorm'],
        type=str,
        help="post functions for hidden layers, could be none, dropout, layernorm, batchnorm, or combination, default is 'layernorm'",
    )
    parser.add_argument(
        "-paf",
        "--post-activation-function",
        nargs="+",
        default=['none'],
        type=str,
        help="post functions for activation layers, could be none or dropout, default is 'none'",
    )
    parser.add_argument(
        "-id",
        "--inverse-dispersion",
        default=10.0,
        type=float,
        help="inverse dispersion prior for negative binomial",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=0.0001,
        type=float,
        help="learning rate for Adam optimizer",
    )
    parser.add_argument(
        "-dr",
        "--decay-rate",
        default=0.9,
        type=float,
        help="decay rate (gamma) of the exponential learning-rate schedule",
    )
    parser.add_argument(
        "--layer-dropout-rate",
        default=0.1,
        type=float,
        help="dropout rate for neural networks",
    )
    parser.add_argument(
        "-b1",
        "--beta-1",
        default=0.95,
        type=float,
        help="beta-1 parameter for Adam optimizer",
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        default=1000,
        type=int,
        help="number of cells to be considered in a batch",
    )
    parser.add_argument(
        "-likeli",
        "--likelihood",
        default='poisson',
        type=str,
        # 'bernoulli' is handled by HMAP.fit (binarized input).
        choices=['negbinomial', 'multinomial', 'poisson', 'bernoulli'],
        help="specify the distribution likelihood function",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="seed for controlling randomness in this example",
    )
    parser.add_argument(
        "--save-model",
        default=None,
        type=str,
        help="path to save model for prediction",
    )

    return parser.parse_args(argv)
1215
+
1216
def set_random_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds, in order: PyTorch; the CUDA generators when a GPU is available
    (current device and all devices); NumPy's global RNG; Python's ``random``
    module; and Pyro.

    NOTE(review): an identically named function is defined near the top of
    this module. This later definition is the one bound after import.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        # Covers the current device and multi-GPU setups.
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [np.random.seed, random.seed, pyro.set_rng_seed]

    for seeder in seeders:
        seeder(seed)
1233
+
1234
+
1235
def main():
    """Command-line entry point: train an HMAP model on a delimited data file.

    Reads the expression matrix, and the optional undesired-factor matrix,
    with ``datatable.fread``. It then builds an ``HMAP`` instance from the
    parsed options, trains it with ``fit``, and optionally saves it to
    ``--save-model``.

    Raises
    ------
    SystemExit
        If ``--data-file`` is missing or does not exist.
    """
    args = parse_args()

    # Validate explicitly: an `assert` is stripped when Python runs with -O.
    if args.data_file is None or not os.path.exists(args.data_file):
        raise SystemExit("data file must be provided")

    dtype = torch.float64 if args.float64 else torch.float32
    torch.set_default_dtype(dtype)

    xs = dt.fread(file=args.data_file, header=True).to_numpy()
    us = None
    if args.undesired_factor_file is not None:
        us = dt.fread(file=args.undesired_factor_file, header=True).to_numpy()

    input_size = xs.shape[1]
    undesired_size = 0 if us is None else us.shape[1]

    # batch_size: number of cells (and labels) to be considered in a batch
    hmap = HMAP(
        input_size=input_size,
        undesired_size=undesired_size,
        codebook_size=args.codebook_size,
        d_dim=args.d_dim,
        d_dist=args.d_dist,
        z_dim=args.z_dim,
        z_dist=args.z_dist,
        hidden_layers=args.hidden_layers,
        hidden_layer_activation=args.hidden_layer_activation,
        loss_func=args.likelihood,
        inverse_dispersion=args.inverse_dispersion,
        nn_dropout=args.layer_dropout_rate,
        use_cuda=args.cuda,
        config_enum=args.enum_discrete,
        post_layer_fct=args.post_layer_function,
        post_act_fct=args.post_activation_function,
        dtype=dtype,
        seed=args.seed,
    )

    hmap.fit(xs, us = us,
             num_epochs=args.num_epochs,
             learning_rate=args.learning_rate,
             batch_size=args.batch_size,
             beta_1=args.beta_1,
             decay_rate=args.decay_rate,
             use_jax=args.jit,
             config_enum=args.enum_discrete,
             )

    if args.save_model is not None:
        # A path ending in 'gz' now gets gzip-compressed content so that the
        # file matches its name; previously it received a plain pickle.
        HMAP.save_model(hmap, args.save_model,
                        compression=args.save_model.endswith('gz'))
1295
+
1296
+
1297
# Start the command-line interface only when this file is executed as a
# script; importing the module has no side effects here.
if __name__ == "__main__":
    main()