MEDfl 0.2.1__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. MEDfl/LearningManager/__init__.py +13 -13
  2. MEDfl/LearningManager/client.py +150 -181
  3. MEDfl/LearningManager/dynamicModal.py +287 -287
  4. MEDfl/LearningManager/federated_dataset.py +60 -60
  5. MEDfl/LearningManager/flpipeline.py +192 -192
  6. MEDfl/LearningManager/model.py +223 -223
  7. MEDfl/LearningManager/params.yaml +14 -14
  8. MEDfl/LearningManager/params_optimiser.py +442 -442
  9. MEDfl/LearningManager/plot.py +229 -229
  10. MEDfl/LearningManager/server.py +181 -189
  11. MEDfl/LearningManager/strategy.py +82 -138
  12. MEDfl/LearningManager/utils.py +331 -331
  13. MEDfl/NetManager/__init__.py +10 -10
  14. MEDfl/NetManager/database_connector.py +43 -43
  15. MEDfl/NetManager/dataset.py +92 -92
  16. MEDfl/NetManager/flsetup.py +320 -320
  17. MEDfl/NetManager/net_helper.py +254 -254
  18. MEDfl/NetManager/net_manager_queries.py +142 -142
  19. MEDfl/NetManager/network.py +194 -194
  20. MEDfl/NetManager/node.py +184 -184
  21. MEDfl/__init__.py +4 -3
  22. MEDfl/scripts/__init__.py +1 -1
  23. MEDfl/scripts/base.py +29 -29
  24. MEDfl/scripts/create_db.py +126 -126
  25. Medfl/LearningManager/__init__.py +13 -0
  26. Medfl/LearningManager/client.py +150 -0
  27. Medfl/LearningManager/dynamicModal.py +287 -0
  28. Medfl/LearningManager/federated_dataset.py +60 -0
  29. Medfl/LearningManager/flpipeline.py +192 -0
  30. Medfl/LearningManager/model.py +223 -0
  31. Medfl/LearningManager/params.yaml +14 -0
  32. Medfl/LearningManager/params_optimiser.py +442 -0
  33. Medfl/LearningManager/plot.py +229 -0
  34. Medfl/LearningManager/server.py +181 -0
  35. Medfl/LearningManager/strategy.py +82 -0
  36. Medfl/LearningManager/utils.py +331 -0
  37. Medfl/NetManager/__init__.py +10 -0
  38. Medfl/NetManager/database_connector.py +43 -0
  39. Medfl/NetManager/dataset.py +92 -0
  40. Medfl/NetManager/flsetup.py +320 -0
  41. Medfl/NetManager/net_helper.py +254 -0
  42. Medfl/NetManager/net_manager_queries.py +142 -0
  43. Medfl/NetManager/network.py +194 -0
  44. Medfl/NetManager/node.py +184 -0
  45. Medfl/__init__.py +3 -0
  46. Medfl/scripts/__init__.py +2 -0
  47. Medfl/scripts/base.py +30 -0
  48. Medfl/scripts/create_db.py +126 -0
  49. alembic/env.py +61 -61
  50. {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/METADATA +120 -108
  51. medfl-2.0.1.dist-info/RECORD +55 -0
  52. {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/WHEEL +1 -1
  53. {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info/licenses}/LICENSE +674 -674
  54. MEDfl-0.2.1.dist-info/RECORD +0 -31
  55. {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/top_level.txt +0 -0
@@ -1,320 +1,320 @@
1
- from datetime import datetime
2
-
3
-
4
- from torch.utils.data import random_split, DataLoader, Dataset
5
-
6
- from MEDfl.LearningManager.federated_dataset import FederatedDataset
7
- from .net_helper import *
8
- from .net_manager_queries import * # Import the sql_queries module
9
- from .network import Network
10
-
11
- from .node import Node
12
-
13
- from MEDfl.NetManager.database_connector import DatabaseManager
14
-
15
-
16
- class FLsetup:
17
- def __init__(self, name: str, description: str, network: Network):
18
- """Initialize a Federated Learning (FL) setup.
19
-
20
- Args:
21
- name (str): The name of the FL setup.
22
- description (str): A description of the FL setup.
23
- network (Network): An instance of the Network class representing the network architecture.
24
- """
25
- self.name = name
26
- self.description = description
27
- self.network = network
28
- self.column_name = None
29
- self.auto = 1 if self.column_name is not None else 0
30
- self.validate()
31
- self.fed_dataset = None
32
-
33
- db_manager = DatabaseManager()
34
- db_manager.connect()
35
- self.eng = db_manager.get_connection()
36
-
37
-
38
-
39
- def validate(self):
40
- """Validate name, description, and network."""
41
- if not isinstance(self.name, str):
42
- raise TypeError("name argument must be a string")
43
-
44
- if not isinstance(self.description, str):
45
- raise TypeError("description argument must be a string")
46
-
47
- if not isinstance(self.network, Network):
48
- raise TypeError(
49
- "network argument must be a MEDfl.NetManager.Network "
50
- )
51
-
52
- def create(self):
53
- """Create an FL setup."""
54
- creation_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
55
- netid = get_netid_from_name(self.network.name)
56
- self.eng.execute(
57
- text(CREATE_FLSETUP_QUERY),
58
- {
59
- "name": self.name,
60
- "description": self.description,
61
- "creation_date": creation_date,
62
- "net_id": netid,
63
- "column_name": self.column_name,
64
- },
65
- )
66
- self.id = get_flsetupid_from_name(self.name)
67
-
68
- def delete(self):
69
- """Delete the FL setup."""
70
- if self.fed_dataset is not None:
71
- self.fed_dataset.delete_Flsetup(FLsetupId=self.id)
72
- self.eng.execute(text(DELETE_FLSETUP_QUERY), {"name": self.name})
73
-
74
- @classmethod
75
- def read_setup(cls, FLsetupId: int):
76
- """Read the FL setup by FLsetupId.
77
-
78
- Args:
79
- FLsetupId (int): The id of the FL setup to read.
80
-
81
- Returns:
82
- FLsetup: An instance of the FLsetup class with the specified FLsetupId.
83
- """
84
- db_manager = DatabaseManager()
85
- db_manager.connect()
86
- my_eng = db_manager.get_connection()
87
-
88
- res = pd.read_sql(
89
- text(READ_SETUP_QUERY), my_eng, params={"flsetup_id": FLsetupId}
90
- ).iloc[0]
91
-
92
- network_res = pd.read_sql(
93
- text(READ_NETWORK_BY_ID_QUERY),
94
- my_eng,
95
- params={"net_id": int(res["NetId"])},
96
- ).iloc[0]
97
- network = Network(network_res["NetName"])
98
- setattr(network, "id", res["NetId"])
99
- fl_setup = cls(res["name"], res["description"], network)
100
- if res["column_name"] == str(None):
101
- res["column_name"] = None
102
- setattr(fl_setup, "column_name", res["column_name"])
103
- setattr(fl_setup, "id", res["FLsetupId"])
104
-
105
- return fl_setup
106
-
107
- @staticmethod
108
- def list_allsetups():
109
- """List all the FL setups.
110
-
111
- Returns:
112
- DataFrame: A DataFrame containing information about all the FL setups.
113
- """
114
- db_manager = DatabaseManager()
115
- db_manager.connect()
116
- my_eng = db_manager.get_connection()
117
-
118
- Flsetups = pd.read_sql(text(READ_ALL_SETUPS_QUERY), my_eng)
119
- return Flsetups
120
-
121
- def create_nodes_from_master_dataset(self, params_dict: dict):
122
- """Create nodes from the master dataset.
123
-
124
- Args:
125
- params_dict (dict): A dictionary containing parameters for node creation.
126
- - column_name (str): The name of the column in the MasterDataset used to create nodes.
127
- - train_nodes (list): A list of node names that will be used for training.
128
- - test_nodes (list): A list of node names that will be used for testing.
129
-
130
- Returns:
131
- list: A list of Node instances created from the master dataset.
132
- """
133
- assert "column_name" in params_dict.keys()
134
- column_name, train_nodes, test_nodes = (
135
- params_dict["column_name"],
136
- params_dict["train_nodes"],
137
- params_dict["test_nodes"],
138
- )
139
- self.column_name = column_name
140
- self.auto = 1
141
-
142
- # Update the Column name of the auto flSetup
143
- query = f"UPDATE FLsetup SET column_name = '{column_name}' WHERE name = '{self.name}'"
144
- self.eng.execute(text(query))
145
-
146
-
147
- # Add Network to DB
148
- # self.network.create_network()
149
-
150
- netid = get_netid_from_name(self.network.name)
151
-
152
- assert self.network.mtable_exists == 1
153
- node_names = pd.read_sql(
154
- text(READ_DISTINCT_NODES_QUERY.format(column_name)), self.eng
155
- )
156
-
157
- nodes = [Node(val[0], 1) for val in node_names.values.tolist()]
158
-
159
- used_nodes = []
160
-
161
- for node in nodes:
162
- if node.name in train_nodes:
163
- node.train = 1
164
- node.create_node(netid)
165
- used_nodes.append(node)
166
- if node.name in test_nodes:
167
- node.train =0
168
- node.create_node(netid)
169
- used_nodes.append(node)
170
- return used_nodes
171
-
172
- def create_dataloader_from_node(
173
- self,
174
- node: Node,
175
- output,
176
- fill_strategy="mean", fit_encode=[], to_drop=[],
177
- train_batch_size: int = 32,
178
- test_batch_size: int = 1,
179
- split_frac: float = 0.2,
180
- dataset: Dataset = None,
181
-
182
- ):
183
- """Create DataLoader from a Node.
184
-
185
- Args:
186
- node (Node): The node from which to create DataLoader.
187
- train_batch_size (int): The batch size for training data.
188
- test_batch_size (int): The batch size for test data.
189
- split_frac (float): The fraction of data to be used for training.
190
- dataset (Dataset): The dataset to use. If None, the method will read the dataset from the node.
191
-
192
- Returns:
193
- DataLoader: The DataLoader instances for training and testing.
194
- """
195
- if dataset is None:
196
- if self.column_name is not None:
197
- dataset = process_data_after_reading(
198
- node.get_dataset(self.column_name), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop
199
- )
200
- else:
201
- dataset = process_data_after_reading(
202
- node.get_dataset(), output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
203
-
204
- dataset_size = len(dataset)
205
- traindata_size = int(dataset_size * (1 - split_frac))
206
- traindata, testdata = random_split(
207
- dataset, [traindata_size, dataset_size - traindata_size]
208
- )
209
- trainloader, testloader = DataLoader(
210
- traindata, batch_size=train_batch_size
211
- ), DataLoader(testdata, batch_size=test_batch_size)
212
- return trainloader, testloader
213
-
214
- def create_federated_dataset(
215
- self, output, fill_strategy="mean", fit_encode=[], to_drop=[], val_frac=0.1, test_frac=0.2
216
- ) -> FederatedDataset:
217
- """Create a federated dataset.
218
-
219
- Args:
220
- output (string): the output feature of the dataset
221
- val_frac (float): The fraction of data to be used for validation.
222
- test_frac (float): The fraction of data to be used for testing.
223
-
224
- Returns:
225
- FederatedDataset: The FederatedDataset instance containing train, validation, and test data.
226
- """
227
-
228
- if not self.column_name:
229
- to_drop.extend(["DataSetName" , "NodeId" , "DataSetId"])
230
- else :
231
- to_drop.extend(["PatientId"])
232
-
233
- netid = self.network.id
234
- train_nodes = pd.read_sql(
235
- text(
236
- f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 1 "
237
- ),
238
- self.eng,
239
- )
240
- test_nodes = pd.read_sql(
241
- text(
242
- f"SELECT Nodes.NodeName FROM Nodes WHERE Nodes.NetId = {netid} AND Nodes.train = 0 "
243
- ),
244
- self.eng,
245
- )
246
-
247
- train_nodes = [
248
- Node(val[0], 1, test_frac) for val in train_nodes.values.tolist()
249
- ]
250
- test_nodes = [Node(val[0], 0) for val in test_nodes.values.tolist()]
251
-
252
- trainloaders, valloaders, testloaders = [], [], []
253
- # if len(test_nodes) == 0:
254
- # raise "test node empty"
255
- if test_nodes is None:
256
- _, testloader = self.create_dataloader_from_node(
257
- train_nodes[0], output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop)
258
- testloaders.append(testloader)
259
- else:
260
- for train_node in train_nodes:
261
- train_valloader, testloader = self.create_dataloader_from_node(
262
- train_node, output, fill_strategy=fill_strategy,
263
- fit_encode=fit_encode, to_drop=to_drop,)
264
- trainloader, valloader = self.create_dataloader_from_node(
265
- train_node,
266
- output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop,
267
- test_batch_size=32,
268
- split_frac=val_frac,
269
- dataset=train_valloader.dataset,
270
- )
271
- trainloaders.append(trainloader)
272
- valloaders.append(valloader)
273
- testloaders.append(testloader)
274
-
275
- for test_node in test_nodes:
276
- _, testloader = self.create_dataloader_from_node(
277
- test_node, output, fill_strategy=fill_strategy, fit_encode=fit_encode, to_drop=to_drop, split_frac=1.0
278
- )
279
- testloaders.append(testloader)
280
- train_nodes_names = [node.name for node in train_nodes]
281
- test_nodes_names = train_nodes_names + [
282
- node.name for node in test_nodes
283
- ]
284
-
285
- # test_nodes_names = [
286
- # node.name for node in test_nodes
287
- # ]
288
-
289
- # Add FlSetup on to the DataBase
290
- # self.create()
291
-
292
- # self.network.update_network(FLsetupId=self.id)
293
- fed_dataset = FederatedDataset(
294
- self.name + "_Feddataset",
295
- train_nodes_names,
296
- test_nodes_names,
297
- trainloaders,
298
- valloaders,
299
- testloaders,
300
- )
301
- self.fed_dataset = fed_dataset
302
- self.fed_dataset.create(self.id)
303
- return self.fed_dataset
304
-
305
-
306
-
307
-
308
- def get_flDataSet(self):
309
- """
310
- Retrieve the federated dataset associated with the FL setup using the FL setup's name.
311
-
312
- Returns:
313
- pandas.DataFrame: DataFrame containing the federated dataset information.
314
- """
315
- return pd.read_sql(
316
- text(
317
- f"SELECT * FROM FedDatasets WHERE FLsetupId = {get_flsetupid_from_name(self.name)}"
318
- ),
319
- self.eng,
320
- )
1
+ from datetime import datetime
2
+
3
+
4
+ from torch.utils.data import random_split, DataLoader, Dataset
5
+
6
+ from MEDfl.LearningManager.federated_dataset import FederatedDataset
7
+ from .net_helper import *
8
+ from .net_manager_queries import * # Import the sql_queries module
9
+ from .network import Network
10
+
11
+ from .node import Node
12
+
13
+ from MEDfl.NetManager.database_connector import DatabaseManager
14
+
15
+
16
class FLsetup:
    """A Federated Learning (FL) setup bound to a :class:`Network`.

    Wraps the database rows that describe the setup and provides helpers to
    create nodes from a master dataset, build per-node DataLoaders, and
    assemble the final :class:`FederatedDataset`.
    """

    def __init__(self, name: str, description: str, network: Network):
        """Initialize a Federated Learning (FL) setup.

        Args:
            name (str): The name of the FL setup.
            description (str): A description of the FL setup.
            network (Network): An instance of the Network class representing
                the network architecture.
        """
        self.name = name
        self.description = description
        self.network = network
        self.column_name = None
        # auto == 1 means the setup was derived automatically from a
        # MasterDataset column; column_name is set later by
        # create_nodes_from_master_dataset().
        self.auto = 1 if self.column_name is not None else 0
        self.validate()
        self.fed_dataset = None

        db_manager = DatabaseManager()
        db_manager.connect()
        self.eng = db_manager.get_connection()

    def validate(self):
        """Validate name, description, and network.

        Raises:
            TypeError: If any constructor argument has the wrong type.
        """
        if not isinstance(self.name, str):
            raise TypeError("name argument must be a string")

        if not isinstance(self.description, str):
            raise TypeError("description argument must be a string")

        if not isinstance(self.network, Network):
            raise TypeError(
                "network argument must be a MEDfl.NetManager.Network "
            )

    def create(self):
        """Create the FL setup row in the database and cache its id."""
        creation_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        netid = get_netid_from_name(self.network.name)
        self.eng.execute(
            text(CREATE_FLSETUP_QUERY),
            {
                "name": self.name,
                "description": self.description,
                "creation_date": creation_date,
                "net_id": netid,
                "column_name": self.column_name,
            },
        )
        self.id = get_flsetupid_from_name(self.name)

    def delete(self):
        """Delete the FL setup (and its federated dataset, if created)."""
        if self.fed_dataset is not None:
            self.fed_dataset.delete_Flsetup(FLsetupId=self.id)
        self.eng.execute(text(DELETE_FLSETUP_QUERY), {"name": self.name})

    @classmethod
    def read_setup(cls, FLsetupId: int):
        """Read the FL setup by FLsetupId.

        Args:
            FLsetupId (int): The id of the FL setup to read.

        Returns:
            FLsetup: An instance of the FLsetup class with the specified
            FLsetupId.
        """
        db_manager = DatabaseManager()
        db_manager.connect()
        my_eng = db_manager.get_connection()

        res = pd.read_sql(
            text(READ_SETUP_QUERY), my_eng, params={"flsetup_id": FLsetupId}
        ).iloc[0]

        network_res = pd.read_sql(
            text(READ_NETWORK_BY_ID_QUERY),
            my_eng,
            params={"net_id": int(res["NetId"])},
        ).iloc[0]
        network = Network(network_res["NetName"])
        setattr(network, "id", res["NetId"])
        fl_setup = cls(res["name"], res["description"], network)
        # The DB stores a literal "None" string for auto-less setups;
        # normalize it back to a real None.
        if res["column_name"] == str(None):
            res["column_name"] = None
        setattr(fl_setup, "column_name", res["column_name"])
        setattr(fl_setup, "id", res["FLsetupId"])

        return fl_setup

    @staticmethod
    def list_allsetups():
        """List all the FL setups.

        Returns:
            DataFrame: A DataFrame containing information about all the FL
            setups.
        """
        db_manager = DatabaseManager()
        db_manager.connect()
        my_eng = db_manager.get_connection()

        Flsetups = pd.read_sql(text(READ_ALL_SETUPS_QUERY), my_eng)
        return Flsetups

    def create_nodes_from_master_dataset(self, params_dict: dict):
        """Create nodes from the master dataset.

        Args:
            params_dict (dict): A dictionary containing parameters for node
                creation.
                - column_name (str): The name of the column in the
                  MasterDataset used to create nodes.
                - train_nodes (list): A list of node names that will be used
                  for training.
                - test_nodes (list): A list of node names that will be used
                  for testing.

        Returns:
            list: A list of Node instances created from the master dataset.
        """
        assert "column_name" in params_dict.keys()
        column_name, train_nodes, test_nodes = (
            params_dict["column_name"],
            params_dict["train_nodes"],
            params_dict["test_nodes"],
        )
        self.column_name = column_name
        self.auto = 1

        # Update the column name of the auto FLsetup.
        # Bound parameters instead of f-string interpolation: avoids SQL
        # injection and breakage on names containing quotes.
        self.eng.execute(
            text(
                "UPDATE FLsetup SET column_name = :column_name "
                "WHERE name = :name"
            ),
            {"column_name": column_name, "name": self.name},
        )

        netid = get_netid_from_name(self.network.name)

        # The master dataset table must already exist on the network.
        assert self.network.mtable_exists == 1
        node_names = pd.read_sql(
            text(READ_DISTINCT_NODES_QUERY.format(column_name)), self.eng
        )

        nodes = [Node(val[0], 1) for val in node_names.values.tolist()]

        used_nodes = []

        for node in nodes:
            if node.name in train_nodes:
                node.train = 1
                node.create_node(netid)
                used_nodes.append(node)
            if node.name in test_nodes:
                node.train = 0
                node.create_node(netid)
                used_nodes.append(node)
        return used_nodes

    def create_dataloader_from_node(
        self,
        node: Node,
        output,
        fill_strategy="mean",
        fit_encode=None,
        to_drop=None,
        train_batch_size: int = 32,
        test_batch_size: int = 1,
        split_frac: float = 0.2,
        dataset: Dataset = None,
    ):
        """Create DataLoader from a Node.

        Args:
            node (Node): The node from which to create DataLoader.
            output: The output (target) feature of the dataset.
            fill_strategy (str): Strategy used to fill missing values.
            fit_encode (list): Columns to label-encode. Defaults to no
                columns.
            to_drop (list): Columns to drop before training. Defaults to no
                columns.
            train_batch_size (int): The batch size for training data.
            test_batch_size (int): The batch size for test data.
            split_frac (float): The fraction of data to be used for testing.
            dataset (Dataset): The dataset to use. If None, the method will
                read the dataset from the node.

        Returns:
            tuple: The DataLoader instances for training and testing.
        """
        # None sentinels instead of mutable defaults: a shared `[]` default
        # would be mutated across calls by downstream code.
        fit_encode = [] if fit_encode is None else fit_encode
        to_drop = [] if to_drop is None else to_drop

        if dataset is None:
            if self.column_name is not None:
                dataset = process_data_after_reading(
                    node.get_dataset(self.column_name),
                    output,
                    fill_strategy=fill_strategy,
                    fit_encode=fit_encode,
                    to_drop=to_drop,
                )
            else:
                dataset = process_data_after_reading(
                    node.get_dataset(),
                    output,
                    fill_strategy=fill_strategy,
                    fit_encode=fit_encode,
                    to_drop=to_drop,
                )

        dataset_size = len(dataset)
        traindata_size = int(dataset_size * (1 - split_frac))
        traindata, testdata = random_split(
            dataset, [traindata_size, dataset_size - traindata_size]
        )
        trainloader, testloader = DataLoader(
            traindata, batch_size=train_batch_size
        ), DataLoader(testdata, batch_size=test_batch_size)
        return trainloader, testloader

    def create_federated_dataset(
        self,
        output,
        fill_strategy="mean",
        fit_encode=None,
        to_drop=None,
        val_frac=0.1,
        test_frac=0.2,
    ) -> FederatedDataset:
        """Create a federated dataset.

        Args:
            output (string): the output feature of the dataset
            fill_strategy (str): Strategy used to fill missing values.
            fit_encode (list): Columns to label-encode.
            to_drop (list): Extra columns to drop (the id/bookkeeping
                columns are always added internally).
            val_frac (float): The fraction of data to be used for validation.
            test_frac (float): The fraction of data to be used for testing.

        Returns:
            FederatedDataset: The FederatedDataset instance containing train,
            validation, and test data.
        """
        fit_encode = [] if fit_encode is None else fit_encode
        # Copy before extending so the caller's list (or a shared default)
        # is never mutated — the original code appended bookkeeping columns
        # to the argument in place on every call.
        to_drop = [] if to_drop is None else list(to_drop)

        if not self.column_name:
            to_drop.extend(["DataSetName", "NodeId", "DataSetId"])
        else:
            to_drop.extend(["PatientId"])

        netid = self.network.id
        # Bound parameters instead of f-string SQL.
        train_nodes = pd.read_sql(
            text(
                "SELECT Nodes.NodeName FROM Nodes "
                "WHERE Nodes.NetId = :net_id AND Nodes.train = 1 "
            ),
            self.eng,
            params={"net_id": netid},
        )
        test_nodes = pd.read_sql(
            text(
                "SELECT Nodes.NodeName FROM Nodes "
                "WHERE Nodes.NetId = :net_id AND Nodes.train = 0 "
            ),
            self.eng,
            params={"net_id": netid},
        )

        train_nodes = [
            Node(val[0], 1, test_frac) for val in train_nodes.values.tolist()
        ]
        test_nodes = [Node(val[0], 0) for val in test_nodes.values.tolist()]

        trainloaders, valloaders, testloaders = [], [], []
        # NOTE(review): the original `if test_nodes is None:` branch was
        # unreachable (test_nodes is always a list after the comprehension
        # above) and has been removed; an empty test-node list is handled
        # naturally by the loops below.
        for train_node in train_nodes:
            # First split: train+val vs. test for this node.
            train_valloader, testloader = self.create_dataloader_from_node(
                train_node,
                output,
                fill_strategy=fill_strategy,
                fit_encode=fit_encode,
                to_drop=to_drop,
            )
            # Second split: train vs. val within the train+val partition.
            trainloader, valloader = self.create_dataloader_from_node(
                train_node,
                output,
                fill_strategy=fill_strategy,
                fit_encode=fit_encode,
                to_drop=to_drop,
                test_batch_size=32,
                split_frac=val_frac,
                dataset=train_valloader.dataset,
            )
            trainloaders.append(trainloader)
            valloaders.append(valloader)
            testloaders.append(testloader)

        for test_node in test_nodes:
            # split_frac=1.0 sends the whole node dataset to the test side.
            _, testloader = self.create_dataloader_from_node(
                test_node,
                output,
                fill_strategy=fill_strategy,
                fit_encode=fit_encode,
                to_drop=to_drop,
                split_frac=1.0,
            )
            testloaders.append(testloader)
        train_nodes_names = [node.name for node in train_nodes]
        # Test names include the train nodes: each train node also
        # contributed a testloader above.
        test_nodes_names = train_nodes_names + [
            node.name for node in test_nodes
        ]

        fed_dataset = FederatedDataset(
            self.name + "_Feddataset",
            train_nodes_names,
            test_nodes_names,
            trainloaders,
            valloaders,
            testloaders,
        )
        self.fed_dataset = fed_dataset
        self.fed_dataset.create(self.id)
        return self.fed_dataset

    def get_flDataSet(self):
        """
        Retrieve the federated dataset associated with the FL setup using the
        FL setup's name.

        Returns:
            pandas.DataFrame: DataFrame containing the federated dataset
            information.
        """
        return pd.read_sql(
            text("SELECT * FROM FedDatasets WHERE FLsetupId = :setup_id"),
            self.eng,
            params={"setup_id": get_flsetupid_from_name(self.name)},
        )