MEDfl 0.1.6__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,233 +1,311 @@
1
- #!/usr/bin/env python3
2
-
3
- import pkg_resources
4
- import torch
5
- import yaml
6
- from sklearn.metrics import *
7
- from yaml.loader import SafeLoader
8
-
9
- from scripts.base import *
10
- import json
11
-
12
-
13
- import pandas as pd
14
- import numpy as np
15
-
16
- yaml_path = pkg_resources.resource_filename(__name__, "params.yaml")
17
- with open(yaml_path) as g:
18
- params = yaml.load(g, Loader=SafeLoader)
19
-
20
- global_yaml_path = pkg_resources.resource_filename(__name__, "../../global_params.yaml")
21
- with open(global_yaml_path) as g:
22
- global_params = yaml.load(g, Loader=SafeLoader)
23
-
24
-
25
- def custom_classification_report(y_true, y_pred_prob):
26
- """
27
- Compute custom classification report metrics including accuracy, sensitivity, specificity, precision, NPV,
28
- F1-score, false positive rate, and true positive rate.
29
-
30
- Args:
31
- y_true (array-like): True labels.
32
- y_pred (array-like): Predicted labels.
33
-
34
- Returns:
35
- dict: A dictionary containing custom classification report metrics.
36
- """
37
- y_pred = (y_pred_prob).round() # Round absolute values of predicted probabilities to the nearest integer
38
-
39
- auc = roc_auc_score(y_true, y_pred_prob) # Calculate AUC
40
-
41
- tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
42
-
43
- # Accuracy
44
- denominator_acc = tp + tn + fp + fn
45
- acc = (tp + tn) / denominator_acc if denominator_acc != 0 else 0.0
46
-
47
- # Sensitivity/Recall
48
- denominator_sen = tp + fn
49
- sen = tp / denominator_sen if denominator_sen != 0 else 0.0
50
-
51
- # Specificity
52
- denominator_sp = tn + fp
53
- sp = tn / denominator_sp if denominator_sp != 0 else 0.0
54
-
55
- # PPV/Precision
56
- denominator_ppv = tp + fp
57
- ppv = tp / denominator_ppv if denominator_ppv != 0 else 0.0
58
-
59
- # NPV
60
- denominator_npv = tn + fn
61
- npv = tn / denominator_npv if denominator_npv != 0 else 0.0
62
-
63
- # F1 Score
64
- denominator_f1 = sen + ppv
65
- f1 = 2 * (sen * ppv) / denominator_f1 if denominator_f1 != 0 else 0.0
66
-
67
- # False Positive Rate
68
- denominator_fpr = fp + tn
69
- fpr = fp / denominator_fpr if denominator_fpr != 0 else 0.0
70
-
71
- # True Positive Rate
72
- denominator_tpr = tp + fn
73
- tpr = tp / denominator_tpr if denominator_tpr != 0 else 0.0
74
-
75
- return {
76
- "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
77
- "Accuracy": round(acc, 3),
78
- "Sensitivity/Recall": round(sen, 3),
79
- "Specificity": round(sp, 3),
80
- "PPV/Precision": round(ppv, 3),
81
- "NPV": round(npv, 3),
82
- "F1-score": round(f1, 3),
83
- "False positive rate": round(fpr, 3),
84
- "True positive rate": round(tpr, 3),
85
- "auc": auc
86
- }
87
-
88
-
89
- def test(model, test_loader, device=torch.device("cpu")):
90
- """
91
- Evaluate a model using a test loader and return a custom classification report.
92
-
93
- Args:
94
- model (torch.nn.Module): PyTorch model to evaluate.
95
- test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
96
- device (torch.device, optional): Device for model evaluation. Default is "cpu".
97
-
98
- Returns:
99
- dict: A dictionary containing custom classification report metrics.
100
- """
101
-
102
- model.eval()
103
- with torch.no_grad():
104
- X_test, y_test = test_loader.dataset[:][0].to(device), test_loader.dataset[:][1].to(device)
105
- y_hat_prob = torch.squeeze(model(X_test), 1).cpu()
106
-
107
- return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.cpu().numpy())
108
-
109
-
110
- column_map = {"object": "VARCHAR(255)", "int64": "INT", "float64": "FLOAT"}
111
-
112
-
113
- def empty_db():
114
- """
115
- Empty the database by deleting records from multiple tables and resetting auto-increment counters.
116
-
117
- Returns:
118
- None
119
- """
120
-
121
- # my_eng.execute(text(f"DELETE FROM {'DataSets'}"))
122
- my_eng.execute(text(f"DELETE FROM {'Nodes'}"))
123
- my_eng.execute(text(f"DELETE FROM {'FedDatasets'}"))
124
- my_eng.execute(text(f"DELETE FROM {'Networks'}"))
125
- my_eng.execute(text(f"DELETE FROM {'FLsetup'}"))
126
-
127
- my_eng.execute(text(f"DELETE FROM {'FLpipeline'}"))
128
- my_eng.execute(text(f"ALTER TABLE {'Nodes'} AUTO_INCREMENT = 1"))
129
- my_eng.execute(text(f"ALTER TABLE {'Networks'} AUTO_INCREMENT = 1"))
130
- my_eng.execute(text(f"ALTER TABLE {'FedDatasets'} AUTO_INCREMENT = 1"))
131
- my_eng.execute(text(f"ALTER TABLE {'FLsetup'} AUTO_INCREMENT = 1"))
132
- my_eng.execute(text(f"ALTER TABLE {'FLpipeline'} AUTO_INCREMENT = 1"))
133
- my_eng.execute(text(f"DELETE FROM {'testResults'}"))
134
- my_eng.execute(text(f"DROP TABLE IF EXISTS {'MasterDataset'}"))
135
- my_eng.execute(text(f"DROP TABLE IF EXISTS {'DataSets'}"))
136
-
137
- def get_pipeline_from_name(name):
138
- """
139
- Get the pipeline ID from its name in the database.
140
-
141
- Args:
142
- name (str): Name of the pipeline.
143
-
144
- Returns:
145
- int: ID of the pipeline.
146
- """
147
-
148
- NodeId = int(
149
- pd.read_sql(
150
- text(f"SELECT id FROM FLpipeline WHERE name = '{name}'"), my_eng
151
- ).iloc[0, 0]
152
- )
153
- return NodeId
154
-
155
- def get_pipeline_confusion_matrix(pipeline_id):
156
- """
157
- Get the global confusion matrix for a pipeline based on test results.
158
-
159
- Args:
160
- pipeline_id (int): ID of the pipeline.
161
-
162
- Returns:
163
- dict: A dictionary representing the global confusion matrix.
164
- """
165
-
166
- data = pd.read_sql(
167
- text(f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
168
- )
169
-
170
- # Convert the column of strings into a list of dictionaries representing confusion matrices
171
- confusion_matrices = [
172
- json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
173
- ]
174
-
175
- # Initialize variables for global confusion matrix
176
- global_TP = global_FP = global_FN = global_TN = 0
177
-
178
- # Iterate through each dictionary and sum the corresponding values for each category
179
- for matrix in confusion_matrices:
180
- global_TP += matrix['TP']
181
- global_FP += matrix['FP']
182
- global_FN += matrix['FN']
183
- global_TN += matrix['TN']
184
-
185
- # Create a global confusion matrix as a dictionary
186
- global_confusion_matrix = {
187
- 'TP': global_TP,
188
- 'FP': global_FP,
189
- 'FN': global_FN,
190
- 'TN': global_TN
191
- }
192
- # Return the list of dictionaries representing confusion matrices
193
- return global_confusion_matrix
194
-
195
- def get_node_confusion_matrix(pipeline_id , node_name):
196
- """
197
- Get the confusion matrix for a specific node in a pipeline based on test results.
198
-
199
- Args:
200
- pipeline_id (int): ID of the pipeline.
201
- node_name (str): Name of the node.
202
-
203
- Returns:
204
- dict: A dictionary representing the confusion matrix for the specified node.
205
- """
206
-
207
- data = pd.read_sql(
208
- text(f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}' AND nodename = '{node_name}'"), my_eng
209
- )
210
-
211
- # Convert the column of strings into a list of dictionaries representing confusion matrices
212
- confusion_matrices = [
213
- json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
214
- ]
215
-
216
-
217
- # Return the list of dictionaries representing confusion matrices
218
- return confusion_matrices[0]
219
-
220
- def get_pipeline_result(pipeline_id):
221
- """
222
- Get the test results for a pipeline.
223
-
224
- Args:
225
- pipeline_id (int): ID of the pipeline.
226
-
227
- Returns:
228
- pandas.DataFrame: DataFrame containing test results for the specified pipeline.
229
- """
230
- data = pd.read_sql(
231
- text(f"SELECT * FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
232
- )
233
- return data
1
+ #!/usr/bin/env python3
2
+
3
+ import pkg_resources
4
+ import torch
5
+ import yaml
6
+ from sklearn.metrics import *
7
+ from yaml.loader import SafeLoader
8
+
9
+ # from scripts.base import *
10
+ import json
11
+
12
+
13
+ import pandas as pd
14
+ import numpy as np
15
+
16
+ import os
17
+ import configparser
18
+
19
+ import subprocess
20
+ import ast
21
+
22
+ from sqlalchemy import text
23
+
24
+ from Medfl.NetManager.database_connector import DatabaseManager
25
+
26
+
27
+ # Get the directory of the current script
28
+ current_directory = os.path.dirname(os.path.abspath(__file__))
29
+
30
+ # Load configuration from the config file
31
+ yaml_path = os.path.join(current_directory, 'params.yaml')
32
+
33
+ with open(yaml_path) as g:
34
+ params = yaml.load(g, Loader=SafeLoader)
35
+
36
+ # global_yaml_path = pkg_resources.resource_filename(__name__, "../../global_params.yaml")
37
+ # with open(global_yaml_path) as g:
38
+ # global_params = yaml.load(g, Loader=SafeLoader)
39
+
40
+
41
+ # Default path for the config file
42
+ DEFAULT_CONFIG_PATH = 'config.ini'
43
+
44
+
45
+ def load_config():
46
+ config = os.environ.get('MEDFL_DB_CONFIG')
47
+
48
+ if config:
49
+ return ast.literal_eval( config )
50
+ else:
51
+ raise ValueError(f"MEDfl db config not found")
52
+
53
+ # Function to allow users to set config path programmatically
54
+
55
+
56
+ def set_db_config(config_path):
57
+ config = configparser.ConfigParser()
58
+ config.read(config_path)
59
+ if (config['mysql']):
60
+ os.environ['MEDFL_DB_CONFIG'] = str(dict(config['mysql']))
61
+ else:
62
+ raise ValueError(f"mysql key not found in file '{config_path}'")
63
+
64
+ # set the master dataset path
65
+
66
+
67
+ def set_masterDataset_path(dataset_path):
68
+ os.environ['MEDFL_MASTERDATASET_PATH'] = dataset_path
69
+
70
+ # Create database
71
+
72
+
73
+ def create_MEDfl_db():
74
+ script_path = os.path.join(os.path.dirname(
75
+ __file__), 'scripts', 'create_db.sh')
76
+ subprocess.run(['sh', script_path], check=True)
77
+
78
+
79
+ def custom_classification_report(y_true, y_pred_prob):
80
+ """
81
+ Compute custom classification report metrics including accuracy, sensitivity, specificity, precision, NPV,
82
+ F1-score, false positive rate, and true positive rate.
83
+
84
+ Args:
85
+ y_true (array-like): True labels.
86
+ y_pred (array-like): Predicted labels.
87
+
88
+ Returns:
89
+ dict: A dictionary containing custom classification report metrics.
90
+ """
91
+ y_pred = (y_pred_prob).round(
92
+ ) # Round absolute values of predicted probabilities to the nearest integer
93
+
94
+ auc = roc_auc_score(y_true, y_pred_prob) # Calculate AUC
95
+
96
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
97
+
98
+ # Accuracy
99
+ denominator_acc = tp + tn + fp + fn
100
+ acc = (tp + tn) / denominator_acc if denominator_acc != 0 else 0.0
101
+
102
+ # Sensitivity/Recall
103
+ denominator_sen = tp + fn
104
+ sen = tp / denominator_sen if denominator_sen != 0 else 0.0
105
+
106
+ # Specificity
107
+ denominator_sp = tn + fp
108
+ sp = tn / denominator_sp if denominator_sp != 0 else 0.0
109
+
110
+ # PPV/Precision
111
+ denominator_ppv = tp + fp
112
+ ppv = tp / denominator_ppv if denominator_ppv != 0 else 0.0
113
+
114
+ # NPV
115
+ denominator_npv = tn + fn
116
+ npv = tn / denominator_npv if denominator_npv != 0 else 0.0
117
+
118
+ # F1 Score
119
+ denominator_f1 = sen + ppv
120
+ f1 = 2 * (sen * ppv) / denominator_f1 if denominator_f1 != 0 else 0.0
121
+
122
+ # False Positive Rate
123
+ denominator_fpr = fp + tn
124
+ fpr = fp / denominator_fpr if denominator_fpr != 0 else 0.0
125
+
126
+ # True Positive Rate
127
+ denominator_tpr = tp + fn
128
+ tpr = tp / denominator_tpr if denominator_tpr != 0 else 0.0
129
+
130
+ return {
131
+ "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
132
+ "Accuracy": round(acc, 3),
133
+ "Sensitivity/Recall": round(sen, 3),
134
+ "Specificity": round(sp, 3),
135
+ "PPV/Precision": round(ppv, 3),
136
+ "NPV": round(npv, 3),
137
+ "F1-score": round(f1, 3),
138
+ "False positive rate": round(fpr, 3),
139
+ "True positive rate": round(tpr, 3),
140
+ "auc": auc
141
+ }
142
+
143
+
144
+ def test(model, test_loader, device=torch.device("cpu")):
145
+ """
146
+ Evaluate a model using a test loader and return a custom classification report.
147
+
148
+ Args:
149
+ model (torch.nn.Module): PyTorch model to evaluate.
150
+ test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
151
+ device (torch.device, optional): Device for model evaluation. Default is "cpu".
152
+
153
+ Returns:
154
+ dict: A dictionary containing custom classification report metrics.
155
+ """
156
+
157
+ model.eval()
158
+ with torch.no_grad():
159
+ X_test, y_test = test_loader.dataset[:][0].to(
160
+ device), test_loader.dataset[:][1].to(device)
161
+ y_hat_prob = torch.squeeze(model(X_test), 1).cpu()
162
+
163
+ return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.cpu().numpy())
164
+
165
+
166
+ column_map = {"object": "VARCHAR(255)", "int64": "INT", "float64": "FLOAT"}
167
+
168
+
169
+ def empty_db():
170
+ """
171
+ Empty the database by deleting records from multiple tables and resetting auto-increment counters.
172
+
173
+ Returns:
174
+ None
175
+ """
176
+ db_manager = DatabaseManager()
177
+ db_manager.connect()
178
+ my_eng = db_manager.get_connection()
179
+
180
+ # my_eng.execute(text(f"DELETE FROM {'DataSets'}"))
181
+ my_eng.execute(text(f"DELETE FROM {'Nodes'}"))
182
+ my_eng.execute(text(f"DELETE FROM {'FedDatasets'}"))
183
+ my_eng.execute(text(f"DELETE FROM {'Networks'}"))
184
+ my_eng.execute(text(f"DELETE FROM {'FLsetup'}"))
185
+
186
+ my_eng.execute(text(f"DELETE FROM {'FLpipeline'}"))
187
+ my_eng.execute(text(f"ALTER TABLE {'Nodes'} AUTO_INCREMENT = 1"))
188
+ my_eng.execute(text(f"ALTER TABLE {'Networks'} AUTO_INCREMENT = 1"))
189
+ my_eng.execute(text(f"ALTER TABLE {'FedDatasets'} AUTO_INCREMENT = 1"))
190
+ my_eng.execute(text(f"ALTER TABLE {'FLsetup'} AUTO_INCREMENT = 1"))
191
+ my_eng.execute(text(f"ALTER TABLE {'FLpipeline'} AUTO_INCREMENT = 1"))
192
+ my_eng.execute(text(f"DELETE FROM {'testResults'}"))
193
+ my_eng.execute(text(f"DROP TABLE IF EXISTS {'MasterDataset'}"))
194
+ my_eng.execute(text(f"DROP TABLE IF EXISTS {'DataSets'}"))
195
+
196
+
197
+ def get_pipeline_from_name(name):
198
+ """
199
+ Get the pipeline ID from its name in the database.
200
+
201
+ Args:
202
+ name (str): Name of the pipeline.
203
+
204
+ Returns:
205
+ int: ID of the pipeline.
206
+ """
207
+ db_manager = DatabaseManager()
208
+ db_manager.connect()
209
+ my_eng = db_manager.get_connection()
210
+
211
+ NodeId = int(
212
+ pd.read_sql(
213
+ text(f"SELECT id FROM FLpipeline WHERE name = '{name}'"), my_eng
214
+ ).iloc[0, 0]
215
+ )
216
+ return NodeId
217
+
218
+
219
+ def get_pipeline_confusion_matrix(pipeline_id):
220
+ """
221
+ Get the global confusion matrix for a pipeline based on test results.
222
+
223
+ Args:
224
+ pipeline_id (int): ID of the pipeline.
225
+
226
+ Returns:
227
+ dict: A dictionary representing the global confusion matrix.
228
+ """
229
+ db_manager = DatabaseManager()
230
+ db_manager.connect()
231
+ my_eng = db_manager.get_connection()
232
+
233
+ data = pd.read_sql(
234
+ text(
235
+ f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
236
+ )
237
+
238
+ # Convert the column of strings into a list of dictionaries representing confusion matrices
239
+ confusion_matrices = [
240
+ json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
241
+ ]
242
+
243
+ # Initialize variables for global confusion matrix
244
+ global_TP = global_FP = global_FN = global_TN = 0
245
+
246
+ # Iterate through each dictionary and sum the corresponding values for each category
247
+ for matrix in confusion_matrices:
248
+ global_TP += matrix['TP']
249
+ global_FP += matrix['FP']
250
+ global_FN += matrix['FN']
251
+ global_TN += matrix['TN']
252
+
253
+ # Create a global confusion matrix as a dictionary
254
+ global_confusion_matrix = {
255
+ 'TP': global_TP,
256
+ 'FP': global_FP,
257
+ 'FN': global_FN,
258
+ 'TN': global_TN
259
+ }
260
+ # Return the list of dictionaries representing confusion matrices
261
+ return global_confusion_matrix
262
+
263
+
264
+ def get_node_confusion_matrix(pipeline_id, node_name):
265
+ """
266
+ Get the confusion matrix for a specific node in a pipeline based on test results.
267
+
268
+ Args:
269
+ pipeline_id (int): ID of the pipeline.
270
+ node_name (str): Name of the node.
271
+
272
+ Returns:
273
+ dict: A dictionary representing the confusion matrix for the specified node.
274
+ """
275
+ db_manager = DatabaseManager()
276
+ db_manager.connect()
277
+ my_eng = db_manager.get_connection()
278
+
279
+ data = pd.read_sql(
280
+ text(
281
+ f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}' AND nodename = '{node_name}'"), my_eng
282
+ )
283
+
284
+ # Convert the column of strings into a list of dictionaries representing confusion matrices
285
+ confusion_matrices = [
286
+ json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
287
+ ]
288
+
289
+ # Return the list of dictionaries representing confusion matrices
290
+ return confusion_matrices[0]
291
+
292
+
293
+ def get_pipeline_result(pipeline_id):
294
+ """
295
+ Get the test results for a pipeline.
296
+
297
+ Args:
298
+ pipeline_id (int): ID of the pipeline.
299
+
300
+ Returns:
301
+ pandas.DataFrame: DataFrame containing test results for the specified pipeline.
302
+ """
303
+ db_manager = DatabaseManager()
304
+ db_manager.connect()
305
+ my_eng = db_manager.get_connection()
306
+
307
+ data = pd.read_sql(
308
+ text(
309
+ f"SELECT * FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
310
+ )
311
+ return data
@@ -1,9 +1,9 @@
1
- # # Medfl/NetworkManager/__init__.py
2
-
3
- # # Import modules from this package
4
- # from .dataset import *
5
- # from .flsetup import *
6
- # from .net_helper import *
7
- # from .net_manager_queries import *
8
- # from .network import *
9
- # from .node import *
1
+ # # Medfl/NetworkManager/__init__.py
2
+
3
+ # # Import modules from this package
4
+ # from .dataset import *
5
+ # from .flsetup import *
6
+ # from .net_helper import *
7
+ # from .net_manager_queries import *
8
+ # from .network import *
9
+ # from .node import *
@@ -0,0 +1,50 @@
1
+ import os
2
+ from sqlalchemy import create_engine
3
+ from configparser import ConfigParser
4
+
5
+ import subprocess
6
+
7
+
8
+ class DatabaseManager:
9
+ def __init__(self):
10
+
11
+ from Medfl.LearningManager.utils import load_config
12
+ db_config = load_config()
13
+ if db_config:
14
+ self.config = db_config
15
+ else:
16
+ self.config = None
17
+ self.engine = None
18
+
19
+ def connect(self):
20
+ if not self.config:
21
+ raise ValueError(
22
+ "Database configuration not loaded. Use load_config() or set_config_path() first.")
23
+ connection_string = (
24
+ f"mysql+mysqlconnector://{self.config['user']}:{self.config['password']}@"
25
+ f"{self.config['host']}:{self.config['port']}/{self.config['database']}"
26
+ )
27
+ self.engine = create_engine(connection_string, pool_pre_ping=True)
28
+
29
+ def get_connection(self):
30
+ if not self.engine:
31
+ self.connect()
32
+ return self.engine.connect()
33
+
34
+ def create_MEDfl_db(self, path_to_csv):
35
+ # Get the directory of the current script
36
+ current_directory = os.path.dirname(__file__)
37
+
38
+ # Define the path to the create_db.py script
39
+ create_db_script_path = os.path.join(
40
+ current_directory, '..', '..', 'scripts', 'create_db.py')
41
+
42
+ # Execute the create_db.py script
43
+ subprocess.run(
44
+ ['python', create_db_script_path, path_to_csv], check=True)
45
+
46
+ return
47
+
48
+ def close(self):
49
+ if self.engine:
50
+ self.engine.dispose()