celldl 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
celldl-0.1.1/PKG-INFO ADDED
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.2
2
+ Name: celldl
3
+ Version: 0.1.1
4
+ Summary: CellDL: Defining Cell Identity by Learning Transcriptome Distributions from Single-Cell Data
5
+ Author: Yin yusong
6
+ Author-email: yyusong526@gmail.com
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.10
11
+ Dynamic: author
12
+ Dynamic: author-email
13
+ Dynamic: classifier
14
+ Dynamic: requires-python
15
+ Dynamic: summary
celldl-0.1.1/README.md ADDED
@@ -0,0 +1,140 @@
1
+ # CellDL: Defining Cell Identity by Learning Transcriptome Distributions
2
+
3
+ [![PyPI version](https://badge.fury.io/py/CellDL.svg)](https://badge.fury.io/py/CellDL)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5
+
6
+ **CellDL** is a deep probabilistic representation learning framework designed to redefine how cell identity is modeled in single-cell RNA-seq (scRNA-seq) data.
7
+
8
+ ## 📖 Introduction & Motivation
9
+
10
+ Cell identity defines what a cell is, how it functions, and what it can become. Most current computational approaches adopt a deterministic paradigm, compressing the cellular transcriptional state into a single, fixed vector. This treats a dynamic, stochastic entity as a static point, discarding the variability and uncertainty that are essential to its biology.
11
+
12
+ **CellDL moves from point estimates to probabilistic representations.**
13
+
14
+ It represents each cell through a set of gene-wise probability distributions. By leveraging a decoupled deep learning architecture, CellDL captures the full distribution of transcriptional states, preserving biological heterogeneity and variability that traditional methods miss.
15
+
16
+ ## 🚀 Key Features
17
+
18
+ * **Probabilistic Representation**: Models gene expression using parametric distributions (e.g., IZIP, ZINB) rather than fixed values.
19
+ * **Decoupled Architecture**: Uses a shared encoder for global cell state and decoupled heads for inferring gene-specific distribution parameters ($\lambda$, $\phi$, etc.).
20
+ * **Biologically-Informed Denoising**: Reconstructs expression profiles based on the expected values ($\mathbb{E}$) of learned distributions, effectively removing technical noise while keeping biological signals.
21
+ * **Generative Data Augmentation**: Generates realistic synthetic cells via controlled perturbation of learned parameters, facilitating the analysis of rare cell populations.
22
+
23
+ ## 🛠 Model Architecture
24
+
25
+ <div align="center">
26
+ <img src="docs/fig1.png" alt="CellDL Architecture" width="800"/>
27
+ <p><em>Figure 1: Schematic of the CellDL model architecture. The model maps cells to a latent embedding and decodes them into gene-specific distributional parameters.</em></p>
28
+ </div>
29
+
30
+ CellDL employs a **decoupled autoencoder architecture**:
31
+ 1. **Encoder**: Maps the raw count matrix to a non-linear latent embedding.
32
+ 2. **Decoupled Decoders**: Independently infer the parameters of the underlying distribution (e.g., Mean expression rate $\lambda$ and Dropout probability $\phi$).
33
+ 3. **Objective**: Minimizes the difference between the expected value of the predicted distribution and the observed data using a self-supervised expectation-based loss.
34
+
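The expectation-based objective in step 3 can be made concrete with a small standalone sketch (plain NumPy, not the package's TensorFlow Probability implementation): the mean of a zero-inflated Poisson is the Poisson rate scaled by the probability of *not* dropping out.

```python
import numpy as np

def zip_mean(lam, phi):
    """Expected value of a zero-inflated Poisson gene model.

    lam : Poisson mean expression rate (lambda)
    phi : dropout (zero-inflation) probability
    E[X] = (1 - phi) * lambda
    """
    return (1.0 - phi) * lam

# A gene with rate 2.0 and dropout probability 0.3 reconstructs to an
# expected count of 1.4 -- this expectation, not a sampled value, is
# what the decoder is trained to match.
print(zip_mean(2.0, 0.3))  # 1.4
```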
35
+ ## 📦 Installation
36
+
37
+ ### Install from PyPI
38
+ ```bash
39
+ pip install celldl
40
+ ```
41
+
42
+ ### Install from Source
43
+ ```bash
44
+ git clone https://github.com/yys-arch/CellDL.git
45
+ cd CellDL
46
+ pip install .
47
+ ```
48
+
49
+ **Requirements:** Python >= 3.10, TensorFlow, TensorFlow Probability, tf-keras, Scanpy, AnnData, scikit-learn, SciPy, NumPy, pandas, and tqdm.
50
+
51
+ ## 💻 Usage Tutorial
52
+
53
+ ### 1. Data Preprocessing
54
+ CellDL provides a robust preprocessing pipeline including HVG selection and normalization.
55
+
56
+ ```python
57
+ import scanpy as sc
58
+ from celldl import data_preprocessing
59
+
60
+ # Load data
61
+ adata = sc.read_h5ad("your_data.h5ad")
62
+
63
+ # Preprocess: Filter, Log-normalize, and Select HVGs
64
+ adata = data_preprocessing(
65
+ adata,
66
+ assay="10x 3' v3", # Optional filtering
67
+ gene_mean_min=0.0125,
68
+ gene_mean_max=3,
69
+ gene_disp_min=0.5
70
+ )
71
+ ```
72
+
73
+ ### 2. Model Training
74
+ Initialize and train the model using one of the supported distribution modes. The paper highlights the **IZIP (Independent Zero-Inflated Poisson)** mode.
75
+
76
+ ```python
77
+ from celldl import build_model, train_model, save_trained_model
78
+
79
+ # Build model with IZIP distribution (Recommended)
80
+ model = build_model(adata, mode='IZIP_mode', bottle_dim=512)
81
+
82
+ # Train
83
+ history = train_model(model, adata, epochs=1000, batch_size=32)
84
+
85
+ # Save
86
+ save_trained_model(model, 'models/celldl_model.keras')
87
+ ```
88
+
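The loss minimized during training is the mean-squared error between the distributions' expected values and the observed (log-normalized) profile. A minimal, self-contained NumPy sketch of that objective (illustrative only; the names are ours, not the package's):

```python
import numpy as np

def expectation_mse(expected, observed):
    # Self-supervised loss: squared distance between the predicted
    # distribution means and the observed expression matrix.
    return float(np.mean((expected - observed) ** 2))

observed = np.array([[1.0, 0.0], [2.0, 3.0]])   # toy log-normalized counts
expected = np.array([[0.9, 0.1], [2.1, 2.8]])   # toy decoder means
print(expectation_mse(expected, observed))      # 0.0175
```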
89
+ ### 3. Denoising (Signal Reconstruction)
90
+ Reconstruct gene expression using the expected value of the inferred distribution.
91
+
92
+ ```python
93
+ from celldl import load_trained_model, denoise_data
94
+
95
+ model = load_trained_model('models/celldl_model.keras')
96
+ adata_denoised = denoise_data(model, adata)
97
+
98
+ # Result is stored in .obsm
99
+ print(adata_denoised.obsm['rna_denoised'])
100
+ ```
101
+
102
+ ### 4. Synthetic Data Generation (Sample Expansion)
103
+ Generate synthetic cells to augment rare populations by perturbing the learned parameters.
104
+
105
+ ```python
106
+ from celldl import generate_sc_synthetic_data
107
+
108
+ # Generate 5 synthetic cells for every original cell
109
+ adata_synthetic = generate_sc_synthetic_data(model, adata, num_samples=5, deviation_scale=0.1)
110
+ ```
111
+
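Under the hood, sample expansion replicates each cell's expected profile `num_samples` times, adds bounded multiplicative noise, and clips at zero. A self-contained NumPy sketch of that scheme (a simplification of the package internals):

```python
import numpy as np

rng = np.random.default_rng(0)

def expand_cells(mean_vals, num_samples=5, deviation_scale=0.1):
    """Replicate expected profiles and perturb them multiplicatively."""
    expanded = np.repeat(mean_vals, num_samples, axis=0)  # (cells * num_samples, genes)
    noise = rng.uniform(-deviation_scale, deviation_scale, size=expanded.shape) * expanded
    return np.clip(expanded + noise, a_min=0, a_max=None)  # counts stay non-negative

means = np.array([[1.0, 0.0, 3.0],
                  [2.0, 4.0, 0.5]])  # toy expected-expression matrix (2 cells, 3 genes)
synthetic = expand_cells(means, num_samples=5)
print(synthetic.shape)  # (10, 3)
```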
112
+ ## 📊 Supported Distributions
113
+
114
+ While the manuscript focuses on IZIP, the package supports multiple distribution families to fit different data characteristics:
115
+
116
+ * `IZIP_mode`: Independent Zero-Inflated Poisson (**Default**)
117
+ * `ZINB_mode`: Zero-Inflated Negative Binomial
118
+ * `NB_mode`: Negative Binomial
119
+ * `Mix_P_NB_mode`: Mixture of Poisson and NB
120
+ * (See documentation for full list of mixture models)
121
+
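For reference, the expected values these modes reconstruct differ only in their parameterization. For example, under the TensorFlow Probability parameterization used in the code (`total_count` r, `probs` theta), the NB and ZINB means are as follows (a hand-written sketch, not the package API):

```python
def nb_mean(r, theta):
    # Negative Binomial mean, TFP parameterization: E[X] = r * theta / (1 - theta)
    return r * theta / (1.0 - theta)

def zinb_mean(r, theta, phi):
    # Zero inflation scales the NB mean by the keep probability (1 - phi)
    return (1.0 - phi) * nb_mean(r, theta)

print(nb_mean(5.0, 0.5))         # 5.0
print(zinb_mean(5.0, 0.5, 0.2))  # 4.0
```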
122
+ ## 📂 Data Availability
123
+
124
+ The datasets used in our manuscript to benchmark and validate CellDL are publicly available through the [CZ CELLxGENE Discover](https://cellxgene.cziscience.com/) platform.
125
+
126
+ | Dataset / Tissue | File Name / Description | Source Link |
127
+ | :--- | :--- | :--- |
128
+ | **Heart** | Tabula Sapiens - Heart | [Collection Link](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) |
129
+ | **Bladder** | Tabula Sapiens - Bladder | [Collection Link](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) |
130
+ | **Breast** | scRNA-seq data - all cells | [Collection Link](https://cellxgene.cziscience.com/collections/4195ab4c-20bd-4cd3-8b3d-65601277e731) |
131
+ | **Bone Marrow** | Fetal Bone Marrow (10x) | [Blood and immune development...](https://cellxgene.cziscience.com/) |
132
+ | **Large Intestine**| Tabula Sapiens - Large_Intestine | [Collection Link](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) |
133
+ | **Lung** | Tabula Sapiens - Lung | [Collection Link](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) |
134
+ | **Skin** | Skin | [Collection Link](https://cellxgene.cziscience.com/collections/43d4bb39-21af-4d05-b973-4c1fed7b916c) |
135
+ | **Spleen** | Tabula Sapiens - Spleen | [Collection Link](https://cellxgene.cziscience.com/collections/e5f58829-1a66-40b5-a624-9046778e74f5) |
136
+ | **iPSC-Derived EBs**<br>(Wellington et al. 2024) | Developmental Regulation of Endothelium | [Collection Link](https://cellxgene.cziscience.com/collections/4a2c25af-558a-45fc-bc9a-54ec44a1d63f) |
137
+
138
+ ## 📧 Contact
139
+
140
+ Email: yyusong526@gmail.com
@@ -0,0 +1,3 @@
1
+ import scanpy as sc
2
+ from .functions import (data_preprocessing, build_model, train_model, save_trained_model,
3
+ load_trained_model, denoise_data, generate_sc_synthetic_data)
@@ -0,0 +1,33 @@
1
+ import scanpy as sc
2
+ from .functions import (data_preprocessing, build_model, train_model, save_trained_model,
3
+ load_trained_model, denoise_data, generate_sc_synthetic_data)
4
+
5
+
6
+ def main_train():
7
+ scobj = sc.read_h5ad("your dataset")
8
+ scobj = data_preprocessing(scobj)
9
+
10
+ model = build_model(scobj, mode='IZIP_mode')
11
+ train_model(model, scobj, epochs=1000)
12
+
13
+ save_trained_model(model, 'CellDL_model.keras')
14
+
15
+ def main_denoise():
16
+ model = load_trained_model('CellDL_model.keras')
17
+ scobj = sc.read_h5ad("your dataset")
18
+ scobj = data_preprocessing(scobj)
19
+ scobj_denoised = denoise_data(model, scobj)
20
+ return scobj_denoised
21
+
22
+ def main_synthetic():
23
+ model = load_trained_model('CellDL_model.keras')
24
+ scobj = sc.read_h5ad("your dataset")
25
+ scobj = data_preprocessing(scobj)
26
+ scobj_synthetic = generate_sc_synthetic_data(model, scobj)
27
+ return scobj_synthetic
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main_train()
32
+ scobj_denoised = main_denoise()
33
+ scobj_synthetic = main_synthetic()
@@ -0,0 +1,480 @@
1
+ from tqdm import tqdm
2
+ import numpy as np
3
+ import pandas as pd
4
+ import scipy
5
+ import anndata
6
+ import tensorflow as tf
7
+ from tensorflow_probability import distributions as tfd
8
+ import tf_keras.optimizers as opt
9
+ from sklearn.preprocessing import StandardScaler
10
+ from sklearn.metrics import adjusted_mutual_info_score
11
+ from scipy.stats import spearmanr
12
+ import scanpy as sc
13
+ import warnings
14
+ warnings.filterwarnings('ignore')
15
+ from tf_keras import layers, models, losses, callbacks, initializers
16
+ Model = models.Model
17
+ Input = layers.Input
18
+ Dense = layers.Dense
19
+ Activation = layers.Activation
20
+ BatchNormalization = layers.BatchNormalization
21
+ Lambda = layers.Lambda
22
+ PReLU = layers.PReLU
23
+ EarlyStopping = callbacks.EarlyStopping
24
+ MeanSquaredError = losses.MeanSquaredError
25
+ load_model = models.load_model
26
+
27
+
28
+ # ==============================================================================
29
+ # Distribution Mean Functions
30
+ # ==============================================================================
31
+
32
+ @tf.function
33
+ def rna_Negbinom_pmf(inputs):
34
+ """Mean of Negative Binomial (r=dispersion, theta=prob)."""
35
+ r, theta = inputs
36
+ nb = tfd.NegativeBinomial(total_count=r, probs=theta)
37
+ return nb.mean()
38
+
39
+
40
+ def rna_Inflatednegbinom_pmf(inputs):
41
+ """Mean of Zero-Inflated Negative Binomial."""
42
+ r, theta, inflated_loc_prob = inputs
43
+ zinb = tfd.ZeroInflatedNegativeBinomial(total_count=r, probs=theta, inflated_loc_probs=inflated_loc_prob)
44
+ return zinb.mean()
45
+
46
+
47
+ def rna_Inflatedpoisson(inputs):
48
+ """Mean of Zero-Inflated Poisson."""
49
+ lambda_, inflated_loc_prob = inputs
50
+ poissonb = tfd.Poisson(lambda_)
51
+ zip_dist = tfd.Inflated(distribution=poissonb, inflated_loc_probs=inflated_loc_prob)
52
+ return zip_dist.mean()
53
+
54
+
55
+ def rna_Indinflatedpoisson(inputs):
56
+ """Mean of Independent Zero-Inflated Poisson."""
57
+ lambda_, inflated_loc_prob = inputs
58
+ poissonb = tfd.Poisson(lambda_)
59
+ ind_zip = tfd.Independent(
60
+ distribution=tfd.Inflated(distribution=poissonb, inflated_loc_probs=inflated_loc_prob),
61
+ reinterpreted_batch_ndims=0
62
+ )
63
+ return ind_zip.mean()
64
+
65
+
66
+ def rna_Mixpoissonnb(inputs):
67
+ """Mean of Mixture (Poisson + Negative Binomial)."""
68
+ lambda_, r, theta, cat = inputs
69
+ poisson = tfd.Poisson(lambda_)
70
+ nb = tfd.NegativeBinomial(total_count=r, probs=theta)
71
+ mixpoissonnb = tfd.Mixture(
72
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
73
+ components=[poisson, nb]
74
+ )
75
+ return mixpoissonnb.mean()
76
+
77
+
78
+ def rna_zindmixpoissonnb(inputs):
79
+ """Mean of Zero-Inflated Mixture (Poisson + NB)."""
80
+ lambda_, r, theta, cat, inflated_loc_prob = inputs
81
+ poisson = tfd.Poisson(lambda_)
82
+ nb = tfd.NegativeBinomial(total_count=r, probs=theta)
83
+ mixpoissonnb = tfd.Mixture(
84
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
85
+ components=[poisson, nb]
86
+ )
87
+ zindmixpoissonnb = tfd.Inflated(distribution=mixpoissonnb, inflated_loc_probs=inflated_loc_prob)
88
+ return zindmixpoissonnb.mean()
89
+
90
+
91
+ def rna_Mixpoissonlognormal(inputs):
92
+ """Mean of Mixture (Poisson + LogNormal)."""
93
+ lambda_, loc, scale, cat = inputs
94
+ poisson = tfd.Poisson(lambda_)
95
+ lognormal = tfd.LogNormal(loc=loc, scale=scale)
96
+ mixpoissonlognormal = tfd.Mixture(
97
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
98
+ components=[poisson, lognormal]
99
+ )
100
+ return mixpoissonlognormal.mean()
101
+
102
+
103
+ def rna_zindmixpoissonlognormal(inputs):
104
+ """Mean of Zero-Inflated Mixture (Poisson + LogNormal)."""
105
+ lambda_, loc, scale, cat, inflated_loc_prob = inputs
106
+ poisson = tfd.Poisson(lambda_)
107
+ lognormal = tfd.LogNormal(loc=loc, scale=scale)
108
+ mixpoissonlognormal = tfd.Mixture(
109
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
110
+ components=[poisson, lognormal]
111
+ )
112
+ zindmixpoissonlognormal = tfd.Inflated(distribution=mixpoissonlognormal, inflated_loc_probs=inflated_loc_prob)
113
+ return zindmixpoissonlognormal.mean()
114
+
115
+
116
+ def rna_indzindmixpoissonlognormal(inputs):
117
+ """Mean of Independent Zero-Inflated Mixture (Poisson + LogNormal)."""
118
+ lambda_, loc, scale, cat, inflated_loc_prob = inputs
119
+ poisson = tfd.Poisson(lambda_)
120
+ lognormal = tfd.LogNormal(loc=loc, scale=scale)
121
+ mixpoissonlognormal = tfd.Mixture(
122
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
123
+ components=[poisson, lognormal]
124
+ )
125
+ zindmixpoissonlognormal = tfd.Inflated(distribution=mixpoissonlognormal, inflated_loc_probs=inflated_loc_prob)
126
+ ind_zind = tfd.Independent(distribution=zindmixpoissonlognormal, reinterpreted_batch_ndims=0)
127
+ return ind_zind.mean()
128
+
129
+
130
+ def rna_indzindmixnblognormal(inputs):
131
+ """Mean of Independent Zero-Inflated Mixture (NB + LogNormal)."""
132
+ r, theta, loc, scale, cat, inflated_loc_prob = inputs
133
+ nb = tfd.NegativeBinomial(total_count=r, probs=theta)
134
+ lognormal = tfd.LogNormal(loc=loc, scale=scale)
135
+ mixnblognormal = tfd.Mixture(
136
+ cat=tfd.Categorical(tf.stack([cat, 1 - cat], axis=-1)),
137
+ components=[nb, lognormal]
138
+ )
139
+ zindmixnblognormal = tfd.Inflated(distribution=mixnblognormal, inflated_loc_probs=inflated_loc_prob)
140
+ ind_zind = tfd.Independent(distribution=zindmixnblognormal, reinterpreted_batch_ndims=0)
141
+ return ind_zind.mean()
142
+
143
+
144
+ # ==============================================================================
145
+ # Reconstruction Layers
146
+ # ==============================================================================
147
+
148
+ def NB_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
149
+ """Output layer for NB distribution."""
150
+ NorAct = lambda x: tf.clip_by_value(tf.nn.softplus(x), 3, 1e10)
151
+ rna_r = Dense(input_dim_rna, kernel_initializer=inikernel, activation=NorAct, name="rna_r")(h_rna_decoder_z)
152
+ rna_theta = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_theta")(
153
+ h_rna_decoder_z)
154
+ rna_mean = Lambda(rna_Negbinom_pmf, output_shape=(input_dim_rna,), name="rna_denoised")([rna_r, rna_theta])
155
+ return rna_mean
156
+
157
+
158
+ def ZINB_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
159
+ """Output layer for ZINB distribution."""
160
+ NorAct = lambda x: tf.clip_by_value(tf.nn.softplus(x), 3, 1e10)
161
+ rna_r = Dense(input_dim_rna, kernel_initializer=inikernel, activation=NorAct, name="rna_r")(h_rna_decoder_z)
162
+ rna_theta = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_theta")(
163
+ h_rna_decoder_z)
164
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
165
+ h_rna_decoder_z)
166
+ rna_mean = Lambda(rna_Inflatednegbinom_pmf, output_shape=(input_dim_rna,), name="rna_denoised")(
167
+ [rna_r, rna_theta, rna_zerorate])
168
+ return rna_mean
169
+
170
+
171
+ def ZIP_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
172
+ """Output layer for ZIP distribution."""
173
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
174
+ h_rna_decoder_z)
175
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
176
+ h_rna_decoder_z)
177
+ rna_mean = Lambda(rna_Inflatedpoisson, output_shape=(input_dim_rna,), name="rna_denoised")(
178
+ [rna_lambda_, rna_zerorate])
179
+ return rna_mean
180
+
181
+
182
+ def IZIP_reconstruct(input_dim, h_rna_decoder_z, inikernel):
183
+ """Output layer for IZIP distribution."""
184
+ rna_lambda_ = Dense(input_dim, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(h_rna_decoder_z)
185
+ rna_zerorate = Dense(input_dim, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
186
+ h_rna_decoder_z)
187
+ rna_mean = Lambda(rna_Indinflatedpoisson, output_shape=(input_dim,), name="rna_denoised")(
188
+ [rna_lambda_, rna_zerorate])
189
+ return rna_mean
190
+
191
+
192
+ def Mix_P_NB_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
193
+ """Output layer for Mixture (Poisson + NB)."""
194
+ NorAct = lambda x: tf.clip_by_value(tf.nn.softplus(x), 3, 1e10)
195
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
196
+ h_rna_decoder_z)
197
+ rna_r = Dense(input_dim_rna, kernel_initializer=inikernel, activation=NorAct, name="rna_r")(h_rna_decoder_z)
198
+ rna_theta = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_theta")(
199
+ h_rna_decoder_z)
200
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
201
+ rna_mean = Lambda(rna_Mixpoissonnb, output_shape=(input_dim_rna,), name="rna_denoised")(
202
+ [rna_lambda_, rna_r, rna_theta, rna_cat])
203
+ return rna_mean
204
+
205
+
206
+ def ZIMix_P_NB_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
207
+ """Output layer for Zero-Inflated Mixture (Poisson + NB)."""
208
+ NorAct = lambda x: tf.clip_by_value(tf.nn.softplus(x), 3, 1e10)
209
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
210
+ h_rna_decoder_z)
211
+ rna_r = Dense(input_dim_rna, kernel_initializer=inikernel, activation=NorAct, name="rna_r")(h_rna_decoder_z)
212
+ rna_theta = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_theta")(
213
+ h_rna_decoder_z)
214
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
215
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
216
+ h_rna_decoder_z)
217
+ rna_mean = Lambda(rna_zindmixpoissonnb, output_shape=(input_dim_rna,), name="rna_denoised")(
218
+ [rna_lambda_, rna_r, rna_theta, rna_cat, rna_zerorate])
219
+ return rna_mean
220
+
221
+
222
+ def Mix_P_logNormal_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
223
+ """Output layer for Mixture (Poisson + LogNormal)."""
224
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
225
+ h_rna_decoder_z)
226
+ rna_loc = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_loc")(h_rna_decoder_z)
227
+ rna_scale = Dense(input_dim_rna, kernel_initializer=inikernel, activation="linear", name="rna_scale")(
228
+ h_rna_decoder_z)
229
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
230
+ rna_mean = Lambda(rna_Mixpoissonlognormal, output_shape=(input_dim_rna,), name="rna_denoised")(
231
+ [rna_lambda_, rna_loc, rna_scale, rna_cat])
232
+ return rna_mean
233
+
234
+
235
+ def ZIMix_P_logNormal_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
236
+ """Output layer for Zero-Inflated Mixture (Poisson + LogNormal)."""
237
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
238
+ h_rna_decoder_z)
239
+ rna_loc = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_loc")(h_rna_decoder_z)
240
+ rna_scale = Dense(input_dim_rna, kernel_initializer=inikernel, activation="linear", name="rna_scale")(
241
+ h_rna_decoder_z)
242
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
243
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
244
+ h_rna_decoder_z)
245
+ rna_mean = Lambda(rna_zindmixpoissonlognormal, output_shape=(input_dim_rna,), name="rna_denoised")(
246
+ [rna_lambda_, rna_loc, rna_scale, rna_cat, rna_zerorate])
247
+ return rna_mean
248
+
249
+
250
+ def IZIMix_P_logNormal_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
251
+ """Output layer for Independent Zero-Inflated Mixture (Poisson + LogNormal)."""
252
+ rna_lambda_ = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_lambda_")(
253
+ h_rna_decoder_z)
254
+ rna_loc = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_loc")(h_rna_decoder_z)
255
+ rna_scale = Dense(input_dim_rna, kernel_initializer=inikernel, activation="linear", name="rna_scale")(
256
+ h_rna_decoder_z)
257
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
258
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
259
+ h_rna_decoder_z)
260
+ rna_mean = Lambda(rna_indzindmixpoissonlognormal, output_shape=(input_dim_rna,), name="rna_denoised")(
261
+ [rna_lambda_, rna_loc, rna_scale, rna_cat, rna_zerorate])
262
+ return rna_mean
263
+
264
+
265
+ def IZIMix_NB_logNormal_reconstruct(input_dim_rna, h_rna_decoder_z, inikernel):
266
+ """Output layer for Independent Zero-Inflated Mixture (NB + LogNormal)."""
267
+ NorAct = lambda x: tf.clip_by_value(tf.nn.softplus(x), 3, 1e10)
268
+ rna_r = Dense(input_dim_rna, kernel_initializer=inikernel, activation=NorAct, name="rna_r")(h_rna_decoder_z)
269
+ rna_theta = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_theta")(
270
+ h_rna_decoder_z)
271
+ rna_loc = Dense(input_dim_rna, kernel_initializer=inikernel, activation="relu", name="rna_loc")(h_rna_decoder_z)
272
+ rna_scale = Dense(input_dim_rna, kernel_initializer=inikernel, activation="linear", name="rna_scale")(
273
+ h_rna_decoder_z)
274
+ rna_cat = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_cat")(h_rna_decoder_z)
275
+ rna_zerorate = Dense(input_dim_rna, kernel_initializer=inikernel, activation="sigmoid", name="rna_zerorate")(
276
+ h_rna_decoder_z)
277
+ rna_mean = Lambda(rna_indzindmixnblognormal, output_shape=(input_dim_rna,), name="rna_denoised")(
278
+ [rna_r, rna_theta, rna_loc, rna_scale, rna_cat, rna_zerorate])
279
+ return rna_mean
280
+
281
+
282
+ # ==============================================================================
283
+ # Core Functions
284
+ # ==============================================================================
285
+
286
+ def load_data(filepath, donor_id, assay, gene_mean_min, gene_mean_max, gene_disp_min):
287
+ """Loads and preprocesses data (Legacy wrapper). See data_preprocessing."""
288
+ scobj = sc.read_h5ad(filepath)
289
+ scobj = scobj[scobj.obs['donor_id'] == donor_id, :]
290
+ scobj = scobj[scobj.obs['assay'] == assay, :]
291
+ if scobj.raw is not None:
292
+ scobj.X = scobj.raw.X
293
+ scobj.var_names_make_unique()
294
+ scobj.var.index = pd.Index(scobj.var['feature_name'].values)
295
+ sc.pp.log1p(scobj)
296
+ sc.pp.highly_variable_genes(scobj, min_mean=gene_mean_min, max_mean=gene_mean_max, min_disp=gene_disp_min)
297
+ scobj = scobj[:, scobj.var["highly_variable"]]
298
+ scobj.obsm["rna_nor"] = scobj.X.toarray()
299
+ scobj.obsm["X_input"] = 1 + scobj.obsm["rna_nor"]
300
+ scaler = StandardScaler()
301
+ scobj.obsm["X_input"] = scaler.fit_transform(scobj.obsm["X_input"])
302
+ return scobj
303
+
304
+
305
+ def build_model(scobj, seed=100, bottle_dim=512, mode='IZIP_mode'):
306
+ """
307
+ Builds the CellDL model.
308
+ Modes: 'IZIP_mode' (default), 'NB_mode', 'ZINB_mode', 'ZIP_mode', 'Mix_P_NB_mode', etc.
309
+ """
310
+ inikernel = initializers.glorot_uniform(seed=seed)
311
+ if "X_input" not in scobj.obsm:
312
+ raise ValueError("scobj.obsm['X_input'] missing. Run data_preprocessing first.")
313
+ input_dim = scobj.obsm["X_input"].shape[1]
314
+ input_data = Input(shape=(input_dim,), name='X_input')
315
+
316
+ # Encoder
317
+ h = input_data
318
+ for units in [2048, 1024]:
319
+ h = Dense(units, kernel_initializer=inikernel)(h)
320
+ h = BatchNormalization()(h)
321
+ h = PReLU()(h)
322
+
323
+ # Bottleneck
324
+ h = Dense(bottle_dim, kernel_initializer=inikernel, name="rna_features")(h)
325
+ h = Activation("relu")(h)
326
+
327
+ # Decoder
328
+ h = Dense(input_dim, kernel_initializer=inikernel, name="rec_dim")(h)
329
+ h = Activation("relu")(h)
330
+
331
+ # Distribution Heads
332
+ if mode == 'IZIP_mode':
333
+ rna_mean = IZIP_reconstruct(input_dim, h, inikernel)
334
+ elif mode == 'NB_mode':
335
+ rna_mean = NB_reconstruct(input_dim, h, inikernel)
336
+ elif mode == 'ZINB_mode':
337
+ rna_mean = ZINB_reconstruct(input_dim, h, inikernel)
338
+ elif mode == 'ZIP_mode':
339
+ rna_mean = ZIP_reconstruct(input_dim, h, inikernel)
340
+ elif mode == 'Mix_P_NB_mode':
341
+ rna_mean = Mix_P_NB_reconstruct(input_dim, h, inikernel)
342
+ elif mode == 'ZIMix_P_NB_mode':
343
+ rna_mean = ZIMix_P_NB_reconstruct(input_dim, h, inikernel)
344
+ elif mode == 'Mix_P_logNormal_mode':
345
+ rna_mean = Mix_P_logNormal_reconstruct(input_dim, h, inikernel)
346
+ elif mode == 'ZIMix_P_logNormal_mode':
347
+ rna_mean = ZIMix_P_logNormal_reconstruct(input_dim, h, inikernel)
348
+ elif mode == 'IZIMix_P_logNormal_mode':
349
+ rna_mean = IZIMix_P_logNormal_reconstruct(input_dim, h, inikernel)
350
+ elif mode == 'IZIMix_NB_logNormal_mode':
351
+ rna_mean = IZIMix_NB_logNormal_reconstruct(input_dim, h, inikernel)
352
+ else:
353
+ raise ValueError(f"Unknown mode: {mode}")
354
+
355
+ model = Model(inputs=input_data, outputs=rna_mean)
356
+ return model
357
+
358
+
359
+ def train_model(model, scobj, lr=0.001, batch_size=32, epochs=3000):
360
+ """Trains the model using RMSprop and EarlyStopping."""
361
+ optimizer = opt.RMSprop(learning_rate=lr, clipvalue=5)
362
+ model.compile(optimizer=optimizer, loss=MeanSquaredError())
363
+ callbacks_list = [EarlyStopping(monitor="loss", patience=15, verbose=2)]
364
+
365
+ history = model.fit(
366
+ x=scobj.obsm["X_input"], y=scobj.obsm["rna_nor"],
367
+ epochs=epochs, callbacks=callbacks_list,
368
+ batch_size=batch_size, shuffle=True, verbose=1
369
+ )
370
+ return history
371
+
372
+
373
+ def denoise_data(model, scobj):
374
+ """Denoises data by calculating the expected value of the learned distribution."""
375
+ temp_denoised_rna = Model(inputs=model.inputs, outputs=model.get_layer("rna_denoised").output).predict(
376
+ [scobj.obsm["X_input"]])
377
+ scobj.obsm["rna_denoised"] = temp_denoised_rna
378
+ scobj_denoised = sc.AnnData(
379
+ X=temp_denoised_rna, obs=scobj.obs, var=scobj.var,
380
+ obsm=scobj.obsm, layers=scobj.layers, uns=scobj.uns, varm=scobj.varm
381
+ )
382
+ return scobj_denoised
383
+
384
+
385
+ def calculate_spearman_correlation(scobj):
386
+ """Calculates mean Spearman correlation between denoised and raw data."""
387
+ temp_denoised_rna = scobj.obsm["rna_denoised"]
388
+ corr_list = [spearmanr(x, y)[0] for x, y in zip(temp_denoised_rna, scobj.obsm["rna_nor"])]
389
+ return np.mean(corr_list)
390
+
391
+
392
+ def save_trained_model(model, filepath):
393
+ """Saves the trained model to a file."""
394
+ model.save(filepath)
395
+
396
+
397
+ def load_trained_model(filepath):
398
+ """Loads a trained model with custom distribution layers."""
399
+ custom_objects = {
400
+ 'rna_Negbinom_pmf': rna_Negbinom_pmf,
401
+ 'rna_Inflatednegbinom_pmf': rna_Inflatednegbinom_pmf,
402
+ 'rna_Inflatedpoisson': rna_Inflatedpoisson,
403
+ 'rna_Indinflatedpoisson': rna_Indinflatedpoisson,
404
+ 'rna_Mixpoissonnb': rna_Mixpoissonnb,
405
+ 'rna_zindmixpoissonnb': rna_zindmixpoissonnb,
406
+ 'rna_Mixpoissonlognormal': rna_Mixpoissonlognormal,
407
+ 'rna_zindmixpoissonlognormal': rna_zindmixpoissonlognormal,
408
+ 'rna_indzindmixpoissonlognormal': rna_indzindmixpoissonlognormal,
409
+ 'rna_indzindmixnblognormal': rna_indzindmixnblognormal,
410
+ }
411
+ return load_model(filepath, custom_objects=custom_objects, safe_mode=False)
412
+
413
+
414
+ def generate_sc_synthetic_data(model, scobj, num_samples=1, deviation_scale=0.1):
415
+ """Generates synthetic cells by perturbing learned distribution parameters."""
416
+ X_input = scobj.obsm["X_input"]
417
+ num_cells, num_genes = X_input.shape
418
+
419
+ # Extract parameters and predict
420
+ lambda_vals = Model(inputs=model.inputs, outputs=model.get_layer("rna_lambda_").output).predict(X_input)
421
+ zero_vals = Model(inputs=model.inputs, outputs=model.get_layer("rna_zerorate").output).predict(X_input)
422
+
423
+ # Calculate perturbed mean
424
+ mean_vals = np.repeat((1 - zero_vals) * lambda_vals, num_samples, axis=0)
425
+ noise = np.random.uniform(-deviation_scale, deviation_scale, size=mean_vals.shape) * mean_vals
426
+ synthetic_data = np.clip(mean_vals + noise, a_min=0, a_max=None)
427
+
428
+ # Construct metadata
429
+ synthetic_obs = pd.DataFrame(np.repeat(scobj.obs.values, num_samples, axis=0), columns=scobj.obs.columns)
430
+ synthetic_obs['original_cell_index'] = np.repeat(np.arange(num_cells), num_samples)
+
+ # Assemble the synthetic cells into a new AnnData object and return it
+ scobj_synthetic = anndata.AnnData(X=synthetic_data, obs=synthetic_obs, var=scobj.var.copy())
+ return scobj_synthetic
+
432
+
433
+ def data_preprocessing(scobj, assay=None, ID=None, gene_mean_min=0.0125, gene_mean_max=3, gene_disp_min=0.5):
434
+ """
435
+ Preprocess single-cell data for CellDL: filter, normalize, and select HVGs.
436
+
437
+ Args:
438
+ scobj: AnnData object.
439
+ assay: (Optional) Filter by 'assay' column.
440
+ ID: (Optional) Filter by 'donor_id' column.
441
+ gene_mean_min/max, gene_disp_min: Thresholds for Highly Variable Genes.
442
+
443
+ Returns:
444
+ AnnData object with prepared input in `.obsm['X_input']`.
445
+ """
446
+ # Use raw counts if available
447
+ if scobj.raw is not None:
448
+ scobj.X = scobj.raw.X
449
+ scobj.var_names_make_unique()
450
+
451
+ # Filter by assay or ID if specified and columns exist
452
+ if assay is not None:
453
+ if 'assay' in scobj.obs.columns:
454
+ scobj = scobj[scobj.obs['assay'] == assay].copy()
455
+ else:
456
+ warnings.warn(f"'assay' column missing; skipping filter assay='{assay}'.")
457
+
458
+ if ID is not None:
459
+ if 'donor_id' in scobj.obs.columns:
460
+ scobj = scobj[scobj.obs['donor_id'] == ID].copy()
461
+ else:
462
+ warnings.warn(f"'donor_id' column missing; skipping filter ID='{ID}'.")
463
+
464
+ # Use feature names if available
465
+ if 'feature_name' in scobj.var.columns:
466
+ scobj.var.index = pd.Index(scobj.var['feature_name'].values)
467
+
468
+ # Standard preprocessing
469
+ sc.pp.log1p(scobj)
470
+ sc.pp.highly_variable_genes(scobj, min_mean=gene_mean_min, max_mean=gene_mean_max, min_disp=gene_disp_min)
471
+ scobj = scobj[:, scobj.var["highly_variable"]].copy()
472
+
473
+ # Prepare dense input for model (handle sparse matrices)
474
+ scobj.obsm["rna_nor"] = scobj.X.toarray() if scipy.sparse.issparse(scobj.X) else scobj.X
475
+
476
+ # Scale data (StandardScaler)
477
+ scaler = StandardScaler()
478
+ scobj.obsm["X_input"] = scaler.fit_transform(1 + scobj.obsm["rna_nor"])
479
+
480
+ return scobj
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.2
2
+ Name: celldl
3
+ Version: 0.1.1
4
+ Summary: CellDL: Defining Cell Identity by Learning Transcriptome Distributions from Single-Cell Data
5
+ Author: Yin yusong
6
+ Author-email: yyusong526@gmail.com
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.10
11
+ Dynamic: author
12
+ Dynamic: author-email
13
+ Dynamic: classifier
14
+ Dynamic: requires-python
15
+ Dynamic: summary
@@ -0,0 +1,9 @@
1
+ README.md
2
+ setup.py
3
+ celldl/__init__.py
4
+ celldl/__main__.py
5
+ celldl/functions.py
6
+ celldl.egg-info/PKG-INFO
7
+ celldl.egg-info/SOURCES.txt
8
+ celldl.egg-info/dependency_links.txt
9
+ celldl.egg-info/top_level.txt
@@ -0,0 +1 @@
1
+ celldl
celldl-0.1.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
celldl-0.1.1/setup.py ADDED
@@ -0,0 +1,16 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='celldl',
5
+ version='0.1.1',
6
+ author='Yin yusong',
7
+ author_email='yyusong526@gmail.com',
8
+ description='CellDL: Defining Cell Identity by Learning Transcriptome Distributions from Single-Cell Data',
9
+ packages=find_packages(),
10
+ classifiers=[
11
+ 'Programming Language :: Python :: 3',
12
+ 'License :: OSI Approved :: MIT License',
13
+ 'Operating System :: OS Independent',
14
+ ],
15
+ python_requires='>=3.10',
16
+ )