CNNreg 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CNNreg/__init__.py +8 -0
- CNNreg/cli.py +67 -0
- CNNreg/data.py +43 -0
- CNNreg/layers.py +91 -0
- CNNreg/losses.py +107 -0
- CNNreg/train.py +134 -0
- cnnreg-0.1.0.dist-info/METADATA +218 -0
- cnnreg-0.1.0.dist-info/RECORD +12 -0
- cnnreg-0.1.0.dist-info/WHEEL +5 -0
- cnnreg-0.1.0.dist-info/entry_points.txt +2 -0
- cnnreg-0.1.0.dist-info/licenses/LICENSE +21 -0
- cnnreg-0.1.0.dist-info/top_level.txt +1 -0
CNNreg/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CNNreg: A CNN-based regression package for cell type deconvolution of bulk RNA-seq using scRNA-seq reference.
|
|
3
|
+
"""
|
|
4
|
+
from .data import data_CSE, flatten_list, divide_by_row_sum, reformat_ref
|
|
5
|
+
from .layers import RefCombLayer, SliceSumLayer, CelltypeScaleLayer, StretchLayer, DeconvProp_S1
|
|
6
|
+
from .losses import (loss_spearmanr, loss_prop, loss_ref, loss_scale,
|
|
7
|
+
loss_stretch, loss_epsilon_insensitive)
|
|
8
|
+
from .train import trainProp
|
CNNreg/cli.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import torch
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from .train import trainProp
|
|
5
|
+
from .data import data_CSE, reformat_ref, divide_by_row_sum
|
|
6
|
+
|
|
7
|
+
def main():
    """Command-line entry point for CNNreg.

    Parses the command line, assembles the parameter dictionary (``paraHash``)
    and data dictionary (``dataHash``) and hands them to ``trainProp``.

    Only the ``train`` run mode is implemented. The other accepted choices are
    rejected with a usage error instead of exiting silently, and ``-bulk`` /
    ``-ref`` are validated up front because training cannot start without them.
    """
    import os  # function-scope import keeps the module import block untouched

    parser = argparse.ArgumentParser(
        description="CNN-based bulk RNA-seq deconvolution",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("-M", dest="runMode", type=str, required=True, choices=['train', 'evaluate', 'predict', 'explain'])
    parser.add_argument("-bulk", dest="bulk", type=str, required=False)
    parser.add_argument("-ref", dest="reference", type=str, required=False)
    parser.add_argument("-pre", dest="prefix", type=str, default="Project")
    parser.add_argument("-o", dest="outDIR", type=str, required=True)
    parser.add_argument("-mP", dest="modelFile", type=str)
    parser.add_argument("-C", dest="kernelSize", type=int, default=10)
    parser.add_argument("-EP", dest="maxEpochCellProp", type=int, default=1000)

    args = parser.parse_args()

    # Fail early with a readable message rather than deep inside pandas.
    if args.runMode != "train":
        parser.error(f"run mode '{args.runMode}' is not implemented yet; only 'train' is available")
    missing = [flag for flag, value in (("-bulk", args.bulk), ("-ref", args.reference)) if not value]
    if missing:
        parser.error("train mode requires " + " and ".join(missing))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoints and the final CSV are written here, so make sure it exists.
    os.makedirs(args.outDIR, exist_ok=True)

    # Set default model file path if not provided
    model_file = args.modelFile if args.modelFile else f"{args.outDIR}/cellprop_model.pt"

    paraHash = {
        "runMode": args.runMode,
        "bulk": args.bulk,
        "reference": args.reference,
        "data_out_dir": args.outDIR,
        "max_epoch_cellprop": args.maxEpochCellProp,
        "model_file": model_file,
        "prefix": args.prefix,
        "device": device,
        "n_kernel": args.kernelSize
    }

    # The bulk CSV is read with genes as rows and samples as columns; the first
    # column is skipped and the matrix is transposed to (n_sample, n_gene).
    df_bulk = pd.read_csv(paraHash["bulk"])
    dataHash = {
        "bulk": torch.tensor(df_bulk.iloc[:, 1:].values.transpose(), dtype=torch.float32).to(device),
        "sample": df_bulk.columns.values[1:],
        "celltype": [f"celltype_{i+1}" for i in range(paraHash["n_kernel"])],
        "CSE": data_CSE(paraHash["reference"], device=device)
    }
    paraHash["n_gene"] = df_bulk.shape[0]
    paraHash["n_sample"] = df_bulk.shape[1] - 1
    paraHash["n_celltype"] = paraHash["n_kernel"]
    paraHash["n_ref"] = dataHash["CSE"].expr_cse.shape[0]

    # Initialize index arrays for celltype and reference features.
    # Cell type x owns every n_celltype-th position starting at x.
    n_celltype = paraHash["n_celltype"]
    k = paraHash["n_gene"] * n_celltype
    dataHash["idx_feature_celltype"] = [list(range(x, k, n_celltype)) for x in range(n_celltype)]
    # Reference x owns the contiguous run [x*k, (x+1)*k).
    dataHash["idx_feature_ref"] = [list(range(x * k, (x + 1) * k)) for x in range(paraHash["n_ref"])]

    dataHash["CSE_reformat"] = reformat_ref(dataHash["CSE"].expr_cse)
    trainProp(dataHash, paraHash)
|
|
65
|
+
|
|
66
|
+
if __name__ == "__main__":
|
|
67
|
+
main()
|
CNNreg/data.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
def flatten_list(nested_list):
    """Flatten arbitrarily nested lists into a single flat list.

    Only ``list`` instances are descended into; tuples, strings and any other
    values are kept as single items. Element order is preserved.

    The traversal uses an explicit stack of iterators instead of recursion, so
    very deep nesting cannot raise ``RecursionError``.

    Parameters
    ----------
    nested_list : iterable
        Items, some of which may themselves be (nested) lists.

    Returns
    -------
    list
        A new list containing every non-list item in depth-first order.
    """
    flattened_list = []
    pending = [iter(nested_list)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                # Descend; the outer iterator resumes after the inner one ends.
                pending.append(iter(item))
                break
            flattened_list.append(item)
        else:
            # Current level exhausted without finding another nested list.
            pending.pop()
    return flattened_list
|
|
14
|
+
|
|
15
|
+
def divide_by_row_sum(array):
    """Normalize each row so that it sums to 1.

    Parameters
    ----------
    array : array-like
        At least 2-D numeric data (anything ``np.asarray`` accepts); rows are
        taken along axis 0 and summed along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape with every row divided by its sum. Rows whose
        sum is exactly zero are returned as all zeros instead of NaN/inf, which
        also avoids NumPy's divide-by-zero warning.
    """
    values = np.asarray(array)
    if not np.issubdtype(values.dtype, np.floating):
        # Integer input must be promoted, otherwise the in-place divide fails.
        values = values.astype(float)
    row_sums = np.sum(values, axis=1, keepdims=True)
    normalized = np.zeros_like(values)
    np.divide(values, row_sums, out=normalized, where=row_sums != 0)
    return normalized
|
|
19
|
+
|
|
20
|
+
def reformat_ref(CSE_data):
    """Transpose the 2-D reference CSE tensor and flatten it to one dimension.

    The result therefore walks the input column by column.
    """
    transposed = CSE_data.t()
    return transposed.flatten(start_dim=0)
|
|
23
|
+
|
|
24
|
+
class data_CSE(torch.utils.data.Dataset):
    """Dataset wrapper for Cell Type Specific Expression (CSE).

    The CSV's first column is kept as the row labels (``self.sample``); all
    remaining columns are loaded as a float32 tensor (``self.expr_cse``) on
    ``device``.
    """

    def __init__(self, f_CSE, device):
        """Read the CSE table from ``f_CSE`` and move the values to ``device``."""
        super().__init__()
        df_cse = pd.read_csv(f_CSE)
        self.sample = df_cse.iloc[:, 0]
        self.expr_cse = torch.tensor(df_cse.iloc[:, 1:].values, dtype=torch.float32).to(device)

    def __len__(self):
        """Return the number of rows in the expression tensor."""
        return self.expr_cse.shape[0]

    def __getitem__(self, idx):
        """Return the expression row(s) and label(s) at position ``idx``.

        ``idx`` may be an int, a slice, a list of ints or an index tensor.
        Labels are looked up positionally (``.iloc``) so that they always
        agree with the tensor indexing, including for negative indices.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return {'expr_cse': self.expr_cse[idx, :], 'sample': self.sample.iloc[idx]}

    def update_data(self, new_CSE):
        """Replace the stored expression tensor with ``new_CSE``."""
        self.expr_cse = new_CSE
|
CNNreg/layers.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torchsort
|
|
5
|
+
|
|
6
|
+
class RefCombLayer(nn.Module):
    """Reference combination layer for CNN deconvolution.

    Holds one learnable weight per (reference, cell type) pair and multiplies
    the flattened reference input element-wise by those weights, tiled once
    per gene.
    """

    def __init__(self, pHash, dHash, device=None):
        """Create ``n_ref * n_celltype`` weights drawn uniformly around ``1/n_ref``.

        ``dHash`` is accepted for signature parity with the other layers and
        is not used here.
        """
        super().__init__()
        num = 1.0 / pHash["n_ref"]
        w = torch.empty(pHash["n_ref"] * pHash["n_celltype"])
        self.weight = nn.Parameter(torch.nn.init.uniform_(w, a=num-num/5, b=num+num/5).to(device))
        self.n_gene = pHash["n_gene"]

    @property
    def w_expand(self):
        """Current weights tiled ``n_gene`` times as a 1-D tensor.

        Computed on access rather than cached in ``__init__``: flattening an
        expanded tensor produces a copy, so a cached value would keep the
        initial weights forever and ignore in-place updates, device moves and
        ``load_state_dict``.
        """
        return self.weight.repeat(self.n_gene)

    def forward(self, x):
        """Multiply ``x`` element-wise by the tiled weights."""
        return torch.mul(x, self.w_expand)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SliceSumLayer(nn.Module):
    """Sum slices of the reference layer output.

    The flat input is viewed as ``(n_celltype * n_gene, n_ref)`` and reduced
    over the last axis.
    """

    def __init__(self, pHash, dHash, device=None):
        """Record the reshape dimensions; ``dHash`` and ``device`` are unused."""
        super().__init__()
        self.slice_row = pHash["n_celltype"] * pHash["n_gene"]
        self.slice_column = pHash["n_ref"]

    def forward(self, x):
        """Return the per-row sums of ``x`` reshaped to (slice_row, slice_column)."""
        grid = x.reshape(self.slice_row, self.slice_column)
        return grid.sum(dim=1)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CelltypeScaleLayer(nn.Module):
    """Scale expression for each cell type.

    One learnable factor per cell type is applied to the input positions
    listed for that cell type in ``dHash["idx_feature_celltype"]``.
    """

    def __init__(self, pHash, dHash, device=None):
        """Create ``n_celltype`` weights drawn from N(1.0, 0.1)."""
        super().__init__()
        w = torch.empty(pHash["n_celltype"])
        self.weight = nn.Parameter(torch.nn.init.normal_(w, 1.0, 0.1).to(device))
        self.idx = dHash["idx_feature_celltype"]

    def forward(self, x):
        """Return the scaled features, interleaved by cell type.

        Each cell type's features are gathered and scaled, the results are
        stacked as rows (one row per cell type), then transposed and
        flattened. A single ``torch.stack`` replaces the repeated ``vstack``
        in a loop, which re-copied the growing result on every iteration.
        """
        scaled = [x[feature_idx] * self.weight[ii] for ii, feature_idx in enumerate(self.idx)]
        return torch.flatten(torch.t(torch.stack(scaled)))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class StretchLayer(nn.Module):
    """Stretch layer to scale features per gene.

    Holds one learnable factor per gene; each factor is applied to
    ``n_celltype`` consecutive positions of the flat input.
    """

    def __init__(self, pHash, dHash, device=None):
        """Create ``n_gene`` weights drawn from N(1.0, 0.1); ``dHash`` is unused."""
        super().__init__()
        w = torch.empty(pHash["n_gene"])
        self.weight = nn.Parameter(torch.nn.init.normal_(w, 1.0, 0.1).to(device))
        self.n_celltype = pHash["n_celltype"]

    @property
    def w_expand(self):
        """Current weights with every entry repeated ``n_celltype`` times.

        Computed on access rather than cached in ``__init__``: flattening the
        transposed, expanded weight produces a copy, so a cached value would
        never see later weight updates or a loaded state dict.
        """
        return self.weight.repeat_interleave(self.n_celltype)

    def forward(self, x):
        """Multiply ``x`` element-wise by the per-gene factors."""
        return torch.mul(x, self.w_expand)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DeconvProp_S1(nn.Module):
    """Stage I CNN model for cell proportion estimation.

    Pipeline: reference combination -> slice sum -> per-cell-type scaling ->
    per-gene stretch -> strided 1-D convolution whose kernels (one output
    channel per sample) are read out as the cell proportions.
    """

    @staticmethod
    def ini_kernel_weight(pHash):
        """Initialize kernel weights for conv1 layer.

        Returns a parameter of shape (n_sample, 1, n_celltype), i.e.
        (out_channels, in_channels, kernel_size), drawn from a normal
        distribution centred on ``1 / n_celltype``.
        """
        n_celltype = pHash["n_celltype"]
        w = torch.empty((pHash["n_sample"], 1, n_celltype))
        weights = torch.nn.init.normal_(w, 1.0/n_celltype, 0.25*1.0/n_celltype).to(pHash["device"])
        return nn.Parameter(weights)

    def __init__(self, dHash, pHash):
        """Build the four custom layers and the convolution from ``pHash``/``dHash``."""
        super().__init__()
        device = pHash["device"]
        self.refLayer = RefCombLayer(pHash, dHash, device=device)
        self.sum = SliceSumLayer(pHash, dHash, device=device)
        self.celltypeScaleLayer = CelltypeScaleLayer(pHash, dHash, device=device)
        self.stretchLayer = StretchLayer(pHash, dHash, device=device)
        # Kernel and stride are both n_kernel, so the windows do not overlap.
        self.conv1 = nn.Conv1d(1, pHash["n_sample"], kernel_size=pHash["n_kernel"],
                               stride=pHash["n_kernel"], device=device, bias=False)
        # Replace the default initialisation with the proportion-like one.
        self.conv1.weight = self.ini_kernel_weight(pHash)

    def forward(self, x, N_batch, N_feature):
        """Run the pipeline and also return the intermediate activations.

        Returns
        -------
        tuple
            ``(y, y0, y1, y2)``: the convolution output with the leading
            dimension squeezed, followed by the outputs of the slice-sum,
            cell-type-scale and stretch layers.
        """
        combined = self.refLayer(x)
        y0 = self.sum(combined)
        y1 = self.celltypeScaleLayer(y0)
        y2 = self.stretchLayer(y1)
        conv_input = y2.view(N_batch, 1, N_feature)
        y = self.conv1(conv_input).squeeze(0)
        return y, y0, y1, y2
|
CNNreg/losses.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torchsort
|
|
3
|
+
|
|
4
|
+
def spearmanr(pred, target, indx=None, **kw):
    """Differentiable Spearman correlation using soft ranking.

    1-D inputs are promoted to a single row. When ``indx`` is given, only
    those columns of both inputs are used. Extra keyword arguments are passed
    to ``torchsort.soft_rank``. One correlation is returned per row.
    """
    if pred.dim() == 1:
        pred = pred.unsqueeze(0)
    if target.dim() == 1:
        target = target.unsqueeze(0)
    if indx is not None:
        pred, target = pred[:, indx], target[:, indx]

    def _centered_soft_rank(values):
        # Soft ranks along the last axis, shifted to zero mean per row.
        ranks = torchsort.soft_rank(values, **kw)
        return ranks - ranks.mean(dim=-1, keepdim=True)

    pred_centered = _centered_soft_rank(pred)
    target_centered = _centered_soft_rank(target)
    covariance = (pred_centered * target_centered).sum(dim=-1)
    pred_norm = torch.sqrt((pred_centered ** 2).sum(dim=-1))
    target_norm = torch.sqrt((target_centered ** 2).sum(dim=-1))
    return covariance / (pred_norm * target_norm)
|
|
21
|
+
|
|
22
|
+
def loss_spearmanr(pred, target, indx=None, **kw):
    """Spearman correlation loss: mean of ``1 - rho`` over rows.

    ``indx`` and any extra keyword arguments are forwarded to ``spearmanr``
    (previously the keyword arguments were accepted but silently dropped).
    """
    return torch.mean(1 - spearmanr(pred, target, indx=indx, **kw))
|
|
25
|
+
|
|
26
|
+
def loss_prop(model):
    """Penalty on negative ``conv1`` weights.

    Returns the summed absolute value of every negative entry across the
    parameters of ``model.conv1``.
    """
    penalty = 0.0
    for kernel in model.conv1.parameters():
        negative_entries = kernel[kernel < 0]
        penalty = penalty + negative_entries.abs().sum()
    return penalty
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def loss_ref(model, pHash):
    """Regularizer for the reference-combination weights.

    For every parameter of ``model.refLayer`` this adds

    * the summed absolute value of all negative weights, and
    * for each cell type, ``|sum(weights of that cell type) - 1|``, where the
      weights of cell type ``x`` occupy the run ``[x*n_ref, (x+1)*n_ref)``.

    The terms are accumulated across parameters (the earlier code overwrote
    the running total at the start of each parameter).
    """
    n_ref = pHash["n_ref"]
    loss = 0.0
    for param in model.refLayer.parameters():
        loss = loss + torch.sum(torch.abs(param[param < 0]))
        for x in range(pHash["n_celltype"]):
            ref_sum = torch.sum(param[x * n_ref:(x + 1) * n_ref])
            # abs() is already non-negative, so no extra clamp at zero is needed.
            loss = loss + torch.abs(ref_sum - 1.0)
    return loss
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def loss_scale(model):
    """Range penalty on the stretch-layer weights (summed form).

    Despite its name, this function reads ``model.stretchLayer``, not the
    cell-type scale layer. For every parameter it adds the summed absolute
    value of the entries below 0.5 and of the entries above 2. Terms are
    accumulated across parameters instead of overwriting the running total.

    NOTE(review): the penalty is ``|w|`` itself rather than the distance to
    the [0.5, 2] band, so for 0 < w < 0.5 its gradient pushes w further
    towards 0 — confirm this is intended.
    """
    loss = 0.0
    for param in model.stretchLayer.parameters():
        loss = loss + torch.sum(torch.abs(param[param < 0.5]))
        loss = loss + torch.sum(torch.abs(param[param > 2]))
    return loss
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def loss_stretch(model, x, dHash, pHash):
    """Range penalty on the stretch-layer weights (mean form).

    For every parameter of ``model.stretchLayer`` this adds the mean absolute
    value of the entries below 0.5 and the mean absolute value of the entries
    above 2. A group with no entries contributes exactly zero; previously
    ``torch.mean`` of the empty selection made the whole loss NaN whenever
    all weights were inside [0.5, 2].

    ``x``, ``dHash`` and ``pHash`` are accepted for interface compatibility
    and are not used.
    """
    loss = 0.0
    for param in model.stretchLayer.parameters():
        for out_of_range in (param[param < 0.5], param[param > 2]):
            # sum / max(count, 1) equals the mean when non-empty and 0 when empty.
            loss = loss + torch.abs(out_of_range).sum() / max(out_of_range.numel(), 1)
    return loss
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def loss_epsilon_insensitive(prediction, target, epsilon):
    """Mean absolute error with a relative dead zone.

    Errors smaller than ``epsilon * target`` contribute zero; larger errors
    contribute only the part exceeding that tolerance.
    """
    excess = torch.abs(prediction - target) - epsilon * target
    floor = torch.zeros_like(prediction)
    return torch.maximum(excess, floor).mean()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def loss_epsilon_insensitive_2(prediction, target, epsilon):
    """Mean absolute error with an absolute dead zone of width ``epsilon``.

    Errors up to ``epsilon`` contribute zero; larger errors contribute only
    the part exceeding ``epsilon``.
    """
    excess = torch.abs(prediction - target) - epsilon
    floor = torch.zeros_like(prediction)
    return torch.maximum(excess, floor).mean()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def loss_large_insensitive(prediction, target, threshold):
    """Mean absolute error capped element-wise at ``threshold * target``."""
    abs_error = torch.abs(prediction - target)
    cap = threshold * target
    return torch.minimum(abs_error, cap).mean()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def loss_large_insensitive_2(prediction, target, threshold):
    """Mean absolute error capped element-wise at the constant ``threshold``."""
    abs_error = torch.abs(prediction - target)
    cap = threshold * torch.ones_like(prediction)
    return torch.minimum(abs_error, cap).mean()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def loss_small_large_insensitive(prediction, target, epsilon, threshold):
    """Absolute error with a relative dead zone and a constant cap.

    The part of each error exceeding ``epsilon * target`` is kept (zero
    otherwise), capped at ``threshold``, and the result is averaged.
    """
    excess = torch.abs(prediction - target) - epsilon * target
    dead_zoned = torch.maximum(excess, torch.zeros_like(prediction))
    capped = torch.minimum(dead_zoned, threshold * torch.ones_like(prediction))
    return capped.mean()
|
|
106
|
+
|
|
107
|
+
|
CNNreg/train.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
from .data import divide_by_row_sum, reformat_ref
|
|
5
|
+
from .layers import DeconvProp_S1
|
|
6
|
+
from .losses import (loss_prop, loss_ref, loss_scale, loss_stretch,
|
|
7
|
+
loss_epsilon_insensitive, loss_epsilon_insensitive_2,
|
|
8
|
+
spearmanr)
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torchmetrics.functional import pearson_corrcoef
|
|
11
|
+
|
|
12
|
+
# https://pytorch.org/docs/0.3.0/optim.html#per-parameter-options
|
|
13
|
+
# https://discuss.pytorch.org/t/parameters-with-requires-grad-false-are-updated-during-training/90096/9
|
|
14
|
+
# https://www.youtube.com/watch?v=DbeIqrwb_dE
|
|
15
|
+
|
|
16
|
+
#### initializing the weights with some randomness improves performance
|
|
17
|
+
#### using optim.Adam worsens the performance a lot
|
|
18
|
+
#### block = 2 seems to perform better than larger block sizes
|
|
19
|
+
#### tuning the celltypeScale layer and the cell proportion layer separately in S1 performs better than tuning them together.
|
|
20
|
+
#### using StretchLayer_2 in S2 worsens the performance. Should stick to StretchLayer
|
|
21
|
+
|
|
22
|
+
def _write_proportions(dHash, pHash, cellprop, suffix):
    """Write one sample-by-cell-type table to ``Prop_predicted_<prefix><suffix>.csv``."""
    df = pd.concat([pd.DataFrame(dHash["sample"]), pd.DataFrame(cellprop)], axis=1)
    x = ["Sample"]
    x.extend(dHash["celltype"])
    df.columns = x
    df.to_csv(pHash["data_out_dir"] + "/" + "Prop_predicted_" + pHash["prefix"] + suffix + ".csv", index=False)


def trainProp(dHash, pHash):
    """
    Train the Stage I CNN model for cell proportion estimation.

    Training cycles through four phases, one per epoch (``epoch % 4``), each
    updating a single layer with a plain gradient step:

    * 0 - convolution kernels (the cell proportions)
    * 1 - reference-combination weights
    * 2 - cell-type scale weights
    * 3 - per-gene stretch weights

    Every 1000 epochs the row-normalised kernels are written as a checkpoint
    CSV. After the last epoch the kernels are written to the final CSV and the
    model state dict is saved to ``pHash["model_file"]``.

    Parameters
    ----------
    dHash : dict
        Data dictionary containing bulk RNA-seq and reference data.
    pHash : dict
        Parameter dictionary with model configuration.
    """
    ## Stage I: scale expression for each cell type and adjust cell proportion

    torch.manual_seed(1334)
    torch.cuda.manual_seed(1334)
    net_Prop_S1 = DeconvProp_S1(dHash, pHash)
    net_Prop_S1.train()
    block = 4
    target = dHash["bulk"]  # (N_sample, N_feature)
    target_mean = torch.mean(target, 0) + 0.0001
    target_mean_adj = 2.0*torch.tanh(target_mean+0.1)
    LR = 0.02  ## cannot use bigger than 0.05
    N = pHash["max_epoch_cellprop"]

    # These selections depend only on the bulk data, so compute them once.
    indx_3 = torch.where((target_mean < 2.0))[0]
    expressed = target_mean >= 0.05

    def _corr_loss(x_predict):
        # 1 - mean Pearson correlation between prediction and bulk over indx_3.
        return 1-torch.mean(pearson_corrcoef(x_predict[:,indx_3].t(), target[:,indx_3].t()))

    for epoch in range(0, N):
        print("epoch = " + str(epoch))
        modd = epoch % block
        x_predict, x_afterRef, x_afterScale, x_afterStretch = net_Prop_S1(dHash["CSE_reformat"], 1, pHash["n_celltype"]*pHash["n_gene"])

        if modd == 1:  # tune reference layer
            train_loss = loss_ref(net_Prop_S1, pHash) + (_corr_loss(x_predict))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.refLayer.weight.sub_(net_Prop_S1.refLayer.weight.grad*LR/pHash["n_ref"])

        elif modd == 2:  # tune cell-type scale layer
            train_loss = loss_scale(net_Prop_S1) + 0.1*(_corr_loss(x_predict))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.celltypeScaleLayer.weight.sub_(net_Prop_S1.celltypeScaleLayer.weight.grad*LR)

        elif modd == 3:  # tune stretch layer
            train_loss = loss_stretch(net_Prop_S1, x_afterStretch, dHash, pHash) + 0.1*(_corr_loss(x_predict))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.stretchLayer.weight.sub_(net_Prop_S1.stretchLayer.weight.grad*LR*pHash["n_gene"]/100)

        else:  # tune convolution kernels (cell proportions)
            # The gene selections below are only used in this phase.
            new_ref = x_afterStretch.reshape([pHash["n_gene"], pHash["n_celltype"]])
            gene_cv = torch.var(new_ref, dim=1)/target_mean
            th = torch.quantile(gene_cv, 0.75)
            indx_1 = torch.where(((gene_cv >= 1.0) | (gene_cv >= th)))[0]
            # Marker genes: at or above the per-cell-type median in at least one
            # cell type, restricted to genes with mean bulk expression >= 0.05.
            quantile_marker = torch.quantile(new_ref, 0.5, dim=0)
            is_marker = (new_ref >= quantile_marker).any(dim=1)
            indx_2 = torch.where(is_marker & expressed)[0]

            rho_1 = spearmanr(x_predict[:,indx_1].t(), target[:,indx_1].t())
            rho_2 = spearmanr(x_predict[:,indx_2], target[:,indx_2])
            train_loss = loss_prop(net_Prop_S1) + \
                loss_epsilon_insensitive(x_predict[:,indx_3]/target_mean_adj[indx_3], target[:,indx_3]/target_mean_adj[indx_3], 0.05) + \
                0.1*loss_epsilon_insensitive_2(rho_1, torch.ones_like(rho_1), 0.05) + \
                0.1*loss_epsilon_insensitive_2(rho_2, torch.ones_like(rho_2), 0.05)
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.conv1.weight.sub_(net_Prop_S1.conv1.weight.grad*LR)

        # Gradients accumulate on every layer in every phase; clear them all.
        with torch.no_grad():
            net_Prop_S1.refLayer.weight.grad.zero_()
            net_Prop_S1.celltypeScaleLayer.weight.grad.zero_()
            net_Prop_S1.stretchLayer.weight.grad.zero_()
            net_Prop_S1.conv1.weight.grad.zero_()

        ## output estimation
        if epoch % 1000 == 0:
            print("celltypeScaleLayer weight")
            print(net_Prop_S1.celltypeScaleLayer.weight.tolist())
            cellprop = np.squeeze(np.array(net_Prop_S1.conv1.weight.tolist()), 1)
            cellprop = divide_by_row_sum(cellprop)
            _write_proportions(dHash, pHash, cellprop, "_epoch_" + str(epoch))

    for name, param in net_Prop_S1.named_parameters():
        print(name, param)

    ### output estimated proportion from Stage I
    # NOTE(review): unlike the checkpoints, these are the raw kernel weights
    # (no divide_by_row_sum) — confirm whether the final table should also be
    # normalised to proportions.
    cellprop = np.squeeze(np.array(net_Prop_S1.conv1.weight.tolist()), 1)
    _write_proportions(dHash, pHash, cellprop, "")
    # save final model
    torch.save(net_Prop_S1.state_dict(), pHash["model_file"])
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
####################################################################################################################################################################
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: CNNreg
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: CNN-based regression for cell type deconvolution of bulk RNA-seq using scRNA-seq reference
|
|
5
|
+
Author-email: Xue Wang <wang.xue@mayo.edu>
|
|
6
|
+
Maintainer-email: Yuanhang Liu <liu.yuanhang@mayo.edu>, Xue Wang <wang.xue@mayo.edu>
|
|
7
|
+
License: MIT
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: torch>=1.10.0
|
|
12
|
+
Requires-Dist: torchmetrics>=0.10.0
|
|
13
|
+
Requires-Dist: torchsort>=0.1.0
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: pandas
|
|
16
|
+
Dynamic: license-file
|
|
17
|
+
|
|
18
|
+
# CNNreg: CNN-based Cell Type Deconvolution
|
|
19
|
+
|
|
20
|
+
[Python 3.9+](https://www.python.org/downloads/)
|
|
21
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
22
|
+
|
|
23
|
+
A deep learning approach for cell type deconvolution of bulk RNA-seq data using single-cell RNA-seq reference data. CNNreg employs a custom CNN-based regression model to estimate cell type proportions in complex tissue samples.
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
### Option 1: Using pip (Recommended)
|
|
28
|
+
|
|
29
|
+
CNNreg requires PyTorch. Install PyTorch first according to your system configuration:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
# For CUDA 12.6 (check your CUDA version: nvidia-smi)
|
|
33
|
+
pip3 install torch --index-url https://download.pytorch.org/whl/cu126
|
|
34
|
+
|
|
35
|
+
# For CUDA 11.8
|
|
36
|
+
pip3 install torch --index-url https://download.pytorch.org/whl/cu118
|
|
37
|
+
|
|
38
|
+
# For CPU only
|
|
39
|
+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Then install CNNreg:
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install CNNreg
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
### Option 2: From Source (Development)
|
|
49
|
+
|
|
50
|
+
```bash
|
|
51
|
+
git clone https://github.com/mwang159/CNNreg.git
|
|
52
|
+
cd CNNreg
|
|
53
|
+
pip install -e .
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
### Option 3: Using Conda Environment
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
# Create environment
|
|
60
|
+
conda create -n cnnreg_env python=3.10
|
|
61
|
+
conda activate cnnreg_env
|
|
62
|
+
|
|
63
|
+
# Install PyTorch (choose appropriate CUDA version)
|
|
64
|
+
pip3 install torch --index-url https://download.pytorch.org/whl/cu126
|
|
65
|
+
|
|
66
|
+
# Install CNNreg
|
|
67
|
+
pip install CNNreg
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### Verify Installation
|
|
71
|
+
|
|
72
|
+
```bash
|
|
73
|
+
# Check CLI command
|
|
74
|
+
cnnreg --help
|
|
75
|
+
|
|
76
|
+
# Test in Python
|
|
77
|
+
python -c "import CNNreg; print('CNNreg installed successfully!')"
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Quick Start
|
|
81
|
+
|
|
82
|
+
### Command Line Interface
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
cnnreg -M train \
|
|
86
|
+
-bulk data/bulk.csv \
|
|
87
|
+
-ref data/sc_ref.csv \
|
|
88
|
+
-o output/ \
|
|
89
|
+
-C 7 \
|
|
90
|
+
-EP 50000 \
|
|
91
|
+
-pre GBM_analysis
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
**Parameters:**
|
|
95
|
+
- `-M`: Mode - currently only `train` is implemented (evaluate/predict/explain coming in future versions)
|
|
96
|
+
- `-bulk`: Path to bulk RNA-seq CSV file
|
|
97
|
+
- `-ref`: Path to reference cell type specific expression (CSE) CSV
|
|
98
|
+
- `-o`: Output directory
|
|
99
|
+
- `-C`: Number of cell types (kernel size)
|
|
100
|
+
- `-EP`: Maximum training epochs
|
|
101
|
+
- `-pre`: Output file prefix
|
|
102
|
+
|
|
103
|
+
**Note**: Training mode automatically generates cell proportion predictions for the input bulk samples. Predictions are saved at checkpoints (every 1000 epochs) and at completion.
|
|
104
|
+
|
|
105
|
+
### Python API
|
|
106
|
+
|
|
107
|
+
```python
|
|
108
|
+
import torch
|
|
109
|
+
import pandas as pd
|
|
110
|
+
from CNNreg.data import data_CSE, reformat_ref
|
|
111
|
+
from CNNreg.train import trainProp
|
|
112
|
+
from CNNreg.layers import DeconvProp_S1
|
|
113
|
+
|
|
114
|
+
# Setup
|
|
115
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
116
|
+
|
|
117
|
+
# Load data
|
|
118
|
+
df_bulk = pd.read_csv("bulk.csv")
|
|
119
|
+
bulk_data = torch.tensor(
|
|
120
|
+
df_bulk.iloc[:, 1:].values.transpose(),
|
|
121
|
+
dtype=torch.float32
|
|
122
|
+
).to(device)
|
|
123
|
+
|
|
124
|
+
# Configure parameters
|
|
125
|
+
pHash = {
|
|
126
|
+
"bulk": "bulk.csv",
|
|
127
|
+
"reference": "sc_ref.csv",
|
|
128
|
+
"data_out_dir": "output/",
|
|
129
|
+
"max_epoch_cellprop": 50000,
|
|
130
|
+
"model_file": "output/model.pt",
|
|
131
|
+
"prefix": "GBM",
|
|
132
|
+
"device": device,
|
|
133
|
+
"n_kernel": 7,
|
|
134
|
+
"n_gene": df_bulk.shape[0],
|
|
135
|
+
"n_sample": df_bulk.shape[1] - 1,
|
|
136
|
+
"n_celltype": 7
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
# Initialize data dictionary
|
|
140
|
+
dHash = {
|
|
141
|
+
"bulk": bulk_data,
|
|
142
|
+
"sample": df_bulk.columns.values[1:],
|
|
143
|
+
"celltype": ["AClike", "MESlike", "NPClike", "OPClike", "OL", "Myeloid", "T"],
|
|
144
|
+
"CSE": data_CSE(pHash["reference"], device=device)
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
# Add required indices
|
|
148
|
+
k = pHash["n_gene"] * pHash["n_celltype"]
|
|
149
|
+
dHash["idx_feature_celltype"] = [
|
|
150
|
+
[int(y+x) for y in range(0, k, pHash["n_celltype"])]
|
|
151
|
+
for x in range(pHash["n_celltype"])
|
|
152
|
+
]
|
|
153
|
+
dHash["CSE_reformat"] = reformat_ref(dHash["CSE"].expr_cse)
|
|
154
|
+
|
|
155
|
+
# Train model
|
|
156
|
+
trainProp(dHash, pHash)
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
## Input Data Format
|
|
160
|
+
|
|
161
|
+
### Bulk RNA-seq Data
|
|
162
|
+
|
|
163
|
+
Rows are genes, columns are samples (the first column holds the gene identifier; sample names are read from the header row):
|
|
164
|
+
|
|
165
|
+
```csv
|
|
166
|
+
Gene,Sample1,Sample2,Sample3,...
|
|
167
|
+
Gene1,0.5,1.1,0.7,...
|
|
168
|
+
Gene2,1.2,0.9,1.3,...
|
|
169
|
+
Gene3,0.8,1.5,0.6,...
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
### Reference scRNA-seq Data
|
|
173
|
+
|
|
174
|
+
Cell Type Specific Expression profiles from scRNA-seq:
|
|
175
|
+
|
|
176
|
+
```csv
|
|
177
|
+
CellType,Gene1,Gene2,Gene3,...
|
|
178
|
+
AClike_ref1,0.3,0.8,0.5,...
|
|
179
|
+
AClike_ref2,0.4,0.7,0.6,...
|
|
180
|
+
MESlike_ref1,0.9,0.2,0.4,...
|
|
181
|
+
```
|
|
182
|
+
|
|
183
|
+
## Output Files
|
|
184
|
+
|
|
185
|
+
- `Prop_predicted_{prefix}_epoch_{N}.csv`: Cell proportions at checkpoint epochs (every 1000)
|
|
186
|
+
- `Prop_predicted_{prefix}.csv`: Final estimated cell proportions
|
|
187
|
+
- `cellprop_model.pt`: Trained PyTorch model (state dict)
|
|
188
|
+
|
|
189
|
+
**Output format:**
|
|
190
|
+
|
|
191
|
+
```csv
|
|
192
|
+
Sample,celltype_1,celltype_2,celltype_3,...
|
|
193
|
+
Sample1,0.23,0.15,0.31,...
|
|
194
|
+
Sample2,0.19,0.28,0.22,...
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
## Model Architecture
|
|
198
|
+
|
|
199
|
+
CNNreg uses a custom 5-layer CNN pipeline specifically designed for biological deconvolution:
|
|
200
|
+
|
|
201
|
+
1. **RefCombLayer**: Combines multiple reference samples per cell type
|
|
202
|
+
2. **SliceSumLayer**: Aggregates reference combinations
|
|
203
|
+
3. **CelltypeScaleLayer**: Scales expression for each cell type independently
|
|
204
|
+
4. **StretchLayer**: Applies gene-specific scaling factors
|
|
205
|
+
5. **Conv1D Layer**: Estimates cell proportions via 1D convolution
|
|
206
|
+
|
|
207
|
+
## Citation
|
|
208
|
+
|
|
209
|
+
If you use CNNreg in your research, please cite:
|
|
210
|
+
|
|
211
|
+
## License
|
|
212
|
+
|
|
213
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
214
|
+
|
|
215
|
+
## Contact
|
|
216
|
+
|
|
217
|
+
- **Issues**: [GitHub Issues](https://github.com/mwang159/CNNreg/issues)
|
|
218
|
+
- **Email**: wang.xue@mayo.edu, liu.yuanhang@mayo.edu
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
CNNreg/__init__.py,sha256=LZNMFSvn-lvAp4amVFYy8HnhxlXLQeXt9aT6i-BS2EQ,454
|
|
2
|
+
CNNreg/cli.py,sha256=YErLfKgt0RD4uRw6_rOdFMajRo9akQhj_ncL4YXPKkI,2870
|
|
3
|
+
CNNreg/data.py,sha256=gAvpMOJHj7T7bh2Jp4gKgM5K5A0DvNGIMV6dVi3p0GY,1335
|
|
4
|
+
CNNreg/layers.py,sha256=SMgAR0ojLvOEqUqt4_hyhj1OKmkvIGaXWtIWz44twkM,3569
|
|
5
|
+
CNNreg/losses.py,sha256=zmpjIFhpmgwosoG6Lgrvq0RObqtvMzhyfxrHKFdzgUc,4699
|
|
6
|
+
CNNreg/train.py,sha256=cJVQ3x4IqAgXfORkavna0B5btZtozOf4GVG7l2Zzmr4,7262
|
|
7
|
+
cnnreg-0.1.0.dist-info/licenses/LICENSE,sha256=6le1Ynl_12mYxMvuMCVOeBVHPY01YMyexsWvTBl0r6Y,1064
|
|
8
|
+
cnnreg-0.1.0.dist-info/METADATA,sha256=EnhVzcsm5HkiQpuVD4_bTr9y-WxnutpAE9e2QLeRutk,5699
|
|
9
|
+
cnnreg-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
+
cnnreg-0.1.0.dist-info/entry_points.txt,sha256=oNea6BqtYSR4t5s0q0t9IV4i4k4z6omc4D7c64lx5Uw,43
|
|
11
|
+
cnnreg-0.1.0.dist-info/top_level.txt,sha256=zSYWDE6KCfeTnHz1PQ5KXUVxpZSIVkjuOS9FrgHWJBA,7
|
|
12
|
+
cnnreg-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Xue Wang
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
CNNreg
|