crisgi 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crisgi/__init__.py +4 -0
- crisgi/_version.py +1 -0
- crisgi/cnn/CNNModel.py +64 -0
- crisgi/cnn/__init__.py +0 -0
- crisgi/cnn/autoEncoder.py +36 -0
- crisgi/cnn/evalution_metrics.py +46 -0
- crisgi/cnn/mlp.py +16 -0
- crisgi/cnn/train.py +42 -0
- crisgi/crisgi.py +986 -0
- crisgi/logistic/LogisticModel.py +37 -0
- crisgi/logistic/__init__.py +0 -0
- crisgi/logistic/evalution_metrics.py +53 -0
- crisgi/logistic/train.py +32 -0
- crisgi/plotting.py +219 -0
- crisgi/plotting_crisgi_time.py +278 -0
- crisgi/simplecnn/SimpleCNNModel.py +65 -0
- crisgi/simplecnn/__init__.py +0 -0
- crisgi/simplecnn/autoEncoder.py +22 -0
- crisgi/simplecnn/evalution_metrics.py +53 -0
- crisgi/simplecnn/mlp.py +14 -0
- crisgi/simplecnn/train.py +45 -0
- crisgi/startpoint_detection.py +61 -0
- crisgi/stringdb_human_v12_gb_net.pk +0 -0
- crisgi/stringdb_human_v12_genes.pk +0 -0
- crisgi/stringdb_mouse_v12_gb_net.pk +0 -0
- crisgi/stringdb_mouse_v12_genes.pk +0 -0
- crisgi/util.py +85 -0
- crisgi-0.1.0.dist-info/METADATA +123 -0
- crisgi-0.1.0.dist-info/RECORD +32 -0
- crisgi-0.1.0.dist-info/WHEEL +5 -0
- crisgi-0.1.0.dist-info/licenses/LICENSE +183 -0
- crisgi-0.1.0.dist-info/top_level.txt +1 -0
crisgi/__init__.py
ADDED
crisgi/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; matches the crisgi-0.1.0 wheel this module ships in.
__version__ = "0.1.0"
|
crisgi/cnn/CNNModel.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import torch
|
|
3
|
+
from crisgi.cnn.autoEncoder import AE
|
|
4
|
+
import crisgi.cnn.autoEncoder
|
|
5
|
+
from crisgi.cnn.mlp import MLP
|
|
6
|
+
import crisgi.cnn.mlp
|
|
7
|
+
from crisgi.cnn.train import train as train_loop
|
|
8
|
+
|
|
9
|
+
# Expose the crisgi.cnn submodules under bare top-level names as well.
# Whole pickled models are loaded below with torch.load(..., weights_only=False);
# presumably those pickles reference the classes as `autoEncoder.AE` and
# `mlp.MLP` (legacy top-level module names), which these aliases make
# resolvable — TODO confirm against the shipped checkpoints.
sys.modules.update({
    'autoEncoder': crisgi.cnn.autoEncoder,
    'mlp': crisgi.cnn.mlp,
})
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CNNModel:
    """Autoencoder + MLP classifier trained jointly.

    The autoencoder (``AE``) produces an encoding and a reconstruction of each
    input; the encoding is fed to an ``MLP`` that outputs class logits.
    Training minimises the reconstruction (MSE) loss plus the classification
    (cross-entropy) loss with a single Adam optimizer over both networks.

    Args:
        device: torch device (or device string) the networks run on.
        ae_path: optional checkpoint for the autoencoder.
        mlp_path: optional checkpoint for the MLP. Checkpoints are loaded only
            when BOTH paths are given. Each may be a ``state_dict`` (as written
            by :meth:`save`) or a whole pickled module.
    """

    def __init__(self, device, ae_path=None, mlp_path=None):
        self.device = device
        self.ae = AE().to(device)
        self.mlp = MLP().to(device)
        self.ae_loss_fn = torch.nn.MSELoss().to(device)
        self.ce_loss_fn = torch.nn.CrossEntropyLoss().to(device)

        if ae_path and mlp_path:
            print(f"Loading pretrained models:\n - AE: {ae_path}\n - MLP: {mlp_path}")
            self.ae = self._load_module(self.ae, ae_path, device)
            self.mlp = self._load_module(self.mlp, mlp_path, device)
        elif ae_path or mlp_path:
            # Loading needs both checkpoints; say so rather than silently
            # ignoring a lone path.
            print("Both ae_path and mlp_path are required to load pretrained "
                  "models; using freshly initialised models.")

        # Build the optimizer only after any pretrained modules are in place,
        # so it tracks the parameters of the networks that are actually used.
        # (If it were built first and a pickled module then replaced self.ae /
        # self.mlp, training would update the discarded modules.)
        self.optimizer = torch.optim.Adam(
            list(self.ae.parameters()) + list(self.mlp.parameters()),
            lr=0.001, weight_decay=1e-4
        )

    @staticmethod
    def _load_module(module, path, device):
        """Return *module* with weights from *path*, or the module pickled there.

        ``save`` writes ``state_dict`` objects, which are loaded into *module*
        in place; a checkpoint holding a whole ``nn.Module`` is returned as is.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        loaded = torch.load(path, map_location=device, weights_only=False)
        if isinstance(loaded, torch.nn.Module):
            return loaded
        module.load_state_dict(loaded)
        return module

    def train(self, train_loader, epochs=10):
        """Train both networks for *epochs* passes over *train_loader*.

        Returns:
            The metrics dict produced by the training loop for the last epoch,
            or ``None`` when ``epochs`` is 0.
        """
        final_metrics = None
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            final_metrics = train_loop(
                self.ae,
                self.mlp,
                train_loader,
                self.ae_loss_fn,
                self.ce_loss_fn,
                self.optimizer,
                self.device
            )
        return final_metrics

    def predict(self, data_loader):
        """Predict a class index for every sample in *data_loader*.

        The loader must yield ``(x, _)`` pairs; the second element is ignored.

        Returns:
            list of predicted class indices (numpy integer scalars), in loader
            order.
        """
        self.ae.eval()
        self.mlp.eval()
        all_predictions = []

        with torch.no_grad():
            for x, _ in data_loader:
                x = x.to(self.device)
                en, _ = self.ae(x)
                out = self.mlp(en)
                predicted = out.argmax(dim=1)
                all_predictions.extend(predicted.cpu().numpy())

        return all_predictions

    def save(self, ae_path, mlp_path):
        """Write the ``state_dict`` of each network to its path."""
        torch.save(self.ae.state_dict(), ae_path)
        torch.save(self.mlp.state_dict(), mlp_path)
        print(f"Models saved to:\n - {ae_path}\n - {mlp_path}")
|
crisgi/cnn/__init__.py
ADDED
|
File without changes
|
|
crisgi/cnn/autoEncoder.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AE(nn.Module):
    """Convolutional autoencoder over 4-channel images.

    Encoder: three stages of 3x3 conv (padding 1) -> BatchNorm -> ReLU ->
    Dropout(0.2) -> 2x2 max-pool, taking channels 4 -> 16 -> 32 -> 64 and
    halving height and width at each stage.

    Decoder: three stride-2 transposed convs (64 -> 32 -> 16 -> 4), each
    doubling height and width, with ReLU between them and a final Sigmoid, so
    reconstructions lie in (0, 1). When the input height and width are
    divisible by 8, the reconstruction has the input's spatial size.

    The order in which layers are appended fixes the ``nn.Sequential``
    indices used in saved ``state_dict`` keys; do not reorder.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = []
        for c_in, c_out in ((4, 16), (16, 32), (32, 64)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.MaxPool2d(2, 2),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_channels = ((64, 32), (32, 16), (16, 4))
        decoder_layers = []
        for stage, (c_in, c_out) in enumerate(decoder_channels):
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1)
            )
            is_last = stage == len(decoder_channels) - 1
            decoder_layers.append(nn.Sigmoid() if is_last else nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(encoding, reconstruction)`` for the batch *x*."""
        encoded = self.encoder(x)
        return encoded, self.decoder(encoded)
|
|
crisgi/cnn/evalution_metrics.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from sklearn.metrics import auc, brier_score_loss, cohen_kappa_score, confusion_matrix, f1_score, roc_curve, roc_auc_score, precision_recall_curve, average_precision_score
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def calculate_pred_metric(label, pred):
    """Compute binary-classification metrics from labels and predictions.

    Args:
        label: ground-truth labels.
        pred: predicted labels. The training loop passes hard ``argmax``
            class predictions, so AUROC, AUPRC and the Brier score here are
            computed from those labels, not from probabilities.

    Returns:
        dict with keys 'AUROC', 'AUPRC', 'Accuracy', 'Precision', 'Recall',
        'Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1_Score', 'Kappa',
        'Brier_Score'. Precision, recall and NPV fall back to 0 when their
        denominator is 0; specificity is not guarded.
    """
    def _ratio_or_zero(numerator, denominator):
        # 0 (not NaN) when the denominator is empty.
        return numerator / denominator if denominator > 0 else 0

    roc_auc = roc_auc_score(label, pred)
    tn, fp, fn, tp = confusion_matrix(label, pred).ravel()

    precision = _ratio_or_zero(tp, tp + fp)
    recall = _ratio_or_zero(tp, tp + fn)

    f1 = f1_score(label, pred)
    kappa = cohen_kappa_score(label, pred)
    brier = brier_score_loss(label, pred)

    # AUPRC is the trapezoidal area under the PR curve (sklearn `auc`),
    # not sklearn's step-wise average precision.
    precision_curve, recall_curve, _ = precision_recall_curve(label, pred)

    return {
        'AUROC': roc_auc,
        'AUPRC': auc(recall_curve, precision_curve),
        'Accuracy': (tp + tn) / (tp + fp + fn + tn),
        'Precision': precision,
        'Recall': recall,
        'Sensitivity': recall,
        'Specificity': tn / (tn + fp),
        'PPV': precision,
        'NPV': _ratio_or_zero(tn, tn + fn),
        'F1_Score': f1,
        'Kappa': kappa,
        'Brier_Score': brier,
    }
|
crisgi/cnn/mlp.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MLP(nn.Module):
    """Two-layer perceptron head that classifies flattened encoder features.

    Architecture: Linear(in_features, hidden_features) -> ReLU ->
    Linear(hidden_features, num_classes), producing raw logits.

    The defaults (64 * 28 * 28 inputs, 64 hidden units, 2 classes) match the
    AE encoder output for 224x224 images (64 channels after three 2x2
    max-pools), so ``MLP()`` builds the same network as before. The sizes
    are parameters so other input resolutions or class counts can be used.

    Args:
        in_features: number of features per sample after flattening.
        hidden_features: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=64 * 28 * 28, hidden_features=64, num_classes=2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten every dimension after the batch dimension and classify."""
        # flatten(1) (unlike view) also handles non-contiguous tensors and
        # empty batches.
        flat = x.flatten(1)
        return self.mlp(flat)
|
crisgi/cnn/train.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from crisgi.cnn.evalution_metrics import calculate_pred_metric
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def train(ae, mlp, train_loader, AE_loss_function, CE_loss_function, optimizer, device):
    """Run one training epoch of the joint autoencoder + classifier.

    For each ``(x, y)`` batch, the autoencoder returns ``(encoding,
    reconstruction)``. The MLP classifies the encoding. The reconstruction
    loss ``AE_loss_function(reconstruction, x)`` and the classification loss
    ``CE_loss_function(logits, y)`` are summed and back-propagated with one
    optimizer step.

    Args:
        ae: autoencoder module; ``ae(x)`` must return ``(encoding, decoded)``.
        mlp: classifier module applied to the encoding.
        train_loader: iterable of ``(x, y)`` batches.
        AE_loss_function: reconstruction loss.
        CE_loss_function: classification loss.
        optimizer: optimizer over the parameters of ``ae`` and ``mlp``.
        device: device the batches are moved to.

    Returns:
        dict of metrics from ``calculate_pred_metric`` over the labels and
        ``argmax`` predictions collected during the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    ae.train()
    mlp.train()
    total_loss = 0.0
    correct = 0
    # Count the samples actually seen; len(train_loader.dataset) overcounts
    # with drop_last or subset samplers.
    seen = 0
    all_predictions = []
    all_labels = []

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        batch_size = x.size(0)

        optimizer.zero_grad()
        en, de = ae(x)
        classification_res = mlp(en)

        AE_loss = AE_loss_function(de, x)
        CE_loss = CE_loss_function(classification_res, y)
        loss = AE_loss + CE_loss

        loss.backward()
        optimizer.step()

        # Assuming mean-reduced losses (the default, as CNNModel uses), weight
        # by batch size so the epoch average below is per sample.
        total_loss += loss.item() * batch_size
        seen += batch_size
        predicted = classification_res.argmax(1)
        correct += (predicted == y).sum().item()

        all_labels.extend(y.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    if seen == 0:
        raise ValueError("train_loader yielded no samples")

    avg_loss = total_loss / seen
    accuracy = 100 * correct / seen

    print(f"Total Train Loss: {avg_loss}")
    print(f"Train Accuracy: {accuracy}%")

    metrics = calculate_pred_metric(all_labels, all_predictions)

    return metrics
|