MEDfl 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/METADATA +127 -128
  2. MEDfl-0.1.33.dist-info/RECORD +34 -0
  3. {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/WHEEL +1 -1
  4. {MEDfl-0.1.31.dist-info → MEDfl-0.1.33.dist-info}/top_level.txt +0 -1
  5. Medfl/LearningManager/__init__.py +13 -13
  6. Medfl/LearningManager/client.py +150 -150
  7. Medfl/LearningManager/dynamicModal.py +287 -287
  8. Medfl/LearningManager/federated_dataset.py +60 -60
  9. Medfl/LearningManager/flpipeline.py +192 -192
  10. Medfl/LearningManager/model.py +223 -223
  11. Medfl/LearningManager/params.yaml +14 -14
  12. Medfl/LearningManager/params_optimiser.py +442 -442
  13. Medfl/LearningManager/plot.py +229 -229
  14. Medfl/LearningManager/server.py +181 -181
  15. Medfl/LearningManager/strategy.py +82 -82
  16. Medfl/LearningManager/utils.py +331 -308
  17. Medfl/NetManager/__init__.py +10 -9
  18. Medfl/NetManager/database_connector.py +43 -48
  19. Medfl/NetManager/dataset.py +92 -92
  20. Medfl/NetManager/flsetup.py +320 -320
  21. Medfl/NetManager/net_helper.py +254 -248
  22. Medfl/NetManager/net_manager_queries.py +142 -137
  23. Medfl/NetManager/network.py +194 -174
  24. Medfl/NetManager/node.py +184 -178
  25. Medfl/__init__.py +3 -2
  26. Medfl/scripts/__init__.py +2 -0
  27. Medfl/scripts/base.py +30 -0
  28. Medfl/scripts/create_db.py +126 -0
  29. alembic/env.py +61 -61
  30. scripts/base.py +29 -29
  31. scripts/config.ini +5 -5
  32. scripts/create_db.py +133 -133
  33. MEDfl/LearningManager/__init__.py +0 -13
  34. MEDfl/LearningManager/client.py +0 -150
  35. MEDfl/LearningManager/dynamicModal.py +0 -287
  36. MEDfl/LearningManager/federated_dataset.py +0 -60
  37. MEDfl/LearningManager/flpipeline.py +0 -192
  38. MEDfl/LearningManager/model.py +0 -223
  39. MEDfl/LearningManager/params.yaml +0 -14
  40. MEDfl/LearningManager/params_optimiser.py +0 -442
  41. MEDfl/LearningManager/plot.py +0 -229
  42. MEDfl/LearningManager/server.py +0 -181
  43. MEDfl/LearningManager/strategy.py +0 -82
  44. MEDfl/LearningManager/utils.py +0 -333
  45. MEDfl/NetManager/__init__.py +0 -9
  46. MEDfl/NetManager/database_connector.py +0 -48
  47. MEDfl/NetManager/dataset.py +0 -92
  48. MEDfl/NetManager/flsetup.py +0 -320
  49. MEDfl/NetManager/net_helper.py +0 -248
  50. MEDfl/NetManager/net_manager_queries.py +0 -137
  51. MEDfl/NetManager/network.py +0 -174
  52. MEDfl/NetManager/node.py +0 -178
  53. MEDfl/__init__.py +0 -2
  54. MEDfl-0.1.31.data/scripts/setup_mysql.sh +0 -22
  55. MEDfl-0.1.31.dist-info/RECORD +0 -54
  56. scripts/db_config.ini +0 -6
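Note on the layout change: version 0.1.31 shipped two parallel top-level packages, MEDfl and Medfl; 0.1.33 drops the uppercase MEDfl tree (see the MEDfl/* deletions above and the top_level.txt change), leaving Medfl as the single import root. A minimal sketch of what that means for user code, assuming the public symbols carry over unchanged:

    # Hypothetical import migration -- module paths are taken from the file
    # listing above; the exact exported names are an assumption.
    # 0.1.31 allowed either spelling:
    #   from MEDfl.LearningManager.utils import load_db_config
    # 0.1.33 keeps only the mixed-case package:
    from Medfl.LearningManager.utils import load_db_config
    from Medfl.NetManager.database_connector import DatabaseManager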
MEDfl/LearningManager/utils.py (removed)
@@ -1,333 +0,0 @@
- #!/usr/bin/env python3
-
- import pkg_resources
- import torch
- import yaml
- from sklearn.metrics import *
- from yaml.loader import SafeLoader
-
-
- from MEDfl.NetManager.database_connector import DatabaseManager
-
- # from scripts.base import *
- import json
-
-
- import pandas as pd
- import numpy as np
-
- import os
- import configparser
-
- import subprocess
- import ast
-
- from sqlalchemy import text
-
- from dotenv import load_dotenv, set_key
-
-
- # Get the directory of the current script
- current_directory = os.path.dirname(os.path.abspath(__file__))
-
- # Load configuration from the config file
- yaml_path = os.path.join(current_directory, 'params.yaml')
-
- with open(yaml_path) as g:
-     params = yaml.load(g, Loader=SafeLoader)
-
- # global_yaml_path = pkg_resources.resource_filename(__name__, "../../global_params.yaml")
- # with open(global_yaml_path) as g:
- # global_params = yaml.load(g, Loader=SafeLoader)
-
-
- # Default path for the config file
- DEFAULT_CONFIG_PATH = 'db_config.ini'
-
-
- # def load_db_config():
- # config = os.environ.get('MEDfl_DB_CONFIG')
-
- # if config:
- # return ast.literal_eval(config)
- # else:
- # raise ValueError(f"MEDfl db config not found")
-
- def load_db_config():
-     # Load environment variables from .env file
-     load_dotenv()
-
-     config = os.getenv('MEDfl_DB_CONFIG')
-
-     if config:
-         return ast.literal_eval(config)
-     else:
-         raise ValueError(f"MEDfl db config not found")
-
-
- # Function to allow users to set config path programmatically
-
-
- # def set_db_config(config_path):
- # config = configparser.ConfigParser()
- # config.read(config_path)
- # if (config['mysql']):
- # os.environ['MEDfl_DB_CONFIG'] = str(dict(config['mysql']))
- # else:
- # raise ValueError(f"mysql key not found in file '{config_path}'")
-
- def set_db_config(config_path):
-     config = configparser.ConfigParser()
-     config.read(config_path)
-     if 'mysql' in config:
-         env_path = '.env'
-         load_dotenv(dotenv_path=env_path)
-         for key, value in config['mysql'].items():
-             set_key(env_path, key, value)
-             os.environ[key] = value
-     else:
-         raise ValueError(f"mysql key not found in file '{config_path}'")
-
-
- # Create databas
-
-
- def create_MEDfl_db():
-     script_path = os.path.join(os.path.dirname(
-         __file__), 'scripts', 'create_db.sh')
-     subprocess.run(['sh', script_path], check=True)
-
-
- def custom_classification_report(y_true, y_pred_prob):
-     """
-     Compute custom classification report metrics including accuracy, sensitivity, specificity, precision, NPV,
-     F1-score, false positive rate, and true positive rate.
-
-     Args:
-         y_true (array-like): True labels.
-         y_pred (array-like): Predicted labels.
-
-     Returns:
-         dict: A dictionary containing custom classification report metrics.
-     """
-     y_pred = (y_pred_prob).round(
-     ) # Round absolute values of predicted probabilities to the nearest integer
-
-     auc = roc_auc_score(y_true, y_pred_prob) # Calculate AUC
-
-     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
-
-     # Accuracy
-     denominator_acc = tp + tn + fp + fn
-     acc = (tp + tn) / denominator_acc if denominator_acc != 0 else 0.0
-
-     # Sensitivity/Recall
-     denominator_sen = tp + fn
-     sen = tp / denominator_sen if denominator_sen != 0 else 0.0
-
-     # Specificity
-     denominator_sp = tn + fp
-     sp = tn / denominator_sp if denominator_sp != 0 else 0.0
-
-     # PPV/Precision
-     denominator_ppv = tp + fp
-     ppv = tp / denominator_ppv if denominator_ppv != 0 else 0.0
-
-     # NPV
-     denominator_npv = tn + fn
-     npv = tn / denominator_npv if denominator_npv != 0 else 0.0
-
-     # F1 Score
-     denominator_f1 = sen + ppv
-     f1 = 2 * (sen * ppv) / denominator_f1 if denominator_f1 != 0 else 0.0
-
-     # False Positive Rate
-     denominator_fpr = fp + tn
-     fpr = fp / denominator_fpr if denominator_fpr != 0 else 0.0
-
-     # True Positive Rate
-     denominator_tpr = tp + fn
-     tpr = tp / denominator_tpr if denominator_tpr != 0 else 0.0
-
-     return {
-         "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
-         "Accuracy": round(acc, 3),
-         "Sensitivity/Recall": round(sen, 3),
-         "Specificity": round(sp, 3),
-         "PPV/Precision": round(ppv, 3),
-         "NPV": round(npv, 3),
-         "F1-score": round(f1, 3),
-         "False positive rate": round(fpr, 3),
-         "True positive rate": round(tpr, 3),
-         "auc": auc
-     }
-
-
- def test(model, test_loader, device=torch.device("cpu")):
-     """
-     Evaluate a model using a test loader and return a custom classification report.
-
-     Args:
-         model (torch.nn.Module): PyTorch model to evaluate.
-         test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
-         device (torch.device, optional): Device for model evaluation. Default is "cpu".
-
-     Returns:
-         dict: A dictionary containing custom classification report metrics.
-     """
-
-     model.eval()
-     with torch.no_grad():
-         X_test, y_test = test_loader.dataset[:][0].to(
-             device), test_loader.dataset[:][1].to(device)
-         y_hat_prob = torch.squeeze(model(X_test), 1).cpu()
-
-     return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.cpu().numpy())
-
-
- column_map = {"object": "VARCHAR(255)", "int64": "INT", "float64": "FLOAT"}
-
-
- def empty_db():
-     """
-     Empty the database by deleting records from multiple tables and resetting auto-increment counters.
-
-     Returns:
-         None
-     """
-     db_manager = DatabaseManager()
-     db_manager.connect()
-     my_eng = db_manager.get_connection()
-
-     # my_eng.execute(text(f"DELETE FROM {'DataSets'}"))
-     my_eng.execute(text(f"DELETE FROM {'Nodes'}"))
-     my_eng.execute(text(f"DELETE FROM {'FedDatasets'}"))
-     my_eng.execute(text(f"DELETE FROM {'Networks'}"))
-     my_eng.execute(text(f"DELETE FROM {'FLsetup'}"))
-
-     my_eng.execute(text(f"DELETE FROM {'FLpipeline'}"))
-     my_eng.execute(text(f"ALTER TABLE {'Nodes'} AUTO_INCREMENT = 1"))
-     my_eng.execute(text(f"ALTER TABLE {'Networks'} AUTO_INCREMENT = 1"))
-     my_eng.execute(text(f"ALTER TABLE {'FedDatasets'} AUTO_INCREMENT = 1"))
-     my_eng.execute(text(f"ALTER TABLE {'FLsetup'} AUTO_INCREMENT = 1"))
-     my_eng.execute(text(f"ALTER TABLE {'FLpipeline'} AUTO_INCREMENT = 1"))
-     my_eng.execute(text(f"DELETE FROM {'testResults'}"))
-     my_eng.execute(text(f"DROP TABLE IF EXISTS {'MasterDataset'}"))
-     my_eng.execute(text(f"DROP TABLE IF EXISTS {'DataSets'}"))
-
-
- def get_pipeline_from_name(name):
-     """
-     Get the pipeline ID from its name in the database.
-
-     Args:
-         name (str): Name of the pipeline.
-
-     Returns:
-         int: ID of the pipeline.
-     """
-     db_manager = DatabaseManager()
-     db_manager.connect()
-     my_eng = db_manager.get_connection()
-
-     NodeId = int(
-         pd.read_sql(
-             text(f"SELECT id FROM FLpipeline WHERE name = '{name}'"), my_eng
-         ).iloc[0, 0]
-     )
-     return NodeId
-
-
- def get_pipeline_confusion_matrix(pipeline_id):
-     """
-     Get the global confusion matrix for a pipeline based on test results.
-
-     Args:
-         pipeline_id (int): ID of the pipeline.
-
-     Returns:
-         dict: A dictionary representing the global confusion matrix.
-     """
-     db_manager = DatabaseManager()
-     db_manager.connect()
-     my_eng = db_manager.get_connection()
-
-     data = pd.read_sql(
-         text(
-             f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
-     )
-
-     # Convert the column of strings into a list of dictionaries representing confusion matrices
-     confusion_matrices = [
-         json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
-     ]
-
-     # Initialize variables for global confusion matrix
-     global_TP = global_FP = global_FN = global_TN = 0
-
-     # Iterate through each dictionary and sum the corresponding values for each category
-     for matrix in confusion_matrices:
-         global_TP += matrix['TP']
-         global_FP += matrix['FP']
-         global_FN += matrix['FN']
-         global_TN += matrix['TN']
-
-     # Create a global confusion matrix as a dictionary
-     global_confusion_matrix = {
-         'TP': global_TP,
-         'FP': global_FP,
-         'FN': global_FN,
-         'TN': global_TN
-     }
-     # Return the list of dictionaries representing confusion matrices
-     return global_confusion_matrix
-
-
- def get_node_confusion_matrix(pipeline_id, node_name):
-     """
-     Get the confusion matrix for a specific node in a pipeline based on test results.
-
-     Args:
-         pipeline_id (int): ID of the pipeline.
-         node_name (str): Name of the node.
-
-     Returns:
-         dict: A dictionary representing the confusion matrix for the specified node.
-     """
-     db_manager = DatabaseManager()
-     db_manager.connect()
-     my_eng = db_manager.get_connection()
-
-     data = pd.read_sql(
-         text(
-             f"SELECT confusionmatrix FROM testResults WHERE pipelineid = '{pipeline_id}' AND nodename = '{node_name}'"), my_eng
-     )
-
-     # Convert the column of strings into a list of dictionaries representing confusion matrices
-     confusion_matrices = [
-         json.loads(matrix.replace("'", "\"")) for matrix in data['confusionmatrix']
-     ]
-
-     # Return the list of dictionaries representing confusion matrices
-     return confusion_matrices[0]
-
-
- def get_pipeline_result(pipeline_id):
-     """
-     Get the test results for a pipeline.
-
-     Args:
-         pipeline_id (int): ID of the pipeline.
-
-     Returns:
-         pandas.DataFrame: DataFrame containing test results for the specified pipeline.
-     """
-     db_manager = DatabaseManager()
-     db_manager.connect()
-     my_eng = db_manager.get_connection()
-
-     data = pd.read_sql(
-         text(
-             f"SELECT * FROM testResults WHERE pipelineid = '{pipeline_id}'"), my_eng
-     )
-     return data
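For context on the configuration flow in the module above: set_db_config() copies the [mysql] section of an INI file into a .env file and the process environment, while load_db_config() expects a separate MEDfl_DB_CONFIG variable containing a dict literal. A minimal usage sketch based on the removed code; the file name and key values are illustrative, and the 0.1.33 import path is an assumption (the helpers presumably carry over into Medfl/LearningManager/utils.py):

    # Illustrative only: exercise the set_db_config() shown in the removed utils.py.
    import configparser

    # Assumed 0.1.33 location of the carried-over helper.
    from Medfl.LearningManager.utils import set_db_config

    cfg = configparser.ConfigParser()
    cfg["mysql"] = {                 # keys mirror those used by DatabaseManager below
        "host": "localhost",         # hypothetical connection values
        "port": "3306",
        "user": "medfl_user",
        "password": "medfl_pass",
        "database": "MEDfl",
    }
    with open("db_config.ini", "w") as fh:
        cfg.write(fh)

    # Copies each [mysql] key into .env (via python-dotenv's set_key) and os.environ.
    set_db_config("db_config.ini")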
MEDfl/NetManager/__init__.py (removed)
@@ -1,9 +0,0 @@
- # # MEDfl/NetworkManager/__init__.py
-
- # # Import modules from this package
- # from .dataset import *
- # from .flsetup import *
- # from .net_helper import *
- # from .net_manager_queries import *
- # from .network import *
- # from .node import *
MEDfl/NetManager/database_connector.py (removed)
@@ -1,48 +0,0 @@
- import os
- from sqlalchemy import create_engine
- from configparser import ConfigParser
-
- import subprocess
-
- class DatabaseManager:
-     def __init__(self):
-
-         from MEDfl.LearningManager.utils import load_db_config
-         db_config = load_db_config()
-         if db_config:
-             self.config = db_config
-         else:
-             self.config = None
-         self.engine = None
-
-     def connect(self):
-         if not self.config:
-             raise ValueError("Database configuration not loaded. Use load_db_config() or set_config_path() first.")
-         connection_string = (
-             f"mysql+mysqlconnector://{self.config['user']}:{self.config['password']}@"
-             f"{self.config['host']}:{self.config['port']}/{self.config['database']}"
-         )
-         self.engine = create_engine(connection_string, pool_pre_ping=True)
-
-     def get_connection(self):
-         if not self.engine:
-             self.connect()
-         return self.engine.connect()
-
-     def create_MEDfl_db(self , path_to_csv):
-         # Get the directory of the current script
-         current_directory = os.path.dirname(__file__)
-
-         # Define the path to the create_db.py script
-         create_db_script_path = os.path.join(current_directory, '..','..', 'scripts', 'create_db.py')
-
-         # Execute the create_db.py script
-         subprocess.run(['python', create_db_script_path, path_to_csv], check=True)
-
-         return
-
-     def close(self):
-         if self.engine:
-             self.engine.dispose()
-
-
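A minimal usage sketch of the DatabaseManager class shown above (the class is retained in Medfl/NetManager/database_connector.py in 0.1.33, per the file list). The import path for the new release and the query are assumptions; a reachable, configured MySQL instance is required:

    # Hedged sketch: open a pooled MySQL connection and run a trivial query.
    from sqlalchemy import text

    from Medfl.NetManager.database_connector import DatabaseManager  # assumed 0.1.33 path

    db_manager = DatabaseManager()      # reads the config from the environment / .env
    db_manager.connect()                # builds the mysql+mysqlconnector engine
    conn = db_manager.get_connection()  # returns a SQLAlchemy Connection

    # Hypothetical query against one of the MEDfl tables referenced in utils.py.
    rows = conn.execute(text("SELECT COUNT(*) FROM Nodes")).fetchall()
    print(rows)

    db_manager.close()                  # dispose of the engine's connection pool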
MEDfl/NetManager/dataset.py (removed)
@@ -1,92 +0,0 @@
- import pandas as pd
- from sqlalchemy import text
-
- from .net_helper import *
- from .net_manager_queries import (DELETE_DATASET, INSERT_DATASET,
-                                   SELECT_ALL_DATASET_NAMES)
- from MEDfl.NetManager.database_connector import DatabaseManager
-
- class DataSet:
-     def __init__(self, name: str, path: str, engine=None):
-         """
-         Initialize a DataSet object.
-
-         :param name: The name of the dataset.
-         :type name: str
-         :param path: The file path of the dataset CSV file.
-         :type path: str
-         """
-         self.name = name
-         self.path = path
-         db_manager = DatabaseManager()
-         db_manager.connect()
-         self.engine = db_manager.get_connection()
-
-     def validate(self):
-         """
-         Validate name and path attributes.
-
-         :raises TypeError: If name or path is not a string.
-         """
-         if not isinstance(self.name, str):
-             raise TypeError("name argument must be a string")
-
-         if not isinstance(self.path, str):
-             raise TypeError("path argument must be a string")
-
-     def upload_dataset(self, NodeId=-1):
-         """
-         Upload the dataset to the database.
-
-         :param NodeId: The NodeId associated with the dataset.
-         :type NodeId: int
-
-         Notes:
-         - Assumes the file at self.path is a valid CSV file.
-         - The dataset is uploaded to the 'DataSets' table in the database.
-         """
-
-         data_df = pd.read_csv(self.path)
-         nodeId = NodeId
-         columns = data_df.columns.tolist()
-
-
-         data_df = process_eicu(data_df)
-         for index, row in data_df.iterrows():
-             query_1 = "INSERT INTO DataSets(DataSetName,nodeId," + "".join(
-                 f"{x}," for x in columns
-             )
-             query_2 = f" VALUES ('{self.name}',{nodeId}, " + "".join(
-                 f"{is_str(data_df, row, x)}," for x in columns
-             )
-             query = query_1[:-1] + ")" + query_2[:-1] + ")"
-
-             self.engine.execute(text(query))
-
-     def delete_dataset(self):
-         """
-         Delete the dataset from the database.
-
-         Notes:
-         - Assumes the dataset name is unique in the 'DataSets' table.
-         """
-         self.engine.execute(text(DELETE_DATASET), {"name": self.name})
-
-     def update_data(self):
-         """
-         Update the data in the dataset.
-
-         Not implemented yet.
-         """
-         pass
-
-     @staticmethod
-     def list_alldatasets(engine):
-         """
-         List all dataset names from the 'DataSets' table.
-
-         :returns: A DataFrame containing the names of all datasets in the 'DataSets' table.
-         :rtype: pd.DataFrame
-         """
-         res = pd.read_sql(text(SELECT_ALL_DATASET_NAMES), engine)
-         return res
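Finally, a minimal usage sketch of the DataSet class shown above (a counterpart survives in Medfl/NetManager/dataset.py per the file list). The import path, file name, and node id are illustrative assumptions, and a configured MySQL database must already be reachable:

    # Hedged sketch: register a CSV file as a MEDfl dataset attached to a node.
    from Medfl.NetManager.dataset import DataSet  # assumed 0.1.33 path

    # "hospital_a.csv" and the node id are placeholder values.
    ds = DataSet(name="hospital_a", path="hospital_a.csv")
    ds.validate()                 # type-checks the name and path arguments
    ds.upload_dataset(NodeId=1)   # inserts one row per CSV record into the DataSets table

    # Listing reuses the SQLAlchemy connection the DataSet opened on construction.
    print(DataSet.list_alldatasets(ds.engine))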