MEDfl 2.0.4.dev0__py3-none-any.whl → 2.0.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/rw/client.py +98 -29
- MEDfl/rw/model.py +46 -74
- MEDfl/rw/server.py +71 -18
- MEDfl/rw/strategy.py +73 -78
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev2.dist-info}/METADATA +2 -14
- MEDfl-2.0.4.dev2.dist-info/RECORD +36 -0
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev2.dist-info}/WHEEL +1 -1
- MEDfl/rw/rwConfig.py +0 -21
- MEDfl/rw/verbose_server.py +0 -21
- Medfl/LearningManager/__init__.py +0 -13
- Medfl/LearningManager/client.py +0 -150
- Medfl/LearningManager/dynamicModal.py +0 -287
- Medfl/LearningManager/federated_dataset.py +0 -60
- Medfl/LearningManager/flpipeline.py +0 -192
- Medfl/LearningManager/model.py +0 -223
- Medfl/LearningManager/params.yaml +0 -14
- Medfl/LearningManager/params_optimiser.py +0 -442
- Medfl/LearningManager/plot.py +0 -229
- Medfl/LearningManager/server.py +0 -181
- Medfl/LearningManager/strategy.py +0 -82
- Medfl/LearningManager/utils.py +0 -331
- Medfl/NetManager/__init__.py +0 -10
- Medfl/NetManager/database_connector.py +0 -43
- Medfl/NetManager/dataset.py +0 -92
- Medfl/NetManager/flsetup.py +0 -320
- Medfl/NetManager/net_helper.py +0 -254
- Medfl/NetManager/net_manager_queries.py +0 -142
- Medfl/NetManager/network.py +0 -194
- Medfl/NetManager/node.py +0 -184
- Medfl/__init__.py +0 -3
- Medfl/scripts/__init__.py +0 -2
- Medfl/scripts/base.py +0 -30
- Medfl/scripts/create_db.py +0 -126
- medfl-2.0.4.dev0.dist-info/RECORD +0 -62
- {medfl-2.0.4.dev0.dist-info/licenses → MEDfl-2.0.4.dev2.dist-info}/LICENSE +0 -0
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev2.dist-info}/top_level.txt +0 -0
Medfl/LearningManager/server.py
DELETED
@@ -1,181 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
|
3
|
-
import copy
|
4
|
-
from typing import Dict, Optional, Tuple
|
5
|
-
|
6
|
-
import flwr as fl
|
7
|
-
import torch
|
8
|
-
|
9
|
-
from .client import FlowerClient
|
10
|
-
from .federated_dataset import FederatedDataset
|
11
|
-
from .model import Model
|
12
|
-
from .strategy import Strategy
|
13
|
-
|
14
|
-
|
15
|
-
class FlowerServer:
    """
    Central server for Federated Learning simulations run with Flower.

    Attributes:
        device (torch.device): "cuda" when available, otherwise "cpu"; the global
            model is moved here and server-side evaluation runs here.
        global_model (Model): The global model that will be federated among clients.
        params: Initial parameters of ``global_model`` (from ``Model.get_parameters()``),
            injected into the Flower strategy as ``initial_parameters``.
        strategy (Strategy): Wrapper whose ``strategy_object`` is the Flower strategy
            used for communication and aggregation.
        num_rounds (int): The number of federated learning rounds to perform.
        num_clients (int): The number of clients participating in the federated learning process.
        fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
        client_resources (Dict[str, float]): Per-client resources forwarded to
            ``fl.simulation.start_simulation``.
        diff_priv (bool): Whether differential privacy is used during the federated learning process.
        accuracies (List[float]): Server-side accuracy, one entry per ``evaluate`` call.
        losses (List[float]): Server-side loss, one entry per ``evaluate`` call.
        auc (List[float]): Server-side AUC, one entry per ``evaluate`` call.
        flower_clients (List[FlowerClient]): Every FlowerClient created by ``client_fn``.
    """

    def __init__(
        self,
        global_model: Model,
        strategy: Strategy,
        num_rounds: int,
        num_clients: int,
        fed_dataset: FederatedDataset,
        diff_privacy: bool = False,
        client_resources: Optional[Dict[str, float]] = None,
    ) -> None:
        """
        Initialize a FlowerServer object with the specified parameters.

        Args:
            global_model (Model): The global model that will be federated among clients.
            strategy (Strategy): The strategy used for federated learning; its
                ``strategy_object`` must already exist (see ``Strategy.create_strategy``).
            num_rounds (int): The number of federated learning rounds to perform.
            num_clients (int): The number of clients participating in the federated learning process.
            fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
            diff_privacy (bool, optional): Whether differential privacy is used during the
                federated learning process. Default is False.
            client_resources (Dict[str, float], optional): Resources per simulated client.
                None (the default) means ``{'num_cpus': 1, 'num_gpus': 0.0}``.

        Raises:
            TypeError: if global_model, num_clients, num_rounds or diff_privacy has the wrong type.
        """
        self.global_model = global_model
        self.strategy = strategy
        self.num_rounds = num_rounds
        self.num_clients = num_clients
        self.fed_dataset = fed_dataset
        self.diff_priv = diff_privacy
        # Validate before touching the model or mutating the (shared) strategy object.
        self.validate()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Snapshot the initial weights, then move the model to the server device.
        self.params = global_model.get_parameters()
        self.global_model.model = global_model.model.to(self.device)

        # A fresh dict per instance: a dict literal as the default would be shared.
        self.client_resources = (
            {"num_cpus": 1, "num_gpus": 0.0}
            if client_resources is None
            else client_resources
        )

        # Configure the underlying Flower strategy for this server.
        strategy_object = self.strategy.strategy_object
        strategy_object.min_available_clients = self.num_clients
        strategy_object.initial_parameters = fl.common.ndarrays_to_parameters(self.params)
        strategy_object.evaluate_fn = self.evaluate

        self.accuracies = []
        self.losses = []
        self.auc = []
        self.flower_clients = []

    def validate(self) -> None:
        """Check the types of global_model, num_clients, num_rounds and diff_priv.

        Raises:
            TypeError: if any of them has the wrong type.
        """
        if not isinstance(self.global_model, Model):
            raise TypeError("global_model argument must be a Model instance")

        # NOTE: the strategy is not type-checked here.

        if not isinstance(self.num_clients, int):
            raise TypeError("num_clients argument must be an int")

        if not isinstance(self.num_rounds, int):
            raise TypeError("num_rounds argument must be an int")

        if not isinstance(self.diff_priv, bool):
            raise TypeError("diff_priv argument must be a bool")

    def client_fn(self, cid) -> FlowerClient:
        """
        Return a FlowerClient object for a specific client ID.

        Args:
            cid: The client ID; ``int(cid)`` indexes the dataset's train/val loaders.

        Returns:
            FlowerClient: A client holding a deep copy of the global model.
        """
        idx = int(cid)
        client_model = copy.deepcopy(self.global_model)
        trainloader = self.fed_dataset.trainloaders[idx]
        valloader = self.fed_dataset.valloaders[idx]

        client = FlowerClient(cid, client_model, trainloader, valloader, self.diff_priv)
        # Kept so per-client results can be plotted later.
        # NOTE(review): if Flower calls client_fn more than once per cid during a
        # simulation, this list holds several entries per client — confirm plotting expects that.
        self.flower_clients.append(client)
        return client

    def evaluate(
        self,
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        """
        Evaluate the global model server-side and record loss, accuracy and AUC.

        The evaluation data is ``fed_dataset.valloaders[0]`` (the first client's
        validation loader). ``server_round`` and ``config`` are accepted for
        Flower's ``evaluate_fn`` signature but not used.

        Args:
            server_round (int): The current round of the federated learning process.
            parameters (fl.common.NDArrays): The global model parameters.
            config (Dict[str, fl.common.Scalar]): Configuration dictionary.

        Returns:
            Tuple[float, Dict[str, fl.common.Scalar]]: The loss and ``{"accuracy": accuracy}``.
        """
        testloader = self.fed_dataset.valloaders[0]

        # Load the latest aggregated parameters before evaluating.
        self.global_model.set_parameters(parameters)
        loss, accuracy, auc = self.global_model.evaluate(testloader, self.device)
        self.auc.append(auc)
        self.losses.append(loss)
        self.accuracies.append(accuracy)

        return loss, {"accuracy": accuracy}

    def run(self) -> "fl.server.History":
        """
        Run the federated learning process with ``fl.simulation.start_simulation``.

        Returns:
            History: The History object returned by Flower for the simulation.
        """
        ray_init_args = {
            "include_dashboard": False,
            # 75 MiB; the original comment says this meets Ray's minimum allowed object store size.
            "object_store_memory": 75 * 1024 * 1024,
        }
        # NOTE(review): presumably clears a database engine handle so the dataset
        # can be shipped to simulated clients — confirm.
        self.fed_dataset.eng = None

        history = fl.simulation.start_simulation(
            client_fn=self.client_fn,
            num_clients=self.num_clients,
            config=fl.server.ServerConfig(self.num_rounds),
            strategy=self.strategy.strategy_object,
            ray_init_args=ray_init_args,
            client_resources=self.client_resources,
        )

        return history
|
181
|
-
|
@@ -1,82 +0,0 @@
|
|
1
|
-
|
2
|
-
from collections import OrderedDict
|
3
|
-
from typing import Dict, List, Optional, Tuple
|
4
|
-
|
5
|
-
import flwr as fl
|
6
|
-
import numpy as np
|
7
|
-
|
8
|
-
import optuna
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
class Strategy:
    """
    Configuration wrapper around a Flower server strategy.

    Attributes:
        name (str): Name of a class in ``fl.server.strategy``. Default is "FedAvg".
        fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
        fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
        min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
        min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
        min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
        initial_parameters (list): Initial parameters of the server model, converted with
            ``fl.common.ndarrays_to_parameters`` in ``create_strategy``. Default is a new empty list.
        evaluate_fn: Evaluation function passed to the Flower strategy. None by default.
        evaluation_methode (str): "centralized" or "distributed"; stored but not used by this class.
        strategy_object: The Flower strategy instance, set by ``create_strategy()``.

    Methods:
        optuna_fed_optimization: create an Optuna study and store HPO settings.
        create_strategy: instantiate the Flower strategy into ``strategy_object``.
        get_strategy_by_name: return the ``fl.server.strategy`` class named ``name``.
    """

    def __init__(
        self,
        name: str = "FedAvg",
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
        evaluation_methode="centralized",
    ) -> None:
        """
        Initialize a Strategy object with the specified parameters.

        Args:
            name (str): The name of the strategy. Default is "FedAvg".
            fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
            fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
            min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
            min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
            min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
            initial_parameters (list, optional): The initial parameters of the server model.
                None (the default) means a new empty list per instance.
            evaluation_methode ("centralized" | "distributed"): Evaluation method. Default is "centralized".
        """
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        # A list literal as the default would be shared by every Strategy instance.
        self.initial_parameters = [] if initial_parameters is None else initial_parameters
        self.evaluate_fn = None
        self.name = name
        self.evaluation_methode = evaluation_methode

    def optuna_fed_optimization(self, direction: str, hpo_rate: int, params_config):
        """Create an Optuna study and store hyper-parameter optimisation settings.

        Args:
            direction (str): Passed to ``optuna.create_study`` ("minimize" or "maximize").
            hpo_rate (int): Stored as ``self.hpo_rate``; its use is not visible in this class.
            params_config: Stored as ``self.params_config``; its use is not visible in this class.
        """
        self.study = optuna.create_study(direction=direction)
        self.hpo_rate = hpo_rate
        self.params_config = params_config

    def create_strategy(self):
        """Instantiate the Flower strategy named ``self.name`` into ``self.strategy_object``."""
        self.strategy_object = self.get_strategy_by_name()(
            fraction_fit=self.fraction_fit,
            fraction_evaluate=self.fraction_evaluate,
            min_fit_clients=self.min_fit_clients,
            min_evaluate_clients=self.min_evaluate_clients,
            min_available_clients=self.min_available_clients,
            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
            evaluate_fn=self.evaluate_fn,
        )

    def get_strategy_by_name(self):
        """Return the class ``fl.server.strategy.<self.name>``.

        Raises:
            AttributeError: if ``fl.server.strategy`` has no attribute with that name.
        """
        # getattr performs the same lookup as eval() did, without executing arbitrary code.
        return getattr(fl.server.strategy, self.name)
|
80
|
-
|
81
|
-
|
82
|
-
|
Medfl/LearningManager/utils.py
DELETED
@@ -1,331 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
|
3
|
-
import pkg_resources
|
4
|
-
import torch
|
5
|
-
import yaml
|
6
|
-
from sklearn.metrics import *
|
7
|
-
from yaml.loader import SafeLoader
|
8
|
-
|
9
|
-
|
10
|
-
from MEDfl.NetManager.database_connector import DatabaseManager
|
11
|
-
|
12
|
-
# from scripts.base import *
|
13
|
-
import json
|
14
|
-
|
15
|
-
|
16
|
-
import pandas as pd
|
17
|
-
import numpy as np
|
18
|
-
|
19
|
-
import os
|
20
|
-
import configparser
|
21
|
-
|
22
|
-
import subprocess
|
23
|
-
import ast
|
24
|
-
|
25
|
-
from sqlalchemy import text
|
26
|
-
|
27
|
-
|
28
|
-
# Directory containing this module; params.yaml is shipped alongside it.
current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the learning parameters bundled with the package.
yaml_path = os.path.join(current_directory, 'params.yaml')

# Explicit encoding: the platform default (locale-dependent) could mis-decode the file.
with open(yaml_path, encoding="utf-8") as g:
    params = yaml.load(g, Loader=SafeLoader)

# Default path for the config file
DEFAULT_CONFIG_PATH = 'db_config.ini'
|
44
|
-
|
45
|
-
|
46
|
-
def load_db_config_dep():
    """Parse the database config dict stored in the ``MEDfl_DB_CONFIG`` env variable.

    Unlike ``load_db_config``, an empty value is treated the same as an unset one.

    Returns:
        The object obtained by ``ast.literal_eval`` of the variable's value.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    raw = os.environ.get('MEDfl_DB_CONFIG')
    if not raw:
        raise ValueError("MEDfl db config not found")
    return ast.literal_eval(raw)
|
53
|
-
|
54
|
-
# Function to allow users to set config path programmatically
|
55
|
-
|
56
|
-
|
57
|
-
def set_db_config_dep(config_path):
    """Store the ``[sqllite]`` section of an INI file in the ``MEDfl_DB_CONFIG`` env variable.

    The section is stored as ``str(dict(section))`` so it can be read back with
    ``ast.literal_eval``.

    Args:
        config_path: Path to an INI file containing a ``[sqllite]`` section.

    Raises:
        ValueError: if the file has no ``[sqllite]`` section or the section has no options.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips unreadable files, so a missing file
    # surfaces below as a missing section.
    config.read(config_path)
    # Check the section explicitly: config['sqllite'] raises KeyError when it is missing.
    section = config['sqllite'] if config.has_section('sqllite') else None
    if not section:
        raise ValueError(f"sqllite key not found in file '{config_path}'")
    os.environ['MEDfl_DB_CONFIG'] = str(dict(section))
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
def load_db_config():
    """Return the dict serialized in the ``MEDfl_DB_CONFIG`` environment variable.

    The value is parsed with ``ast.literal_eval``.

    Raises:
        ValueError: if the environment variable is not set.
    """
    serialized = os.getenv("MEDfl_DB_CONFIG")
    if serialized is None:
        raise ValueError("Environment variable MEDfl_DB_CONFIG not found")
    return ast.literal_eval(serialized)
|
74
|
-
|
75
|
-
# Function to allow users to set config path programmatically
|
76
|
-
|
77
|
-
|
78
|
-
def set_db_config(config_path):
    """Store ``{"database": config_path}`` as a string in the ``MEDfl_DB_CONFIG`` env variable.

    ``load_db_config`` reads it back with ``ast.literal_eval``.

    Args:
        config_path: Value stored under the ``database`` key (the SQLite database path).
    """
    # The docstring previously followed the first statement, so it was a dead
    # string expression rather than the function's __doc__.
    os.environ['MEDfl_DB_CONFIG'] = str({"database": config_path})
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
# Create the MEDfl database
|
91
|
-
|
92
|
-
|
93
|
-
def create_MEDfl_db():
    """Run the ``scripts/create_db.sh`` shell script located next to this module.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    # NOTE(review): this resolves to <this module's dir>/scripts/create_db.sh, but the
    # packaged script appears to be Medfl/scripts/create_db.py — confirm the path exists.
    module_dir = os.path.dirname(__file__)
    script = os.path.join(module_dir, 'scripts', 'create_db.sh')
    subprocess.run(['sh', script], check=True)
|
97
|
-
|
98
|
-
|
99
|
-
def custom_classification_report(y_true, y_pred_prob):
    """
    Compute binary classification metrics: accuracy, sensitivity, specificity,
    precision, NPV, F1-score, false positive rate, true positive rate and AUC.

    Predicted labels are ``y_pred_prob.round()``, i.e. a 0.5 threshold. For NumPy
    arrays an exact 0.5 rounds to 0 (round-half-to-even).

    Args:
        y_true (array-like): True labels; assumed to be 0/1 — TODO confirm for all callers.
        y_pred_prob (numpy.ndarray): Predicted probabilities of the positive class.

    Returns:
        dict: "confusion matrix" (TP/FP/FN/TN counts), the rate metrics rounded to
        3 decimals, and the unrounded "auc". "auc" is NaN when ``y_true`` contains
        a single class, because ROC AUC is undefined in that case.
    """
    y_pred = y_pred_prob.round()

    # roc_auc_score raises ValueError when y_true has only one class.
    if np.unique(y_true).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, y_pred_prob)

    # Pin the label set so the matrix is always 2x2. Without it, a single-class
    # input yields a 1x1 matrix and the 4-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def _safe_div(num, den):
        # Metrics with an empty denominator are reported as 0.0.
        return num / den if den != 0 else 0.0

    acc = _safe_div(tp + tn, tp + tn + fp + fn)
    sen = _safe_div(tp, tp + fn)  # sensitivity / recall
    sp = _safe_div(tn, tn + fp)  # specificity
    ppv = _safe_div(tp, tp + fp)  # precision
    npv = _safe_div(tn, tn + fn)
    f1 = _safe_div(2 * (sen * ppv), sen + ppv)
    fpr = _safe_div(fp, fp + tn)
    tpr = _safe_div(tp, tp + fn)

    return {
        "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
        "Accuracy": round(acc, 3),
        "Sensitivity/Recall": round(sen, 3),
        "Specificity": round(sp, 3),
        "PPV/Precision": round(ppv, 3),
        "NPV": round(npv, 3),
        "F1-score": round(f1, 3),
        "False positive rate": round(fpr, 3),
        "True positive rate": round(tpr, 3),
        "auc": auc,
    }
|
162
|
-
|
163
|
-
|
164
|
-
def test(model, test_loader, device=torch.device("cpu")):
    """
    Evaluate a model on the whole test dataset and return a custom classification report.

    The entire dataset is taken in one slice (``test_loader.dataset[:]``); the
    loader's batching is not used.

    Args:
        model (torch.nn.Module): PyTorch model to evaluate.
        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
        device (torch.device, optional): Device for model evaluation. Default is "cpu".

    Returns:
        dict: The result of ``custom_classification_report`` on the model's predictions.
    """
    model.eval()
    with torch.no_grad():
        # Slice the dataset once so features and labels come from the same call.
        samples = test_loader.dataset[:]
        X_test = samples[0].to(device)
        y_test = samples[1].to(device)
        # NOTE(review): squeezing dim 1 presumes model output of shape (N, 1) — confirm.
        y_hat_prob = torch.squeeze(model(X_test), 1).cpu()

    return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.numpy())
|
184
|
-
|
185
|
-
|
186
|
-
# Dtype name -> SQL column type. The keys look like pandas dtype names
# (usage is not visible in this file — confirm).
column_map = {
    "object": "VARCHAR(255)",
    "int64": "INT",
    "float64": "FLOAT",
}
|
187
|
-
|
188
|
-
|
189
|
-
def empty_db():
    """
    Empty the MEDfl database.

    Deletes all rows from Nodes, FedDatasets, Networks, FLsetup, FLpipeline and
    testResults, resets the id counters of the first five, and drops the
    MasterDataset and DataSets tables. All statements run in one transaction on a
    single connection; the transaction is committed and the connection/engine
    are released on exit.

    Returns:
        None
    """
    # Table names are fixed identifiers (never user input), so f-string SQL is safe here.
    cleared_tables = ("Nodes", "FedDatasets", "Networks", "FLsetup", "FLpipeline")
    counter_tables = ("Nodes", "Networks", "FedDatasets", "FLsetup", "FLpipeline")

    db_manager = DatabaseManager()
    db_manager.connect()
    try:
        with db_manager.get_connection() as conn, conn.begin():
            for table in cleared_tables:
                conn.execute(text(f"DELETE FROM {table}"))

            if conn.dialect.name == "sqlite":
                # SQLite has no `ALTER TABLE ... AUTO_INCREMENT`. AUTOINCREMENT tables
                # keep their counters in sqlite_sequence, which only exists if such
                # a table exists. Other tables restart at rowid 1 once emptied.
                has_sequence = conn.execute(
                    text(
                        "SELECT name FROM sqlite_master "
                        "WHERE type = 'table' AND name = 'sqlite_sequence'"
                    )
                ).first() is not None
                if has_sequence:
                    for table in counter_tables:
                        conn.execute(
                            text("DELETE FROM sqlite_sequence WHERE name = :name"),
                            {"name": table},
                        )
            else:
                for table in counter_tables:
                    conn.execute(text(f"ALTER TABLE {table} AUTO_INCREMENT = 1"))

            conn.execute(text("DELETE FROM testResults"))
            conn.execute(text("DROP TABLE IF EXISTS MasterDataset"))
            conn.execute(text("DROP TABLE IF EXISTS DataSets"))
    finally:
        db_manager.close()
|
215
|
-
|
216
|
-
|
217
|
-
def get_pipeline_from_name(name):
    """
    Get the pipeline ID from its name in the database.

    Args:
        name (str): Name of the pipeline.

    Returns:
        int: ID of the first pipeline row with that name.

    Raises:
        IndexError: if no pipeline has that name.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    try:
        with db_manager.get_connection() as conn:
            # Bound parameter: a quote in the name cannot break or inject into the query.
            result = pd.read_sql(
                text("SELECT id FROM FLpipeline WHERE name = :name"),
                conn,
                params={"name": name},
            )
    finally:
        db_manager.close()

    pipeline_id = int(result.iloc[0, 0])
    return pipeline_id
|
237
|
-
|
238
|
-
|
239
|
-
def get_pipeline_confusion_matrix(pipeline_id):
    """
    Get the global confusion matrix for a pipeline by summing its test results.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        dict: ``{'TP': ..., 'FP': ..., 'FN': ..., 'TN': ...}`` summed over every
        testResults row of the pipeline (all zeros when there are none).
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    try:
        with db_manager.get_connection() as conn:
            # Bound parameter instead of string interpolation. The id is bound as a
            # string, matching the quoted literal the query used before.
            data = pd.read_sql(
                text("SELECT confusionmatrix FROM testResults WHERE pipelineid = :pipeline_id"),
                conn,
                params={"pipeline_id": str(pipeline_id)},
            )
    finally:
        db_manager.close()

    global_confusion_matrix = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    for raw in data["confusionmatrix"]:
        # Stored presumably as a Python-style dict string (single quotes);
        # swap the quotes so it parses as JSON.
        matrix = json.loads(raw.replace("'", "\""))
        for key in global_confusion_matrix:
            global_confusion_matrix[key] += matrix[key]

    return global_confusion_matrix
|
282
|
-
|
283
|
-
|
284
|
-
def get_node_confusion_matrix(pipeline_id, node_name):
    """
    Get the confusion matrix for a specific node in a pipeline based on test results.

    Args:
        pipeline_id (int): ID of the pipeline.
        node_name (str): Name of the node.

    Returns:
        dict: The confusion matrix of the first matching testResults row.

    Raises:
        IndexError: if no test result exists for that pipeline/node pair.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    try:
        with db_manager.get_connection() as conn:
            # Bound parameters instead of string interpolation. Both are bound as
            # strings, matching the quoted literals the query used before.
            data = pd.read_sql(
                text(
                    "SELECT confusionmatrix FROM testResults "
                    "WHERE pipelineid = :pipeline_id AND nodename = :node_name"
                ),
                conn,
                params={"pipeline_id": str(pipeline_id), "node_name": str(node_name)},
            )
    finally:
        db_manager.close()

    # Only the first matching row is returned. The value is presumably a
    # Python-style dict string (single quotes); swap the quotes to parse it as JSON.
    first = data["confusionmatrix"].iloc[0]
    return json.loads(first.replace("'", "\""))
|
311
|
-
|
312
|
-
|
313
|
-
def get_pipeline_result(pipeline_id):
    """
    Get the test results for a pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        pandas.DataFrame: All testResults columns for the specified pipeline.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    try:
        with db_manager.get_connection() as conn:
            # Bound parameter instead of string interpolation; bound as a string,
            # matching the quoted literal the query used before.
            data = pd.read_sql(
                text("SELECT * FROM testResults WHERE pipelineid = :pipeline_id"),
                conn,
                params={"pipeline_id": str(pipeline_id)},
            )
    finally:
        db_manager.close()
    return data
|
Medfl/NetManager/__init__.py
DELETED
@@ -1,10 +0,0 @@
|
|
1
|
-
# # MEDfl/NetManager/__init__.py
|
2
|
-
|
3
|
-
# # Import modules from this package
|
4
|
-
# from .dataset import *
|
5
|
-
# from .flsetup import *
|
6
|
-
# from .net_helper import *
|
7
|
-
# from .net_manager_queries import *
|
8
|
-
# from .network import *
|
9
|
-
# from .node import *
|
10
|
-
# from .database_connector import *
|
@@ -1,43 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
import subprocess
|
3
|
-
from sqlalchemy import create_engine
|
4
|
-
from configparser import ConfigParser
|
5
|
-
|
6
|
-
class DatabaseManager:
    """Create SQLAlchemy engines/connections for the SQLite database named in ``MEDfl_DB_CONFIG``."""

    def __init__(self):
        """Load the database configuration via ``load_db_config``.

        Raises:
            ValueError: propagated from ``load_db_config`` when MEDfl_DB_CONFIG is not set.
        """
        # Imported here, not at module level, because MEDfl.LearningManager.utils
        # itself imports DatabaseManager from this module.
        from MEDfl.LearningManager.utils import load_db_config
        db_config = load_db_config()
        self.config = db_config if db_config else None
        self.engine = None

    def connect(self):
        """Create the engine for the SQLite file at ``config['database']``.

        Raises:
            ValueError: if no configuration is loaded.
        """
        if not self.config:
            raise ValueError(
                "Database configuration not loaded. Use load_db_config() or set_db_config() first."
            )
        database_path = self.config['database']
        connection_string = f"sqlite:///{database_path}"
        # pool_pre_ping checks pooled connections before they are handed out.
        self.engine = create_engine(connection_string, pool_pre_ping=True)

    def get_connection(self):
        """Return a new Connection, creating the engine first if needed.

        The caller is responsible for closing the returned connection.
        """
        if not self.engine:
            self.connect()
        return self.engine.connect()

    def create_MEDfl_db(self, path_to_csv):
        """Run ``../scripts/create_db.py`` (relative to this module) with *path_to_csv*.

        Args:
            path_to_csv: Passed to the script as its single command-line argument.

        Raises:
            subprocess.CalledProcessError: if the script exits with a non-zero status.
        """
        import sys

        current_directory = os.path.dirname(__file__)
        create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')
        # Use the running interpreter, so the script sees the same environment
        # (virtualenv, installed packages) instead of whatever `python` is on PATH.
        subprocess.run([sys.executable, create_db_script_path, path_to_csv], check=True)

    def close(self):
        """Dispose of the engine's connection pool, if an engine was created."""
        if self.engine:
            self.engine.dispose()
|