FedModelKit 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FedModelKit/README.md +25 -0
- FedModelKit/__init__.py +17 -0
- FedModelKit/aggregator.py +41 -0
- FedModelKit/cli.py +97 -0
- FedModelKit/default_create_functions.py +114 -0
- FedModelKit/interface.py +130 -0
- FedModelKit/local_learner.py +61 -0
- FedModelKit/py.typed +0 -0
- FedModelKit/src/utils.py +65 -0
- FedModelKit/templates/__init__template.py +0 -0
- FedModelKit/templates/client_app_template.py +118 -0
- FedModelKit/templates/ds_template.ipynb +332 -0
- FedModelKit/templates/extern_pyproject_template.toml +17 -0
- FedModelKit/templates/images/doSendModels.png +0 -0
- FedModelKit/templates/images/doWaitsForJobs.png +0 -0
- FedModelKit/templates/images/dsAggregateModels.png +0 -0
- FedModelKit/templates/images/dsDoneSubmittingJobs.png +0 -0
- FedModelKit/templates/images/dsSendsJobs.png +0 -0
- FedModelKit/templates/images/overview.png +0 -0
- FedModelKit/templates/main_template.py +27 -0
- FedModelKit/templates/pyproject_template.toml +53 -0
- FedModelKit/templates/readme_template.md +48 -0
- FedModelKit/templates/server_app_template.py +204 -0
- FedModelKit/templates/task_template.py +140 -0
- FedModelKit/templates/uv_template.lock +2812 -0
- FedModelKit/templates.py +76 -0
- fedmodelkit-0.5.0.dist-info/METADATA +283 -0
- fedmodelkit-0.5.0.dist-info/RECORD +31 -0
- fedmodelkit-0.5.0.dist-info/WHEEL +4 -0
- fedmodelkit-0.5.0.dist-info/entry_points.txt +2 -0
- fedmodelkit-0.5.0.dist-info/licenses/LICENSE +23 -0
FedModelKit/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
### **Interface Directory**
|
|
2
|
+
|
|
3
|
+
1. **`interface.py`**
|
|
4
|
+
- This file defines the `FederatedModel` class, which bundles the local-learner factory, the aggregator factory, and their respective names. The MLflow integration (the `submit_fl_model` function for logging the model to MLflow) is currently disabled: it is kept in the source only as an inactive string literal and is not exported by the package.
|
|
5
|
+
|
|
6
|
+
2. **`aggregator.py`**
|
|
7
|
+
- This file defines the `AggFactoryProtocol` and `AggProtocol` protocols. The `AggFactoryProtocol` is a protocol for the function creating the aggregator instance, while the `AggProtocol` defines the methods and attributes that an aggregator must implement.
|
|
8
|
+
|
|
9
|
+
3. **`local_learner.py`**
|
|
10
|
+
- This file defines the `LLFactoryProtocol` and `LLProtocol` protocols. The `LLFactoryProtocol` is a protocol for the function creating the local learner instance, while the `LLProtocol` defines the methods and attributes that a local learner must implement. This file provides the structure for implementing custom local learners.
|
|
11
|
+
|
|
12
|
+
4. **`default_create_functions.py`**
|
|
13
|
+
- This file provides default implementations for creating local learners and aggregators: `default_create_local_learner` (a linear PyTorch classifier for the Iris dataset) and `default_create_aggregator` (a plain, unweighted average of parameters and metrics).
|
|
14
|
+
|
|
15
|
+
5. **`cli.py`**
|
|
16
|
+
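- This file contains the `main` entry point of the `fmk` command-line tool (`fmk init -n <project_name>`) and the `create_structure` function it calls, which scaffolds a new federated learning project in the current working directory. It creates a project folder containing `main.py`, `pyproject.toml` and a package with `__init__.py`, `task.py`, `server_app.py` and `client_app.py`. It also writes a top-level `pyproject.toml`, `uv.lock`, `README.md`, `ds.ipynb` and the accompanying images.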
|
|
17
|
+
|
|
18
|
+
6. **`templates.py`**
|
|
19
|
+
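- This file contains template functions that return the content of the generated files, such as `get_main_template`, `get_task_template`, `get_server_template`, `get_client_template`, `get_pyproject_template`, `get_readme_template` and `get_ds_template`. It also provides `get_images`, which copies the tutorial images.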
|
|
20
|
+
|
|
21
|
+
7. **`templates/readme_template.md`**
|
|
22
|
+
- This file contains the template for the `README.md` file that is generated by the `create_structure` function. It provides a detailed guide on how to use the package functionalities.
|
|
23
|
+
|
|
24
|
+
8. **`src/utils.py`**
|
|
25
|
+
- This file contains the `Utils` class. Its static helpers convert between NumPy arrays or PyTorch state dicts and Flower `ParametersRecord` objects, i.e. they serialize and deserialize model parameters.
|
FedModelKit/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
'''from .interface import FederatedModel, submit_fl_model'''
|
|
2
|
+
from .interface import FederatedModel
|
|
3
|
+
from .local_learner import LLFactoryProtocol, LLProtocol
|
|
4
|
+
from .aggregator import AggFactoryProtocol, AggProtocol
|
|
5
|
+
from .default_create_functions import default_create_local_learner, default_create_aggregator
|
|
6
|
+
from .cli import create_structure
|
|
7
|
+
import flwr
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from FedModelKit import *``. The Flower package itself
# (``flwr``) is re-exported next to the library's own public API.
__all__ = [
    "FederatedModel",
    "LLFactoryProtocol",
    "LLProtocol",
    "AggFactoryProtocol",
    "AggProtocol",
    "default_create_local_learner",
    "default_create_aggregator",
    "create_structure",
    "flwr",
]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import flwr
|
|
2
|
+
from typing import Protocol, runtime_checkable
|
|
3
|
+
from typing import Protocol, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@runtime_checkable
class AggProtocol(Protocol):
    """Structural interface that every aggregation strategy must satisfy.

    Any object exposing both ``aggregate_parameters`` and ``aggregate_metrics``
    with these signatures conforms. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, AggProtocol)`` is allowed, but it
    only verifies that the two methods exist, not their signatures.
    """

    def aggregate_parameters(
        self,
        results: list[flwr.common.ParametersRecord],
        config: Optional[flwr.common.ConfigsRecord] = None,
    ) -> flwr.common.ParametersRecord:
        """Combine the model parameters reported by several clients into one record.

        Args:
            results: One ``flwr.common.ParametersRecord`` per client.
            config: Optional strategy configuration; implementations may ignore it.

        Returns:
            The combined ``flwr.common.ParametersRecord``.
        """
        ...

    def aggregate_metrics(
        self,
        results: list[flwr.common.MetricsRecord],
        config: Optional[flwr.common.ConfigsRecord] = None,
    ) -> flwr.common.MetricsRecord:
        """Combine the metrics reported by several clients into one record.

        Args:
            results: One ``flwr.common.MetricsRecord`` per client.
            config: Optional strategy configuration; implementations may ignore it.

        Returns:
            The combined ``flwr.common.MetricsRecord``.
        """
        ...
|
|
31
|
+
|
|
32
|
+
@runtime_checkable
class AggFactoryProtocol(Protocol):
    """Signature of a zero-argument factory that builds an aggregator.

    ``default_create_aggregator`` is the library's built-in example of such a
    factory.
    """

    def __call__(self) -> AggProtocol:
        """Build and return an object that satisfies ``AggProtocol``."""
        ...
|
FedModelKit/cli.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
from .templates import (
|
|
5
|
+
get_main_template,
|
|
6
|
+
get_task_template,
|
|
7
|
+
get_readme_template,
|
|
8
|
+
get_server_template,
|
|
9
|
+
get_client_template,
|
|
10
|
+
get_pyproject_template,
|
|
11
|
+
get_extern_pyproject_template,
|
|
12
|
+
get_uv_template,
|
|
13
|
+
get_init_template,
|
|
14
|
+
get_ds_template,
|
|
15
|
+
get_images,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
def _write_tree(base_dir: str, tree: dict) -> None:
    """Materialise a nested ``{name: content}`` mapping under ``base_dir``.

    A ``dict`` value becomes a sub-directory (created if missing) whose entries
    are written recursively; any other value is written verbatim as the text
    content of a file. Existing files with the same path are overwritten.
    """
    for name, content in tree.items():
        path = os.path.join(base_dir, name)
        if isinstance(content, dict):
            os.makedirs(path, exist_ok=True)
            _write_tree(path, content)
        else:
            # Explicit UTF-8 so the generated files do not depend on the
            # platform's locale encoding (e.g. cp1252 on Windows).
            with open(path, "w", encoding="utf-8") as f:
                f.write(content)


def create_structure(exp_name: str = "new_experiment") -> None:
    """Scaffold a new federated-learning project in the current working directory.

    Writes the following tree under ``os.getcwd()``. Existing files with the
    same names are overwritten::

        <exp_name>/
            main.py
            pyproject.toml
            <exp_name>/
                __init__.py
                task.py
                server_app.py
                client_app.py
        pyproject.toml
        uv.lock
        README.md
        ds.ipynb

    It then calls ``get_images`` to add the tutorial images.

    Args:
        exp_name: Experiment name. It is used for the project directory, the
            inner package directory and the generated ``pyproject.toml`` files.
    """
    root_dir = os.getcwd()

    structure = {
        f"{exp_name}": {
            "main.py": get_main_template(),
            "pyproject.toml": get_pyproject_template(exp_name),
            f"{exp_name}": {
                "__init__.py": get_init_template(),
                "task.py": get_task_template(),
                "server_app.py": get_server_template(),
                "client_app.py": get_client_template(),
            },
        },
        "pyproject.toml": get_extern_pyproject_template(exp_name),
        "uv.lock": get_uv_template(),
        "README.md": get_readme_template(),
        "ds.ipynb": get_ds_template(),
    }

    _write_tree(root_dir, structure)

    # Create images directory and copy images
    get_images(root_dir)

    print(f"Project structure for {exp_name} created successfully in {root_dir}.")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def main(argv=None) -> None:
    """Entry point of the ``fmk`` command-line tool.

    Supported usage::

        fmk init -n/--name <project_name>

    ``init`` scaffolds a new federated-learning project in the current working
    directory via ``create_structure``.

    Args:
        argv: Arguments to parse, excluding the program name. Defaults to
            ``sys.argv[1:]``, so the console-script entry point can keep
            calling ``main()`` with no arguments.

    Raises:
        SystemExit: With code 1 when no arguments are given, or with argparse's
            code 2 for unknown commands or missing/invalid options.
    """
    if argv is None:
        argv = sys.argv[1:]

    if not argv:
        # Usage errors go to stderr so they are not mixed with normal output.
        print("Usage: fmk init -n/--name [project_name]", file=sys.stderr)
        sys.exit(1)

    parser = argparse.ArgumentParser(prog='fmk', description='FedModelKit CLI for managing federated learning projects')
    subparsers = parser.add_subparsers(dest='command')

    # init subcommand
    init_parser = subparsers.add_parser('init', help='Initialize the project')
    init_parser.add_argument('-n', '--name', required=True, help='Set the experiment title')

    args = parser.parse_args(argv)

    if args.command == 'init':
        print(f"Initializing project with title: {args.name}")

        # Create the project structure
        create_structure(args.name)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Create your custom function defining a PyTorch-based model and returning its instance
def default_create_local_learner():
    """Build the default local learner: a single linear layer for the Iris dataset.

    Returns:
        A ``DefaultLocalLearner`` (a ``torch.nn.Module`` conforming to
        ``LLProtocol``) with 4 input features and 3 output classes. It is
        trained with ``CrossEntropyLoss`` and Adam (lr=0.001).
    """
    import pandas as pd
    from torch import nn, optim
    from torch.utils.data import DataLoader, Dataset
    import torch
    import flwr

    from .src.utils import Utils  # Dependency script stored inside the src directory

    class DefaultLocalLearner(nn.Module):
        def __init__(self, input_size: int) -> None:
            super(DefaultLocalLearner, self).__init__()
            self.linear = nn.Linear(input_size, 3)
            # Not applied inside _forward: CrossEntropyLoss expects raw logits
            # and applies log-softmax itself. Apply this to _forward's output
            # when class probabilities are needed.
            self.softmax = nn.Softmax(dim=1)
            self.loss_fn = nn.CrossEntropyLoss()
            self.optimizer = optim.Adam(self.parameters(), lr=0.001)

        def _forward(self, x):
            """Return raw (unnormalised) class logits for the batch ``x``."""
            # Returning softmax output here would normalise twice inside
            # CrossEntropyLoss and flatten the gradients.
            return self.linear(x)

        def get_parameters(self) -> flwr.common.ParametersRecord:
            """Serialise the current ``state_dict`` into a Flower ``ParametersRecord``."""
            return Utils.pytorch_to_parameter_record(self.state_dict())

        def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
            """Load weights received as a Flower ``ParametersRecord``."""
            self.load_state_dict(Utils.parameters_to_pytorch_state_dict(parameters))

        def prepare_data(self, data: pd.DataFrame) -> None:
            """Wrap ``data`` in a batched ``DataLoader`` stored on ``self.dataloader``.

            ``data`` must contain a ``"class"`` column of Iris species names,
            which is encoded to 0/1/2. Features are taken from every column
            except the last one. ``data`` itself is left unmodified.
            """
            class IrisDataset(Dataset):
                def __init__(self, dataframe: pd.DataFrame) -> None:
                    # Work on a copy so label encoding does not mutate the
                    # caller's DataFrame in place.
                    self.dataframe = dataframe.copy()
                    self.dataframe.loc[:, "class"] = self.dataframe.loc[:, "class"].replace(
                        {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
                    )

                def __getitem__(self, idx: int):
                    # NOTE(review): assumes "class" is the last column; the
                    # features are all preceding columns.
                    x = self.dataframe.iloc[idx, :-1].to_numpy("float32")
                    y = torch.tensor(self.dataframe.iloc[idx, -1], dtype=torch.long)
                    return x, y

                def __len__(self) -> int:
                    return len(self.dataframe)

            self.dataloader = DataLoader(IrisDataset(data), batch_size=32)

        def train_round(self) -> flwr.common.MetricsRecord:  # type: ignore
            """Run one pass over the prepared data.

            Returns:
                ``MetricsRecord`` with ``"loss"``: the per-sample mean training
                loss over the pass.

            Raises:
                ValueError: If the prepared data is empty.
            """
            total_loss = 0.0
            n_samples = 0
            for x, y in self.dataloader:
                self.optimizer.zero_grad()
                loss = self.loss_fn(self._forward(x), y)
                loss.backward()
                self.optimizer.step()
                # Weight by batch size so a small final batch does not skew the mean.
                total_loss += loss.item() * y.size(0)
                n_samples += y.size(0)

            if n_samples == 0:
                raise ValueError("train_round() called with no training data; check prepare_data() input.")

            return flwr.common.MetricsRecord({"loss": total_loss / n_samples})  # type: ignore

        def evaluate(self) -> flwr.common.MetricsRecord:
            """Compute classification accuracy on the prepared data.

            Returns:
                ``MetricsRecord`` with ``"accuracy"`` in ``[0, 1]``.

            Raises:
                ValueError: If the prepared data is empty.
            """
            correct = 0
            total = 0
            with torch.no_grad():
                for x, y in self.dataloader:
                    # argmax of logits equals argmax of softmax probabilities.
                    _, predicted = torch.max(self._forward(x), 1)
                    total += y.size(0)
                    correct += (predicted == y).sum().item()

            if total == 0:
                raise ValueError("evaluate() called with no evaluation data; check prepare_data() input.")

            return flwr.common.MetricsRecord({"accuracy": correct / total})

    return DefaultLocalLearner(4)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Create your custom function defining an aggregation strategy and returning its instance
def default_create_aggregator():
    """Build the default aggregator, which takes a plain (unweighted) average.

    Returns:
        A ``DefaultAggregator`` instance conforming to ``AggProtocol``.
    """
    from collections import OrderedDict
    import numpy as np
    import flwr
    from typing import Optional

    from .src.utils import Utils  # Dependency script stored inside the src directory

    class DefaultAggregator:
        """Averages client parameters and metrics, giving every client equal weight."""

        def aggregate_parameters(self, results: list[flwr.common.ParametersRecord], config: Optional[flwr.common.ConfigsRecord]=None
        ) -> flwr.common.ParametersRecord:
            """Element-wise mean of the clients' parameter arrays.

            Args:
                results: One ``ParametersRecord`` per client. Every record must
                    contain the keys of the first one, with arrays of matching
                    shapes; a missing key raises ``KeyError``.
                config: Accepted for ``AggProtocol`` compatibility; unused.

            Returns:
                A ``ParametersRecord`` holding, for each key of the first
                record, the average of that entry across all clients.

            Raises:
                ValueError: If ``results`` is empty.
            """
            if not results:
                raise ValueError("aggregate_parameters() requires at least one ParametersRecord.")

            parameters = [Utils.parameters_to_dict(param) for param in results]
            result = OrderedDict()
            for key in parameters[0]:
                # Fresh, writable accumulator shaped like the first client's array
                # (the deserialised arrays are np.frombuffer views).
                this_array: np.ndarray = np.zeros_like(parameters[0][key])
                for p in parameters:
                    this_array += p[key]
                result[key] = this_array / len(results)
            return Utils.dict_to_parameter_record(result)

        def aggregate_metrics(self, results: list[flwr.common.MetricsRecord], config: Optional[flwr.common.ConfigsRecord]=None) -> flwr.common.MetricsRecord:
            """Arithmetic mean of each scalar metric across clients.

            Args:
                results: One ``MetricsRecord`` per client. Every record must
                    contain the keys of the first one; a missing key raises
                    ``KeyError``.
                config: Accepted for ``AggProtocol`` compatibility; unused.

            Returns:
                A ``MetricsRecord`` mapping each key of the first record to its mean.

            Raises:
                ValueError: If ``results`` is empty, or a value is not an
                    ``int``/``float`` (e.g. a list-valued metric).
            """
            if not results:
                raise ValueError("aggregate_metrics() requires at least one MetricsRecord.")

            result = OrderedDict()
            for key in results[0].keys():
                cumsum = 0.0
                for m in results:
                    if not isinstance(m[key], (int, float)):
                        raise ValueError(
                            f"flwr.common.MetricsRecord value type not supported: {type(m[key])}"
                        )
                    cumsum += m[key]  # type: ignore
                result[key] = cumsum / len(results)
            return flwr.common.MetricsRecord(result)  # type: ignore

    return DefaultAggregator()
|
FedModelKit/interface.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
'''import mlflow
|
|
2
|
+
from mlflow import MlflowClient
|
|
3
|
+
from mlflow.pyfunc.model import PythonModel
|
|
4
|
+
from mlflow.pyfunc import log_model'''
|
|
5
|
+
|
|
6
|
+
from .local_learner import LLFactoryProtocol, LLProtocol
|
|
7
|
+
from .aggregator import AggFactoryProtocol, AggProtocol
|
|
8
|
+
from .default_create_functions import default_create_local_learner, default_create_aggregator
|
|
9
|
+
|
|
10
|
+
from typing import Literal
|
|
11
|
+
import os
|
|
12
|
+
import warnings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FederatedModel():
    """
    Bundle of the factories (and names) that define a federated learning model.

    ``FederatedModel`` builds nothing itself. It stores a zero-argument factory
    for the local learner and one for the aggregation strategy, so that
    client-side and server-side code can each create their own instances.
    If no aggregator factory is given, ``default_create_aggregator`` (a plain,
    unweighted average, named "PlainAVG") is used. If no local-learner factory
    is given, ``default_create_local_learner`` is used.

    Attributes:
        create_local_learner (LLFactoryProtocol): Callable returning a local learner.
        model_name (str): Name identifying the model (default "Default_iris_model").
        create_aggregator (AggFactoryProtocol): Callable returning an aggregator.
        aggregator_name (str): Name of the aggregation strategy (default "PlainAVG").

    Raises:
        TypeError: If ``create_local_learner`` or ``create_aggregator`` is not callable.
    """
    def __init__(self, create_local_learner: LLFactoryProtocol = default_create_local_learner,
                 model_name: str = "Default_iris_model",
                 create_aggregator: AggFactoryProtocol = default_create_aggregator,
                 aggregator_name: str = "PlainAVG"
                 ) -> None:
        # Fail fast: the factories are only invoked later (e.g. by client code
        # calling ``create_local_learner()``), where a non-callable would
        # surface far from its cause.
        if not callable(create_local_learner):
            raise TypeError(
                f"create_local_learner must be callable, got {type(create_local_learner).__name__}"
            )
        if not callable(create_aggregator):
            raise TypeError(
                f"create_aggregator must be callable, got {type(create_aggregator).__name__}"
            )
        self.create_aggregator = create_aggregator
        self.aggregator_name = aggregator_name
        self.create_local_learner = create_local_learner
        self.model_name = model_name
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# NOTE: ``submit_fl_model`` is currently disabled. The whole function is kept
# inside a module-level string literal, so it is never defined or executed.
# It also needs ``mlflow``, ``MlflowClient`` and ``log_model``, whose imports at
# the top of this module are likewise wrapped in a string, and the package
# ``__init__`` import of it is commented out the same way.
# NOTE(review): if this is re-enabled, note two side effects:
#   - it sets MLFLOW_TRACKING_INSECURE_TLS="true", which disables TLS
#     certificate verification;
#   - it calls warnings.filterwarnings("ignore"), which silences all warnings
#     process-wide, not just the MLflow input-example warning.
'''def submit_fl_model(model: FederatedModel,
                    platform_url: str,
                    username: str,
                    password: str,
                    experiment_name: str,
                    disease: Literal["AML", "SCD"],
                    trained: bool) -> dict:
    """
    Submit the model and aggregation strategy to MLflow.

    This function logs the model and aggregator as artifacts, registers the model in MLflow,
    and ensures a new run is created in the specified experiment.

    Args:
        model (FederatedModel): The federated learning model to be submitted.
        platform_url (str): The URL of the MLflow tracking server.
        username (str): The username for authenticating with the MLflow tracking server.
        password (str): The password for authenticating with the MLflow tracking server.
        experiment_name (str): The name of the MLflow experiment to use or create.
        disease (Literal["AML", "SCD"]): The use case for the model.
        trained (bool): Indicates whether the model is already trained.

    Raises:
        RuntimeError: If the username or password is not provided.
    """
    # Ignore warnings about input example
    warnings.filterwarnings("ignore")

    # Check if username and password are provided
    if not username or not password:
        raise RuntimeError("Username and password must be provided.")
    # Check if the use case is valid
    assert disease in ["AML", "SCD"], "Disease must be either 'AML' or 'SCD'"

    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
    os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true"

    MLFLOW_URL = platform_url
    mlflow.set_tracking_uri(MLFLOW_URL)

    # Test model and aggregator protocol
    model_instance = model.create_local_learner()
    aggregator_instance = model.create_aggregator()
    assert isinstance(model.create_aggregator, AggFactoryProtocol), "create_aggregator function does not conform to AggFactoryProtocol"
    assert isinstance(model.create_local_learner, LLFactoryProtocol), "create_local_learner function does not conform to LLFactoryProtocol"
    assert isinstance(model_instance, LLProtocol), "Local learner does not conform to LLProtocol"
    assert isinstance(aggregator_instance, AggProtocol), "Aggregator instance does not conform to AggProtocol"

    # Ensure the experiment exists or create it
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = mlflow.create_experiment(experiment_name)
    else:
        experiment_id = experiment.experiment_id


    # Start a new run
    with mlflow.start_run(experiment_id=experiment_id) as run:
        mlflow.set_tag("use_case", disease)
        mlflow.set_tag("trained", str(trained))
        if os.path.isdir("./src"):
            model_info = log_model(
                artifact_path="model",
                python_model=model,
                registered_model_name=model.model_name,
                code_paths=["src"],
            )
        else:
            model_info = log_model(
                artifact_path="model",
                python_model=model,
                registered_model_name=model.model_name,
            )
        mlflow_client = MlflowClient(tracking_uri=MLFLOW_URL)
        model_meta = mlflow_client.get_latest_versions(model.model_name, stages=["None"])
        version = model_meta[0].version
        # mlflow_client.update_model_version(model.model_name, version, description)
        tags = {"use_case": disease, "trained": str(trained)}
        for key, value in tags.items():
            mlflow_client.set_model_version_tag(model.model_name, version, key, value)
        # mlflow_client.set_experiment_tag(experiment_id, "use_case", disease)
        # mlflow_client.set_experiment_tag(experiment_id, "Trained", str(trained))
    return {
        "detail": f"Model '{model.model_name}' registered.",
        "model_uuid": model_info.model_uuid,
        "run_id": model_info.run_id,
        "model_uri": model_info.model_uri,
    }'''
|
|
130
|
+
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Protocol, runtime_checkable
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import flwr
|
|
5
|
+
|
|
6
|
+
@runtime_checkable
class LLProtocol(Protocol):
    """Structural interface that every local learner must satisfy.

    A local learner does the following:
    - ingests its data through ``prepare_data``;
    - trains and evaluates on it;
    - exchanges its weights as Flower ``ParametersRecord`` objects.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, LLProtocol)``
    is allowed, but it only verifies that the methods exist.
    """

    def prepare_data(self, data: pd.DataFrame) -> None:
        """Make ``data`` ready for subsequent training or evaluation.

        Args:
            data: Input data as a pandas ``DataFrame``.
        """
        ...

    def train_round(self) -> flwr.common.MetricsRecord:
        """Train the model on the prepared data.

        Returns:
            flwr.common.MetricsRecord: Metrics describing the training step.
        """
        ...

    def get_parameters(self) -> flwr.common.ParametersRecord:
        """Export the model's current weights.

        Returns:
            flwr.common.ParametersRecord: The current model parameters.
        """
        ...

    def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
        """Replace the model's weights.

        Args:
            parameters: A ``flwr.common.ParametersRecord`` with the weights to load.
        """
        ...

    def evaluate(self) -> flwr.common.MetricsRecord:
        """Measure model performance on the prepared data.

        Returns:
            flwr.common.MetricsRecord: The evaluation metrics.
        """
        ...
|
|
51
|
+
|
|
52
|
+
@runtime_checkable
class LLFactoryProtocol(Protocol):
    """Signature of a zero-argument factory that builds a local learner.

    ``default_create_local_learner`` is the library's built-in example of such
    a factory.
    """

    def __call__(self) -> LLProtocol:
        """Build and return an object that satisfies ``LLProtocol``."""
        ...
|
FedModelKit/py.typed
ADDED
|
File without changes
|
FedModelKit/src/utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
import flwr
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Utils:
    """Static helpers that convert model parameters to and from Flower records.

    Parameters travel as ``flwr.common.ParametersRecord`` objects whose values
    are ``flwr.common.Array`` instances wrapping ``numpy.ndarray.tobytes``
    buffers (see ``_ndarray_to_array``).
    """

    @staticmethod
    def _basic_array_deserialisation(array: flwr.common.Array) -> flwr.common.NDArray:
        """Rebuild a NumPy array from a Flower ``Array`` (inverse of ``_ndarray_to_array``).

        The result is created with ``np.frombuffer``. When ``array.data`` is an
        immutable ``bytes`` object (as produced by ``_ndarray_to_array``), the
        returned array is read-only, so copy it before modifying it in place.
        """
        return np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

    @staticmethod
    def parameters_to_dict(params_record: flwr.common.ParametersRecord) -> OrderedDict:
        """Deserialise every entry of ``params_record`` into a NumPy array.

        Returns:
            An ``OrderedDict`` mapping each record key to its ndarray, in the
            record's iteration order.
        """
        state_dict = OrderedDict()
        for k, v in params_record.items():
            state_dict[k] = Utils._basic_array_deserialisation(v)

        return state_dict

    @staticmethod
    def _ndarray_to_array(ndarray: flwr.common.NDArray) -> flwr.common.Array:
        """Represent NumPy ndarray as Array."""
        return flwr.common.Array(
            data=ndarray.tobytes(),
            dtype=str(ndarray.dtype),
            stype="numpy.ndarray.tobytes",
            shape=list(ndarray.shape),
        )

    @staticmethod
    def dict_to_parameter_record(
        parameters: OrderedDict[str, flwr.common.NDArray],
    ) -> flwr.common.ParametersRecord:
        """Serialise a ``{name: ndarray}`` mapping into a ``ParametersRecord``, preserving key order."""
        state_dict = OrderedDict()
        for k, v in parameters.items():
            state_dict[k] = Utils._ndarray_to_array(v)

        return flwr.common.ParametersRecord(state_dict)

    @staticmethod
    def pytorch_to_parameter_record(
        state_dict: dict,
    ) -> flwr.common.ParametersRecord:
        """Serialise your PyTorch model.

        Args:
            state_dict: A PyTorch ``state_dict`` (name -> tensor).

        Returns:
            A ``ParametersRecord`` with one ``Array`` per tensor.
        """
        transformed_state_dict = OrderedDict()

        for k, v in state_dict.items():
            # detach()/cpu() so tensors that require grad or live on a GPU can
            # be converted; .numpy() only works on CPU tensors without grad.
            transformed_state_dict[k] = Utils._ndarray_to_array(v.detach().cpu().numpy())

        return flwr.common.ParametersRecord(transformed_state_dict)

    @staticmethod
    def parameters_to_pytorch_state_dict(
        params_record: flwr.common.ParametersRecord,
    ) -> dict:
        """Reconstruct PyTorch state_dict from its serialised representation.

        Returns:
            A ``dict`` mapping each record key to a ``torch.Tensor`` copy of
            the stored array.
        """
        # Make sure to import locally torch as it is only available in the server
        import torch

        state_dict = {}
        for k, v in params_record.items():
            # torch.tensor copies, so the read-only frombuffer view is not shared.
            state_dict[k] = torch.tensor(Utils._basic_array_deserialisation(v))

        return state_dict
|
|
File without changes
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import pickle
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from flwr.client import ClientApp
|
|
7
|
+
from flwr.common import Message, Context
|
|
8
|
+
from flwr.common.record import RecordSet, MetricsRecord, ConfigsRecord
|
|
9
|
+
from sklearn.preprocessing import OneHotEncoder
|
|
10
|
+
import FedModelKit as msi
|
|
11
|
+
|
|
12
|
+
from model_example import create_local_learner #type: ignore[import]
|
|
13
|
+
from load_data import load_data #type: ignore[import]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Initialize the Flower ClientApp
app = ClientApp()


def _require_client_state(ctx: Context) -> None:
    """Fail clearly if the ``query`` handler has not populated the client state yet.

    ``train`` and ``eval`` read ``ctx.state.local_learner`` and ``ctx.state.data``,
    which are set only by ``query``.

    Raises:
        RuntimeError: If either attribute is missing from ``ctx.state``.
    """
    missing = [name for name in ("local_learner", "data") if not hasattr(ctx.state, name)]
    if missing:
        raise RuntimeError(
            "The context state is not being stored correctly: missing "
            + ", ".join(missing)
            + " (the query handler must run before train/evaluate)."
        )


@app.query()
def query(msg: Message, ctx: Context) -> Message:
    """
    Query function to be executed by the Flower client. This function handles the
    initial configuration sent by the server: it loads this client's data split
    and creates the local learner, storing both in ``ctx.state`` for the
    ``train`` and ``eval`` handlers.
    """

    # Retrieve the configuration sent by the server
    fancy_config = msg.content.configs_records['fancy_config']

    # Load the client split data using the load_data function
    data = load_data(fancy_config['num_clients'], fancy_config['client_id'])

    # Instantiate the federated model
    federated_model = msi.FederatedModel(create_local_learner=create_local_learner, model_name='simple_lr')

    # Store the local learner and the data split in the context
    # To store in context other objects, you can use ctx.state.<object_name> = <object>
    ctx.state.local_learner = federated_model.create_local_learner()
    ctx.state.data = data

    return msg.create_reply(RecordSet())


@app.train()
def train(msg: Message, ctx: Context):
    """
    Train function to be executed by the Flower client.
    This function handles the training of the local model using the data provided.
    """

    # Retrieve the local learner and the client split from the context
    _require_client_state(ctx)
    local_learner = ctx.state.local_learner
    data = ctx.state.data

    # Retrieve configuration sent by the server - example
    #fancy_config = msg.content.configs_records['fancy_config']
    #local_epochs = fancy_config['local_epochs']

    # Retrieve the model parameters sent by the server
    fancy_parameters = msg.content.parameters_records['fancy_model']
    local_learner.set_parameters(fancy_parameters)

    # Prepare the data using the local learner
    local_learner.prepare_data(data)

    # Perform local training and obtain training metrics
    train_metrics = local_learner.train_round()

    # Retrieve the trained model parameters
    new_parameters_records = local_learner.get_parameters()

    # Construct a reply message carrying updated model parameters and generated metrics
    reply_content = RecordSet()
    reply_content.parameters_records['fancy_model_returned'] = new_parameters_records
    reply_content.metrics_records['train_metrics'] = train_metrics

    # Store the metrics and the local learner in the context for future reference
    ctx.state.metrics_records['prev'] = train_metrics
    ctx.state.local_learner = local_learner

    # Return the reply message to the server
    return msg.create_reply(reply_content)


@app.evaluate()
def eval(msg: Message, ctx: Context):
    """
    Evaluate function to be executed by the Flower client.
    This function handles the evaluation of the local model using the data provided.
    """

    # Retrieve the local learner and the client split from the context
    _require_client_state(ctx)
    local_learner = ctx.state.local_learner
    data = ctx.state.data

    # Retrieve configuration sent by the server - example
    #fancy_config = msg.content.configs_records['fancy_config']
    #local_epochs = fancy_config['local_epochs']

    # Retrieve the model parameters sent by the server
    fancy_parameters = msg.content.parameters_records['fancy_model']
    local_learner.set_parameters(fancy_parameters)

    # Prepare the data using the local learner
    local_learner.prepare_data(data)

    # Evaluate the model and obtain evaluation metrics
    eval_metrics = local_learner.evaluate()

    # Construct a reply message with evaluation metrics
    reply_content = RecordSet()
    reply_content.metrics_records['eval_metrics'] = eval_metrics

    # Store the metrics and the local learner in the context for future reference
    ctx.state.metrics_records['prev'] = eval_metrics
    ctx.state.local_learner = local_learner

    # Return the reply message to the server
    return msg.create_reply(reply_content)
|