crisgi 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
crisgi/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .crisgi import CRISGITime
2
+ from ._version import __version__
3
+
4
# Public API of the package: the CRISGITime class and the package version string,
# both imported at the top of this module.
__all__ = ['CRISGITime', '__version__']
crisgi/_version.py ADDED
@@ -0,0 +1 @@
1
# Single source of truth for the package version (re-exported by crisgi/__init__.py).
__version__ = '0.1.0'
crisgi/cnn/CNNModel.py ADDED
@@ -0,0 +1,64 @@
1
+ import sys
2
+ import torch
3
+ from crisgi.cnn.autoEncoder import AE
4
+ import crisgi.cnn.autoEncoder
5
+ from crisgi.cnn.mlp import MLP
6
+ import crisgi.cnn.mlp
7
+ from crisgi.cnn.train import train as train_loop
8
+
9
# Register the cnn submodules under the bare top-level module names `autoEncoder`
# and `mlp`, so `import autoEncoder` / `import mlp` resolve to these modules.
# NOTE(review): presumably this exists so torch.load(..., weights_only=False) can
# unpickle model files saved when AE/MLP lived in top-level modules of those
# names — confirm against how the pretrained checkpoints were produced.
sys.modules['autoEncoder'] = crisgi.cnn.autoEncoder
sys.modules['mlp'] = crisgi.cnn.mlp
11
+
12
+
13
class CNNModel:
    """Autoencoder + MLP classifier trained jointly with a combined loss.

    The encoder output of ``AE`` is fed to an ``MLP`` head producing 2 logits;
    training minimises reconstruction MSE plus classification cross-entropy.
    Both networks share a single Adam optimizer.

    Args:
        device: torch device (or device string) the models are placed on.
        ae_path: optional path to pretrained AE weights.
        mlp_path: optional path to pretrained MLP weights. Pretrained weights
            are loaded only when BOTH paths are given. Either a ``state_dict``
            (the format written by :meth:`save`) or a whole pickled module of
            the same architecture is accepted.
    """

    def __init__(self, device, ae_path=None, mlp_path=None):
        self.device = device
        self.ae = AE().to(device)
        self.mlp = MLP().to(device)
        self.ae_loss_fn = torch.nn.MSELoss().to(device)
        self.ce_loss_fn = torch.nn.CrossEntropyLoss().to(device)
        self.optimizer = torch.optim.Adam(
            list(self.ae.parameters()) + list(self.mlp.parameters()),
            lr=0.001, weight_decay=1e-4
        )

        if ae_path and mlp_path:
            print(f"Loading pretrained models:\n - AE: {ae_path}\n - MLP: {mlp_path}")
            # Copy weights INTO the existing modules instead of rebinding
            # self.ae / self.mlp: the optimizer above holds references to these
            # exact Parameter objects, so rebinding would make train() update
            # parameters that are no longer used in the forward pass.
            self._load_weights(self.ae, ae_path)
            self._load_weights(self.mlp, mlp_path)

    def _load_weights(self, module, path):
        """Load the weights stored at *path* into *module* in place.

        Accepts a ``state_dict`` mapping or a whole pickled ``nn.Module`` of the
        same architecture (whose ``state_dict`` is then copied over).
        """
        # NOTE: weights_only=False unpickles arbitrary Python objects (needed for
        # whole-module checkpoints). Only load checkpoint files from trusted sources.
        obj = torch.load(path, map_location=self.device, weights_only=False)
        state = obj.state_dict() if isinstance(obj, torch.nn.Module) else obj
        module.load_state_dict(state)

    def train(self, train_loader, epochs=10):
        """Train both networks for ``epochs`` epochs.

        Returns:
            The metrics dict produced by the last epoch's training loop, or
            ``None`` when ``epochs`` is 0.
        """
        final_metrics = None
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            final_metrics = train_loop(
                self.ae,
                self.mlp,
                train_loader,
                self.ae_loss_fn,
                self.ce_loss_fn,
                self.optimizer,
                self.device
            )
        return final_metrics

    def predict(self, data_loader):
        """Predict a class index for every sample yielded by ``data_loader``.

        ``data_loader`` must yield ``(x, _)`` pairs; the second element is
        ignored. Both modules are switched to eval mode (and left there).

        Returns:
            A list with one argmax class index (numpy integer) per sample.
        """
        self.ae.eval()
        self.mlp.eval()
        all_predictions = []

        with torch.no_grad():
            for x, _ in data_loader:
                x = x.to(self.device)
                en, _ = self.ae(x)
                out = self.mlp(en)
                predicted = out.argmax(dim=1)
                all_predictions.extend(predicted.cpu().numpy())

        return all_predictions

    def save(self, ae_path, mlp_path):
        """Save the ``state_dict`` of the AE and MLP to the given paths."""
        torch.save(self.ae.state_dict(), ae_path)
        torch.save(self.mlp.state_dict(), mlp_path)
        print(f"Models saved to:\n - {ae_path}\n - {mlp_path}")
crisgi/cnn/__init__.py ADDED
File without changes
crisgi/cnn/autoEncoder.py ADDED
@@ -0,0 +1,36 @@
1
+ from torch import nn
2
+
3
+
4
class AE(nn.Module):
    """Convolutional autoencoder with a 4-channel input and output.

    The encoder applies three Conv(3x3) -> BatchNorm -> ReLU -> Dropout(0.2)
    -> MaxPool(2) stages (4 -> 16 -> 32 -> 64 channels), halving the spatial
    size each time. The decoder mirrors it with three stride-2 transposed
    convolutions (64 -> 32 -> 16 -> 4 channels), doubling the spatial size each
    time, with ReLU between stages and a final Sigmoid.

    ``forward`` returns ``(encoded, reconstruction)``.
    """

    def __init__(self):
        super().__init__()

        down_layers = []
        for in_ch, out_ch in ((4, 16), (16, 32), (32, 64)):
            down_layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.MaxPool2d(2, 2),
            ])
        self.encoder = nn.Sequential(*down_layers)

        up_specs = ((64, 32), (32, 16), (16, 4))
        up_layers = []
        for stage, (in_ch, out_ch) in enumerate(up_specs):
            up_layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1)
            )
            # The last stage squashes the reconstruction to (0, 1).
            is_last = stage == len(up_specs) - 1
            up_layers.append(nn.Sigmoid() if is_last else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction
crisgi/cnn/evalution_metrics.py ADDED
@@ -0,0 +1,46 @@
1
+ from sklearn.metrics import auc, brier_score_loss, cohen_kappa_score, confusion_matrix, f1_score, roc_curve, roc_auc_score, precision_recall_curve, average_precision_score
2
+
3
+
4
def calculate_pred_metric(label, pred):
    """Compute binary-classification metrics from hard class predictions.

    Args:
        label: ground-truth class labels, expected to be 0/1.
        pred: predicted class labels, expected to be 0/1, same length as ``label``.

    Returns:
        A dict with keys 'AUROC', 'AUPRC', 'Accuracy', 'Precision', 'Recall',
        'Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1_Score', 'Kappa' and
        'Brier_Score'. Ratios whose denominator is zero are reported as 0.
        'AUROC' is NaN when ``label`` contains only one class.
    """
    # ROC AUC is undefined when only one class is present in the ground truth;
    # roc_auc_score raises ValueError in that case. Report NaN instead of aborting.
    if len(set(label)) < 2:
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(label, pred)

    # labels=[0, 1] forces a 2x2 matrix even when a class is absent from both
    # label and pred, so the four-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(label, pred, labels=[0, 1]).ravel()

    # Ratio metrics; every denominator is guarded the same way.
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    sensitivity = recall
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = precision
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    f1 = f1_score(label, pred)
    kappa = cohen_kappa_score(label, pred)
    brier = brier_score_loss(label, pred)

    # Area under the precision-recall curve (trapezoidal rule). With hard 0/1
    # predictions the curve has only a few points.
    precision_vals, recall_vals, _ = precision_recall_curve(label, pred)
    pr_auc = auc(recall_vals, precision_vals)

    return {
        'AUROC': roc_auc,
        'AUPRC': pr_auc,
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'Sensitivity': sensitivity,
        'Specificity': specificity,
        'PPV': ppv,
        'NPV': npv,
        'F1_Score': f1,
        'Kappa': kappa,
        'Brier_Score': brier
    }
crisgi/cnn/mlp.py ADDED
@@ -0,0 +1,16 @@
1
+ from torch import nn
2
+
3
+
4
class MLP(nn.Module):
    """Two-layer classification head producing 2 logits.

    The input is flattened per sample and must contain exactly 64 * 28 * 28
    values per sample (e.g. a (N, 64, 28, 28) feature map), mapped through a
    64-unit ReLU hidden layer to 2 outputs.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 64
        n_classes = 2
        self.mlp = nn.Sequential(
            nn.Linear(64 * 28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_classes),
        )

    def forward(self, x):
        batch = x.size(0)
        return self.mlp(x.view(batch, -1))
crisgi/cnn/train.py ADDED
@@ -0,0 +1,42 @@
1
+ from crisgi.cnn.evalution_metrics import calculate_pred_metric
2
+
3
+
4
def train(ae, mlp, train_loader, AE_loss_function, CE_loss_function, optimizer, device):
    """Run one training epoch of the joint autoencoder + classifier.

    For every ``(x, y)`` batch the loss is
    ``AE_loss_function(reconstruction, x) + CE_loss_function(logits, y)``,
    where ``(encoded, reconstruction) = ae(x)`` and ``logits = mlp(encoded)``.

    Args:
        ae: autoencoder module returning ``(encoded, reconstruction)``.
        mlp: classifier module mapping ``encoded`` to class logits.
        train_loader: iterable of ``(x, y)`` tensor batches.
        AE_loss_function: reconstruction loss.
        CE_loss_function: classification loss.
        optimizer: optimizer over the parameters of ``ae`` and ``mlp``.
        device: device the batches are moved to.

    Returns:
        The metrics dict from ``calculate_pred_metric`` over all labels and
        predictions seen during the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    ae.train()
    mlp.train()
    total_loss = 0.0
    correct = 0
    n_seen = 0
    all_predictions = []
    all_labels = []

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        en, de = ae(x)
        classification_res = mlp(en)

        AE_loss = AE_loss_function(de, x)
        CE_loss = CE_loss_function(classification_res, y)
        loss = AE_loss + CE_loss

        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Losses are assumed mean-reduced (the torch default), so weight each
        # batch by its size to get a true per-sample average over the epoch.
        total_loss += loss.item() * batch_size
        n_seen += batch_size
        predicted = classification_res.argmax(1)
        correct += (predicted == y).sum().item()

        all_labels.extend(y.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    # Count samples actually seen (not len(dataset)) so drop_last/samplers are
    # handled correctly; fail clearly instead of dividing by zero.
    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")

    avg_loss = total_loss / n_seen
    accuracy = 100 * correct / n_seen

    print(f"Total Train Loss: {avg_loss}")
    print(f"Train Accuracy: {accuracy}%")

    metrics = calculate_pred_metric(all_labels, all_predictions)

    return metrics