CNNreg 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
CNNreg/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ CNNreg: A CNN-based regression package for cell type deconvolution of bulk RNA-seq using an scRNA-seq reference.
3
+ """
4
+ from .data import data_CSE, flatten_list, divide_by_row_sum, reformat_ref
5
+ from .layers import RefCombLayer, SliceSumLayer, CelltypeScaleLayer, StretchLayer, DeconvProp_S1
6
+ from .losses import (loss_spearmanr, loss_prop, loss_ref, loss_scale,
7
+ loss_stretch, loss_epsilon_insensitive)
8
+ from .train import trainProp
CNNreg/cli.py ADDED
@@ -0,0 +1,67 @@
1
+ import argparse
2
+ import torch
3
+ import pandas as pd
4
+ from .train import trainProp
5
+ from .data import data_CSE, reformat_ref, divide_by_row_sum
6
+
7
def main():
    """Entry point of the ``cnnreg`` command-line tool.

    Parses the command line, loads the bulk matrix (rows are genes; the first
    column holds gene IDs and the remaining column headers are sample names)
    and the reference CSE table, assembles the ``paraHash``/``dataHash``
    dictionaries consumed by :func:`trainProp`, and trains the Stage I model.

    Only ``-M train`` is implemented. The other modes are still accepted by
    the parser (so existing command lines parse) but exit with a usage error
    instead of silently doing nothing. Problems that would otherwise surface
    only as obscure pandas/torch errors (missing ``-bulk``/``-ref``, a
    reference table whose width is not ``n_gene * n_celltype``) are reported
    up front.
    """
    import os

    parser = argparse.ArgumentParser(
        description="CNN-based bulk RNA-seq deconvolution",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-M", dest="runMode", type=str, required=True,
                        choices=['train', 'evaluate', 'predict', 'explain'])
    parser.add_argument("-bulk", dest="bulk", type=str, required=False)
    parser.add_argument("-ref", dest="reference", type=str, required=False)
    parser.add_argument("-pre", dest="prefix", type=str, default="Project")
    parser.add_argument("-o", dest="outDIR", type=str, required=True)
    parser.add_argument("-mP", dest="modelFile", type=str)
    parser.add_argument("-C", dest="kernelSize", type=int, default=10)
    parser.add_argument("-EP", dest="maxEpochCellProp", type=int, default=1000)

    args = parser.parse_args()

    if args.runMode != "train":
        parser.error(f"run mode '{args.runMode}' is not implemented yet; only 'train' is supported")
    # -bulk/-ref are optional for the parser but mandatory for training.
    if not args.bulk or not args.reference:
        parser.error("-bulk and -ref are required for -M train")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint CSVs and the model are written under outDIR; create it now so
    # a long run does not fail on its first write.
    os.makedirs(args.outDIR, exist_ok=True)
    # Default model file path if not provided.
    model_file = args.modelFile if args.modelFile else os.path.join(args.outDIR, "cellprop_model.pt")
    model_dir = os.path.dirname(model_file)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    paraHash = {
        "runMode": args.runMode,
        "bulk": args.bulk,
        "reference": args.reference,
        "data_out_dir": args.outDIR,
        "max_epoch_cellprop": args.maxEpochCellProp,
        "model_file": model_file,
        "prefix": args.prefix,
        "device": device,
        "n_kernel": args.kernelSize,
    }

    df_bulk = pd.read_csv(paraHash["bulk"])
    dataHash = {
        # Transposed to (n_sample, n_gene): rows of the CSV are genes.
        "bulk": torch.tensor(df_bulk.iloc[:, 1:].values.transpose(), dtype=torch.float32).to(device),
        "sample": df_bulk.columns.values[1:],
        "celltype": [f"celltype_{i + 1}" for i in range(paraHash["n_kernel"])],
        "CSE": data_CSE(paraHash["reference"], device=device),
    }
    paraHash["n_gene"] = df_bulk.shape[0]
    paraHash["n_sample"] = df_bulk.shape[1] - 1
    paraHash["n_celltype"] = paraHash["n_kernel"]
    paraHash["n_ref"] = dataHash["CSE"].expr_cse.shape[0]

    n_gene = paraHash["n_gene"]
    n_celltype = paraHash["n_celltype"]
    k = n_gene * n_celltype

    # Each reference row must hold one value per (gene, cell type) pair;
    # otherwise the model's element-wise products fail with a shape error.
    n_ref_cols = dataHash["CSE"].expr_cse.shape[1]
    if n_ref_cols != k:
        parser.error(
            f"reference has {n_ref_cols} value columns, expected "
            f"n_gene * n_celltype = {n_gene} * {n_celltype} = {k}"
        )

    # Feature j of the flattened profile belongs to cell type j % n_celltype.
    dataHash["idx_feature_celltype"] = [list(range(c, k, n_celltype)) for c in range(n_celltype)]
    # Reference r occupies the contiguous block [r*k, (r+1)*k).
    dataHash["idx_feature_ref"] = [list(range(r * k, (r + 1) * k)) for r in range(paraHash["n_ref"])]

    dataHash["CSE_reformat"] = reformat_ref(dataHash["CSE"].expr_cse)
    trainProp(dataHash, paraHash)


if __name__ == "__main__":
    main()
CNNreg/data.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
def flatten_list(nested_list):
    """Return a new list holding the leaves of ``nested_list`` in depth-first order.

    Only ``list`` instances are descended into; tuples, strings and any other
    objects are kept as single elements.
    """
    leaves = []
    for element in nested_list:
        leaves.extend(flatten_list(element) if isinstance(element, list) else (element,))
    return leaves
14
+
15
def divide_by_row_sum(array):
    """Return ``array`` with every row divided by that row's sum.

    The row totals are computed with ``keepdims=True`` so they broadcast across
    the columns. A row whose total is 0 produces ``nan``/``inf`` entries,
    following numpy division semantics.
    """
    per_row_total = np.sum(array, axis=1, keepdims=True)
    normalised = array / per_row_total
    return normalised
19
+
20
def reformat_ref(CSE_data):
    """Flatten a 2-D reference tensor column by column.

    This transposes ``CSE_data`` and reads the result row-major. All entries of
    column 0 come first, then those of column 1, and so on.
    """
    transposed = CSE_data.t()
    return transposed.flatten()
23
+
24
class data_CSE(torch.utils.data.Dataset):
    """Dataset wrapper for a Cell-type-Specific Expression (CSE) reference table.

    The table is read with ``pandas.read_csv``. Its first column is kept as the
    row labels (``self.sample``). The remaining columns become a ``float32``
    tensor ``self.expr_cse`` of shape ``(n_rows, n_columns - 1)`` on ``device``.

    Parameters
    ----------
    f_CSE : str or file-like
        Anything accepted by ``pandas.read_csv``.
    device : torch.device or str
        Device on which ``expr_cse`` is placed.
    """

    def __init__(self, f_CSE, device):
        super().__init__()
        df_cse = pd.read_csv(f_CSE)
        self.sample = df_cse.iloc[:, 0]
        self.expr_cse = torch.tensor(df_cse.iloc[:, 1:].values, dtype=torch.float32).to(device)

    def __len__(self):
        """Return the number of reference rows."""
        return self.expr_cse.shape[0]

    def __getitem__(self, idx):
        """Return ``{'expr_cse': row(s), 'sample': label(s)}`` at position(s) ``idx``.

        Both fields are indexed positionally: ``.iloc`` is used for the labels
        so they match the tensor rows. This makes negative indices work; plain
        ``Series[...]`` is label-based and raised ``KeyError`` for them.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return {'expr_cse': self.expr_cse[idx, :], 'sample': self.sample.iloc[idx]}

    def update_data(self, new_CSE):
        """Replace the expression tensor; the row labels are left unchanged."""
        self.expr_cse = new_CSE
CNNreg/layers.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torchsort
5
+
6
class RefCombLayer(nn.Module):
    """Reference combination layer for CNN deconvolution.

    Holds one learnable weight per (cell type, reference) pair, which is
    ``n_ref * n_celltype`` weights in total. They are initialised uniformly in
    ``[0.8 / n_ref, 1.2 / n_ref]``. ``forward`` multiplies its input
    element-wise by this weight vector tiled ``n_gene`` times. The input must
    therefore have length ``n_gene * n_celltype * n_ref``.
    """

    def __init__(self, pHash, dHash, device=None):
        super().__init__()
        num = 1.0 / pHash["n_ref"]
        w = torch.empty(pHash["n_ref"] * pHash["n_celltype"])
        self.weight = nn.Parameter(torch.nn.init.uniform_(w, a=num - num / 5, b=num + num / 5).to(device))
        self.n_gene = pHash["n_gene"]

    @property
    def w_expand(self):
        """Weight vector tiled ``n_gene`` times, rebuilt from the live parameter."""
        # Must not be cached in __init__: flattening an expanded tensor makes a
        # copy, so a cached value would never see in-place updates to `weight`.
        return self.weight.repeat(self.n_gene)

    def forward(self, x):
        return torch.mul(x, self.w_expand)
17
+
18
+
19
class SliceSumLayer(nn.Module):
    """Collapse the reference axis by summation.

    The flat input is viewed as ``(n_celltype * n_gene, n_ref)``. Each run of
    ``n_ref`` consecutive values is summed, giving ``n_celltype * n_gene``
    outputs.
    """

    def __init__(self, pHash, dHash, device=None):
        super().__init__()
        self.slice_row = pHash["n_celltype"] * pHash["n_gene"]
        self.slice_column = pHash["n_ref"]

    def forward(self, x):
        grouped = x.reshape(self.slice_row, self.slice_column)
        return grouped.sum(dim=1)
29
+
30
+
31
class CelltypeScaleLayer(nn.Module):
    """Multiply every feature of each cell type by a learnable per-cell-type factor.

    ``dHash["idx_feature_celltype"][c]`` lists the positions of cell type
    ``c``'s features in the flat input. The factors are initialised from
    N(1.0, 0.1). In the output, the scaled features are re-interleaved so that
    position ``g * n_celltype + c`` holds the ``g``-th feature of cell type
    ``c``.
    """

    def __init__(self, pHash, dHash, device=None):
        super().__init__()
        init = torch.nn.init.normal_(torch.empty(pHash["n_celltype"]), 1.0, 0.1)
        self.weight = nn.Parameter(init.to(device))
        self.idx = dHash["idx_feature_celltype"]

    def forward(self, x):
        scaled_rows = [x[positions] * self.weight[c] for c, positions in enumerate(self.idx)]
        stacked = torch.vstack(scaled_rows)  # one row per cell type
        return torch.flatten(torch.t(stacked))
45
+
46
+
47
class StretchLayer(nn.Module):
    """Per-gene stretch layer.

    Holds one learnable factor per gene, initialised from N(1.0, 0.1). Each
    gene's block of ``n_celltype`` consecutive features is multiplied by that
    gene's factor. The factor vector is therefore repeated element-wise:
    ``[w0] * C + [w1] * C + ...``.
    """

    def __init__(self, pHash, dHash, device=None):
        super().__init__()
        w = torch.empty(pHash["n_gene"])
        self.weight = nn.Parameter(torch.nn.init.normal_(w, 1.0, 0.1).to(device))
        self.n_celltype = pHash["n_celltype"]

    @property
    def w_expand(self):
        """Per-feature factors, rebuilt from the live parameter on every access."""
        # Must not be cached in __init__: flatten(t(expand(...))) makes a copy,
        # so a cached value would never see in-place updates to `weight`.
        return self.weight.repeat_interleave(self.n_celltype)

    def forward(self, x):
        return torch.mul(x, self.w_expand)
58
+
59
+
60
class DeconvProp_S1(nn.Module):
    """Stage I CNN model for cell proportion estimation.

    The pipeline is RefCombLayer -> SliceSumLayer -> CelltypeScaleLayer ->
    StretchLayer -> Conv1d. The Conv1d has ``n_sample`` output channels, no
    bias, and a kernel size equal to its stride (``n_kernel``). Its weight has
    shape ``(n_sample, 1, n_celltype)``; the trainer reads it out as per-sample
    cell proportions.
    """

    @staticmethod
    def ini_kernel_weight(pHash):
        """Return a new conv1 weight of shape ``(n_sample, 1, n_celltype)``.

        Entries are drawn from N(1 / n_celltype, 0.25 / n_celltype) and placed
        on ``pHash["device"]``.
        """
        n_celltype = pHash["n_celltype"]
        w = torch.empty((pHash["n_sample"], 1, n_celltype))
        weights = torch.nn.init.normal_(w, 1.0 / n_celltype, 0.25 * 1.0 / n_celltype).to(pHash["device"])
        return nn.Parameter(weights)

    def __init__(self, dHash, pHash):
        super().__init__()
        # conv1 uses stride n_kernel but its weight is n_celltype wide; the two
        # must agree or each window would straddle several genes.
        if pHash["n_kernel"] != pHash["n_celltype"]:
            raise ValueError(
                f"n_kernel ({pHash['n_kernel']}) must equal n_celltype ({pHash['n_celltype']})"
            )
        device = pHash["device"]
        # Construction order is kept unchanged so seeded runs reproduce the
        # same parameter initialisation.
        self.refLayer = RefCombLayer(pHash, dHash, device=device)
        self.sum = SliceSumLayer(pHash, dHash, device=device)
        self.celltypeScaleLayer = CelltypeScaleLayer(pHash, dHash, device=device)
        self.stretchLayer = StretchLayer(pHash, dHash, device=device)
        self.conv1 = nn.Conv1d(1, pHash["n_sample"], kernel_size=pHash["n_kernel"],
                               stride=pHash["n_kernel"], device=device, bias=False)
        self.conv1.weight = self.ini_kernel_weight(pHash)

    def forward(self, x, N_batch, N_feature):
        """Run the pipeline on the flattened reference ``x``.

        ``N_batch`` and ``N_feature`` give the shape ``(N_batch, 1, N_feature)``
        used to view the stretched features before conv1. The trainer passes
        ``1`` and ``n_celltype * n_gene``.

        Returns ``(prediction, after_sum, after_celltype_scale, after_stretch)``.
        ``prediction`` is conv1's output with a leading size-1 dimension
        squeezed.
        """
        y = self.refLayer(x)
        y0 = self.sum(y)
        y1 = self.celltypeScaleLayer(y0)
        y2 = self.stretchLayer(y1)
        y = y2.view(N_batch, 1, N_feature)
        y = self.conv1(y).squeeze(0)
        return y, y0, y1, y2
CNNreg/losses.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import torchsort
3
+
4
def spearmanr(pred, target, indx=None, **kw):
    """Differentiable Spearman rank correlation, computed row-wise.

    1-D inputs are treated as a single row. If ``indx`` is given, only those
    columns are used. Ranks come from ``torchsort.soft_rank``, and ``**kw`` is
    forwarded to it. The result is the Pearson correlation of the soft ranks
    along the last dimension, one value per row.
    """
    if pred.dim() == 1:
        pred = pred.unsqueeze(0)
    if target.dim() == 1:
        target = target.unsqueeze(0)
    if indx is not None:
        pred, target = pred[:, indx], target[:, indx]

    def _centred_soft_rank(values):
        ranks = torchsort.soft_rank(values, **kw)
        return ranks - ranks.mean(dim=-1, keepdim=True)

    a = _centred_soft_rank(pred)
    b = _centred_soft_rank(target)
    numerator = (a * b).sum(dim=-1)
    denominator = torch.sqrt((a ** 2).sum(dim=-1)) * torch.sqrt((b ** 2).sum(dim=-1))
    return numerator / denominator
21
+
22
def loss_spearmanr(pred, target, indx=None, **kw):
    """Spearman correlation loss: ``mean(1 - rho)`` over rows.

    ``rho`` comes from :func:`spearmanr`. ``**kw`` (for example the
    ``torchsort.soft_rank`` options) is now forwarded; previously it was
    accepted but silently dropped.
    """
    return torch.mean(1 - spearmanr(pred, target, indx=indx, **kw))
25
+
26
def loss_prop(model):
    """L1 penalty on negative cell-proportion weights.

    Sums ``|w|`` over every negative entry of every parameter of
    ``model.conv1``. Returns the float ``0.0`` if conv1 has no parameters.
    """
    penalty = 0.0
    for weights in model.conv1.parameters():
        negative_part = weights[weights < 0]
        penalty = penalty + negative_part.abs().sum()
    return penalty
33
+
34
+
35
def loss_ref(model, pHash):
    """Regulariser for the reference-combination weights.

    For every parameter of ``model.refLayer`` it adds two terms:

    * an L1 penalty on the negative entries, and
    * ``|sum - 1|`` for each cell type's block of ``n_ref`` consecutive weights
      (cell type ``c`` owns entries ``c * n_ref`` to ``(c + 1) * n_ref - 1``).
      This pushes each cell type's reference mix towards summing to 1.

    Penalties accumulate across parameters. The previous version reset the
    running total for each parameter, so only the last parameter's penalties
    survived.
    """
    n_celltype, n_ref = pHash["n_celltype"], pHash["n_ref"]
    loss = 0.0
    for param in model.refLayer.parameters():
        loss = loss + torch.sum(torch.abs(param[param < 0]))
        group_sums = param[: n_celltype * n_ref].reshape(n_celltype, n_ref).sum(dim=1)
        # The old max(|s - 1|, 0) is identical to |s - 1|.
        loss = loss + torch.sum(torch.abs(group_sums - 1.0))
    return loss
45
+
46
+
47
+
48
def loss_scale(model):
    """Range penalty on the per-gene stretch factors.

    For every parameter of ``model.stretchLayer`` it adds two sums:

    * the sum of ``|w|`` over entries below 0.5, and
    * the sum of ``|w|`` over entries above 2.

    Penalties accumulate across parameters; the previous version reset the
    total for each parameter.

    NOTE(review): despite its name this penalises ``model.stretchLayer``, not
    ``model.celltypeScaleLayer``; an earlier celltypeScaleLayer penalty was
    commented out. If a caller applies only celltypeScaleLayer's gradient
    after this loss, this term has no effect on that update. Confirm this is
    intended.
    """
    loss = 0.0
    for param in model.stretchLayer.parameters():
        loss = loss + torch.sum(torch.abs(param[param < 0.5]))
        loss = loss + torch.sum(torch.abs(param[param > 2]))
    return loss
66
+
67
+
68
+
69
def _mean_abs(selected):
    """Mean of ``|selected|``, or a zero scalar if ``selected`` is empty.

    ``torch.mean`` of an empty tensor is NaN, which would poison the loss value.
    """
    if selected.numel() == 0:
        return selected.new_zeros(())
    return torch.mean(torch.abs(selected))


def loss_stretch(model, x, dHash, pHash):
    """Mean-based range penalty on the per-gene stretch factors.

    For every parameter of ``model.stretchLayer`` it adds the mean ``|w|`` of
    entries below 0.5 plus the mean ``|w|`` of entries above 2. An empty
    selection contributes 0 instead of NaN; that is the common case when all
    factors sit near their N(1, 0.1) initialisation. Penalties accumulate
    across parameters.

    ``x``, ``dHash`` and ``pHash`` are currently unused and are kept for call
    compatibility.
    """
    loss = 0.0
    for param in model.stretchLayer.parameters():
        loss = loss + _mean_abs(param[param < 0.5])
        loss = loss + _mean_abs(param[param > 2])
    return loss
79
+
80
+
81
def loss_epsilon_insensitive(prediction, target, epsilon):
    """Mean epsilon-insensitive absolute error with a tolerance relative to ``target``.

    For each element, an absolute error up to ``epsilon * target`` costs
    nothing; any excess beyond that is charged linearly. The result is the mean
    over all elements.
    """
    slack = epsilon * target
    excess = torch.abs(prediction - target) - slack
    hinge = torch.max(excess, torch.zeros_like(prediction))
    return torch.mean(hinge)
84
+
85
+
86
def loss_epsilon_insensitive_2(prediction, target, epsilon):
    """Mean epsilon-insensitive absolute error with a fixed tolerance ``epsilon``."""
    excess = (prediction - target).abs() - epsilon
    return torch.max(excess, torch.zeros_like(prediction)).mean()
89
+
90
+
91
def loss_large_insensitive(prediction, target, threshold):
    """Mean absolute error with each element's error capped at ``threshold * target``."""
    error = (prediction - target).abs()
    cap = threshold * target
    return torch.min(error, cap).mean()
94
+
95
+
96
+
97
def loss_large_insensitive_2(prediction, target, threshold):
    """Mean absolute error with every element's error capped at the constant ``threshold``."""
    ceiling = threshold * torch.ones_like(prediction)
    return torch.mean(torch.min(torch.abs(prediction - target), ceiling))
100
+
101
+
102
def loss_small_large_insensitive(prediction, target, epsilon, threshold):
    """Epsilon-insensitive absolute error with a capped penalty.

    For each element, errors within ``epsilon * target`` are ignored. The
    remaining excess is capped at ``threshold``, and the mean is returned.
    """
    excess = torch.abs(prediction - target) - epsilon * target
    hinge = torch.max(excess, torch.zeros_like(prediction))
    capped = torch.min(hinge, threshold * torch.ones_like(prediction))
    return torch.mean(capped)
106
+
107
+
CNNreg/train.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ from .data import divide_by_row_sum, reformat_ref
5
+ from .layers import DeconvProp_S1
6
+ from .losses import (loss_prop, loss_ref, loss_scale, loss_stretch,
7
+ loss_epsilon_insensitive, loss_epsilon_insensitive_2,
8
+ spearmanr)
9
+ import torch.nn as nn
10
+ from torchmetrics.functional import pearson_corrcoef
11
+
12
+ # https://pytorch.org/docs/0.3.0/optim.html#per-parameter-options
13
+ # https://discuss.pytorch.org/t/parameters-with-requires-grad-false-are-updated-during-training/90096/9
14
+ # https://www.youtube.com/watch?v=DbeIqrwb_dE
15
+
16
+ #### initializing weights with some randomness improves performance
17
+ #### use of optim.Adam worsen the performance a lot
18
+ #### block = 2 seems perform better than larger block size
19
+ #### tuning the celltypeScale layer and the cell proportion layer separately in S1 performs better than tuning them together.
20
+ #### using StretchLayer_2 in S2 worsen the performance. Should stick to StretchLayer
21
+
22
def _kernel_weights(net):
    """Return ``net.conv1.weight`` as a float64 array of shape ``(n_sample, n_celltype)``."""
    # Going through tolist() keeps the same float64 values (and CSV formatting) as before.
    return np.squeeze(np.array(net.conv1.weight.tolist()), 1)


def _write_cellprop_csv(dHash, cellprop, path):
    """Write a ``Sample`` column plus one column per ``dHash["celltype"]`` entry to ``path``."""
    df = pd.concat([pd.DataFrame(dHash["sample"]), pd.DataFrame(cellprop)], axis=1)
    df.columns = ["Sample", *dHash["celltype"]]
    df.to_csv(path, index=False)


def trainProp(dHash, pHash):
    """
    Train the Stage I CNN model for cell proportion estimation.

    Epochs cycle through four phases (``epoch % 4``). Each phase computes a
    loss, back-propagates, and applies a plain gradient step to one layer only:

    * 0 - conv1 (cell proportions): negativity penalty, plus an
      epsilon-insensitive fit on genes with mean expression < 2, plus Spearman
      terms on high variance-to-mean genes and on marker genes;
    * 1 - refLayer: ``loss_ref`` + (1 - mean Pearson);
    * 2 - celltypeScaleLayer: ``loss_scale`` + 0.1 * (1 - mean Pearson);
    * 3 - stretchLayer: ``loss_stretch`` + 0.1 * (1 - mean Pearson).

    Every 1000 epochs (starting at epoch 0), row-normalised proportions are
    written to ``<data_out_dir>/Prop_predicted_<prefix>_epoch_<epoch>.csv``.
    At the end, the raw conv1 weights are written to
    ``<data_out_dir>/Prop_predicted_<prefix>.csv`` and the ``state_dict`` is
    saved to ``pHash["model_file"]``. Missing output directories are created.

    Parameters
    ----------
    dHash : dict
        Data dictionary: ``bulk`` (tensor, n_sample x n_gene), ``CSE_reformat``,
        ``sample``, ``celltype`` and ``idx_feature_celltype``.
    pHash : dict
        Parameter dictionary: ``n_gene``, ``n_celltype``, ``n_ref``,
        ``n_sample``, ``n_kernel``, ``device``, ``max_epoch_cellprop``,
        ``data_out_dir``, ``prefix`` and ``model_file``.
    """
    import os

    out_dir = pHash["data_out_dir"]
    os.makedirs(out_dir, exist_ok=True)
    model_dir = os.path.dirname(pHash["model_file"])
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    ## Stage I: scale expression for each cell type and adjust cell proportion
    torch.manual_seed(1334)
    torch.cuda.manual_seed(1334)
    net_Prop_S1 = DeconvProp_S1(dHash, pHash)
    net_Prop_S1.train()

    block = 4  # number of phases in the training cycle
    n_gene, n_celltype = pHash["n_gene"], pHash["n_celltype"]
    target = dHash["bulk"]  # (N_sample, N_feature)
    target_mean = torch.mean(target, 0) + 0.0001  # per-gene mean; offset avoids division by zero
    target_mean_adj = 2.0 * torch.tanh(target_mean + 0.1)
    LR = 0.02  # the original notes say values above 0.05 do not work
    N = pHash["max_epoch_cellprop"]

    for epoch in range(N):
        print("epoch = " + str(epoch))
        modd = epoch % block
        x_predict, _, _, x_afterStretch = net_Prop_S1(dHash["CSE_reformat"], 1, n_celltype * n_gene)

        # Gene subsets are re-derived every epoch from the current stretched reference.
        new_ref = x_afterStretch.reshape([n_gene, n_celltype])
        gene_var = torch.var(new_ref, dim=1)
        quantile_marker = torch.quantile(new_ref, 0.5, dim=0)  # per-cell-type median
        ll = []
        for ii in range(n_celltype):
            ll.extend(torch.where(new_ref[:, ii] >= quantile_marker[ii])[0].tolist())
        indx_marker = list(set(ll))
        gene_cv = gene_var / target_mean  # variance-to-mean ratio per gene
        th = torch.quantile(gene_cv, 0.75)
        indx_1 = torch.where((gene_cv >= 1.0) | (gene_cv >= th))[0]
        indx_2 = list(set(indx_marker).intersection(set(torch.where(target_mean >= 0.05)[0].tolist())))
        indx_3 = torch.where(target_mean < 2.0)[0]

        if modd == 1:  # tune reference layer
            train_loss = loss_ref(net_Prop_S1, pHash) + \
                (1 - torch.mean(pearson_corrcoef(x_predict[:, indx_3].t(), target[:, indx_3].t())))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.refLayer.weight.sub_(net_Prop_S1.refLayer.weight.grad * LR / pHash["n_ref"])

        elif modd == 2:  # tune cell-type scale layer
            train_loss = loss_scale(net_Prop_S1) + \
                0.1 * (1 - torch.mean(pearson_corrcoef(x_predict[:, indx_3].t(), target[:, indx_3].t())))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.celltypeScaleLayer.weight.sub_(net_Prop_S1.celltypeScaleLayer.weight.grad * LR)

        elif modd == 3:  # tune stretch layer
            train_loss = loss_stretch(net_Prop_S1, x_afterStretch, dHash, pHash) + \
                0.1 * (1 - torch.mean(pearson_corrcoef(x_predict[:, indx_3].t(), target[:, indx_3].t())))
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.stretchLayer.weight.sub_(net_Prop_S1.stretchLayer.weight.grad * LR * n_gene / 100)

        else:  # tune cell proportions (conv1)
            rho_1 = spearmanr(x_predict[:, indx_1].t(), target[:, indx_1].t())
            rho_2 = spearmanr(x_predict[:, indx_2], target[:, indx_2])
            train_loss = loss_prop(net_Prop_S1) + \
                loss_epsilon_insensitive(x_predict[:, indx_3] / target_mean_adj[indx_3],
                                         target[:, indx_3] / target_mean_adj[indx_3], 0.05) + \
                0.1 * loss_epsilon_insensitive_2(rho_1, torch.ones_like(rho_1), 0.05) + \
                0.1 * loss_epsilon_insensitive_2(rho_2, torch.ones_like(rho_2), 0.05)
            train_loss.backward()
            with torch.no_grad():
                net_Prop_S1.conv1.weight.sub_(net_Prop_S1.conv1.weight.grad * LR)

        # Each phase applied only one layer's gradient; clear all of them.
        # set_to_none=False zeroes existing grads and skips any that are None.
        net_Prop_S1.zero_grad(set_to_none=False)

        ## output estimation
        if epoch % 1000 == 0:
            print("celltypeScaleLayer weight")
            print(net_Prop_S1.celltypeScaleLayer.weight.tolist())
            cellprop = divide_by_row_sum(_kernel_weights(net_Prop_S1))
            _write_cellprop_csv(
                dHash, cellprop,
                os.path.join(out_dir, "Prop_predicted_" + pHash["prefix"] + "_epoch_" + str(epoch) + ".csv"),
            )

    for name, param in net_Prop_S1.named_parameters():
        print(name, param)

    ### output estimated proportion from Stage I
    # NOTE(review): unlike the checkpoint files, this table holds the raw conv1
    # weights (not row-normalised). Confirm which one downstream users expect.
    _write_cellprop_csv(
        dHash, _kernel_weights(net_Prop_S1),
        os.path.join(out_dir, "Prop_predicted_" + pHash["prefix"] + ".csv"),
    )
    # save final model
    torch.save(net_Prop_S1.state_dict(), pHash["model_file"])
132
+
133
+
134
+ ####################################################################################################################################################################
@@ -0,0 +1,218 @@
1
+ Metadata-Version: 2.4
2
+ Name: CNNreg
3
+ Version: 0.1.0
4
+ Summary: CNN-based regression for cell type deconvolution of bulk RNA-seq using scRNA-seq reference
5
+ Author-email: Xue Wang <wang.xue@mayo.edu>
6
+ Maintainer-email: Yuanhang Liu <liu.yuanhang@mayo.edu>, Xue Wang <wang.xue@mayo.edu>
7
+ License: MIT
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: torch>=1.10.0
12
+ Requires-Dist: torchmetrics>=0.10.0
13
+ Requires-Dist: torchsort>=0.1.0
14
+ Requires-Dist: numpy
15
+ Requires-Dist: pandas
16
+ Dynamic: license-file
17
+
18
+ # CNNreg: CNN-based Cell Type Deconvolution
19
+
20
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
21
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
22
+
23
+ A deep learning approach for cell type deconvolution of bulk RNA-seq data using single-cell RNA-seq reference data. CNNreg employs a custom CNN-based regression model to estimate cell type proportions in complex tissue samples.
24
+
25
+ ## Installation
26
+
27
+ ### Option 1: Using pip (Recommended)
28
+
29
+ CNNreg requires PyTorch. Install PyTorch first according to your system configuration:
30
+
31
+ ```bash
32
+ # For CUDA 12.6 (check your CUDA version: nvidia-smi)
33
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu126
34
+
35
+ # For CUDA 11.8
36
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu118
37
+
38
+ # For CPU only
39
+ pip3 install torch --index-url https://download.pytorch.org/whl/cpu
40
+ ```
41
+
42
+ Then install CNNreg:
43
+
44
+ ```bash
45
+ pip install CNNreg
46
+ ```
47
+
48
+ ### Option 2: From Source (Development)
49
+
50
+ ```bash
51
+ git clone https://github.com/mwang159/CNNreg.git
52
+ cd CNNreg
53
+ pip install -e .
54
+ ```
55
+
56
+ ### Option 3: Using Conda Environment
57
+
58
+ ```bash
59
+ # Create environment
60
+ conda create -n cnnreg_env python=3.10
61
+ conda activate cnnreg_env
62
+
63
+ # Install PyTorch (choose appropriate CUDA version)
64
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu126
65
+
66
+ # Install CNNreg
67
+ pip install CNNreg
68
+ ```
69
+
70
+ ### Verify Installation
71
+
72
+ ```bash
73
+ # Check CLI command
74
+ cnnreg --help
75
+
76
+ # Test in Python
77
+ python -c "import CNNreg; print('CNNreg installed successfully!')"
78
+ ```
79
+
80
+ ## Quick Start
81
+
82
+ ### Command Line Interface
83
+
84
+ ```bash
85
+ cnnreg -M train \
86
+ -bulk data/bulk.csv \
87
+ -ref data/sc_ref.csv \
88
+ -o output/ \
89
+ -C 7 \
90
+ -EP 50000 \
91
+ -pre GBM_analysis
92
+ ```
93
+
94
+ **Parameters:**
95
+ - `-M`: Mode - currently only `train` is implemented (evaluate/predict/explain coming in future versions)
96
+ - `-bulk`: Path to bulk RNA-seq CSV file
97
+ - `-ref`: Path to reference cell type specific expression (CSE) CSV
98
+ - `-o`: Output directory
99
+ - `-C`: Number of cell types (kernel size)
100
+ - `-EP`: Maximum training epochs
101
+ - `-pre`: Output file prefix
102
+
103
+ **Note**: Training mode automatically generates cell proportion predictions for the input bulk samples. Predictions are saved at checkpoints (every 1000 epochs) and at completion.
104
+
105
+ ### Python API
106
+
107
+ ```python
108
+ import torch
109
+ import pandas as pd
110
+ from CNNreg.data import data_CSE, reformat_ref
111
+ from CNNreg.train import trainProp
112
+ from CNNreg.layers import DeconvProp_S1
113
+
114
+ # Setup
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+
117
+ # Load data
118
+ df_bulk = pd.read_csv("bulk.csv")
119
+ bulk_data = torch.tensor(
120
+ df_bulk.iloc[:, 1:].values.transpose(),
121
+ dtype=torch.float32
122
+ ).to(device)
123
+
124
+ # Configure parameters
125
+ pHash = {
126
+ "bulk": "bulk.csv",
127
+ "reference": "sc_ref.csv",
128
+ "data_out_dir": "output/",
129
+ "max_epoch_cellprop": 50000,
130
+ "model_file": "output/model.pt",
131
+ "prefix": "GBM",
132
+ "device": device,
133
+ "n_kernel": 7,
134
+ "n_gene": df_bulk.shape[0],
135
+ "n_sample": df_bulk.shape[1] - 1,
136
+ "n_celltype": 7
137
+ }
138
+
139
+ # Initialize data dictionary
140
+ dHash = {
141
+ "bulk": bulk_data,
142
+ "sample": df_bulk.columns.values[1:],
143
+ "celltype": ["AClike", "MESlike", "NPClike", "OPClike", "OL", "Myeloid", "T"],
144
+ "CSE": data_CSE(pHash["reference"], device=device)
145
+ }
146
+
147
+ # Add required indices
148
+ k = pHash["n_gene"] * pHash["n_celltype"]
149
+ dHash["idx_feature_celltype"] = [
150
+ [int(y+x) for y in range(0, k, pHash["n_celltype"])]
151
+ for x in range(pHash["n_celltype"])
152
+ ]
153
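+ dHash["CSE_reformat"] = reformat_ref(dHash["CSE"].expr_cse)
+ pHash["n_ref"] = dHash["CSE"].expr_cse.shape[0]  # number of reference rows; required by the model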
154
+
155
+ # Train model
156
+ trainProp(dHash, pHash)
157
+ ```
158
+
159
+ ## Input Data Format
160
+
161
+ ### Bulk RNA-seq Data
162
+
163
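+ Rows are genes and columns are samples. The first column holds gene identifiers, and the remaining column headers are used as sample names:
164
+
165
+ ```csv
166
+ Gene,Sample1,Sample2,Sample3,...
167
+ Gene1,0.5,1.1,0.7,...
168
+ Gene2,1.2,0.9,1.3,...
169
+ Gene3,0.8,1.5,0.6,...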
170
+ ```
171
+
172
+ ### Reference scRNA-seq Data
173
+
174
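+ Cell Type Specific Expression (CSE) profiles from scRNA-seq. Each row is one reference profile set, and the first column holds its name. The remaining `n_gene × n_celltype` columns hold expression values in gene-major order: all cell types for Gene1, then all cell types for Gene2, and so on. Genes must be in the same order as in the bulk file:
175
+
176
+ ```csv
177
+ Reference,Gene1_AClike,Gene1_MESlike,...,Gene2_AClike,Gene2_MESlike,...
178
+ ref1,0.3,0.9,...,0.8,0.2,...
179
+ ref2,0.4,0.8,...,0.7,0.3,...
180
+ ref3,0.5,1.0,...,0.6,0.4,...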
181
+ ```
182
+
183
+ ## Output Files
184
+
185
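+ - `Prop_predicted_{prefix}_epoch_{N}.csv`: Cell proportions at checkpoint epochs (every 1000, starting at epoch 0), with each row normalized to sum to 1
186
+ - `Prop_predicted_{prefix}.csv`: Final estimated cell proportions (raw kernel weights, not row-normalized)
187
+ - `cellprop_model.pt`: Trained PyTorch model (state dict). This is the default file name in the output directory; override it with `-mP`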
188
+
189
+ **Output format:**
190
+
191
+ ```csv
192
+ Sample,celltype_1,celltype_2,celltype_3,...
193
+ Sample1,0.23,0.15,0.31,...
194
+ Sample2,0.19,0.28,0.22,...
195
+ ```
196
+
197
+ ## Model Architecture
198
+
199
+ CNNreg uses a custom 5-layer CNN pipeline specifically designed for biological deconvolution:
200
+
201
+ 1. **RefCombLayer**: Combines multiple reference samples per cell type
202
+ 2. **SliceSumLayer**: Aggregates reference combinations
203
+ 3. **CelltypeScaleLayer**: Scales expression for each cell type independently
204
+ 4. **StretchLayer**: Applies gene-specific scaling factors
205
+ 5. **Conv1D Layer**: Estimates cell proportions via 1D convolution
206
+
207
+ ## Citation
208
+
209
+ If you use CNNreg in your research, please cite:
210
+
211
+ ## License
212
+
213
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
214
+
215
+ ## Contact
216
+
217
+ - **Issues**: [GitHub Issues](https://github.com/mwang159/CNNreg/issues)
218
+ - **Email**: wang.xue@mayo.edu, liu.yuanhang@mayo.edu
@@ -0,0 +1,12 @@
1
+ CNNreg/__init__.py,sha256=LZNMFSvn-lvAp4amVFYy8HnhxlXLQeXt9aT6i-BS2EQ,454
2
+ CNNreg/cli.py,sha256=YErLfKgt0RD4uRw6_rOdFMajRo9akQhj_ncL4YXPKkI,2870
3
+ CNNreg/data.py,sha256=gAvpMOJHj7T7bh2Jp4gKgM5K5A0DvNGIMV6dVi3p0GY,1335
4
+ CNNreg/layers.py,sha256=SMgAR0ojLvOEqUqt4_hyhj1OKmkvIGaXWtIWz44twkM,3569
5
+ CNNreg/losses.py,sha256=zmpjIFhpmgwosoG6Lgrvq0RObqtvMzhyfxrHKFdzgUc,4699
6
+ CNNreg/train.py,sha256=cJVQ3x4IqAgXfORkavna0B5btZtozOf4GVG7l2Zzmr4,7262
7
+ cnnreg-0.1.0.dist-info/licenses/LICENSE,sha256=6le1Ynl_12mYxMvuMCVOeBVHPY01YMyexsWvTBl0r6Y,1064
8
+ cnnreg-0.1.0.dist-info/METADATA,sha256=EnhVzcsm5HkiQpuVD4_bTr9y-WxnutpAE9e2QLeRutk,5699
9
+ cnnreg-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ cnnreg-0.1.0.dist-info/entry_points.txt,sha256=oNea6BqtYSR4t5s0q0t9IV4i4k4z6omc4D7c64lx5Uw,43
11
+ cnnreg-0.1.0.dist-info/top_level.txt,sha256=zSYWDE6KCfeTnHz1PQ5KXUVxpZSIVkjuOS9FrgHWJBA,7
12
+ cnnreg-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ cnnreg = CNNreg.cli:main
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xue Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ CNNreg