scmtg 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scmtg-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.1
2
+ Name: scmtg
3
+ Version: 0.0.1
4
+ Summary: scMTG learns generative Markov transitions for single-cell temporal dynamics
5
+ Home-page: https://github.com/liuq-lab/scMTG
6
+ Author: Xuejian Cui
7
+ Author-email: cuixj@hit.edu.cn
8
+ License: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Operating System :: MacOS :: MacOS X
14
+ Classifier: Operating System :: Microsoft :: Windows
15
+ Classifier: Operating System :: POSIX :: Linux
16
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
17
+ Requires-Python: >=3.8.0
18
+ Requires-Dist: tensorflow-gpu==2.6.0
19
+ Requires-Dist: keras==2.6.0
@@ -0,0 +1,3 @@
1
+ __version__ = '0.0.1'
2
+ from .scMTG import *
3
+ from . util import *
@@ -0,0 +1,132 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import tensorflow.keras.backend as K
4
+
5
+
6
def _nan2zero(x):
    """Return ``x`` with every NaN entry replaced by 0 (other entries unchanged)."""
    nan_mask = tf.math.is_nan(x)
    return tf.where(nan_mask, tf.zeros_like(x), x)
8
+
9
def _nan2inf(x):
    """Return ``x`` with every NaN entry replaced by +inf (other entries unchanged)."""
    positive_inf = tf.zeros_like(x) + np.inf
    return tf.where(tf.math.is_nan(x), positive_inf, x)
11
+
12
def _nelem(x):
    """Count the non-NaN entries of ``x``, cast to ``x.dtype``.

    A count of zero is reported as 1 so that callers dividing by it never
    divide by zero when every entry is NaN.
    """
    is_valid = tf.cast(tf.logical_not(tf.math.is_nan(x)), tf.float32)
    count = tf.reduce_sum(is_valid)
    count = tf.where(tf.equal(count, 0.), 1., count)
    return tf.cast(count, x.dtype)
15
+
16
+
17
def _reduce_mean(x):
    """Mean of all entries of ``x``, ignoring NaN entries."""
    denominator = _nelem(x)
    total = tf.reduce_sum(_nan2zero(x))
    return tf.divide(total, denominator)
21
+
22
+
23
def mse_loss(y_true, y_pred):
    """Mean squared error between ``y_pred`` and ``y_true``, ignoring NaN errors."""
    squared_error = tf.square(y_pred - y_true)
    return _reduce_mean(squared_error)
27
+
28
+
29
def poisson_loss(y_true, y_pred):
    """Poisson negative log-likelihood averaged over the observed entries.

    NaN entries of ``y_true`` are treated as missing. They are excluded from
    both the numerator and the denominator of the mean.

    Args:
        y_true: observed counts; NaN marks a missing entry.
        y_pred: predicted Poisson rates.

    Returns:
        Scalar float32 loss.
    """
    y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    observed = tf.logical_not(tf.math.is_nan(y_true))
    nelem = _nelem(y_true)
    y_true = _nan2zero(y_true)

    # last term can be avoided since it doesn't depend on y_pred
    # however keeping it gives a nice lower bound to zero
    ret = y_pred - y_true * tf.math.log(y_pred + 1e-10) + tf.math.lgamma(y_true + 1.0)

    # Missing entries were filled with 0 above, which still yields a loss of
    # y_pred there; zero them so the sum matches the observed-only denominator.
    ret = tf.where(observed, ret, tf.zeros_like(ret))
    return tf.divide(tf.reduce_sum(ret), nelem)
40
+
41
class NB(object):
    """Negative binomial (NB) negative log-likelihood.

    ``y_pred`` is the NB mean, and ``theta`` acts as a dispersion: the loss
    below uses ``1 / theta`` as the NB size parameter.
    """

    def __init__(self, theta=None, masking=False, scope='nbinom_loss/',
                 scale_factor=1.0, debug=False):
        """Store the loss configuration.

        Args:
            theta: dispersion tensor, broadcastable against ``y_pred``; must be
                set before ``loss`` is called.
            masking: if True, NaN entries of ``y_true`` are treated as missing.
            scope: name scope wrapping the loss ops.
            scale_factor: multiplier applied to ``y_pred`` inside ``loss``.
            debug: if True, check intermediate terms for inf/NaN and write
                histogram summaries.
        """
        self.eps = 1e-10
        self.scale_factor = scale_factor
        self.debug = debug
        self.scope = scope
        self.masking = masking
        self.theta = theta

    def loss(self, y_true, y_pred, mean=True):
        """Return the NB negative log-likelihood of ``y_true`` given mean ``y_pred``.

        Args:
            y_true: observed counts (NaN marks a missing entry when masking is on).
            y_pred: predicted NB means; multiplied by ``scale_factor`` first.
            mean: if True, reduce to a scalar mean (over observed entries when
                masking); otherwise return the element-wise loss.

        Returns:
            Scalar or element-wise float32 loss; NaN results are mapped to +inf.
        """
        scale_factor = self.scale_factor
        eps = self.eps

        with tf.name_scope(self.scope):
            y_true = tf.cast(y_true, tf.float32)
            y_pred = tf.cast(y_pred, tf.float32) * scale_factor

            if self.masking:
                # The mask must be taken before NaNs are zero-filled.
                observed = tf.logical_not(tf.math.is_nan(y_true))
                nelem = _nelem(y_true)
                y_true = _nan2zero(y_true)

            # Clip theta
            theta = tf.minimum(self.theta, 1e6)
            inv_theta = 1 / (theta + eps)  # NB size parameter

            # lgamma(x) = ln(Gamma(x))
            t1 = (tf.math.lgamma(inv_theta + eps) + tf.math.lgamma(y_true + 1.0)
                  - tf.math.lgamma(y_true + inv_theta + eps))
            t2 = ((inv_theta + y_true) * tf.math.log(1.0 + y_pred * theta)
                  - (y_true * tf.math.log(theta * y_pred + eps)))

            if self.debug:
                # tf.verify_tensor_all_finite only exists in TF1; the TF2
                # equivalent is tf.debugging.assert_all_finite.
                assert_ops = [
                    tf.debugging.assert_all_finite(y_pred, 'y_pred has inf/nans'),
                    tf.debugging.assert_all_finite(t1, 't1 has inf/nans'),
                    tf.debugging.assert_all_finite(t2, 't2 has inf/nans')]

                tf.summary.histogram('t1', t1)
                tf.summary.histogram('t2', t2)

                with tf.control_dependencies(assert_ops):
                    final = t1 + t2

            else:
                final = t1 + t2

            final = _nan2inf(final)

            if mean:
                if self.masking:
                    # Zero-filled missing entries still produce a non-zero
                    # loss; drop them so the numerator matches ``nelem``.
                    final = tf.where(observed, final, tf.zeros_like(final))
                    final = tf.divide(tf.reduce_sum(final), nelem)
                else:
                    final = tf.reduce_mean(final)

        return final
94
+
95
class ZINB(NB):
    """Zero-inflated negative binomial (ZINB) negative log-likelihood.

    Mixes a point mass at zero with weight ``pi`` and the NB distribution of
    the parent class. ``theta`` keeps the parent's dispersion parameterization
    (``1 / theta`` is the NB size).
    """

    def __init__(self, pi, ridge_lambda=0.0, scope='zinb_loss/', **kwargs):
        """Store the zero-inflation weight and ridge strength.

        Args:
            pi: zero-inflation probability tensor, broadcastable against ``y_pred``.
            ridge_lambda: weight of the ``pi ** 2`` penalty added element-wise.
            scope: name scope wrapping the loss ops.
            **kwargs: forwarded to ``NB.__init__`` (``theta``, ``masking``, ...).
        """
        super().__init__(scope=scope, **kwargs)
        self.pi = pi
        self.ridge_lambda = ridge_lambda

    def loss(self, y_true, y_pred, mean=True):
        """Return the ZINB negative log-likelihood of ``y_true`` given mean ``y_pred``.

        Entries with ``y_true < 1e-8`` use the zero-case likelihood
        ``pi + (1 - pi) * NB(0)``. All other entries use ``(1 - pi) * NB(y)``.

        Args:
            y_true: observed counts (NaN marks a missing entry when masking is on).
            y_pred: predicted NB means; multiplied by ``scale_factor`` first.
            mean: if True, reduce to a scalar mean (over observed entries when
                masking); otherwise return the element-wise loss.
        """
        scale_factor = self.scale_factor
        eps = self.eps

        with tf.name_scope(self.scope):
            # Element-wise NB NLL plus -log(1 - pi) for the NB mixture component.
            nb_case = super().loss(y_true, y_pred, mean=False) - tf.math.log(1.0 - self.pi + eps)

            y_true = tf.cast(y_true, tf.float32)
            y_pred = tf.cast(y_pred, tf.float32) * scale_factor
            theta = tf.minimum(self.theta, 1e6)

            if self.masking:
                observed = tf.logical_not(tf.math.is_nan(y_true))

            # NB probability of observing zero: (r / (r + mu)) ** r with r = 1/theta.
            zero_nb = tf.pow((1 / theta) / (y_pred + 1 / theta), 1 / theta)
            zero_case = -tf.math.log(self.pi + ((1.0 - self.pi) * zero_nb) + eps)
            result = tf.where(tf.less(y_true, 1e-8), zero_case, nb_case)
            ridge = self.ridge_lambda * tf.square(self.pi)
            result += ridge

            if mean:
                if self.masking:
                    # ``result`` is finite at missing positions (the parent
                    # zero-fills them), so mark them NaN explicitly; the
                    # NaN-aware mean then excludes them from sum and count.
                    result = tf.where(observed, result, tf.zeros_like(result) + np.nan)
                    result = _reduce_mean(result)
                else:
                    result = tf.reduce_mean(result)

            result = _nan2inf(result)

            if self.debug:
                tf.summary.histogram('nb_case', nb_case)
                tf.summary.histogram('zero_nb', zero_nb)
                tf.summary.histogram('zero_case', zero_case)
                tf.summary.histogram('ridge', ridge)

            return result
@@ -0,0 +1,278 @@
1
+ import os
2
+ import tensorflow as tf
3
+ import tensorflow.keras.backend as K
4
+ from keras.layers import Lambda, Layer
5
+
6
def MeanAct(x):
    """Activation for NB means: softplus clipped to [0, 1e4]."""
    return tf.clip_by_value(tf.nn.softplus(x), 0, 1e4)


def DispAct(x):
    """Activation for NB dispersions: softplus clipped to [1e-4, 1e4]."""
    return tf.clip_by_value(tf.nn.softplus(x), 1e-4, 1e4)
8
+
9
+
10
class BaseFullyConnectedNet(tf.keras.Model):
    """Fully connected encoder network.

    Hidden layers are Dense -> [BatchNorm] -> [Dropout] -> LeakyReLU(0.2),
    followed by a linear Dense output layer (optionally squashed with tanh).
    """

    def __init__(self, input_dim, z_dim, output_dim, model_name, nb_units=None,
                 concat_every_fcl=False, batchnorm=False, dropout=True, last_relu=False):
        """Create the layers and build them once on a symbolic input.

        Args:
            input_dim: number of input features.
            z_dim: column offset; when ``concat_every_fcl`` is True, the input
                columns from ``z_dim`` onward are re-concatenated after every
                hidden layer.
            output_dim: number of output units.
            model_name: prefix for the name scopes used in ``call``.
            nb_units: hidden layer widths; defaults to ``[256]``.
            concat_every_fcl: see ``z_dim``.
            batchnorm: apply batch normalization after each hidden Dense layer.
            dropout: apply dropout (rate 0.1) after each hidden Dense layer.
            last_relu: despite the name, applies ``tanh`` to the output.
        """
        super(BaseFullyConnectedNet, self).__init__()
        if nb_units is None:
            nb_units = [256]
        self.input_layer = tf.keras.layers.Input((input_dim,))
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.model_name = model_name
        self.nb_units = nb_units
        self.concat_every_fcl = concat_every_fcl
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.last_relu = last_relu
        self.all_layers = []
        # Build the FC stacks: [dense, batchnorm, dropout, activation] per hidden layer.
        for units in self.nb_units:
            fc_layer = tf.keras.layers.Dense(units=units, activation=None)
            norm_layer = tf.keras.layers.BatchNormalization()
            dropout_layer = tf.keras.layers.Dropout(0.1)
            act_layer = tf.keras.layers.LeakyReLU(alpha=0.2)
            self.all_layers.append([fc_layer, norm_layer, dropout_layer, act_layer])
        fc_layer = tf.keras.layers.Dense(units=self.output_dim, activation=None)
        self.all_layers.append([fc_layer, None, None, None])

        self.out = self.call(self.input_layer)

    def call(self, inputs, training=True):
        """Return the encoder output.

        Args:
            inputs: tensor with shape [batch_size, input_dim]
        Returns:
            float32 tensor with shape [batch_size, output_dim]
        """
        y = inputs[:, self.z_dim:]
        # Start from the inputs so that an empty ``nb_units`` still works.
        x = inputs
        for i, layers in enumerate(self.all_layers[:-1]):
            fc_layer, norm_layer, dropout_layer, act_layer = layers
            with tf.name_scope("%s_layer_%d" % (self.model_name, i + 1)):
                x = fc_layer(x)
                if self.batchnorm:
                    x = norm_layer(x)
                if self.dropout:
                    x = dropout_layer(x)
                x = act_layer(x)
                if self.concat_every_fcl:
                    x = tf.concat([x, y], axis=1)
        fc_layer, _, _, _ = self.all_layers[-1]
        # (sic) scope name kept unchanged.
        with tf.name_scope("%s_layer_ouput" % self.model_name):
            output = fc_layer(x)
            # The output layer itself is linear; ``last_relu`` adds a tanh.
            if self.last_relu:
                output = tf.keras.activations.tanh(output)
        return output
77
+
78
class Decoder2(tf.keras.Model):
    """Decoder network with negative-binomial output heads.

    Hidden layers are Dense -> [BatchNorm] -> [Dropout] -> LeakyReLU(0.2). The
    last hidden representation feeds two Dense heads: a dispersion head
    (``DispAct``) and a mean head (``MeanAct``).
    """

    def __init__(self, input_dim, z_dim, output_dim, model_name, nb_units=None,
                 batchnorm=False, dropout=True, last_relu=False):
        """Create the layers and build them once on a symbolic input.

        Args:
            input_dim: number of input (latent) features.
            z_dim: stored on the instance; not used by ``call``.
            output_dim: number of output features of each head.
            model_name: prefix for the name scopes used in ``call``.
            nb_units: hidden layer widths; defaults to ``[256]``.
            batchnorm: apply batch normalization after each hidden Dense layer.
            dropout: apply dropout (rate 0.1) after each hidden Dense layer.
            last_relu: apply an extra ReLU to the mean head.
        """
        super(Decoder2, self).__init__()
        if nb_units is None:
            nb_units = [256]
        self.input_layer = tf.keras.layers.Input((input_dim,))
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.model_name = model_name
        self.nb_units = nb_units
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.last_relu = last_relu
        self.all_layers = []
        # Build the FC stacks: [dense, batchnorm, dropout, activation] per hidden layer.
        for units in self.nb_units:
            fc_layer = tf.keras.layers.Dense(units=units, activation=None)
            norm_layer = tf.keras.layers.BatchNormalization()
            dropout_layer = tf.keras.layers.Dropout(0.1)
            act_layer = tf.keras.layers.LeakyReLU(alpha=0.2)
            self.all_layers.append([fc_layer, norm_layer, dropout_layer, act_layer])

        disp_layer = tf.keras.layers.Dense(
            units=self.output_dim,
            activation=DispAct,
            name='dispersion'
        )
        mean_layer = tf.keras.layers.Dense(
            units=self.output_dim,
            activation=MeanAct,
            name='mean'
        )
        self.all_layers.append([disp_layer, mean_layer])

        self.disp, self.mean = self.call(self.input_layer)

    def call(self, inputs, training=True):
        """Return the decoded ``(disp, mean)`` pair.

        Args:
            inputs: tensor with shape [batch_size, input_dim]
        Returns:
            Two float32 tensors, each with shape [batch_size, output_dim].
        """
        # Start from the inputs so that an empty ``nb_units`` still works.
        x = inputs
        for i, layers in enumerate(self.all_layers[:-1]):
            fc_layer, norm_layer, dropout_layer, act_layer = layers
            with tf.name_scope("%s_layer_%d" % (self.model_name, i + 1)):
                x = fc_layer(x)
                if self.batchnorm:
                    x = norm_layer(x)
                if self.dropout:
                    x = dropout_layer(x)
                x = act_layer(x)
        disp_layer, mean_layer = self.all_layers[-1]
        with tf.name_scope("%s_layer_output" % self.model_name):
            # The heads apply DispAct / MeanAct themselves.
            disp = disp_layer(x)
            mean = mean_layer(x)
            if self.last_relu:
                mean = tf.nn.relu(mean)
        return disp, mean
148
+
149
class Generator(tf.keras.Model):
    """Generator network.

    Hidden layers are Dense -> [BatchNorm] -> [Dropout] -> LeakyReLU(0.2). The
    output layer is linear unless ``last_relu`` is set. In that case the
    output is ``scale_factor * tanh(output)``, where ``scale_factor`` is a
    trainable variable initialised to 3.0.
    """

    def __init__(self, input_dim, z_dim, output_dim, model_name, nb_units=None,
                 concat_every_fcl=False, batchnorm=False, dropout=True, last_relu=False):
        """Create the layers and build them once on a symbolic input.

        Args:
            input_dim: number of input features.
            z_dim: column offset; when ``concat_every_fcl`` is True, the input
                columns from ``z_dim`` onward are re-concatenated after every
                hidden layer.
            output_dim: number of output units.
            model_name: prefix for the name scopes used in ``call``.
            nb_units: hidden layer widths; defaults to ``[256]``.
            concat_every_fcl: see ``z_dim``.
            batchnorm: apply batch normalization after each hidden Dense layer.
            dropout: apply dropout (rate 0.1) after each hidden Dense layer.
            last_relu: despite the name, applies ``scale_factor * tanh`` to the output.
        """
        super(Generator, self).__init__()
        if nb_units is None:
            nb_units = [256]
        self.input_layer = tf.keras.layers.Input((input_dim,))
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.model_name = model_name
        self.nb_units = nb_units
        self.concat_every_fcl = concat_every_fcl
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.last_relu = last_relu
        self.scale_factor = tf.Variable(initial_value=3.0, trainable=True)
        self.all_layers = []

        # Build the FC stacks: one entry per hidden layer plus the output layer.
        for i in range(len(self.nb_units) + 1):
            units = self.output_dim if i == len(nb_units) else self.nb_units[i]
            fc_layer = tf.keras.layers.Dense(units=units, activation=None)
            norm_layer = tf.keras.layers.BatchNormalization()
            dropout_layer = tf.keras.layers.Dropout(0.1)
            self.all_layers.append([fc_layer, norm_layer, dropout_layer])
        self.out = self.call(self.input_layer)

    def call(self, inputs, training=True):
        """Return the output of the Generator.

        Args:
            inputs: tensor with shape [batch_size, input_dim]; columns from
                ``z_dim`` onward are the part re-concatenated when
                ``concat_every_fcl`` is set.
        Returns:
            float32 tensor with shape [batch_size, output_dim]
        """
        y = inputs[:, self.z_dim:]
        # Start from the inputs so that an empty ``nb_units`` still works.
        x = inputs
        for i, layers in enumerate(self.all_layers[:-1]):
            fc_layer, norm_layer, dropout_layer = layers
            with tf.name_scope("%s_g_layer_%d" % (self.model_name, i)):
                x = fc_layer(x)
                if self.batchnorm:
                    x = norm_layer(x)
                if self.dropout:
                    x = dropout_layer(x)
                x = tf.nn.leaky_relu(x, alpha=0.2)
                if self.concat_every_fcl:
                    x = tf.concat([x, y], axis=1)
        fc_layer, _, _ = self.all_layers[-1]
        with tf.name_scope("%s_g_layer_output" % self.model_name):
            output = fc_layer(x)
            # The output layer itself is linear; ``last_relu`` adds a scaled tanh.
            if self.last_relu:
                output = self.scale_factor * tf.math.tanh(output)
        return output
213
+
214
class Discriminator(tf.keras.Model):
    """Discriminator (critic) network.

    Hidden layers are Dense -> [BatchNorm, never on the first hidden layer] ->
    [Dropout] -> LeakyReLU(0.2). The output layer is linear unless
    ``last_relu`` is set.
    """

    def __init__(self, input_dim, z_dim, output_dim, model_name, nb_units=None,
                 batchnorm=False, dropout=True, last_relu=False):
        """Create the layers and build them once on a symbolic input.

        Args:
            input_dim: number of input features.
            z_dim: stored on the instance; not used by ``call``.
            output_dim: number of output units.
            model_name: prefix for the name scopes used in ``call``.
            nb_units: hidden layer widths; defaults to ``[256]``.
            batchnorm: apply batch normalization after each hidden Dense layer
                except the first.
            dropout: apply dropout (rate 0.1) after each hidden Dense layer.
            last_relu: apply a ReLU to the output.
        """
        super(Discriminator, self).__init__()
        if nb_units is None:
            nb_units = [256]
        self.input_layer = tf.keras.layers.Input((input_dim,))
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.model_name = model_name
        self.nb_units = nb_units
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.last_relu = last_relu
        self.all_layers = []
        # Build the FC stacks: one entry per hidden layer plus the output layer.
        for i in range(len(nb_units) + 1):
            units = self.output_dim if i == len(nb_units) else self.nb_units[i]
            fc_layer = tf.keras.layers.Dense(units=units, activation=None)
            norm_layer = tf.keras.layers.BatchNormalization()
            dropout_layer = tf.keras.layers.Dropout(0.1)
            self.all_layers.append([fc_layer, norm_layer, dropout_layer])

        self.out = self.call(self.input_layer)

    def call(self, inputs, training=True):
        """Return the output of the Discriminator network.

        Args:
            inputs: tensor with shape [batch_size, input_dim]
        Returns:
            float32 tensor with shape [batch_size, output_dim]
        """
        # Start from the inputs; with an empty ``nb_units`` only the output
        # layer is applied (previously it was applied twice and crashed).
        x = inputs
        for i, (fc_layer, norm_layer, dropout_layer) in enumerate(self.all_layers[:-1]):
            with tf.name_scope("%s_d_layer_%d" % (self.model_name, i)):
                x = fc_layer(x)
                # The first hidden layer is never batch-normalized.
                if self.batchnorm and i > 0:
                    x = norm_layer(x)
                if self.dropout:
                    x = dropout_layer(x)
                x = tf.nn.leaky_relu(x, alpha=0.2)
        fc_layer, _, _ = self.all_layers[-1]
        with tf.name_scope("%s_d_layer_output" % self.model_name):
            output = fc_layer(x)
            if self.last_relu:
                output = tf.nn.relu(output)
        return output
@@ -0,0 +1,295 @@
1
+ import tensorflow as tf
2
+ from .model import Decoder2, Generator, Discriminator, BaseFullyConnectedNet
3
+ import numpy as np
4
+ import pandas as pd
5
+ from .util import Sequential_sampler
6
+ from .loss import NB
7
+ import dateutil.tz
8
+ import datetime
9
+ import sys
10
+ import copy
11
+ import os
12
+ import json
13
+
14
# Order GPUs by PCI bus id and expose only GPU 1 to TensorFlow. This runs at
# import time; values the caller has already exported are left untouched
# instead of being silently overwritten.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
16
+
17
class scMTG(object):
    """Markov Transition Generative model for time-series single cell analysis.

    A shared autoencoder (``encoder`` plus NB ``decoder``) maps every time point
    into a latent space of size ``params['z_dim']``. For each consecutive pair
    of time points (t, t+1) there is one generator, mapping (latent cell at t,
    noise) to a latent cell at t+1, and one WGAN-GP critic comparing generated
    and observed latents at t+1.
    """

    def __init__(self, params, timestamp=None, random_seed=None):
        """Build networks, optimizers, output folders and the checkpoint.

        Args:
            params: dict of hyper-parameters. Keys read by this class:
                ``e_dim``, ``z_dim``, ``noise_dim``, ``nb_time``, ``e_units``,
                ``d_units``, ``gen_units``, ``dis_units``, ``lr``, ``sd``,
                ``alpha``, ``beta``, ``ae_gan_freq``, ``g_d_freq``,
                ``output_dir``, ``save_model``, ``save_res``.
            timestamp: name of the run sub-folder; defaults to the current
                local time formatted as ``%Y%m%d_%H%M%S``.
            random_seed: optional seed for Python, NumPy and TensorFlow RNGs.
        """
        super(scMTG, self).__init__()
        self.params = params
        self.timestamp = timestamp
        if random_seed is not None:
            self._set_random_seed(random_seed)

        # initialize the shared autoencoder (encoder + decoder)
        self.encoder = BaseFullyConnectedNet(input_dim=params['e_dim'], z_dim=params['z_dim'], output_dim=params['z_dim'],
                                             model_name='e_net', nb_units=params['e_units'], last_relu=False)
        self.decoder = Decoder2(input_dim=params['z_dim'], z_dim=params['z_dim'], output_dim=params['e_dim'],
                                model_name='d_net', nb_units=params['d_units'], last_relu=False)

        # initialize the T-1 Markov generators and T-1 discriminators
        n_transitions = params['nb_time'] - 1
        self.generators = [Generator(input_dim=params['noise_dim'] + params['z_dim'], z_dim=params['z_dim'],
                                     output_dim=params['z_dim'], model_name='g_net_%d' % i,
                                     nb_units=params['gen_units'], concat_every_fcl=False, last_relu=True)
                           for i in range(n_transitions)]
        self.discriminators = [Discriminator(input_dim=params['z_dim'], z_dim=params['z_dim'], output_dim=1,
                                             model_name='d_net_%d' % i, nb_units=params['dis_units'], last_relu=False)
                               for i in range(n_transitions)]
        lr_schedule1 = tf.keras.optimizers.schedules.ExponentialDecay(params['lr'], decay_steps=100000, decay_rate=0.9, staircase=True)
        # The encoder's updates driven by the generator loss use a 10x smaller learning rate.
        lr_schedule2 = tf.keras.optimizers.schedules.ExponentialDecay(params['lr'] / 10.0, decay_steps=100000, decay_rate=0.9, staircase=True)
        self.ae_optimizer = tf.keras.optimizers.Adam(lr_schedule1, beta_1=0.5, beta_2=0.9)
        self.e_optimizer = tf.keras.optimizers.Adam(lr_schedule2, beta_1=0.5, beta_2=0.9)
        self.g_optimizer = tf.keras.optimizers.Adam(lr_schedule1, beta_1=0.5, beta_2=0.9)
        self.d_optimizer = tf.keras.optimizers.Adam(lr_schedule1, beta_1=0.5, beta_2=0.9)

        self.initialize_nets()

        if self.timestamp is None:
            now = datetime.datetime.now(dateutil.tz.tzlocal())
            self.timestamp = now.strftime('%Y%m%d_%H%M%S')

        self.best_path = "{}/{}/best_model/".format(
            params['output_dir'], self.timestamp)

        if self.params['save_model'] and not os.path.exists(self.best_path):
            os.makedirs(self.best_path)

        self.save_dir = "{}/{}".format(
            params['output_dir'], self.timestamp)

        if self.params['save_res'] and not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self.ckpt = tf.train.Checkpoint(encoder=self.encoder,
                                        decoder=self.decoder,
                                        generators=self.generators,
                                        discriminators=self.discriminators,
                                        ae_optimizer=self.ae_optimizer,
                                        e_optimizer=self.e_optimizer,
                                        g_optimizer=self.g_optimizer,
                                        d_optimizer=self.d_optimizer)

    @staticmethod
    def _set_random_seed(random_seed):
        """Seed the Python, NumPy and TensorFlow random number generators.

        ``tf.keras.utils.set_random_seed`` does this in one call, but it is
        missing from older TensorFlow releases. There the three generators are
        seeded individually.
        """
        set_seed = getattr(tf.keras.utils, 'set_random_seed', None)
        if set_seed is not None:
            set_seed(random_seed)
            return
        import random
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.random.set_seed(random_seed)

    def get_config(self):
        """Return the hyper-parameters this model was created with."""
        return {
            "params": self.params,
        }

    def initialize_nets(self, print_summary=True):
        """Build every network in scMTG by calling it once on an all-zero batch.

        Args:
            print_summary: if True, print the Keras summary of each network.
        """
        n_transitions = self.params['nb_time'] - 1
        self.encoder(np.zeros((1, self.params['e_dim'])))
        self.decoder(np.zeros((1, self.params['z_dim'])), 1.0)
        for i in range(n_transitions):
            self.generators[i](np.zeros((1, self.params['z_dim'] + self.params['noise_dim'])))
        for i in range(n_transitions):
            self.discriminators[i](np.zeros((1, self.params['z_dim'])))
        if print_summary:
            # Model.summary() prints by itself and returns None; wrapping it in
            # print() used to emit stray "None" lines.
            self.encoder.summary()
            self.decoder.summary()
            for generator in self.generators:
                generator.summary()
            for discriminator in self.discriminators:
                discriminator.summary()

    @tf.function
    def train_ae_step(self, data_series):
        """Run one optimization step of the shared autoencoder.

        Every slice of ``data_series`` along axis 0 (one per time point) is
        encoded and decoded. The NB negative log-likelihood of the data, given
        the decoded means and dispersions, is minimized.

        Args:
            data_series: stacked per-time-point batches (presumably
                [nb_time, batch_size, e_dim]; TODO confirm against the sampler).
        Returns:
            Scalar reconstruction loss.
        """
        ae_variables = self.encoder.trainable_variables + self.decoder.trainable_variables
        with tf.GradientTape() as tape:
            embed_series = tf.map_fn(lambda item: self.encoder(item), data_series)

            disps, means = [], []
            for i in range(len(embed_series)):
                disp, mean = self.decoder(embed_series[i])
                disps.append(disp)
                means.append(mean)
            nb = NB(theta=tf.concat(disps, axis=0), debug=False)
            # Time points are stacked along the batch axis, matching the
            # concatenation order of the decoder outputs above.
            loss_rec = nb.loss(tf.reshape(data_series, [-1, data_series.shape[-1]]), tf.concat(means, axis=0), mean=True)

        gradients = tape.gradient(loss_rec, ae_variables)
        self.ae_optimizer.apply_gradients(zip(gradients, ae_variables))
        return loss_rec

    @tf.function
    def train_gen_step(self, data_series):
        """Run one optimization step of the generators (and of the encoder).

        Generator t maps (latent of time t, Gaussian noise) to a fake latent
        for time t+1. The loss is the WGAN generator loss ``-mean(D(fake))``
        plus ``params['beta']`` times the mean squared distance between each
        fake latent and its source latent at time t.

        Returns:
            ``(loss_g, loss_td, loss_g_total)``.
        """
        noise = tf.random.normal(shape=(self.params['nb_time'] - 1, data_series.shape[1], self.params['noise_dim']),
                                 mean=0., stddev=self.params['sd'])
        with tf.GradientTape(persistent=True) as gen_tape:
            embed_series = tf.map_fn(lambda item: self.encoder(item), data_series)

            # data from time point 1,2,...,T-1 (noise appended)
            data_previous = tf.concat([embed_series[:-1], noise], axis=-1)

            # generated data for time point 2,...,T
            data_gen = tf.TensorArray(tf.float32, size=data_previous.shape[0])
            for i in range(data_previous.shape[0]):
                data_gen = data_gen.write(i, self.generators[i](data_previous[i]))
            data_gen = data_gen.stack()

            dz_gen = tf.TensorArray(tf.float32, size=data_gen.shape[0])
            for i in range(data_gen.shape[0]):
                dz_gen = dz_gen.write(i, self.discriminators[i](data_gen[i]))
            dz_gen = dz_gen.stack()

            loss_g = -tf.reduce_mean(dz_gen)
            loss_td = tf.reduce_mean((data_gen - embed_series[:-1]) ** 2)
            loss_g_total = loss_g + self.params['beta'] * loss_td

        # The persistent tape serves two gradient calls: generators, then encoder.
        g_variables = sum([item.trainable_variables for item in self.generators], [])
        g_gradients = gen_tape.gradient(loss_g_total, g_variables)
        self.g_optimizer.apply_gradients(zip(g_gradients, g_variables))

        e_gradients = gen_tape.gradient(loss_g_total, self.encoder.trainable_variables)
        self.e_optimizer.apply_gradients(zip(e_gradients, self.encoder.trainable_variables))
        return loss_g, loss_td, loss_g_total

    @tf.function
    def train_disc_step(self, data_series):
        """Run one optimization step of the critics (WGAN-GP).

        Returns:
            ``(loss_d, loss_gp, loss_d_total)`` where ``loss_d`` is
            ``mean(D(fake)) - mean(D(real))`` and ``loss_d_total`` adds
            ``params['alpha']`` times the gradient penalty.
        """
        epsilon_z = tf.random.uniform(shape=(self.params['nb_time'] - 1, 1, 1), minval=0., maxval=1.)
        noise = tf.random.normal(shape=(self.params['nb_time'] - 1, data_series.shape[1], self.params['noise_dim']),
                                 mean=0., stddev=self.params['sd'])
        with tf.GradientTape(persistent=True) as disc_tape:
            embed_series = tf.map_fn(lambda item: self.encoder(item), data_series)
            data_previous = tf.concat([embed_series[:-1], noise], axis=-1)

            data_gen = tf.TensorArray(tf.float32, size=self.params['nb_time'] - 1)
            for i in range(self.params['nb_time'] - 1):
                data_gen = data_gen.write(i, self.generators[i](data_previous[i]))
            data_gen = data_gen.stack()

            data_true = embed_series[1:]

            dz_gen = tf.TensorArray(tf.float32, size=self.params['nb_time'] - 1)
            for i in range(self.params['nb_time'] - 1):
                dz_gen = dz_gen.write(i, self.discriminators[i](data_gen[i]))
            dz_gen = dz_gen.stack()

            dz_true = tf.TensorArray(tf.float32, size=self.params['nb_time'] - 1)
            for i in range(self.params['nb_time'] - 1):
                dz_true = dz_true.write(i, self.discriminators[i](data_true[i]))
            dz_true = dz_true.stack()
            loss_d = tf.reduce_mean(dz_gen) - tf.reduce_mean(dz_true)

            # gradient penalty for z on random interpolates of real and fake latents.
            # NOTE(review): epsilon is drawn once per transition (not per sample)
            # and the norm below is taken over the batch AND feature axes of each
            # transition, unlike the per-sample penalty of the original WGAN-GP;
            # confirm this is intended before tuning ``alpha``.
            data_hat = epsilon_z * data_gen + (1 - epsilon_z) * data_true

            dz_hat = tf.TensorArray(tf.float32, size=self.params['nb_time'] - 1)
            for i in range(self.params['nb_time'] - 1):
                dz_hat = dz_hat.write(i, self.discriminators[i](data_hat[i]))
            dz_hat = dz_hat.stack()

            grads = tf.gradients(dz_hat, data_hat)[0]
            grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]))
            loss_gp = tf.reduce_mean(tf.square(grad_norms - 1))

            loss_d_total = loss_d + self.params['alpha'] * loss_gp

        d_variables = sum([item.trainable_variables for item in self.discriminators], [])
        d_gradients = disc_tape.gradient(loss_d_total, d_variables)
        self.d_optimizer.apply_gradients(zip(d_gradients, d_variables))
        return loss_d, loss_gp, loss_d_total

    def train(self, data, weights, batch_size=32, n_iter=100000, batches_per_verbose=500, verbose=1):
        """Alternate autoencoder and WGAN-GP updates, keeping the best checkpoint.

        Every iteration updates the autoencoder. Every ``params['ae_gan_freq']``
        iterations, the critics are additionally updated ``params['g_d_freq']``
        times and the generators once. Whenever the summed loss improves, a
        checkpoint is written under ``<best_path>/<batch_idx>/``. Training
        stops early once 10000 iterations pass without improvement (after at
        least 50000 iterations). The best checkpoint is restored at the end.
        When ``params['save_res']`` is set, the loss log and evaluation outputs
        are written into ``save_dir``.

        Args:
            data: per-time-point data matrices passed to ``Sequential_sampler``.
            weights: per-time-point sampling weights passed to ``Sequential_sampler``.
            batch_size: cells sampled per time point per batch.
            n_iter: maximum iteration index (inclusive).
            batches_per_verbose: logging period in iterations.
            verbose: print the loss line when truthy.
        """
        if self.params['save_res']:
            with open('{}/params.txt'.format(self.save_dir), 'w') as f_params:
                f_params.write(str(self.params))
        self.data_sampler = Sequential_sampler(data=data, weights=weights, batch_size=batch_size)

        best_loss = float('inf')
        best_batch_idx = 0
        loss_records = []
        for batch_idx in range(n_iter + 1):
            batch_data_series = self.data_sampler.next_batch()

            # train autoencoders
            loss_rec = self.train_ae_step(batch_data_series)

            # control the update frequency of autoencoder vs GAN
            if batch_idx % self.params['ae_gan_freq'] == 0:
                # update G once and update D multiple times for WGAN-GP
                for _ in range(self.params['g_d_freq']):
                    batch_data_series = self.data_sampler.next_batch()
                    loss_d, loss_gp, loss_d_total = self.train_disc_step(batch_data_series)
                batch_data_series = self.data_sampler.next_batch()
                loss_g, loss_td, loss_g_total = self.train_gen_step(batch_data_series)

            # Between GAN updates the most recent GAN losses are reused.
            loss_total = loss_rec + loss_g_total + loss_d_total
            loss_records.append([batch_idx, loss_total.numpy(), loss_rec.numpy(), loss_gp.numpy(),
                                 loss_d.numpy(), loss_g.numpy(), loss_td.numpy()])
            if loss_total < best_loss:
                best_loss = loss_total
                best_batch_idx = batch_idx
                self.ckpt.save('{}/{}/best_model'.format(self.best_path, batch_idx))
            if batch_idx - best_batch_idx >= 10000 and batch_idx >= 50000:
                print("Early stop at {} and best at {}!".format(batch_idx, best_batch_idx))
                break

            if batch_idx % batches_per_verbose == 0:
                loss_contents = '''Iteration [%d] : loss_t [%.4f], loss_rec [%.4f], loss_gp [%.4f], loss_d [%.4f], loss_g [%.4f] loss_td [%.4f]''' \
                    % (batch_idx, loss_total, loss_rec, loss_gp, loss_d, loss_g, loss_td)
                if verbose:
                    print(loss_contents)
        loss_pd = pd.DataFrame(loss_records, columns=['batch_idx', 'loss_total', 'loss_rec', 'loss_gp', 'loss_d', 'loss_g', 'loss_td'])
        if self.params['save_res']:
            # save_dir only exists when save_res is set.
            loss_pd.to_csv(self.save_dir + '/loss_pd.csv', sep='\t', index=False)

        # Each improvement was saved into its own ``<best_path>/<batch_idx>/``
        # folder above; restore the one from the best iteration.
        best_ckpt_path = tf.train.latest_checkpoint('{}/{}'.format(self.best_path, best_batch_idx))
        if best_ckpt_path is not None:
            self.ckpt.restore(best_ckpt_path)
        if self.params['save_res']:
            self.evaluate(self.data_sampler.load_all(), 'best')

    def evaluate(self, data_series, batch_idx):
        """Embed every time point, generate one-step-ahead cells and save them.

        Writes ``data_embed_at_<batch_idx>.npz`` (latent embeddings),
        ``data_gen_at_<batch_idx>.npz`` (generated latents for time points
        2..T) and ``data_gen_org_at_<batch_idx>.npz`` (decoder means of the
        generated latents) into ``self.save_dir``.
        """
        embed_series = [self.encoder.predict(item) for item in data_series]
        # contain T-1 time points
        data_previous = [np.concatenate([data, np.random.normal(0., self.params['sd'], size=(data.shape[0], self.params['noise_dim']))], axis=1)
                         for data in embed_series[:-1]]
        data_gen_z = [self.generators[i].predict(data) for i, data in enumerate(data_previous)]
        # The decoder returns (disp, mean); keep the mean.
        data_gen_org = [self.decoder.predict(item)[-1] for item in data_gen_z]

        np.savez('{}/data_embed_at_{}.npz'.format(self.save_dir, batch_idx), embed_series)
        np.savez('{}/data_gen_at_{}.npz'.format(self.save_dir, batch_idx), data_gen_z)
        np.savez('{}/data_gen_org_at_{}.npz'.format(self.save_dir, batch_idx), data_gen_org)

    def thresholding_(self, trans_mtx):
        """Row-normalize ``trans_mtx`` and zero out weak transitions.

        Rows are normalized to sum to one (all-zero rows stay zero). The
        threshold is the smallest row maximum among the non-empty rows, so
        every non-empty row keeps at least its largest entry. Entries below the
        threshold are zeroed and the rows are renormalized.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            trans_mtx2 = trans_mtx / np.sum(trans_mtx, axis=1, keepdims=True)
        trans_mtx2[np.isnan(trans_mtx2)] = 0.0

        nonempty_rows = trans_mtx2[np.sum(trans_mtx2, axis=1) > 0]
        if nonempty_rows.shape[0] == 0:
            # Nothing to threshold (min over no rows would otherwise raise).
            return trans_mtx2
        thresh = nonempty_rows.max(axis=1).min()

        trans_mtx3 = trans_mtx2.copy()
        trans_mtx3[trans_mtx3 < thresh] = 0.0
        with np.errstate(divide='ignore', invalid='ignore'):
            trans_mtx3 = trans_mtx3 / np.sum(trans_mtx3, axis=1, keepdims=True)
        trans_mtx3[np.isnan(trans_mtx3)] = 0.0
        return trans_mtx3

    def compute_trans_mat(self, times=1, n_noise=1000, n_chunk=1000, random_seed=1, thresholding=True, save_mtx=False):
        """Estimate the transition matrix from time point ``times`` to ``times + 1``.

        Each cell at time ``times`` is paired with ``n_noise`` shared noise
        draws and pushed through generator ``times``. Entry (i, j) is the
        average over draws of the Gaussian kernel
        ``exp(-||z_j - g(z_i, noise)||^2 / (2 * sd^2))``, where ``z_j`` is the
        latent of cell j at time ``times + 1``.

        Args:
            times: index of the source time point (0-based).
            n_noise: number of noise draws per source cell.
            n_chunk: number of draws processed at once (bounds peak memory).
            random_seed: TensorFlow seed used for the noise draws.
            thresholding: post-process with ``thresholding_``.
            save_mtx: save the matrix to ``<save_dir>/trans_mtx_<times>.npz``.

        Returns:
            numpy array of shape [n_cells(times), n_cells(times + 1)].

        Raises:
            ValueError: if ``n_noise`` or ``n_chunk`` is smaller than 1.
        """
        if n_noise < 1 or n_chunk < 1:
            raise ValueError('n_noise and n_chunk must be positive, got {} and {}'.format(n_noise, n_chunk))
        data_series = self.data_sampler.load_all()
        embed_series = [self.encoder.predict(item) for item in data_series]

        tf.random.set_seed(random_seed)
        noises = tf.random.normal(shape=(n_noise, self.params['noise_dim']), mean=0.0, stddev=self.params['sd'])

        embed_data = embed_series[times]
        embed_next = embed_series[times + 1]
        const = -0.5 / (self.params['sd'] ** 2)
        trans_mtx = []
        for embed_data0 in embed_data:
            gen_input = tf.concat([tf.tile(tf.reshape(embed_data0, (1, -1)), [n_noise, 1]), noises], axis=-1)
            embed_gen1 = self.generators[times].predict(gen_input)
            # Sum the kernel chunk by chunk and average over n_noise at the
            # end, so uneven last chunks and n_noise < n_chunk are handled.
            kernel_sum = 0.0
            for start in range(0, n_noise, n_chunk):
                chunk = embed_gen1[start:start + n_chunk]
                # [n_next, 1, dim] - [1, chunk, dim] -> squared distances [n_next, chunk]
                sq_dist = tf.reduce_sum((embed_next[:, None, :] - chunk[None, :, :]) ** 2, axis=-1)
                kernel_sum += tf.reduce_sum(tf.math.exp(const * sq_dist), axis=-1)
            trans_mtx.append(kernel_sum / n_noise)
        trans_mtx = np.array(trans_mtx)

        if thresholding:
            trans_mtx = self.thresholding_(trans_mtx)

        if save_mtx:
            np.savez('{}/trans_mtx_{}.npz'.format(self.save_dir, times), t1=trans_mtx)
        return trans_mtx
293
+
294
+
295
+
@@ -0,0 +1,108 @@
1
+ import numpy as np
2
+ import os
3
+ import math
4
+ import pandas as pd
5
+ import scipy
6
+ import anndata as ad
7
+ import scib
8
+
9
+
10
def _logistic(x, L, k, center=0):
    """Scaled logistic curve ``L / (1 + exp(-k * (x - center)))``.

    Evaluated elementwise when ``x`` is a NumPy array.
    """
    exponent = -k * (x - center)
    return L / (1 + np.exp(exponent))
12
+
13
def _gen_logistic(p, sup, inf, center, width):
    """Logistic curve rising from ``inf`` to ``sup`` around ``center``.

    ``width`` sets the steepness through ``k = 4 / width``, which makes the
    slope at ``center`` equal to ``(sup - inf) / width``.
    """
    amplitude = sup - inf
    steepness = 4.0 / width
    return inf + _logistic(p, L=amplitude, k=steepness, center=center)
15
+
16
def beta(p, beta_max=1.7, beta_min=0.3, beta_center=0.25, beta_width=0.5):
    """Logistic map of a score ``p`` onto ``[beta_min, beta_max]``.

    Centred at ``beta_center`` with transition width ``beta_width``;
    ``growth_rate`` uses it as the birth term from the proliferation score.
    """
    return _gen_logistic(p, sup=beta_max, inf=beta_min,
                         center=beta_center, width=beta_width)
18
+
19
def delta(a, delta_max=1.7, delta_min=0.3, delta_center=0.1, delta_width=0.2):
    """Logistic map of a score ``a`` onto ``[delta_min, delta_max]``.

    Centred at ``delta_center`` with transition width ``delta_width``;
    ``growth_rate`` uses it as the death term from the apoptosis score.
    """
    return _gen_logistic(a, sup=delta_max, inf=delta_min,
                         center=delta_center, width=delta_width)
21
+
22
def growth_rate(adata, proliferation_key="proliferation", apoptosis_key="apoptosis", delta_t=1.0,
                beta_max=1.7, beta_min=0.3, beta_center=0.25, beta_width=0.5,
                delta_max=1.7, delta_min=0.3, delta_center=0.1, delta_width=0.2):
    """Per-cell growth factor ``exp((birth - death) * delta_t)``.

    ``birth`` is ``beta`` applied to ``adata.obs[proliferation_key]`` and
    ``death`` is ``delta`` applied to ``adata.obs[apoptosis_key]``; the
    ``beta_*`` / ``delta_*`` arguments are forwarded to those curves.

    Returns:
        NumPy array with one growth factor per row of ``adata.obs``.
    """
    proliferation = adata.obs[proliferation_key].values
    birth = beta(proliferation, beta_max=beta_max, beta_min=beta_min,
                 beta_center=beta_center, beta_width=beta_width)

    apoptosis = adata.obs[apoptosis_key].values
    death = delta(apoptosis, delta_max=delta_max, delta_min=delta_min,
                  delta_center=delta_center, delta_width=delta_width)

    net_rate = birth - death
    return np.exp(net_rate * delta_t)
31
+
32
class Sequential_sampler(object):
    """Draw aligned mini-batches from a list of per-time-point data matrices.

    Each element of ``data`` is converted to a float32 array whose rows are
    the samples of one time point. ``next_batch`` draws ``batch_size``
    distinct row indices per time point (without replacement, weighted by
    ``weights``) and stacks the selected rows. All time points must therefore
    share the same trailing shape and have at least ``batch_size`` rows with
    non-zero weight.

    Note: the constructor calls ``np.random.seed(random_seed)``, which reseeds
    NumPy's global random state; all sampling draws from that global state.
    """

    def __init__(self, data, weights=None, batch_size=32, random_seed=123):
        np.random.seed(random_seed)
        self.data = [np.array(matrix, dtype='float32') for matrix in data]
        self.nb_time = len(self.data)
        self.batch_size = batch_size
        self.sample_sizes = [matrix.shape[0] for matrix in self.data]
        if weights is not None:
            # Normalise user weights so each time point's probabilities sum to 1.
            self.weights = [np.array(w, dtype='float32') / np.sum(w) for w in weights]
        else:
            # Uniform probabilities over the rows of each time point.
            self.weights = [np.ones(n, dtype='float32') / n for n in self.sample_sizes]
        self.idx_gens = [
            self.create_idx_generator(sample_size=n, time_idx=t)
            for t, n in enumerate(self.sample_sizes)
        ]

    def create_idx_generator(self, sample_size, time_idx, random_seed=123):
        """Endlessly yield batches of row indices for time point ``time_idx``.

        ``random_seed`` is accepted but unused; indices come from NumPy's
        global random state.
        """
        while True:
            yield np.random.choice(
                sample_size,
                size=self.batch_size,
                replace=False,
                p=self.weights[time_idx],
            )

    def next_batch(self):
        """Return the stacked batch, shape (nb_time, batch_size, ...) for 2-D inputs."""
        picks = [next(gen) for gen in self.idx_gens]
        return np.stack([matrix[rows, :] for matrix, rows in zip(self.data, picks)])

    def load_all(self):
        """Return the list of float32 data matrices itself (not a copy)."""
        return self.data
60
+
61
def interpolate(p0, p1, tmap, interp_frac, size, seed=1):
    """Sample points on straight lines between coupled cells of two time points.

    A pair ``(i, j)`` is drawn with probability proportional to
    ``tmap[i, j] / colsum_j ** (1 - interp_frac)``, where ``colsum_j`` is the
    sum of column ``j`` of ``tmap`` (zero column sums are replaced by machine
    epsilon). Each drawn pair yields the point
    ``(1 - interp_frac) * p0[i] + interp_frac * p1[j]``.

    Args:
        p0: (n0, n_genes) dense array or SciPy sparse matrix/array.
        p1: (n1, n_genes) dense array or SciPy sparse matrix/array.
        tmap: (n0, n1) coupling between the rows of ``p0`` and ``p1``.
        interp_frac: Interpolation fraction; 0 gives rows of ``p0``, 1 rows of ``p1``.
        size: Number of points to sample.
        seed: Passed to ``np.random.seed`` (this reseeds NumPy's global state).

    Returns:
        float64 array of shape (size, n_genes).

    Raises:
        ValueError: If the gene counts or the ``tmap`` shape do not match, or
            if ``tmap`` carries no positive mass to sample from.
    """
    # Explicit submodule import: ``import scipy`` alone does not guarantee
    # that ``scipy.sparse`` is loaded.
    import scipy.sparse

    # issparse also recognises the sparse-array classes that isspmatrix misses.
    p0 = p0.toarray() if scipy.sparse.issparse(p0) else p0
    p1 = p1.toarray() if scipy.sparse.issparse(p1) else p1
    p0 = np.asarray(p0, dtype=np.float64)
    p1 = np.asarray(p1, dtype=np.float64)
    tmap = np.asarray(tmap, dtype=np.float64)
    if p0.shape[1] != p1.shape[1]:
        raise ValueError("Unable to interpolate. Number of genes do not match")
    if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]:
        raise ValueError("Unable to interpolate. Tmap size is {}, expected {}"
                         .format(tmap.shape, (len(p0), len(p1))))
    n_cols = p1.shape[0]

    # Reweight each column by its total mass raised to (1 - interp_frac);
    # empty columns get eps to avoid dividing by zero.
    a = np.power(tmap.sum(axis=0), 1. - interp_frac)
    a[a == 0] = np.finfo(float).eps
    p = (tmap / a).ravel(order='C')
    total = p.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError("Unable to interpolate. Tmap has no positive mass to sample from")
    p = p / total

    np.random.seed(seed)
    choices = np.random.choice(p.size, p=p, size=size)
    # Flat C-order index -> (row of p0, row of p1).
    rows0, rows1 = np.divmod(choices, n_cols)
    return p0[rows0] * (1 - interp_frac) + p1[rows1] * interp_frac
82
+
83
def cal_metrics(gen_data, real_data):
    """Compare a generated data matrix with a real one.

    Returns:
        Tuple ``(mmd_value, pearson_corr, spearman_corr, lisi_value)``:
        ``mmd_value`` from ``cal_mmd(gen_data, real_data)``; Pearson and
        Spearman correlations between the column means of ``real_data`` and
        ``gen_data``; ``lisi_value`` from ``cal_lisi(gen_data, real_data)``.
    """
    # Explicit submodule import: ``import scipy`` alone does not guarantee
    # that ``scipy.stats`` is loaded.
    import scipy.stats

    mmd_value = cal_mmd(gen_data, real_data)

    # Correlate the column-wise means of the real and generated matrices.
    x, y = np.mean(real_data, axis=0), np.mean(gen_data, axis=0)
    pearson_corr, _ = scipy.stats.pearsonr(x, y)
    spearman_corr, _ = scipy.stats.spearmanr(x, y)

    lisi_value = cal_lisi(gen_data, real_data)

    return mmd_value, pearson_corr, spearman_corr, lisi_value
93
+
94
def cal_mmd(f_of_X, f_of_Y):
    """Linear-kernel MMD statistic from row-paired samples.

    With ``d_i = f_of_X[i] - f_of_Y[i]``, returns the mean over consecutive
    rows of the dot product ``d_i . d_{i+1}``. For independent rows this
    estimates ``||mean(X) - mean(Y)||^2``.

    Args:
        f_of_X: (n, d) array-like of samples.
        f_of_Y: (n, d) array-like of samples, paired row-by-row with ``f_of_X``.

    Returns:
        The statistic as a float; ``nan`` when fewer than two rows are given.

    Raises:
        ValueError: If the inputs are not 2-D arrays of identical shape
            (row pairing would otherwise silently broadcast).
    """
    f_of_X = np.asarray(f_of_X)
    f_of_Y = np.asarray(f_of_Y)
    if f_of_X.ndim != 2 or f_of_X.shape != f_of_Y.shape:
        raise ValueError(
            "cal_mmd expects two 2-D arrays of the same shape, got {} and {}"
            .format(f_of_X.shape, f_of_Y.shape))
    if f_of_X.shape[0] < 2:
        # No consecutive pair exists; avoid numpy's empty-mean warning.
        return np.nan
    # Not named ``delta`` to avoid shadowing the module-level function.
    diff = f_of_X - f_of_Y
    return np.mean((diff[:-1] * diff[1:]).sum(1))
99
+
100
def cal_lisi(gen_data, real_data):
    """Score how well generated and real cells mix, via scib's iLISI.

    Generated rows followed by real rows are stacked into one AnnData object,
    labelled ``'gen'`` / ``'true'`` in ``obs['batch']``, and scored with
    ``scib.me.ilisi_graph(..., batch_key="batch", type_="full")``.
    """
    stacked = np.concatenate((gen_data, real_data), axis=0)
    labels = ['gen'] * len(gen_data) + ['true'] * len(real_data)
    adata_scib = ad.AnnData(stacked)
    adata_scib.obs['batch'] = pd.Categorical(labels)
    return scib.me.ilisi_graph(adata_scib, batch_key="batch", type_="full")
105
+
106
+
107
+
108
+
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.1
2
+ Name: scmtg
3
+ Version: 0.0.1
4
+ Summary: scMTG learns generative Markov transitions for single-cell temporal dynamics
5
+ Home-page: https://github.com/liuq-lab/scMTG
6
+ Author: Xuejian Cui
7
+ Author-email: cuixj@hit.edu.cn
8
+ License: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Operating System :: MacOS :: MacOS X
14
+ Classifier: Operating System :: Microsoft :: Windows
15
+ Classifier: Operating System :: POSIX :: Linux
16
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
17
+ Requires-Python: >=3.8.0
18
+ Requires-Dist: tensorflow-gpu==2.6.0
19
+ Requires-Dist: keras==2.6.0
@@ -0,0 +1,11 @@
1
+ setup.py
2
+ scMTG/__init__.py
3
+ scMTG/loss.py
4
+ scMTG/model.py
5
+ scMTG/scMTG.py
6
+ scMTG/util.py
7
+ scmtg.egg-info/PKG-INFO
8
+ scmtg.egg-info/SOURCES.txt
9
+ scmtg.egg-info/dependency_links.txt
10
+ scmtg.egg-info/requires.txt
11
+ scmtg.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ tensorflow-gpu==2.6.0
2
+ keras==2.6.0
@@ -0,0 +1 @@
1
+ scMTG
scmtg-0.0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
scmtg-0.0.1/setup.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ from setuptools import setup, find_packages
4
+
5
setup(name='scmtg',
      version='0.0.1',
      packages=find_packages(),
      description='scMTG learns generative Markov transitions for single-cell temporal dynamics',
      long_description='',

      author='Xuejian Cui',
      author_email='cuixj@hit.edu.cn',
      url="https://github.com/liuq-lab/scMTG",
      python_requires='>=3.8.0',
      license='MIT',

      classifiers=[
          'Development Status :: 4 - Beta',
          'Intended Audience :: Science/Research',
          'License :: OSI Approved :: MIT License',
          'Programming Language :: Python :: 3.8',
          'Operating System :: MacOS :: MacOS X',
          'Operating System :: Microsoft :: Windows',
          'Operating System :: POSIX :: Linux',
          'Topic :: Scientific/Engineering :: Bio-Informatics',
      ],

      # scMTG/util.py imports numpy, pandas, scipy, anndata and scib at module
      # level, and scMTG/__init__.py imports util, so importing the package
      # fails unless these are installed too.
      install_requires=[
          'tensorflow-gpu==2.6.0',
          'keras==2.6.0',
          'numpy',
          'pandas',
          'scipy',
          'anndata',
          'scib',
      ],
      )