MEDfl 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Medfl/LearningManager/__init__.py +13 -0
- Medfl/LearningManager/client.py +150 -0
- Medfl/LearningManager/dynamicModal.py +287 -0
- Medfl/LearningManager/federated_dataset.py +57 -0
- Medfl/LearningManager/flpipeline.py +189 -0
- Medfl/LearningManager/model.py +223 -0
- Medfl/LearningManager/params_optimiser.py +442 -0
- Medfl/LearningManager/plot.py +229 -0
- Medfl/LearningManager/server.py +179 -0
- Medfl/LearningManager/strategy.py +82 -0
- Medfl/LearningManager/utils.py +233 -0
- Medfl/NetManager/__init__.py +9 -0
- Medfl/NetManager/dataset.py +91 -0
- Medfl/NetManager/flsetup.py +304 -0
- Medfl/NetManager/net_helper.py +243 -0
- Medfl/NetManager/net_manager_queries.py +137 -0
- Medfl/NetManager/network.py +160 -0
- Medfl/NetManager/node.py +181 -0
- Medfl/__init__.py +2 -0
- {Medfl-0.1.0.dist-info → Medfl-0.1.4.dist-info}/METADATA +19 -18
- Medfl-0.1.4.dist-info/RECORD +29 -0
- {Medfl-0.1.0.dist-info → Medfl-0.1.4.dist-info}/WHEEL +1 -1
- {Medfl-0.1.0.dist-info → Medfl-0.1.4.dist-info}/top_level.txt +1 -0
- Medfl-0.1.0.dist-info/RECORD +0 -10
- {Medfl-0.1.0.data → Medfl-0.1.4.data}/scripts/setup_mysql.sh +0 -0
Medfl/LearningManager/utils.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+
+import pkg_resources
+import torch
+import yaml
+from sklearn.metrics import *
+from yaml.loader import SafeLoader
+
+from scripts.base import *
+import json
+
+
+import pandas as pd
+import numpy as np
+
+yaml_path = pkg_resources.resource_filename(__name__, "params.yaml")
+with open(yaml_path) as g:
+    params = yaml.load(g, Loader=SafeLoader)
+
+global_yaml_path = pkg_resources.resource_filename(__name__, "../../global_params.yaml")
+with open(global_yaml_path) as g:
+    global_params = yaml.load(g, Loader=SafeLoader)
+
+
+def custom_classification_report(y_true, y_pred_prob):
+    """
+    Compute custom classification report metrics including accuracy, sensitivity, specificity, precision, NPV,
+    F1-score, false positive rate, and true positive rate.
+
+    Args:
+        y_true (array-like): True labels.
+        y_pred_prob (array-like): Predicted probabilities.
+
+    Returns:
+        dict: A dictionary containing custom classification report metrics.
+    """
+    y_pred = (y_pred_prob).round()  # Round predicted probabilities to the nearest integer
+
+    auc = roc_auc_score(y_true, y_pred_prob)  # Calculate AUC
+
+    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+
+    # Accuracy
+    denominator_acc = tp + tn + fp + fn
+    acc = (tp + tn) / denominator_acc if denominator_acc != 0 else 0.0
+
+    # Sensitivity/Recall
+    denominator_sen = tp + fn
+    sen = tp / denominator_sen if denominator_sen != 0 else 0.0
+
+    # Specificity
+    denominator_sp = tn + fp
+    sp = tn / denominator_sp if denominator_sp != 0 else 0.0
+
+    # PPV/Precision
+    denominator_ppv = tp + fp
+    ppv = tp / denominator_ppv if denominator_ppv != 0 else 0.0
+
+    # NPV
+    denominator_npv = tn + fn
+    npv = tn / denominator_npv if denominator_npv != 0 else 0.0
+
+    # F1 Score
+    denominator_f1 = sen + ppv
+    f1 = 2 * (sen * ppv) / denominator_f1 if denominator_f1 != 0 else 0.0
+
+    # False Positive Rate
+    denominator_fpr = fp + tn
+    fpr = fp / denominator_fpr if denominator_fpr != 0 else 0.0
+
+    # True Positive Rate
+    denominator_tpr = tp + fn
+    tpr = tp / denominator_tpr if denominator_tpr != 0 else 0.0
+
+    return {
+        "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
+        "Accuracy": round(acc, 3),
+        "Sensitivity/Recall": round(sen, 3),
+        "Specificity": round(sp, 3),
+        "PPV/Precision": round(ppv, 3),
+        "NPV": round(npv, 3),
+        "F1-score": round(f1, 3),
+        "False positive rate": round(fpr, 3),
+        "True positive rate": round(tpr, 3),
+        "auc": auc
+    }
+
+
+def test(model, test_loader, device=torch.device("cpu")):
+    """
+    Evaluate a model using a test loader and return a custom classification report.
+
+    Args:
+        model (torch.nn.Module): PyTorch model to evaluate.
+        test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
+        device (torch.device, optional): Device for model evaluation. Default is "cpu".
+
+    Returns:
+        dict: A dictionary containing custom classification report metrics.
+    """
+
+    model.eval()
+    with torch.no_grad():
+        X_test, y_test = test_loader.dataset[:][0].to(device), test_loader.dataset[:][1].to(device)
+        y_hat_prob = torch.squeeze(model(X_test), 1).cpu()
+
+    return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.cpu().numpy())
+
+
+column_map = {"object": "VARCHAR(255)", "int64": "INT", "float64": "FLOAT"}
+
+
+def empty_db():
+    """
+    Empty the database by deleting records from multiple tables and resetting auto-increment counters.
+
+    Returns:
+        None
+    """
+
+    # my_eng.execute(text(f"DELETE FROM {'DataSets'}"))
+    my_eng.execute(text(f"DELETE FROM {'Nodes'}"))
+    my_eng.execute(text(f"DELETE FROM {'FedDatasets'}"))
+    my_eng.execute(text(f"DELETE FROM {'Networks'}"))
+    my_eng.execute(text(f"DELETE FROM {'FLsetup'}"))
+
+    my_eng.execute(text(f"DELETE FROM {'FLpipeline'}"))
+    my_eng.execute(text(f"ALTER TABLE {'Nodes'} AUTO_INCREMENT = 1"))
+    my_eng.execute(text(f"ALTER TABLE {'Networks'} AUTO_INCREMENT = 1"))
+    my_eng.execute(text(f"ALTER TABLE {'FedDatasets'} AUTO_INCREMENT = 1"))
+    my_eng.execute(text(f"ALTER TABLE {'FLsetup'} AUTO_INCREMENT = 1"))
+    my_eng.execute(text(f"ALTER TABLE {'FLpipeline'} AUTO_INCREMENT = 1"))
+    my_eng.execute(text(f"DELETE FROM {'testResults'}"))
+    my_eng.execute(text(f"DROP TABLE IF EXISTS {'MasterDataset'}"))
+    my_eng.execute(text(f"DROP TABLE IF EXISTS {'DataSets'}"))
+
+def get_pipeline_from_name(name):
+    """
+    Get the pipeline ID from its name in the database.
+
+    Args:
+        name (str): Name of the pipeline.
+
+    Returns:
+        int: ID of the pipeline.
+    """
+
+    NodeId = int(
+        pd.read_sql(
+            text(f"SELECT id FROM FLpipeline WHERE name = '{name}'"), my_eng
+        ).iloc[0, 0]
+    )
+    return NodeId
+
+def get_pipeline_confusion_matrix(pipeline_id):
+    """
+    Get the global confusion matrix for a pipeline based on test results.
+
+    Args:
+        pipeline_id (int): ID of the pipeline.
+
+    Returns:
+        dict: A dictionary representing the global confusion matrix.
+    """
+
+    data = pd.read_sql(
+        text(f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
+    )
+
+    # Convert the column of strings into a list of dictionaries representing confusion matrices
+    confusion_matrices = [
+        json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
+    ]
+
+    # Initialize variables for global confusion matrix
+    global_TP = global_FP = global_FN = global_TN = 0
+
+    # Iterate through each dictionary and sum the corresponding values for each category
+    for matrix in confusion_matrices:
+        global_TP += matrix['TP']
+        global_FP += matrix['FP']
+        global_FN += matrix['FN']
+        global_TN += matrix['TN']
+
+    # Create a global confusion matrix as a dictionary
+    global_confusion_matrix = {
+        'TP': global_TP,
+        'FP': global_FP,
+        'FN': global_FN,
+        'TN': global_TN
+    }
+    # Return the global confusion matrix
+    return global_confusion_matrix
+
+def get_node_confusion_matrix(pipeline_id, node_name):
+    """
+    Get the confusion matrix for a specific node in a pipeline based on test results.
+
+    Args:
+        pipeline_id (int): ID of the pipeline.
+        node_name (str): Name of the node.
+
+    Returns:
+        dict: A dictionary representing the confusion matrix for the specified node.
+    """
+
+    data = pd.read_sql(
+        text(f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}' AND nodename = '{node_name}'"), my_eng
+    )
+
+    # Convert the column of strings into a list of dictionaries representing confusion matrices
+    confusion_matrices = [
+        json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
+    ]
+
+
+    # Return the confusion matrix of the first matching result
+    return confusion_matrices[0]
+
+def get_pipeline_result(pipeline_id):
+    """
+    Get the test results for a pipeline.
+
+    Args:
+        pipeline_id (int): ID of the pipeline.
+
+    Returns:
+        pandas.DataFrame: DataFrame containing test results for the specified pipeline.
+    """
+    data = pd.read_sql(
+        text(f"SELECT * FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
+    )
+    return data
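For orientation, here is a minimal sketch of how the metrics helper added in this file could be exercised. The toy labels, the probability array, and the assumption that the module imports cleanly (its params.yaml and ../../global_params.yaml resources must resolve at import time) are illustrative, not taken from the package:

import numpy as np
from Medfl.LearningManager.utils import custom_classification_report  # hypothetical usage context

# Toy ground-truth labels and predicted probabilities (illustrative only)
y_true = np.array([0, 1, 1, 0, 1])
y_pred_prob = np.array([0.2, 0.8, 0.6, 0.4, 0.9])

report = custom_classification_report(y_true, y_pred_prob)
print(report["confusion matrix"], report["Accuracy"], report["auc"])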
Medfl/NetManager/dataset.py
@@ -0,0 +1,91 @@
+import pandas as pd
+from sqlalchemy import text
+
+from scripts.base import my_eng
+from .net_helper import *
+from .net_manager_queries import (DELETE_DATASET, INSERT_DATASET,
+                                  SELECT_ALL_DATASET_NAMES)
+
+
+class DataSet:
+    def __init__(self, name: str, path: str, engine=None):
+        """
+        Initialize a DataSet object.
+
+        :param name: The name of the dataset.
+        :type name: str
+        :param path: The file path of the dataset CSV file.
+        :type path: str
+        """
+        self.name = name
+        self.path = path
+        self.engine = engine if engine is not None else my_eng
+
+    def validate(self):
+        """
+        Validate name and path attributes.
+
+        :raises TypeError: If name or path is not a string.
+        """
+        if not isinstance(self.name, str):
+            raise TypeError("name argument must be a string")
+
+        if not isinstance(self.path, str):
+            raise TypeError("path argument must be a string")
+
+    def upload_dataset(self, NodeId=-1):
+        """
+        Upload the dataset to the database.
+
+        :param NodeId: The NodeId associated with the dataset.
+        :type NodeId: int
+
+        Notes:
+        - Assumes the file at self.path is a valid CSV file.
+        - The dataset is uploaded to the 'DataSets' table in the database.
+        """
+
+        data_df = pd.read_csv(self.path)
+        nodeId = NodeId
+        columns = data_df.columns.tolist()
+
+
+        data_df = process_eicu(data_df)
+        for index, row in data_df.iterrows():
+            query_1 = "INSERT INTO DataSets(DataSetName,nodeId," + "".join(
+                f"{x}," for x in columns
+            )
+            query_2 = f" VALUES ('{self.name}',{nodeId}, " + "".join(
+                f"{is_str(data_df, row, x)}," for x in columns
+            )
+            query = query_1[:-1] + ")" + query_2[:-1] + ")"
+
+            self.engine.execute(text(query))
+
+    def delete_dataset(self):
+        """
+        Delete the dataset from the database.
+
+        Notes:
+        - Assumes the dataset name is unique in the 'DataSets' table.
+        """
+        self.engine.execute(text(DELETE_DATASET), {"name": self.name})
+
+    def update_data(self):
+        """
+        Update the data in the dataset.
+
+        Not implemented yet.
+        """
+        pass
+
+    @staticmethod
+    def list_alldatasets(engine):
+        """
+        List all dataset names from the 'DataSets' table.
+
+        :returns: A DataFrame containing the names of all datasets in the 'DataSets' table.
+        :rtype: pd.DataFrame
+        """
+        res = pd.read_sql(text(SELECT_ALL_DATASET_NAMES), engine)
+        return res
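As a rough usage sketch of the DataSet class above: the connection string, CSV path, and node id below are hypothetical, and a SQLAlchemy 1.x-style engine with a working MEDfl schema (or the default scripts.base.my_eng) is assumed:

from sqlalchemy import create_engine
from Medfl.NetManager.dataset import DataSet

# Hypothetical database and file locations
engine = create_engine("mysql+pymysql://user:password@localhost/MEDfl")

ds = DataSet(name="eicu_hospital_1", path="/data/eicu_hospital_1.csv", engine=engine)
ds.validate()                            # type-checks the name and path attributes
ds.upload_dataset(NodeId=1)              # inserts each CSV row into the DataSets table
print(DataSet.list_alldatasets(engine))  # DataFrame of all dataset names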
Medfl/NetManager/flsetup.py
@@ -0,0 +1,304 @@
+from datetime import datetime
+
+
+from torch.utils.data import random_split, DataLoader, Dataset
+
+from Medfl.LearningManager.federated_dataset import FederatedDataset
+from .net_helper import *
+from .net_manager_queries import *  # Import the sql_queries module
+from .network import Network
+
+from .node import Node
+
+
+class FLsetup:
+    def __init__(self, name: str, description: str, network: Network):
+        """Initialize a Federated Learning (FL) setup.
+
+        Args:
+            name (str): The name of the FL setup.
+            description (str): A description of the FL setup.
+            network (Network): An instance of the Network class representing the network architecture.
+        """
+        self.name = name
+        self.description = description
+        self.network = network
+        self.column_name = None
+        self.auto = 1 if self.column_name is not None else 0
+        self.validate()
+        self.fed_dataset = None
+
+    def validate(self):
+        """Validate name, description, and network."""
+        if not isinstance(self.name, str):
+            raise TypeError("name argument must be a string")
+
+        if not isinstance(self.description, str):
+            raise TypeError("description argument must be a string")
+
+        if not isinstance(self.network, Network):
+            raise TypeError(
+                "network argument must be a Medfl.NetManager.Network"
+            )
+
+    def create(self):
+        """Create an FL setup."""
+        creation_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        netid = get_netid_from_name(self.network.name)
+        my_eng.execute(
+            text(CREATE_FLSETUP_QUERY),
+            {
+                "name": self.name,
+                "description": self.description,
+                "creation_date": creation_date,
+                "net_id": netid,
+                "column_name": self.column_name,
+            },
+        )
+        self.id = get_flsetupid_from_name(self.name)
+
+    def delete(self):
+        """Delete the FL setup."""
+        if self.fed_dataset is not None:
+            self.fed_dataset.delete_Flsetup(FLsetupId=self.id)
+        my_eng.execute(text(DELETE_FLSETUP_QUERY), {"name": self.name})
+
+    @classmethod
+    def read_setup(cls, FLsetupId: int):
+        """Read the FL setup by FLsetupId.
+
+        Args:
+            FLsetupId (int): The id of the FL setup to read.
+
+        Returns:
+            FLsetup: An instance of the FLsetup class with the specified FLsetupId.
+        """
+        res = pd.read_sql(
+            text(READ_SETUP_QUERY), my_eng, params={"flsetup_id": FLsetupId}
+        ).iloc[0]
+
+        network_res = pd.read_sql(
+            text(READ_NETWORK_BY_ID_QUERY),
+            my_eng,
+            params={"net_id": int(res["NetId"])},
+        ).iloc[0]
+        network = Network(network_res["NetName"])
+        setattr(network, "id", res["NetId"])
+        fl_setup = cls(res["name"], res["description"], network)
+        if res["column_name"] == str(None):
+            res["column_name"] = None
+        setattr(fl_setup, "column_name", res["column_name"])
+        setattr(fl_setup, "id", res["FLsetupId"])
+
+        return fl_setup
+
+    @staticmethod
+    def list_allsetups():
+        """List all the FL setups.
+
+        Returns:
+            DataFrame: A DataFrame containing information about all the FL setups.
+        """
+        Flsetups = pd.read_sql(text(READ_ALL_SETUPS_QUERY), my_eng)
+        return Flsetups
+
+    def create_nodes_from_master_dataset(self, params_dict: dict):
+        """Create nodes from the master dataset.
+
+        Args:
+            params_dict (dict): A dictionary containing parameters for node creation.
+                - column_name (str): The name of the column in the MasterDataset used to create nodes.
+                - train_nodes (list): A list of node names that will be used for training.
+                - test_nodes (list): A list of node names that will be used for testing.
+
+        Returns:
+            list: A list of Node instances created from the master dataset.
+        """
+        assert "column_name" in params_dict.keys()
+        column_name, train_nodes, test_nodes = (
+            params_dict["column_name"],
+            params_dict["train_nodes"],
+            params_dict["test_nodes"],
+        )
+        self.column_name = column_name
+        self.auto = 1
+
+        # Update the column name of the auto FLsetup
+        query = f"UPDATE FLsetup SET column_name = '{column_name}' WHERE name = '{self.name}'"
+        my_eng.execute(text(query))
+
+
+        # Add Network to DB
+        # self.network.create_network()
+
+        netid = get_netid_from_name(self.network.name)
+
+        assert self.network.mtable_exists == 1
+        node_names = pd.read_sql(
+            text(READ_DISTINCT_NODES_QUERY.format(column_name)), my_eng
+        )
+
+        nodes = [Node(val[0], 1) for val in node_names.values.tolist()]
+
+        used_nodes = []
+
+        for node in nodes:
+            if node.name in train_nodes:
+                node.train = 1
+                node.create_node(netid)
+                used_nodes.append(node)
+            if node.name in test_nodes:
+                node.train = 0
+                node.create_node(netid)
+                used_nodes.append(node)
+        return used_nodes
+
+    def create_dataloader_from_node(
+        self,
+        node: Node,
+        output,
+        fill_strategy="mean", fit_encode=[], to_drop=[],
+        train_batch_size: int = 32,
+        test_batch_size: int = 1,
+        split_frac: float = 0.2,
+        dataset: Dataset = None,
+
+    ):
+        """Create DataLoaders from a Node.
+
+        Args:
+            node (Node): The node from which to create DataLoaders.
+            train_batch_size (int): The batch size for training data.
+            test_batch_size (int): The batch size for test data.
+            split_frac (float): The fraction of data to be held out for testing.
+            dataset (Dataset): The dataset to use. If None, the method will read the dataset from the node.
+
+        Returns:
+            DataLoader: The DataLoader instances for training and testing.
+        """
+        if dataset is None:
+            if self.column_name is not None:
+                dataset = process_data_after_reading(
+                    node.get_dataset(self.column_name), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop
+                )
+            else:
+                dataset = process_data_after_reading(
+                    node.get_dataset(), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
+
+        dataset_size = len(dataset)
+        traindata_size = int(dataset_size * (1 - split_frac))
+        traindata, testdata = random_split(
+            dataset, [traindata_size, dataset_size - traindata_size]
+        )
+        trainloader, testloader = DataLoader(
+            traindata, batch_size=train_batch_size
+        ), DataLoader(testdata, batch_size=test_batch_size)
+        return trainloader, testloader
+
+    def create_federated_dataset(
+        self, output, fill_strategy="mean", fit_encode=[], to_drop=[], val_frac=0.1, test_frac=0.2
+    ) -> FederatedDataset:
+        """Create a federated dataset.
+
+        Args:
+            output (string): The output feature of the dataset.
+            val_frac (float): The fraction of data to be used for validation.
+            test_frac (float): The fraction of data to be used for testing.
+
+        Returns:
+            FederatedDataset: The FederatedDataset instance containing train, validation, and test data.
+        """
+
+        if not self.column_name:
+            to_drop.extend(["DataSetName", "NodeId", "DataSetId"])
+        else:
+            to_drop.extend(["PatientId"])
+
+        netid = self.network.id
+        train_nodes = pd.read_sql(
+            text(
+                f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 1 "
+            ),
+            my_eng,
+        )
+        test_nodes = pd.read_sql(
+            text(
+                f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 0 "
+            ),
+            my_eng,
+        )
+
+        train_nodes = [
+            Node(val[0], 1, test_frac) for val in train_nodes.values.tolist()
+        ]
+        test_nodes = [Node(val[0], 0) for val in test_nodes.values.tolist()]
+
+        trainloaders, valloaders, testloaders = [], [], []
+        # if len(test_nodes) == 0:
+        #     raise "test node empty"
+        if test_nodes is None:
+            _, testloader = self.create_dataloader_from_node(
+                train_nodes[0], output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
+            testloaders.append(testloader)
+        else:
+            for train_node in train_nodes:
+                train_valloader, testloader = self.create_dataloader_from_node(
+                    train_node, output, fill_strategy=fill_strategy,
+                    fit_encode=fit_encode, to_drop=to_drop,)
+                trainloader, valloader = self.create_dataloader_from_node(
+                    train_node,
+                    output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop,
+                    test_batch_size=32,
+                    split_frac=val_frac,
+                    dataset=train_valloader.dataset,
+                )
+                trainloaders.append(trainloader)
+                valloaders.append(valloader)
+                testloaders.append(testloader)
+
+            for test_node in test_nodes:
+                _, testloader = self.create_dataloader_from_node(
+                    test_node, output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop, split_frac=1.0
+                )
+                testloaders.append(testloader)
+        train_nodes_names = [node.name for node in train_nodes]
+        test_nodes_names = train_nodes_names + [
+            node.name for node in test_nodes
+        ]
+
+        # test_nodes_names = [
+        #     node.name for node in test_nodes
+        # ]
+
+        # Add FLsetup to the database
+        # self.create()
+
+        # self.network.update_network(FLsetupId=self.id)
+        fed_dataset = FederatedDataset(
+            self.name + "_Feddataset",
+            train_nodes_names,
+            test_nodes_names,
+            trainloaders,
+            valloaders,
+            testloaders,
+        )
+        self.fed_dataset = fed_dataset
+        self.fed_dataset.create(self.id)
+        return self.fed_dataset
+
+
+
+
+    def get_flDataSet(self):
+        """
+        Retrieve the federated dataset associated with the FL setup using the FL setup's name.
+
+        Returns:
+            pandas.DataFrame: DataFrame containing the federated dataset information.
+        """
+        return pd.read_sql(
+            text(
+                f"SELECT * FROM FedDatasets WHERE FLsetupId = {get_flsetupid_from_name(self.name)}"
+            ),
+            my_eng,
+        )
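To tie the pieces together, a hedged end-to-end sketch of how flsetup.py appears intended to be used. The network name, setup name, column, node names, and output feature below are hypothetical; the MEDfl database is assumed to be initialised beforehand, and Network.create_network() is assumed from the commented-out call inside create_nodes_from_master_dataset:

from Medfl.NetManager.network import Network
from Medfl.NetManager.flsetup import FLsetup

# Hypothetical network and setup names
net = Network("hospital_net")
net.create_network()  # assumed Network API; registers the network in the database

fl_setup = FLsetup(
    name="exp1_setup",
    description="auto setup split by hospital id",
    network=net,
)
fl_setup.create()

# One node per distinct value of a MasterDataset column (names are illustrative)
fl_setup.create_nodes_from_master_dataset({
    "column_name": "hospitalid",
    "train_nodes": ["hospital_1", "hospital_2"],
    "test_nodes": ["hospital_3"],
})

# Build per-node train/validation/test DataLoaders and register the FederatedDataset
fed_dataset = fl_setup.create_federated_dataset(output="mortality")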