MEDfl 0.2.1__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/LearningManager/__init__.py +13 -13
- MEDfl/LearningManager/client.py +150 -181
- MEDfl/LearningManager/dynamicModal.py +287 -287
- MEDfl/LearningManager/federated_dataset.py +60 -60
- MEDfl/LearningManager/flpipeline.py +192 -192
- MEDfl/LearningManager/model.py +223 -223
- MEDfl/LearningManager/params.yaml +14 -14
- MEDfl/LearningManager/params_optimiser.py +442 -442
- MEDfl/LearningManager/plot.py +229 -229
- MEDfl/LearningManager/server.py +181 -189
- MEDfl/LearningManager/strategy.py +82 -138
- MEDfl/LearningManager/utils.py +331 -331
- MEDfl/NetManager/__init__.py +10 -10
- MEDfl/NetManager/database_connector.py +43 -43
- MEDfl/NetManager/dataset.py +92 -92
- MEDfl/NetManager/flsetup.py +320 -320
- MEDfl/NetManager/net_helper.py +254 -254
- MEDfl/NetManager/net_manager_queries.py +142 -142
- MEDfl/NetManager/network.py +194 -194
- MEDfl/NetManager/node.py +184 -184
- MEDfl/__init__.py +4 -3
- MEDfl/scripts/__init__.py +1 -1
- MEDfl/scripts/base.py +29 -29
- MEDfl/scripts/create_db.py +126 -126
- Medfl/LearningManager/__init__.py +13 -0
- Medfl/LearningManager/client.py +150 -0
- Medfl/LearningManager/dynamicModal.py +287 -0
- Medfl/LearningManager/federated_dataset.py +60 -0
- Medfl/LearningManager/flpipeline.py +192 -0
- Medfl/LearningManager/model.py +223 -0
- Medfl/LearningManager/params.yaml +14 -0
- Medfl/LearningManager/params_optimiser.py +442 -0
- Medfl/LearningManager/plot.py +229 -0
- Medfl/LearningManager/server.py +181 -0
- Medfl/LearningManager/strategy.py +82 -0
- Medfl/LearningManager/utils.py +331 -0
- Medfl/NetManager/__init__.py +10 -0
- Medfl/NetManager/database_connector.py +43 -0
- Medfl/NetManager/dataset.py +92 -0
- Medfl/NetManager/flsetup.py +320 -0
- Medfl/NetManager/net_helper.py +254 -0
- Medfl/NetManager/net_manager_queries.py +142 -0
- Medfl/NetManager/network.py +194 -0
- Medfl/NetManager/node.py +184 -0
- Medfl/__init__.py +3 -0
- Medfl/scripts/__init__.py +2 -0
- Medfl/scripts/base.py +30 -0
- Medfl/scripts/create_db.py +126 -0
- alembic/env.py +61 -61
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/METADATA +120 -108
- medfl-2.0.1.dist-info/RECORD +55 -0
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/WHEEL +1 -1
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info/licenses}/LICENSE +674 -674
- MEDfl-0.2.1.dist-info/RECORD +0 -31
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/top_level.txt +0 -0
MEDfl/NetManager/flsetup.py
CHANGED
@@ -1,320 +1,320 @@
|
|
1
|
-
from datetime import datetime
|
2
|
-
|
3
|
-
|
4
|
-
from torch.utils.data import random_split, DataLoader, Dataset
|
5
|
-
|
6
|
-
from MEDfl.LearningManager.federated_dataset import FederatedDataset
|
7
|
-
from .net_helper import *
|
8
|
-
from .net_manager_queries import * # Import the sql_queries module
|
9
|
-
from .network import Network
|
10
|
-
|
11
|
-
from .node import Node
|
12
|
-
|
13
|
-
from MEDfl.NetManager.database_connector import DatabaseManager
|
14
|
-
|
15
|
-
|
16
|
-
class FLsetup:
|
17
|
-
def __init__(self, name: str, description: str, network: Network):
|
18
|
-
"""Initialize a Federated Learning (FL) setup.
|
19
|
-
|
20
|
-
Args:
|
21
|
-
name (str): The name of the FL setup.
|
22
|
-
description (str): A description of the FL setup.
|
23
|
-
network (Network): An instance of the Network class representing the network architecture.
|
24
|
-
"""
|
25
|
-
self.name = name
|
26
|
-
self.description = description
|
27
|
-
self.network = network
|
28
|
-
self.column_name = None
|
29
|
-
self.auto = 1 if self.column_name is not None else 0
|
30
|
-
self.validate()
|
31
|
-
self.fed_dataset = None
|
32
|
-
|
33
|
-
db_manager = DatabaseManager()
|
34
|
-
db_manager.connect()
|
35
|
-
self.eng = db_manager.get_connection()
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
def validate(self):
|
40
|
-
"""Validate name, description, and network."""
|
41
|
-
if not isinstance(self.name, str):
|
42
|
-
raise TypeError("name argument must be a string")
|
43
|
-
|
44
|
-
if not isinstance(self.description, str):
|
45
|
-
raise TypeError("description argument must be a string")
|
46
|
-
|
47
|
-
if not isinstance(self.network, Network):
|
48
|
-
raise TypeError(
|
49
|
-
"network argument must be a MEDfl.NetManager.Network "
|
50
|
-
)
|
51
|
-
|
52
|
-
def create(self):
|
53
|
-
"""Create an FL setup."""
|
54
|
-
creation_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
55
|
-
netid = get_netid_from_name(self.network.name)
|
56
|
-
self.eng.execute(
|
57
|
-
text(CREATE_FLSETUP_QUERY),
|
58
|
-
{
|
59
|
-
"name": self.name,
|
60
|
-
"description": self.description,
|
61
|
-
"creation_date": creation_date,
|
62
|
-
"net_id": netid,
|
63
|
-
"column_name": self.column_name,
|
64
|
-
},
|
65
|
-
)
|
66
|
-
self.id = get_flsetupid_from_name(self.name)
|
67
|
-
|
68
|
-
def delete(self):
|
69
|
-
"""Delete the FL setup."""
|
70
|
-
if self.fed_dataset is not None:
|
71
|
-
self.fed_dataset.delete_Flsetup(FLsetupId=self.id)
|
72
|
-
self.eng.execute(text(DELETE_FLSETUP_QUERY), {"name": self.name})
|
73
|
-
|
74
|
-
@classmethod
|
75
|
-
def read_setup(cls, FLsetupId: int):
|
76
|
-
"""Read the FL setup by FLsetupId.
|
77
|
-
|
78
|
-
Args:
|
79
|
-
FLsetupId (int): The id of the FL setup to read.
|
80
|
-
|
81
|
-
Returns:
|
82
|
-
FLsetup: An instance of the FLsetup class with the specified FLsetupId.
|
83
|
-
"""
|
84
|
-
db_manager = DatabaseManager()
|
85
|
-
db_manager.connect()
|
86
|
-
my_eng = db_manager.get_connection()
|
87
|
-
|
88
|
-
res = pd.read_sql(
|
89
|
-
text(READ_SETUP_QUERY), my_eng, params={"flsetup_id": FLsetupId}
|
90
|
-
).iloc[0]
|
91
|
-
|
92
|
-
network_res = pd.read_sql(
|
93
|
-
text(READ_NETWORK_BY_ID_QUERY),
|
94
|
-
my_eng,
|
95
|
-
params={"net_id": int(res["NetId"])},
|
96
|
-
).iloc[0]
|
97
|
-
network = Network(network_res["NetName"])
|
98
|
-
setattr(network, "id", res["NetId"])
|
99
|
-
fl_setup = cls(res["name"], res["description"], network)
|
100
|
-
if res["column_name"] == str(None):
|
101
|
-
res["column_name"] = None
|
102
|
-
setattr(fl_setup, "column_name", res["column_name"])
|
103
|
-
setattr(fl_setup, "id", res["FLsetupId"])
|
104
|
-
|
105
|
-
return fl_setup
|
106
|
-
|
107
|
-
@staticmethod
|
108
|
-
def list_allsetups():
|
109
|
-
"""List all the FL setups.
|
110
|
-
|
111
|
-
Returns:
|
112
|
-
DataFrame: A DataFrame containing information about all the FL setups.
|
113
|
-
"""
|
114
|
-
db_manager = DatabaseManager()
|
115
|
-
db_manager.connect()
|
116
|
-
my_eng = db_manager.get_connection()
|
117
|
-
|
118
|
-
Flsetups = pd.read_sql(text(READ_ALL_SETUPS_QUERY), my_eng)
|
119
|
-
return Flsetups
|
120
|
-
|
121
|
-
def create_nodes_from_master_dataset(self, params_dict: dict):
|
122
|
-
"""Create nodes from the master dataset.
|
123
|
-
|
124
|
-
Args:
|
125
|
-
params_dict (dict): A dictionary containing parameters for node creation.
|
126
|
-
- column_name (str): The name of the column in the MasterDataset used to create nodes.
|
127
|
-
- train_nodes (list): A list of node names that will be used for training.
|
128
|
-
- test_nodes (list): A list of node names that will be used for testing.
|
129
|
-
|
130
|
-
Returns:
|
131
|
-
list: A list of Node instances created from the master dataset.
|
132
|
-
"""
|
133
|
-
assert "column_name" in params_dict.keys()
|
134
|
-
column_name, train_nodes, test_nodes = (
|
135
|
-
params_dict["column_name"],
|
136
|
-
params_dict["train_nodes"],
|
137
|
-
params_dict["test_nodes"],
|
138
|
-
)
|
139
|
-
self.column_name = column_name
|
140
|
-
self.auto = 1
|
141
|
-
|
142
|
-
# Update the Column name of the auto flSetup
|
143
|
-
query = f"UPDATE FLsetup SET column_name = '{column_name}' WHERE name = '{self.name}'"
|
144
|
-
self.eng.execute(text(query))
|
145
|
-
|
146
|
-
|
147
|
-
# Add Network to DB
|
148
|
-
# self.network.create_network()
|
149
|
-
|
150
|
-
netid = get_netid_from_name(self.network.name)
|
151
|
-
|
152
|
-
assert self.network.mtable_exists == 1
|
153
|
-
node_names = pd.read_sql(
|
154
|
-
text(READ_DISTINCT_NODES_QUERY.format(column_name)), self.eng
|
155
|
-
)
|
156
|
-
|
157
|
-
nodes = [Node(val[0], 1) for val in node_names.values.tolist()]
|
158
|
-
|
159
|
-
used_nodes = []
|
160
|
-
|
161
|
-
for node in nodes:
|
162
|
-
if node.name in train_nodes:
|
163
|
-
node.train = 1
|
164
|
-
node.create_node(netid)
|
165
|
-
used_nodes.append(node)
|
166
|
-
if node.name in test_nodes:
|
167
|
-
node.train =0
|
168
|
-
node.create_node(netid)
|
169
|
-
used_nodes.append(node)
|
170
|
-
return used_nodes
|
171
|
-
|
172
|
-
def create_dataloader_from_node(
|
173
|
-
self,
|
174
|
-
node: Node,
|
175
|
-
output,
|
176
|
-
fill_strategy="mean", fit_encode=[], to_drop=[],
|
177
|
-
train_batch_size: int = 32,
|
178
|
-
test_batch_size: int = 1,
|
179
|
-
split_frac: float = 0.2,
|
180
|
-
dataset: Dataset = None,
|
181
|
-
|
182
|
-
):
|
183
|
-
"""Create DataLoader from a Node.
|
184
|
-
|
185
|
-
Args:
|
186
|
-
node (Node): The node from which to create DataLoader.
|
187
|
-
train_batch_size (int): The batch size for training data.
|
188
|
-
test_batch_size (int): The batch size for test data.
|
189
|
-
split_frac (float): The fraction of data to be used for training.
|
190
|
-
dataset (Dataset): The dataset to use. If None, the method will read the dataset from the node.
|
191
|
-
|
192
|
-
Returns:
|
193
|
-
DataLoader: The DataLoader instances for training and testing.
|
194
|
-
"""
|
195
|
-
if dataset is None:
|
196
|
-
if self.column_name is not None:
|
197
|
-
dataset = process_data_after_reading(
|
198
|
-
node.get_dataset(self.column_name), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop
|
199
|
-
)
|
200
|
-
else:
|
201
|
-
dataset = process_data_after_reading(
|
202
|
-
node.get_dataset(), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
|
203
|
-
|
204
|
-
dataset_size = len(dataset)
|
205
|
-
traindata_size = int(dataset_size * (1 - split_frac))
|
206
|
-
traindata, testdata = random_split(
|
207
|
-
dataset, [traindata_size, dataset_size - traindata_size]
|
208
|
-
)
|
209
|
-
trainloader, testloader = DataLoader(
|
210
|
-
traindata, batch_size=train_batch_size
|
211
|
-
), DataLoader(testdata, batch_size=test_batch_size)
|
212
|
-
return trainloader, testloader
|
213
|
-
|
214
|
-
def create_federated_dataset(
|
215
|
-
self, output, fill_strategy="mean", fit_encode=[], to_drop=[], val_frac=0.1, test_frac=0.2
|
216
|
-
) -> FederatedDataset:
|
217
|
-
"""Create a federated dataset.
|
218
|
-
|
219
|
-
Args:
|
220
|
-
output (string): the output feature of the dataset
|
221
|
-
val_frac (float): The fraction of data to be used for validation.
|
222
|
-
test_frac (float): The fraction of data to be used for testing.
|
223
|
-
|
224
|
-
Returns:
|
225
|
-
FederatedDataset: The FederatedDataset instance containing train, validation, and test data.
|
226
|
-
"""
|
227
|
-
|
228
|
-
if not self.column_name:
|
229
|
-
to_drop.extend(["DataSetName" , "NodeId" , "DataSetId"])
|
230
|
-
else :
|
231
|
-
to_drop.extend(["PatientId"])
|
232
|
-
|
233
|
-
netid = self.network.id
|
234
|
-
train_nodes = pd.read_sql(
|
235
|
-
text(
|
236
|
-
f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 1 "
|
237
|
-
),
|
238
|
-
self.eng,
|
239
|
-
)
|
240
|
-
test_nodes = pd.read_sql(
|
241
|
-
text(
|
242
|
-
f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 0 "
|
243
|
-
),
|
244
|
-
self.eng,
|
245
|
-
)
|
246
|
-
|
247
|
-
train_nodes = [
|
248
|
-
Node(val[0], 1, test_frac) for val in train_nodes.values.tolist()
|
249
|
-
]
|
250
|
-
test_nodes = [Node(val[0], 0) for val in test_nodes.values.tolist()]
|
251
|
-
|
252
|
-
trainloaders, valloaders, testloaders = [], [], []
|
253
|
-
# if len(test_nodes) == 0:
|
254
|
-
# raise "test node empty"
|
255
|
-
if test_nodes is None:
|
256
|
-
_, testloader = self.create_dataloader_from_node(
|
257
|
-
train_nodes[0], output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
|
258
|
-
testloaders.append(testloader)
|
259
|
-
else:
|
260
|
-
for train_node in train_nodes:
|
261
|
-
train_valloader, testloader = self.create_dataloader_from_node(
|
262
|
-
train_node, output, fill_strategy=fill_strategy,
|
263
|
-
fit_encode=fit_encode, to_drop=to_drop,)
|
264
|
-
trainloader, valloader = self.create_dataloader_from_node(
|
265
|
-
train_node,
|
266
|
-
output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop,
|
267
|
-
test_batch_size=32,
|
268
|
-
split_frac=val_frac,
|
269
|
-
dataset=train_valloader.dataset,
|
270
|
-
)
|
271
|
-
trainloaders.append(trainloader)
|
272
|
-
valloaders.append(valloader)
|
273
|
-
testloaders.append(testloader)
|
274
|
-
|
275
|
-
for test_node in test_nodes:
|
276
|
-
_, testloader = self.create_dataloader_from_node(
|
277
|
-
test_node, output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop, split_frac=1.0
|
278
|
-
)
|
279
|
-
testloaders.append(testloader)
|
280
|
-
train_nodes_names = [node.name for node in train_nodes]
|
281
|
-
test_nodes_names = train_nodes_names + [
|
282
|
-
node.name for node in test_nodes
|
283
|
-
]
|
284
|
-
|
285
|
-
# test_nodes_names = [
|
286
|
-
# node.name for node in test_nodes
|
287
|
-
# ]
|
288
|
-
|
289
|
-
# Add FlSetup on to the DataBase
|
290
|
-
# self.create()
|
291
|
-
|
292
|
-
# self.network.update_network(FLsetupId=self.id)
|
293
|
-
fed_dataset = FederatedDataset(
|
294
|
-
self.name + "_Feddataset",
|
295
|
-
train_nodes_names,
|
296
|
-
test_nodes_names,
|
297
|
-
trainloaders,
|
298
|
-
valloaders,
|
299
|
-
testloaders,
|
300
|
-
)
|
301
|
-
self.fed_dataset = fed_dataset
|
302
|
-
self.fed_dataset.create(self.id)
|
303
|
-
return self.fed_dataset
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
def get_flDataSet(self):
|
309
|
-
"""
|
310
|
-
Retrieve the federated dataset associated with the FL setup using the FL setup's name.
|
311
|
-
|
312
|
-
Returns:
|
313
|
-
pandas.DataFrame: DataFrame containing the federated dataset information.
|
314
|
-
"""
|
315
|
-
return pd.read_sql(
|
316
|
-
text(
|
317
|
-
f"SELECT * FROM FedDatasets WHERE FLsetupId = {get_flsetupid_from_name(self.name)}"
|
318
|
-
),
|
319
|
-
self.eng,
|
320
|
-
)
|
1
|
+
from datetime import datetime
|
2
|
+
|
3
|
+
|
4
|
+
from torch.utils.data import random_split, DataLoader, Dataset
|
5
|
+
|
6
|
+
from MEDfl.LearningManager.federated_dataset import FederatedDataset
|
7
|
+
from .net_helper import *
|
8
|
+
from .net_manager_queries import * # Import the sql_queries module
|
9
|
+
from .network import Network
|
10
|
+
|
11
|
+
from .node import Node
|
12
|
+
|
13
|
+
from MEDfl.NetManager.database_connector import DatabaseManager
|
14
|
+
|
15
|
+
|
16
|
+
class FLsetup:
    """A Federated Learning setup bound to a Network.

    Wraps the database rows describing the setup and builds the
    train/validation/test DataLoaders consumed by the learning manager.
    """

    def __init__(self, name: str, description: str, network: Network):
        """Initialize a Federated Learning (FL) setup.

        Args:
            name (str): The name of the FL setup.
            description (str): A description of the FL setup.
            network (Network): An instance of the Network class representing
                the network architecture.

        Raises:
            TypeError: If any argument has the wrong type (via validate()).
        """
        self.name = name
        self.description = description
        self.network = network
        self.column_name = None
        # column_name is always None at this point, so auto starts at 0;
        # it is flipped to 1 by create_nodes_from_master_dataset().
        self.auto = 1 if self.column_name is not None else 0
        self.validate()
        self.fed_dataset = None

        db_manager = DatabaseManager()
        db_manager.connect()
        self.eng = db_manager.get_connection()

    def validate(self):
        """Validate name, description, and network.

        Raises:
            TypeError: If name/description are not strings or network is not
                a Network instance.
        """
        if not isinstance(self.name, str):
            raise TypeError("name argument must be a string")

        if not isinstance(self.description, str):
            raise TypeError("description argument must be a string")

        if not isinstance(self.network, Network):
            raise TypeError(
                "network argument must be a MEDfl.NetManager.Network "
            )

    def create(self):
        """Create the FL setup row in the database and cache its id."""
        creation_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        netid = get_netid_from_name(self.network.name)
        self.eng.execute(
            text(CREATE_FLSETUP_QUERY),
            {
                "name": self.name,
                "description": self.description,
                "creation_date": creation_date,
                "net_id": netid,
                "column_name": self.column_name,
            },
        )
        self.id = get_flsetupid_from_name(self.name)

    def delete(self):
        """Delete the FL setup (and its federated dataset, if one exists)."""
        if self.fed_dataset is not None:
            self.fed_dataset.delete_Flsetup(FLsetupId=self.id)
        self.eng.execute(text(DELETE_FLSETUP_QUERY), {"name": self.name})

    @classmethod
    def read_setup(cls, FLsetupId: int):
        """Read the FL setup by FLsetupId.

        Args:
            FLsetupId (int): The id of the FL setup to read.

        Returns:
            FLsetup: An instance of the FLsetup class with the specified
                FLsetupId.
        """
        db_manager = DatabaseManager()
        db_manager.connect()
        my_eng = db_manager.get_connection()

        res = pd.read_sql(
            text(READ_SETUP_QUERY), my_eng, params={"flsetup_id": FLsetupId}
        ).iloc[0]

        network_res = pd.read_sql(
            text(READ_NETWORK_BY_ID_QUERY),
            my_eng,
            params={"net_id": int(res["NetId"])},
        ).iloc[0]
        network = Network(network_res["NetName"])
        setattr(network, "id", res["NetId"])
        fl_setup = cls(res["name"], res["description"], network)
        # The DB stores the literal string "None" for setups created without
        # a column_name; normalise it back to a real None.
        if res["column_name"] == str(None):
            res["column_name"] = None
        setattr(fl_setup, "column_name", res["column_name"])
        setattr(fl_setup, "id", res["FLsetupId"])

        return fl_setup

    @staticmethod
    def list_allsetups():
        """List all the FL setups.

        Returns:
            DataFrame: A DataFrame containing information about all the FL
                setups.
        """
        db_manager = DatabaseManager()
        db_manager.connect()
        my_eng = db_manager.get_connection()

        Flsetups = pd.read_sql(text(READ_ALL_SETUPS_QUERY), my_eng)
        return Flsetups

    def create_nodes_from_master_dataset(self, params_dict: dict):
        """Create nodes from the master dataset.

        Args:
            params_dict (dict): A dictionary containing parameters for node
                creation.
                - column_name (str): The name of the column in the
                  MasterDataset used to create nodes.
                - train_nodes (list): A list of node names used for training.
                - test_nodes (list): A list of node names used for testing.

        Returns:
            list: A list of Node instances created from the master dataset.
        """
        assert "column_name" in params_dict.keys()
        column_name, train_nodes, test_nodes = (
            params_dict["column_name"],
            params_dict["train_nodes"],
            params_dict["test_nodes"],
        )
        self.column_name = column_name
        self.auto = 1

        # Update the column name of the auto FL setup. Bound parameters
        # instead of f-string interpolation to avoid SQL injection through
        # column/setup names.
        query = (
            "UPDATE FLsetup SET column_name = :column_name WHERE name = :name"
        )
        self.eng.execute(
            text(query), {"column_name": column_name, "name": self.name}
        )

        # Add Network to DB
        # self.network.create_network()

        netid = get_netid_from_name(self.network.name)

        assert self.network.mtable_exists == 1
        node_names = pd.read_sql(
            text(READ_DISTINCT_NODES_QUERY.format(column_name)), self.eng
        )

        nodes = [Node(val[0], 1) for val in node_names.values.tolist()]

        used_nodes = []

        for node in nodes:
            if node.name in train_nodes:
                node.train = 1
                node.create_node(netid)
                used_nodes.append(node)
            if node.name in test_nodes:
                node.train = 0
                node.create_node(netid)
                used_nodes.append(node)
        return used_nodes

    def create_dataloader_from_node(
        self,
        node: Node,
        output,
        fill_strategy="mean", fit_encode=None, to_drop=None,
        train_batch_size: int = 32,
        test_batch_size: int = 1,
        split_frac: float = 0.2,
        dataset: Dataset = None,
    ):
        """Create train and test DataLoaders from a Node.

        Args:
            node (Node): The node from which to create DataLoader.
            output (str): The output (target) feature of the dataset.
            fill_strategy (str): Strategy used to fill missing values.
            fit_encode (list): Columns to fit-encode during preprocessing.
            to_drop (list): Columns to drop during preprocessing.
            train_batch_size (int): The batch size for training data.
            test_batch_size (int): The batch size for test data.
            split_frac (float): The fraction of data held out for testing.
            dataset (Dataset): The dataset to use. If None, the method will
                read the dataset from the node.

        Returns:
            tuple: The DataLoader instances for training and testing.
        """
        # None defaults instead of mutable list defaults: a shared default
        # list would leak state between calls.
        fit_encode = [] if fit_encode is None else fit_encode
        to_drop = [] if to_drop is None else to_drop

        if dataset is None:
            if self.column_name is not None:
                dataset = process_data_after_reading(
                    node.get_dataset(self.column_name), output,
                    fill_strategy=fill_strategy, fit_encode=fit_encode,
                    to_drop=to_drop,
                )
            else:
                dataset = process_data_after_reading(
                    node.get_dataset(), output, fill_strategy=fill_strategy,
                    fit_encode=fit_encode, to_drop=to_drop,
                )

        dataset_size = len(dataset)
        traindata_size = int(dataset_size * (1 - split_frac))
        traindata, testdata = random_split(
            dataset, [traindata_size, dataset_size - traindata_size]
        )
        trainloader, testloader = DataLoader(
            traindata, batch_size=train_batch_size
        ), DataLoader(testdata, batch_size=test_batch_size)
        return trainloader, testloader

    def create_federated_dataset(
        self, output, fill_strategy="mean", fit_encode=None, to_drop=None,
        val_frac=0.1, test_frac=0.2,
    ) -> FederatedDataset:
        """Create a federated dataset.

        Args:
            output (str): The output (target) feature of the dataset.
            fill_strategy (str): Strategy used to fill missing values.
            fit_encode (list): Columns to fit-encode during preprocessing.
            to_drop (list): Columns to drop during preprocessing.
            val_frac (float): The fraction of data used for validation.
            test_frac (float): The fraction of data used for testing.

        Returns:
            FederatedDataset: The FederatedDataset instance containing train,
                validation, and test data.
        """
        fit_encode = [] if fit_encode is None else fit_encode
        # Work on a copy so the caller's list (or a shared default) is never
        # mutated by the extend() calls below.
        to_drop = list(to_drop) if to_drop is not None else []

        if not self.column_name:
            to_drop.extend(["DataSetName", "NodeId", "DataSetId"])
        else:
            to_drop.extend(["PatientId"])

        netid = self.network.id
        # Bound parameters instead of f-string interpolation (SQL injection
        # hygiene, even though netid is expected to be an int).
        train_nodes = pd.read_sql(
            text(
                "SELECT Nodes.NodeName FROM Nodes "
                "WHERE Nodes.NetId = :netid AND Nodes.train = 1"
            ),
            self.eng,
            params={"netid": netid},
        )
        test_nodes = pd.read_sql(
            text(
                "SELECT Nodes.NodeName FROM Nodes "
                "WHERE Nodes.NetId = :netid AND Nodes.train = 0"
            ),
            self.eng,
            params={"netid": netid},
        )

        train_nodes = [
            Node(val[0], 1, test_frac) for val in train_nodes.values.tolist()
        ]
        test_nodes = [Node(val[0], 0) for val in test_nodes.values.tolist()]

        trainloaders, valloaders, testloaders = [], [], []

        # NOTE: test_nodes is always a list here, so the original
        # `if test_nodes is None` fallback could never run; that dead branch
        # has been removed. An empty test_nodes list simply adds no extra
        # test loaders — each train node's held-out split is still appended.
        for train_node in train_nodes:
            train_valloader, testloader = self.create_dataloader_from_node(
                train_node, output, fill_strategy=fill_strategy,
                fit_encode=fit_encode, to_drop=to_drop,
            )
            trainloader, valloader = self.create_dataloader_from_node(
                train_node,
                output, fill_strategy=fill_strategy, fit_encode=fit_encode,
                to_drop=to_drop,
                test_batch_size=32,
                split_frac=val_frac,
                dataset=train_valloader.dataset,
            )
            trainloaders.append(trainloader)
            valloaders.append(valloader)
            testloaders.append(testloader)

        for test_node in test_nodes:
            # split_frac=1.0 puts the whole node dataset into the test split.
            _, testloader = self.create_dataloader_from_node(
                test_node, output, fill_strategy=fill_strategy,
                fit_encode=fit_encode, to_drop=to_drop, split_frac=1.0,
            )
            testloaders.append(testloader)

        train_nodes_names = [node.name for node in train_nodes]
        # Test loaders include the train nodes' held-out splits, so the test
        # name list is train names followed by dedicated test node names.
        test_nodes_names = train_nodes_names + [
            node.name for node in test_nodes
        ]

        fed_dataset = FederatedDataset(
            self.name + "_Feddataset",
            train_nodes_names,
            test_nodes_names,
            trainloaders,
            valloaders,
            testloaders,
        )
        self.fed_dataset = fed_dataset
        self.fed_dataset.create(self.id)
        return self.fed_dataset

    def get_flDataSet(self):
        """
        Retrieve the federated dataset associated with the FL setup using the
        FL setup's name.

        Returns:
            pandas.DataFrame: DataFrame containing the federated dataset
                information.
        """
        return pd.read_sql(
            text("SELECT * FROM FedDatasets WHERE FLsetupId = :flsetup_id"),
            self.eng,
            params={"flsetup_id": get_flsetupid_from_name(self.name)},
        )
|