MEDfl 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/METADATA +127 -128
- MEDfl-0.1.33.dist-info/RECORD +34 -0
- {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/WHEEL +1 -1
- {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/top_level.txt +0 -1
- Medfl/LearningManager/__init__.py +13 -13
- Medfl/LearningManager/client.py +150 -150
- Medfl/LearningManager/dynamicModal.py +287 -287
- Medfl/LearningManager/federated_dataset.py +60 -60
- Medfl/LearningManager/flpipeline.py +192 -192
- Medfl/LearningManager/model.py +223 -223
- Medfl/LearningManager/params.yaml +14 -14
- Medfl/LearningManager/params_optimiser.py +442 -442
- Medfl/LearningManager/plot.py +229 -229
- Medfl/LearningManager/server.py +181 -181
- Medfl/LearningManager/strategy.py +82 -82
- Medfl/LearningManager/utils.py +331 -308
- Medfl/NetManager/__init__.py +10 -9
- Medfl/NetManager/database_connector.py +43 -48
- Medfl/NetManager/dataset.py +92 -92
- Medfl/NetManager/flsetup.py +320 -320
- Medfl/NetManager/net_helper.py +254 -248
- Medfl/NetManager/net_manager_queries.py +142 -137
- Medfl/NetManager/network.py +194 -174
- Medfl/NetManager/node.py +184 -178
- Medfl/__init__.py +3 -2
- Medfl/scripts/__init__.py +2 -0
- Medfl/scripts/base.py +30 -0
- Medfl/scripts/create_db.py +126 -0
- alembic/env.py +61 -61
- scripts/base.py +29 -29
- scripts/config.ini +5 -5
- scripts/create_db.py +133 -133
- MEDfl/LearningManager/__init__.py +0 -13
- MEDfl/LearningManager/client.py +0 -150
- MEDfl/LearningManager/dynamicModal.py +0 -287
- MEDfl/LearningManager/federated_dataset.py +0 -60
- MEDfl/LearningManager/flpipeline.py +0 -192
- MEDfl/LearningManager/model.py +0 -223
- MEDfl/LearningManager/params.yaml +0 -14
- MEDfl/LearningManager/params_optimiser.py +0 -442
- MEDfl/LearningManager/plot.py +0 -229
- MEDfl/LearningManager/server.py +0 -181
- MEDfl/LearningManager/strategy.py +0 -82
- MEDfl/LearningManager/utils.py +0 -333
- MEDfl/NetManager/__init__.py +0 -9
- MEDfl/NetManager/database_connector.py +0 -48
- MEDfl/NetManager/dataset.py +0 -92
- MEDfl/NetManager/flsetup.py +0 -320
- MEDfl/NetManager/net_helper.py +0 -248
- MEDfl/NetManager/net_manager_queries.py +0 -137
- MEDfl/NetManager/network.py +0 -174
- MEDfl/NetManager/node.py +0 -178
- MEDfl/__init__.py +0 -2
- MEDfl-0.1.31.data/scripts/setup_mysql.sh +0 -22
- MEDfl-0.1.31.dist-info/RECORD +0 -54
- scripts/db_config.ini +0 -6
scripts/base.py
CHANGED
@@ -1,30 +1,30 @@
import mysql.connector
from sqlalchemy import create_engine, text
from configparser import ConfigParser
import yaml
import pkg_resources
import os
from urllib.parse import quote

# Get the directory of the current script
current_directory = os.path.dirname(os.path.abspath(__file__))

# Configuration files looked up next to this script, in priority order.
# ``db_config.ini`` is the historical name; newer packages ship
# ``config.ini`` instead, so fall back to that file when the first is absent.
_CONFIG_FILE_NAMES = ('db_config.ini', 'config.ini')


def _load_config(directory):
    """Load the first candidate config file that has a ``[mysql]`` section.

    Args:
        directory (str): Folder in which the candidate files are looked up.

    Returns:
        tuple[str, ConfigParser]: Path of the file used and its parser.

    Raises:
        KeyError: If no candidate file exists or none has a ``[mysql]``
            section (the same exception type ``config['mysql']`` raised).
    """
    tried = []
    for name in _CONFIG_FILE_NAMES:
        path = os.path.join(directory, name)
        tried.append(path)
        parser = ConfigParser()
        # read() silently skips missing files and returns the files it parsed.
        if parser.read(path) and parser.has_section('mysql'):
            return path, parser
    raise KeyError("no [mysql] section found in " + " or ".join(tried))


# Load configuration from the config file
config_file_path, config = _load_config(current_directory)
mysql_config = config['mysql']

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# user name or password cannot be mistaken for URL delimiters.
connection_string = (
    f"mysql+mysqlconnector://{quote(mysql_config['user'], safe='')}:"
    f"{quote(mysql_config['password'], safe='')}@"
    f"{mysql_config['host']}:{mysql_config['port']}/{mysql_config['database']}"
)

eng = create_engine(
    connection_string,
    execution_options={"autocommit": True},
)

my_eng = eng.connect()
|
scripts/config.ini
CHANGED
@@ -1,6 +1,6 @@
; MySQL connection settings.
;
; If this file is read with Python's configparser (default settings), only
; full-line comments like these are safe: inline comments are not stripped
; and would become part of the value.

[mysql]
host = localhost
port = 3306

user = root
; Empty value: connect with an empty password.
password =

database = MEDfl
|
scripts/create_db.py
CHANGED
@@ -1,133 +1,133 @@
import sys
import mysql.connector
import pandas as pd
from mysql.connector import Error

from configparser import ConfigParser
import os

# Configuration files looked up next to this script, in priority order.
# ``db_config.ini`` is the historical name; newer packages ship
# ``config.ini`` instead, so fall back to that file when the first is absent.
_CONFIG_FILE_NAMES = ('db_config.ini', 'config.ini')

# Database name used when the config has no (or an empty) ``database`` entry.
_DEFAULT_DATABASE = 'MEDfl'

# pandas dtype name -> MySQL column type for the CSV-derived DataSets columns.
_COLUMN_TYPES = {
    "object": "VARCHAR(255)",
    "int64": "INT",
    "float64": "FLOAT",
    "bool": "TINYINT(1)",
}


def _load_config(directory):
    """Load the first candidate config file that has a ``[mysql]`` section.

    Args:
        directory (str): Folder in which the candidate files are looked up.

    Returns:
        tuple[str, ConfigParser]: Path of the file used and its parser.

    Raises:
        KeyError: If no candidate file exists or none has a ``[mysql]``
            section (the same exception type ``config['mysql']`` raised).
    """
    tried = []
    for name in _CONFIG_FILE_NAMES:
        path = os.path.join(directory, name)
        tried.append(path)
        parser = ConfigParser()
        # read() silently skips missing files and returns the files it parsed.
        if parser.read(path) and parser.has_section('mysql'):
            return path, parser
    raise KeyError("no [mysql] section found in " + " or ".join(tried))


def _quote_ident(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Backticks inside the name are doubled (MySQL's escape inside quoted
    identifiers), so names containing spaces, reserved words or backticks
    cannot break the surrounding statement. ``bytes``/``bytearray`` values,
    which a MySQL driver may return for metadata, are decoded as UTF-8 first.
    """
    if isinstance(name, (bytes, bytearray)):
        name = name.decode('utf-8')
    return "`" + str(name).replace("`", "``") + "`"


def _dataset_column_defs(data_df):
    """Return the ``<col> <TYPE>,`` fragment for every column of *data_df*.

    Raises:
        ValueError: If a column's dtype has no entry in ``_COLUMN_TYPES``.
    """
    fragments = []
    for col in data_df.columns:
        dtype_name = str(data_df[col].dtype)
        sql_type = _COLUMN_TYPES.get(dtype_name)
        if sql_type is None:
            raise ValueError(
                f"Unsupported dtype {dtype_name!r} for CSV column {col!r}; "
                f"supported: {', '.join(_COLUMN_TYPES)}"
            )
        fragments.append(f"{_quote_ident(col)} {sql_type},")
    return "".join(fragments)


def main(csv_file_path):
    """(Re)create the MEDfl schema in the configured MySQL database.

    WARNING: every existing table in the target database is dropped first.

    Connection settings come from the ``[mysql]`` section of the config file
    next to this script. The DataSets table gets one extra column per column
    of the master CSV file, typed from that column's pandas dtype.

    Args:
        csv_file_path (str): Path of the master CSV file.

    Returns:
        bool: True on success; False if a MySQL error occurred (it is printed).

    Raises:
        KeyError: If no config file with a ``[mysql]`` section is found.
        ValueError: If a CSV column has an unsupported dtype.
    """
    current_directory = os.path.dirname(os.path.abspath(__file__))
    _, config = _load_config(current_directory)
    mysql_config = config['mysql']
    database = _quote_ident(mysql_config.get('database') or _DEFAULT_DATABASE)

    # Read and validate the CSV before touching the database: all tables are
    # dropped below, so failing afterwards would leave a half-built schema.
    data_df = pd.read_csv(csv_file_path)
    column_defs = _dataset_column_defs(data_df)

    mydb = None
    mycursor = None
    try:
        mydb = mysql.connector.connect(
            host=mysql_config['host'],
            port=mysql_config.getint('port', fallback=3306),
            user=mysql_config['user'],
            password=mysql_config['password'],
        )
        mycursor = mydb.cursor()

        # Create the database if it doesn't exist, then select it
        mycursor.execute(f"CREATE DATABASE IF NOT EXISTS {database}")
        mycursor.execute(f"USE {database}")

        # Drop every existing table in the database
        mycursor.execute("SHOW TABLES")
        for row in mycursor.fetchall():
            mycursor.execute(f"DROP TABLE IF EXISTS {_quote_ident(row[0])}")

        # Create Networks table
        mycursor.execute(
            """CREATE TABLE Networks(
                NetId INT NOT NULL AUTO_INCREMENT,
                NetName VARCHAR(255),
                PRIMARY KEY (NetId)
            );"""
        )

        # Create FLsetup table.
        # NOTE(review): utf8mb4_0900_ai_ci is a MySQL 8.0 collation; older
        # MySQL servers and most MariaDB versions reject it - confirm the
        # supported servers.
        mycursor.execute(
            """CREATE TABLE FLsetup (
                FLsetupId int NOT NULL AUTO_INCREMENT,
                name varchar(255) NOT NULL,
                description varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL,
                creation_date datetime NOT NULL,
                NetId int NOT NULL,
                column_name varchar(255) DEFAULT NULL,
                PRIMARY KEY (`FLsetupId`)
            )"""
        )

        # Create Nodes table
        mycursor.execute(
            """CREATE TABLE Nodes (
                NodeId int NOT NULL AUTO_INCREMENT,
                NodeName varchar(255) DEFAULT NULL,
                train tinyint(1) DEFAULT '1',
                NetId int DEFAULT NULL,
                PRIMARY KEY (NodeId)
            )"""
        )

        # Create DataSets table; its trailing columns mirror the master CSV.
        # NOTE(review): a CHECK that evaluates to NULL counts as satisfied,
        # so this constraint never rejects a row (a NULL NodeId passes too).
        mycursor.execute(
            f"""CREATE TABLE DataSets(
                DataSetId INT NOT NULL AUTO_INCREMENT,
                DataSetName VARCHAR(255),
                NodeId INT CHECK (NodeId = -1 OR NodeId IS NOT NULL),
                {column_defs}
                PRIMARY KEY (DataSetId)
            )"""
        )

        # Create FLpipeline table
        mycursor.execute(
            """CREATE TABLE FLpipeline(
                id int NOT NULL AUTO_INCREMENT,
                name varchar(255) NOT NULL,
                description varchar(255) NOT NULL,
                creation_date datetime NOT NULL,
                results longtext NOT NULL,
                PRIMARY KEY (id)
            )"""
        )

        # Create test results table.
        # NOTE(review): in MySQL, LONG is a synonym for MEDIUMTEXT, so these
        # metric columns store text rather than numbers - confirm this is
        # intended before changing it, as readers may rely on the type.
        mycursor.execute(
            """CREATE TABLE testResults(
                pipelineId INT,
                nodename VARCHAR(100) NOT NULL,
                confusionmatrix VARCHAR(255),
                accuracy LONG,
                sensivity LONG,
                ppv LONG,
                npv LONG,
                f1score LONG,
                fpr LONG,
                tpr LONG,
                PRIMARY KEY (pipelineId , nodename)
            )"""
        )

        # Create FederatedDataset table
        mycursor.execute(
            """CREATE TABLE FedDatasets (
                FedId int NOT NULL AUTO_INCREMENT,
                FLsetupId int DEFAULT NULL,
                FLpipeId int DEFAULT NULL,
                name varchar(255) NOT NULL,
                PRIMARY KEY (FedId)
            )"""
        )

        mydb.commit()
        return True

    except Error as e:
        print(f"Error: {e}")
        return False

    finally:
        # Always release the cursor and connection, even after an error.
        if mycursor is not None:
            mycursor.close()
        if mydb is not None:
            mydb.close()


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python script.py <path_to_csv_file>")
        sys.exit(1)
    csv_file_path = sys.argv[1]
    # Propagate failure to the shell instead of always exiting with status 0.
    sys.exit(0 if main(csv_file_path) else 1)
|
@@ -1,13 +0,0 @@
|
|
1
|
-
# # MEDfl/LearningManager/__init__.py
|
2
|
-
|
3
|
-
# # Import modules from this package
|
4
|
-
# from .client import *
|
5
|
-
# from .dynamicModal import *
|
6
|
-
# from .flpipeline import *
|
7
|
-
# from .federated_dataset import *
|
8
|
-
# from .model import *
|
9
|
-
# from .params_optimiser import *
|
10
|
-
# from .plot import *
|
11
|
-
# from .server import *
|
12
|
-
# from .strategy import *
|
13
|
-
# from .utils import *
|
MEDfl/LearningManager/client.py
DELETED
@@ -1,150 +0,0 @@
#!/usr/bin/env python3
import flwr as fl
from opacus import PrivacyEngine
from torch.utils.data import DataLoader

from .model import Model
from .utils import params
import torch


class FlowerClient(fl.client.NumPyClient):
    """
    FlowerClient class for creating MEDfl clients.

    Attributes:
        cid (str): Client ID.
        local_model (Model): Local model of the federated learning network.
        trainloader (DataLoader): DataLoader for training data; replaced by
            the loader returned by Opacus when differential privacy is on.
        valloader (DataLoader): DataLoader for validation data.
        diff_priv (bool): Flag indicating whether to use differential privacy.
        device (torch.device): Device the local model was moved to.
        privacy_engine (PrivacyEngine): Opacus engine passed to training.
        epsilons (list): Epsilon returned by ``Model.train`` for every local
            epoch run in ``fit``.
        accuracies (list): Accuracy from every ``evaluate`` call.
        losses (list): Loss from every ``evaluate`` call.
    """

    def __init__(self, cid: str, local_model: Model, trainloader: DataLoader, valloader: DataLoader, diff_priv: bool = params["diff_privacy"]):
        """
        Initializes the FlowerClient instance.

        Args:
            cid (str): Client ID. When CUDA is available it must also parse
                as an int, because it selects the GPU (int(cid) % device count).
            local_model (Model): Local model of the federated learning network.
            trainloader (DataLoader): DataLoader for training data.
            valloader (DataLoader): DataLoader for validation data.
            diff_priv (bool): Flag indicating whether to use differential privacy.

        Raises:
            TypeError: If an argument has the wrong type (see ``validate``).
            RuntimeError: If CUDA is reported available but no device is found.
        """
        self.cid = cid
        self.local_model = local_model
        self.trainloader = trainloader
        self.valloader = valloader
        self.diff_priv = diff_priv
        # Validate first so bad arguments raise the intended TypeError instead
        # of an unrelated error from the device/privacy setup below.
        self.validate()

        if torch.cuda.is_available():
            num_cuda_devices = torch.cuda.device_count()
            if num_cuda_devices > 0:
                # Spread clients over the available GPUs by client ID.
                device_idx = int(self.cid) % num_cuda_devices
                self.device = torch.device(f"cuda:{device_idx}")
            else:
                # Handle case where CUDA is available but no CUDA devices are found
                raise RuntimeError("CUDA is available, but no CUDA devices are found.")
        else:
            # Handle case where CUDA is not available
            self.device = torch.device("cpu")
        self.local_model.model.to(self.device)

        self.privacy_engine = PrivacyEngine(secure_mode=False)
        self.epsilons = []
        self.accuracies = []
        self.losses = []
        if self.diff_priv:
            # Opacus wraps model, optimizer and loader so that training for
            # train_epochs epochs targets the configured (EPSILON, DELTA).
            model, optimizer, self.trainloader = self.privacy_engine.make_private_with_epsilon(
                module=self.local_model.model.train(),
                optimizer=self.local_model.optimizer,
                data_loader=self.trainloader,
                epochs=params["train_epochs"],
                target_epsilon=float(params["EPSILON"]),
                target_delta=float(params["DELTA"]),
                max_grad_norm=params["MAX_GRAD_NORM"],
            )
            self.local_model.model = model
            self.local_model.optimizer = optimizer

    def validate(self):
        """Validate cid, local_model, trainloader, valloader and diff_priv.

        Raises:
            TypeError: If any of these attributes has the wrong type.
        """
        if not isinstance(self.cid, str):
            raise TypeError("cid argument must be a string")

        if not isinstance(self.local_model, Model):
            raise TypeError("local_model argument must be a MEDfl.LearningManager.model.Model")

        if not isinstance(self.trainloader, DataLoader):
            raise TypeError("trainloader argument must be a torch.utils.data.dataloader")

        if not isinstance(self.valloader, DataLoader):
            raise TypeError("valloader argument must be a torch.utils.data.dataloader")

        if not isinstance(self.diff_priv, bool):
            raise TypeError("diff_priv argument must be a bool")

    @staticmethod
    def _num_examples(loader):
        """Return the number of samples behind *loader*.

        Flower's ``NumPyClient.fit``/``evaluate`` report a number of
        *examples*, which strategies such as FedAvg use to weight clients;
        ``len(loader)`` would be the number of batches instead. Falls back to
        the batch count when the dataset defines no ``__len__`` (e.g. an
        IterableDataset).
        """
        try:
            return len(loader.dataset)
        except TypeError:
            return len(loader)

    def get_parameters(self, config):
        """
        Returns the current parameters of the local model.

        Args:
            config: Configuration information (unused).

        Returns:
            Numpy array: Parameters of the local model.
        """
        print(f"[Client {self.cid}] get_parameters")
        return self.local_model.get_parameters()

    def fit(self, parameters, config):
        """
        Fits the local model to the received parameters using federated learning.

        Runs ``params["train_epochs"]`` local epochs, recording each epoch's
        epsilon in ``self.epsilons``.

        Args:
            parameters: Parameters received from the server.
            config: Configuration information (only printed).

        Returns:
            Tuple: Parameters of the local model, number of training examples,
            and ``{"epsilon": <epsilon of the last epoch>}``.
        """
        print('\n -------------------------------- \n this is the config of the client')
        print(f"[Client {self.cid}] fit, config: {config}")
        print('\n -------------------------------- \n ')
        self.local_model.set_parameters(parameters)
        for epoch in range(params["train_epochs"]):
            epsilon = self.local_model.train(
                self.trainloader,
                epoch=epoch,
                device=self.device,
                privacy_engine=self.privacy_engine,
                diff_priv=self.diff_priv,
            )
            self.epsilons.append(epsilon)
            print(f"epsilon of client {self.cid} : eps = {epsilon}")
        return (
            self.local_model.get_parameters(),
            self._num_examples(self.trainloader),
            {"epsilon": epsilon},
        )

    def evaluate(self, parameters, config):
        """
        Evaluates the local model on the validation data and returns the loss and accuracy.

        Args:
            parameters: Parameters received from the server.
            config: Configuration information (only printed).

        Returns:
            Tuple: Loss, number of validation examples, and accuracy information.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        self.local_model.set_parameters(parameters)
        # Model.evaluate also returns an AUC, which is not reported here.
        loss, accuracy, _auc = self.local_model.evaluate(
            self.valloader, device=self.device
        )
        self.losses.append(loss)
        self.accuracies.append(accuracy)

        return float(loss), self._num_examples(self.valloader), {"accuracy": float(accuracy)}
|