FedModelKit 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
FedModelKit/README.md ADDED
@@ -0,0 +1,25 @@
1
+ ### **Interface Directory**
2
+
3
+ 1. **`mlflow_interface.py`**
4
+ - This file defines the `FederatedModel` class, which encapsulates the local learner factory, the aggregation strategy factory, and their respective names. It is intended to facilitate the integration of these components into the MLflow tracking system. It also contains the `submit_fl_model` function for logging the model to MLflow, which is currently disabled (commented out).
5
+
6
+ 2. **`aggregator.py`**
7
+ - This file defines the `AggFactoryProtocol` and `AggProtocol` protocols. The `AggFactoryProtocol` is a protocol for the function creating the aggregator instance, while the `AggProtocol` defines the methods and attributes that an aggregator must implement.
8
+
9
+ 3. **`local_learner.py`**
10
+ - This file defines the `LLFactoryProtocol` and `LLProtocol` protocols. The `LLFactoryProtocol` is a protocol for the function creating the local learner instance, while the `LLProtocol` defines the methods and attributes that a local learner must implement. This file provides the structure for implementing custom local learners.
11
+
12
+ 4. **`default_create_functions.py`**
13
+ - This file provides default implementations for creating local learners and aggregators via the `default_create_local_learner` and `default_create_aggregator` functions.
14
+
15
+ 5. **`cli.py`**
16
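+ - This file contains the `create_structure` function, which scaffolds a new federated-learning experiment in the current working directory. It creates the necessary directories and files, including `main.py`, `pyproject.toml`, the experiment package (`__init__.py`, `task.py`, `server_app.py`, `client_app.py`), `uv.lock`, `README.md`, and `ds.ipynb`, and copies the image assets. It also provides the `main` entry point for the `fmk init -n <name>` command.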
17
+
18
+ 6. **`templates.py`**
19
+ - This file contains the template functions that generate the content of the scaffolded files, such as `get_main_template`, `get_task_template`, `get_server_template`, `get_client_template`, `get_pyproject_template`, and `get_readme_template`, plus `get_images` for copying image assets.
20
+
21
+ 7. **`readme_template.txt`**
22
+ - This file contains the template for the `README.md` file generated by the `create_structure` function. It provides a detailed guide on how to use the package's functionality.
23
+
24
+ 8. **`src/utils.py`**
25
+ - This file contains utility functions for managing and converting Flower objects. It includes methods for serializing and deserializing model parameters and other helper functions.
@@ -0,0 +1,17 @@
1
+ '''from .interface import FederatedModel, submit_fl_model'''
2
+ from .interface import FederatedModel
3
+ from .local_learner import LLFactoryProtocol, LLProtocol
4
+ from .aggregator import AggFactoryProtocol, AggProtocol
5
+ from .default_create_functions import default_create_local_learner, default_create_aggregator
6
+ from .cli import create_structure
7
+ import flwr
8
+
9
# Public API of the package, in the order it is re-exported above.
__all__ = [
    # Model wrapper
    "FederatedModel",
    # Local-learner protocols
    "LLFactoryProtocol",
    "LLProtocol",
    # Aggregator protocols
    "AggFactoryProtocol",
    "AggProtocol",
    # Default factory implementations
    "default_create_local_learner",
    "default_create_aggregator",
    # Project scaffolding
    "create_structure",
    # Flower itself is re-exported so templates can reach it through the package
    "flwr",
]
@@ -0,0 +1,41 @@
1
+ import flwr
2
+ from typing import Protocol, runtime_checkable
3
+ from typing import Protocol, Optional
4
+
5
+
6
@runtime_checkable
class AggProtocol(Protocol):
    """
    Defines the required structure for aggregation strategies in federated learning.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, AggProtocol)``
    only checks that ``obj`` has both methods; their signatures and return
    types are not verified at runtime.
    """

    def aggregate_parameters(
        self,
        results: list[flwr.common.ParametersRecord],
        config: Optional[flwr.common.ConfigsRecord] = None,
    ) -> flwr.common.ParametersRecord:
        """
        Aggregates model parameters from a list of clients.

        Args:
            results: List of `flwr.common.ParametersRecord` from clients.
            config: Optional `flwr.common.ConfigsRecord` carrying strategy
                settings. Implementations may ignore it.
        Returns:
            Aggregated `flwr.common.ParametersRecord`.
        """
        ...

    def aggregate_metrics(
        self,
        results: list[flwr.common.MetricsRecord],
        config: Optional[flwr.common.ConfigsRecord] = None,
    ) -> flwr.common.MetricsRecord:
        """
        Aggregates metrics from a list of clients.

        Args:
            results: List of `flwr.common.MetricsRecord` from clients.
            config: Optional `flwr.common.ConfigsRecord` carrying strategy
                settings. Implementations may ignore it.
        Returns:
            Aggregated `flwr.common.MetricsRecord`.
        """
        ...
31
+
32
@runtime_checkable
class AggFactoryProtocol(Protocol):
    """Structural type for the zero-argument callable that builds an aggregator.

    It is meant to match factories such as ``default_create_aggregator``:
    calling one with no arguments yields an object satisfying ``AggProtocol``.
    """

    def __call__(self) -> AggProtocol:
        """Build and return a new object conforming to ``AggProtocol``."""
        ...
FedModelKit/cli.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ from .templates import (
5
+ get_main_template,
6
+ get_task_template,
7
+ get_readme_template,
8
+ get_server_template,
9
+ get_client_template,
10
+ get_pyproject_template,
11
+ get_extern_pyproject_template,
12
+ get_uv_template,
13
+ get_init_template,
14
+ get_ds_template,
15
+ get_images,
16
+ )
17
+
18
def _write_tree(base_dir: str, tree: dict) -> None:
    """Recursively materialise ``tree`` below ``base_dir``.

    Each key is a file or directory name. A ``dict`` value becomes a
    directory (created if missing) whose entries are written recursively;
    any other value is written as the text content of a file.
    """
    for name, content in tree.items():
        path = os.path.join(base_dir, name)
        if isinstance(content, dict):
            os.makedirs(path, exist_ok=True)
            _write_tree(path, content)
        else:
            # Explicit UTF-8 so the generated files do not depend on the
            # platform's locale encoding (which is not UTF-8 on every OS).
            with open(path, "w", encoding="utf-8") as f:
                f.write(content)


def create_structure(exp_name: str = "new_experiment") -> None:
    """Scaffold a new federated-learning experiment in the current directory.

    Layout created under ``os.getcwd()``::

        <exp_name>/
            main.py
            pyproject.toml
            <exp_name>/
                __init__.py
                task.py
                server_app.py
                client_app.py
        pyproject.toml
        uv.lock
        README.md
        ds.ipynb

    Files that already exist at these paths are overwritten. Afterwards
    ``get_images`` is called to create the images directory and copy images.

    Args:
        exp_name: Experiment name, used for both the outer project directory
            and the inner package directory.
    """
    root_dir = os.getcwd()

    structure = {
        f"{exp_name}": {
            "main.py": get_main_template(),
            "pyproject.toml": get_pyproject_template(exp_name),
            f"{exp_name}": {
                "__init__.py": get_init_template(),
                "task.py": get_task_template(),
                "server_app.py": get_server_template(),
                "client_app.py": get_client_template(),
            },
        },
        "pyproject.toml": get_extern_pyproject_template(exp_name),
        "uv.lock": get_uv_template(),
        "README.md": get_readme_template(),
        "ds.ipynb": get_ds_template(),
    }

    # Recursive writer: supports any nesting depth, not just two levels.
    _write_tree(root_dir, structure)

    # Create images directory and copy images
    get_images(root_dir)

    print(f"Project structure for {exp_name} created successfully in {root_dir}.")
61
+
62
+
63
+ '''def main():
64
+ if len(sys.argv) < 2:
65
+ print("Usage: msi <command> [options]")
66
+ sys.exit(1)
67
+
68
+ command = sys.argv[1]
69
+
70
+ if command == "init":
71
+ structure_type = sys.argv[2] if len(sys.argv) > 2 else "default"
72
+ create_structure(structure_type)
73
+ else:
74
+ print(f"Unknown command: {command}")
75
+ sys.exit(1)
76
+ '''
77
+
78
def main(argv: "list[str] | None" = None) -> None:
    """Entry point for the ``fmk`` command-line tool.

    Supported command: ``fmk init -n/--name <project_name>``, which scaffolds
    a new experiment via ``create_structure``.

    Args:
        argv: Arguments excluding the program name. When omitted (as the
            console script does), ``sys.argv[1:]`` is used.

    Exits with status 1 when no command is given, and with argparse's
    status 2 on malformed arguments.
    """
    if argv is None:
        argv = sys.argv[1:]

    if not argv:
        print("Usage: fmk init -n/--name [project_name]")
        sys.exit(1)

    parser = argparse.ArgumentParser(prog='fmk', description='FedModelKit CLI for managing federated learning projects')
    subparsers = parser.add_subparsers(dest='command')

    # init subcommand
    init_parser = subparsers.add_parser('init', help='Initialize the project')
    init_parser.add_argument('-n', '--name', required=True, help='Set the experiment title')

    args = parser.parse_args(argv)

    if args.command == 'init':
        print(f"Initializing project with title: {args.name}")

        # Create the project structure
        create_structure(args.name)
    else:
        # Arguments were parsed but no subcommand was selected (e.g. "fmk --");
        # previously this path exited silently with status 0.
        parser.print_help()
        sys.exit(1)
@@ -0,0 +1,114 @@
1
# Create your custom function defining a PyTorch-based model and returning its instance
def default_create_local_learner():
    """Build the default local learner: one linear layer over 4 Iris features.

    Returns:
        A ``DefaultLocalLearner`` (a ``torch.nn.Module``) with 4 inputs and
        3 output classes that implements the methods required by
        ``LLProtocol``: ``prepare_data``, ``train_round``, ``evaluate``,
        ``get_parameters`` and ``set_parameters``.
    """
    import pandas as pd
    from torch import nn, optim
    from torch.utils.data import DataLoader, Dataset
    import torch
    import flwr

    from .src.utils import Utils  # Dependency script stored inside the src directory

    class DefaultLocalLearner(nn.Module):
        def __init__(self, input_size: int) -> None:
            super().__init__()
            self.linear = nn.Linear(input_size, 3)
            # Only for turning logits into probabilities. It has no
            # parameters, so the exchanged state_dict is unaffected.
            self.softmax = nn.Softmax(dim=1)
            self.loss_fn = nn.CrossEntropyLoss()
            self.optimizer = optim.Adam(self.parameters(), lr=0.001)

        def _forward(self, x):
            # Return raw logits. CrossEntropyLoss applies log-softmax itself,
            # so feeding it softmax outputs (as before) squashes the gradients.
            return self.linear(x)

        def get_parameters(self) -> flwr.common.ParametersRecord:
            """Serialise the current ``state_dict`` into a ParametersRecord."""
            return Utils.pytorch_to_parameter_record(self.state_dict())

        def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
            """Load a ParametersRecord into the model."""
            self.load_state_dict(Utils.parameters_to_pytorch_state_dict(parameters))

        def prepare_data(self, data: pd.DataFrame) -> None:
            """Build ``self.dataloader`` (batch size 32) from ``data``.

            The "class" column's Iris species names are mapped to 0/1/2; the
            last column is used as the label and all other columns as features.
            """

            class IrisDataset(Dataset):
                def __init__(self, dataframe: pd.DataFrame) -> None:
                    # Work on a copy so the caller's DataFrame is not mutated.
                    frame = dataframe.copy()
                    label_map = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
                    # Same semantics as Series.replace: unmapped values pass through.
                    frame["class"] = frame["class"].map(lambda v: label_map.get(v, v))
                    self.dataframe = frame
                    # Convert once up front instead of per item in __getitem__.
                    self.features = torch.tensor(frame.iloc[:, :-1].to_numpy("float32"))
                    self.labels = torch.tensor(frame.iloc[:, -1].to_numpy().astype("int64"))

                def __getitem__(self, idx: int):
                    return self.features[idx], self.labels[idx]

                def __len__(self) -> int:
                    return len(self.dataframe)

            dataset = IrisDataset(data)
            self.dataloader = DataLoader(dataset, batch_size=32)

        def train_round(self) -> flwr.common.MetricsRecord:  # type: ignore
            """Run one pass over the prepared data.

            Returns:
                MetricsRecord with "loss": the mean batch loss of the pass.

            Raises:
                ValueError: If the prepared data yields no batches.
            """
            self.train()
            total_loss = 0.0
            num_batches = 0
            for x, y in self.dataloader:
                self.optimizer.zero_grad()
                loss = self.loss_fn(self._forward(x), y)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("No training data: prepare_data() produced an empty dataset")
            return flwr.common.MetricsRecord({"loss": total_loss / num_batches})  # type: ignore

        def evaluate(self) -> flwr.common.MetricsRecord:
            """Compute classification accuracy on the prepared data.

            Raises:
                ValueError: If the prepared data contains no samples.
            """
            self.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for x, y in self.dataloader:
                    # argmax over logits equals argmax over softmax probabilities.
                    predicted = self._forward(x).argmax(dim=1)
                    total += y.size(0)
                    correct += (predicted == y).sum().item()

            if total == 0:
                raise ValueError("No evaluation data: prepare_data() produced an empty dataset")
            return flwr.common.MetricsRecord({"accuracy": correct / total})

        def predict_proba(self, x):
            """Return class probabilities for a batch of feature rows ``x``."""
            with torch.no_grad():
                return self.softmax(self._forward(x))

    return DefaultLocalLearner(4)
73
+
74
+
75
# Create your custom function defining an aggregation strategy and returning its instance
def default_create_aggregator():
    """Build the default "PlainAVG" aggregator.

    The returned object satisfies ``AggProtocol``. It averages client
    parameters and metrics with equal weight per client; it does not weight
    by client sample counts.
    """
    from collections import OrderedDict
    import numpy as np
    import flwr
    from typing import Optional

    from .src.utils import Utils  # Dependency script stored inside the src directory

    class DefaultAggregator:

        def aggregate_parameters(
            self,
            results: list[flwr.common.ParametersRecord],
            config: Optional[flwr.common.ConfigsRecord] = None,
        ) -> flwr.common.ParametersRecord:
            """Element-wise mean of every named parameter across ``results``.

            The keys of the first record drive the aggregation, so every other
            record must contain them. ``config`` is accepted for protocol
            compatibility and ignored.

            Raises:
                ValueError: If ``results`` is empty.
            """
            if not results:
                raise ValueError("aggregate_parameters() requires at least one ParametersRecord")

            parameters = [Utils.parameters_to_dict(param) for param in results]
            result = OrderedDict()
            for key in parameters[0].keys():
                # Accumulate into a fresh writable array: the deserialised
                # arrays may be read-only views over the record's bytes.
                this_array: np.ndarray = np.zeros_like(parameters[0][key])
                for p in parameters:
                    this_array += p[key]
                result[key] = this_array / len(results)
            return Utils.dict_to_parameter_record(result)

        def aggregate_metrics(
            self,
            results: list[flwr.common.MetricsRecord],
            config: Optional[flwr.common.ConfigsRecord] = None,
        ) -> flwr.common.MetricsRecord:
            """Arithmetic mean of each scalar metric across ``results``.

            The keys of the first record drive the aggregation. ``config`` is
            accepted for protocol compatibility and ignored.

            Raises:
                ValueError: If ``results`` is empty, or a value is not int/float.
            """
            if not results:
                raise ValueError("aggregate_metrics() requires at least one MetricsRecord")

            result = OrderedDict()
            for key in results[0].keys():
                cumsum = 0.0
                for m in results:
                    value = m[key]
                    if not isinstance(value, (int, float)):
                        raise ValueError(
                            f"flwr.common.MetricsRecord value type not supported: {type(value)}"
                        )
                    cumsum += value  # type: ignore
                result[key] = cumsum / len(results)
            return flwr.common.MetricsRecord(result)  # type: ignore

    return DefaultAggregator()
@@ -0,0 +1,130 @@
1
+ '''import mlflow
2
+ from mlflow import MlflowClient
3
+ from mlflow.pyfunc.model import PythonModel
4
+ from mlflow.pyfunc import log_model'''
5
+
6
+ from .local_learner import LLFactoryProtocol, LLProtocol
7
+ from .aggregator import AggFactoryProtocol, AggProtocol
8
+ from .default_create_functions import default_create_local_learner, default_create_aggregator
9
+
10
+ from typing import Literal
11
+ import os
12
+ import warnings
13
+
14
+
15
class FederatedModel:
    """
    Container pairing a local-learner factory with an aggregator factory.

    If no local-learner factory is given, ``default_create_local_learner`` is
    used. If no aggregator factory is given, ``default_create_aggregator`` is
    used; it is an unweighted parameter/metric average, named "PlainAVG".

    The constructor only stores the factories and names; it does not call
    them. The class was designed to be logged to MLflow via
    ``submit_fl_model``, which is currently disabled in this module.

    Attributes:
        create_local_learner (LLFactoryProtocol): Zero-argument factory returning a local learner.
        model_name (str): Name under which the model is registered (default "Default_iris_model").
        create_aggregator (AggFactoryProtocol): Zero-argument factory returning an aggregator.
        aggregator_name (str): Name of the aggregation strategy (default "PlainAVG").
    """

    def __init__(self, create_local_learner: LLFactoryProtocol = default_create_local_learner,
                 model_name: str = "Default_iris_model",
                 create_aggregator: AggFactoryProtocol = default_create_aggregator,
                 aggregator_name: str = "PlainAVG"
                 ) -> None:
        self.create_aggregator = create_aggregator
        self.aggregator_name = aggregator_name
        self.create_local_learner = create_local_learner
        self.model_name = model_name
38
+
39
+
40
+
41
+ '''def submit_fl_model(model: FederatedModel,
42
+ platform_url: str,
43
+ username: str,
44
+ password: str,
45
+ experiment_name: str,
46
+ disease: Literal["AML", "SCD"],
47
+ trained: bool) -> dict:
48
+ """
49
+ Submit the model and aggregation strategy to MLflow.
50
+
51
+ This function logs the model and aggregator as artifacts, registers the model in MLflow,
52
+ and ensures a new run is created in the specified experiment.
53
+
54
+ Args:
55
+ model (FederatedModel): The federated learning model to be submitted.
56
+ platform_url (str): The URL of the MLflow tracking server.
57
+ username (str): The username for authenticating with the MLflow tracking server.
58
+ password (str): The password for authenticating with the MLflow tracking server.
59
+ experiment_name (str): The name of the MLflow experiment to use or create.
60
+ disease (Literal["AML", "SCD"]): The use case for the model.
61
+ trained (bool): Indicates whether the model is already trained.
62
+
63
+ Raises:
64
+ RuntimeError: If the username or password is not provided.
65
+ """
66
+ # Ignore warnings about input example
67
+ warnings.filterwarnings("ignore")
68
+
69
+ # Check if username and password are provided
70
+ if not username or not password:
71
+ raise RuntimeError("Username and password must be provided.")
72
+ # Check if the use case is valid
73
+ assert disease in ["AML", "SCD"], "Disease must be either 'AML' or 'SCD'"
74
+
75
+ os.environ['MLFLOW_TRACKING_USERNAME'] = username
76
+ os.environ['MLFLOW_TRACKING_PASSWORD'] = password
77
+ os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true"
78
+
79
+ MLFLOW_URL = platform_url
80
+ mlflow.set_tracking_uri(MLFLOW_URL)
81
+
82
+ # Test model and aggregator protocol
83
+ model_instance = model.create_local_learner()
84
+ aggregator_instance = model.create_aggregator()
85
+ assert isinstance(model.create_aggregator, AggFactoryProtocol), "create_aggregator function does not conform to AggFactoryProtocol"
86
+ assert isinstance(model.create_local_learner, LLFactoryProtocol), "create_local_learner function does not conform to LLFactoryProtocol"
87
+ assert isinstance(model_instance, LLProtocol), "Local learner does not conform to LLProtocol"
88
+ assert isinstance(aggregator_instance, AggProtocol), "Aggregator instance does not conform to AggProtocol"
89
+
90
+ # Ensure the experiment exists or create it
91
+ experiment = mlflow.get_experiment_by_name(experiment_name)
92
+ if experiment is None:
93
+ experiment_id = mlflow.create_experiment(experiment_name)
94
+ else:
95
+ experiment_id = experiment.experiment_id
96
+
97
+
98
+ # Start a new run
99
+ with mlflow.start_run(experiment_id=experiment_id) as run:
100
+ mlflow.set_tag("use_case", disease)
101
+ mlflow.set_tag("trained", str(trained))
102
+ if os.path.isdir("./src"):
103
+ model_info = log_model(
104
+ artifact_path="model",
105
+ python_model=model,
106
+ registered_model_name=model.model_name,
107
+ code_paths=["src"],
108
+ )
109
+ else:
110
+ model_info = log_model(
111
+ artifact_path="model",
112
+ python_model=model,
113
+ registered_model_name=model.model_name,
114
+ )
115
+ mlflow_client = MlflowClient(tracking_uri=MLFLOW_URL)
116
+ model_meta = mlflow_client.get_latest_versions(model.model_name, stages=["None"])
117
+ version = model_meta[0].version
118
+ # mlflow_client.update_model_version(model.model_name, version, description)
119
+ tags = {"use_case": disease, "trained": str(trained)}
120
+ for key, value in tags.items():
121
+ mlflow_client.set_model_version_tag(model.model_name, version, key, value)
122
+ # mlflow_client.set_experiment_tag(experiment_id, "use_case", disease)
123
+ # mlflow_client.set_experiment_tag(experiment_id, "Trained", str(trained))
124
+ return {
125
+ "detail": f"Model '{model.model_name}' registered.",
126
+ "model_uuid": model_info.model_uuid,
127
+ "run_id": model_info.run_id,
128
+ "model_uri": model_info.model_uri,
129
+ }'''
130
+
@@ -0,0 +1,61 @@
1
+ from typing import Protocol, runtime_checkable
2
+
3
+ import pandas as pd
4
+ import flwr
5
+
6
@runtime_checkable
class LLProtocol(Protocol):
    """Structural interface that every local learner must provide.

    A conforming object is given a client's data through ``prepare_data``,
    has its weights read and written through ``get_parameters`` /
    ``set_parameters``, and reports metrics from ``train_round`` and
    ``evaluate``.
    """

    def prepare_data(self, data: pd.DataFrame) -> None:
        """Get the learner ready to train or evaluate on ``data``.

        Args:
            data: The client's input data as a pandas DataFrame.
        """
        ...

    def train_round(self) -> flwr.common.MetricsRecord:
        """Run one round of local training.

        Returns:
            flwr.common.MetricsRecord: Metrics measured during training.
        """
        ...

    def get_parameters(self) -> flwr.common.ParametersRecord:
        """Export the model's current weights.

        Returns:
            flwr.common.ParametersRecord: Snapshot of the current parameters.
        """
        ...

    def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
        """Install new weights into the model.

        Args:
            parameters: Record holding the parameters to load.
        """
        ...

    def evaluate(self) -> flwr.common.MetricsRecord:
        """Measure model performance on the prepared data.

        Returns:
            flwr.common.MetricsRecord: Evaluation metrics.
        """
        ...
51
+
52
@runtime_checkable
class LLFactoryProtocol(Protocol):
    """Structural type for the zero-argument callable that builds a local learner.

    It is meant to match factories such as ``default_create_local_learner``:
    calling one with no arguments yields an object satisfying ``LLProtocol``.
    """

    def __call__(self) -> LLProtocol:
        """Build and return a new object conforming to ``LLProtocol``."""
        ...
FedModelKit/py.typed ADDED
File without changes
@@ -0,0 +1,65 @@
1
+ from collections import OrderedDict
2
+ import flwr
3
+
4
+ import numpy as np
5
+
6
+
7
class Utils:
    """Static helpers converting between Flower records, NumPy arrays and PyTorch state dicts."""

    @staticmethod
    def _basic_array_deserialisation(array: flwr.common.Array) -> flwr.common.NDArray:
        """Rebuild a NumPy array from a ``flwr.common.Array`` made by ``_ndarray_to_array``.

        ``np.frombuffer`` does not copy, so the result is a view over
        ``array.data``. If that buffer is immutable ``bytes`` the view is
        read-only; copy it before mutating in place.
        """
        return np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)

    @staticmethod
    def parameters_to_dict(params_record: flwr.common.ParametersRecord) -> OrderedDict:
        """Deserialise every entry of ``params_record`` into a NumPy array, keeping key order."""
        state_dict = OrderedDict()
        for k, v in params_record.items():
            state_dict[k] = Utils._basic_array_deserialisation(v)

        return state_dict

    @staticmethod
    def _ndarray_to_array(ndarray: flwr.common.NDArray) -> flwr.common.Array:
        """Represent NumPy ndarray as Array (raw C-order bytes plus dtype and shape)."""
        return flwr.common.Array(
            data=ndarray.tobytes(),
            dtype=str(ndarray.dtype),
            stype="numpy.ndarray.tobytes",
            shape=list(ndarray.shape),
        )

    @staticmethod
    def dict_to_parameter_record(
        parameters: OrderedDict[str, flwr.common.NDArray],
    ) -> flwr.common.ParametersRecord:
        """Serialise a name -> ndarray mapping into a ParametersRecord."""
        state_dict = OrderedDict()
        for k, v in parameters.items():
            state_dict[k] = Utils._ndarray_to_array(v)

        return flwr.common.ParametersRecord(state_dict)

    @staticmethod
    def pytorch_to_parameter_record(
        state_dict: dict,
    ) -> flwr.common.ParametersRecord:
        """Serialise your PyTorch model's ``state_dict`` into a ParametersRecord.

        Tensors are detached and moved to CPU first, so this also works for
        GPU-resident models or tensors that track gradients.
        """
        transformed_state_dict = OrderedDict()

        for k, v in state_dict.items():
            transformed_state_dict[k] = Utils._ndarray_to_array(v.detach().cpu().numpy())

        return flwr.common.ParametersRecord(transformed_state_dict)

    @staticmethod
    def parameters_to_pytorch_state_dict(
        params_record: flwr.common.ParametersRecord,
    ) -> dict:
        """Reconstruct PyTorch state_dict from its serialised representation."""
        # Imported lazily so this module can be imported where torch is not installed.
        import torch

        state_dict = {}
        for k, v in params_record.items():
            # torch.tensor copies, so the tensor does not alias the record's buffer.
            state_dict[k] = torch.tensor(Utils._basic_array_deserialisation(v))

        return state_dict
File without changes
@@ -0,0 +1,118 @@
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import pickle
5
+ from pathlib import Path
6
+ from flwr.client import ClientApp
7
+ from flwr.common import Message, Context
8
+ from flwr.common.record import RecordSet, MetricsRecord, ConfigsRecord
9
+ from sklearn.preprocessing import OneHotEncoder
10
+ import FedModelKit as msi
11
+
12
+ from model_example import create_local_learner #type: ignore[import]
13
+ from load_data import load_data #type: ignore[import]
14
+
15
+
16
# Initialize the Flower ClientApp
app = ClientApp()


@app.query()
def query(msg: Message, ctx: Context) -> Message:
    """
    Query function executed by the Flower client. It handles the initial
    configuration sent by the server.

    Reads the ``'fancy_config'`` ConfigsRecord (``num_clients``, ``client_id``),
    loads this client's data split and builds the local learner. Both are
    stored in ``ctx.state`` for the train/evaluate handlers.
    """

    # Retrieve the configuration sent by the server
    fancy_config = msg.content.configs_records['fancy_config']

    # Load the client split data using the load_data function
    data = load_data(fancy_config['num_clients'], fancy_config['client_id'])

    # Instantiate the federated model
    federated_model = msi.FederatedModel(create_local_learner=create_local_learner, model_name='simple_lr')

    # Store the local learner and the data split in the context
    # To store in context other objects, you can use ctx.state.<object_name> = <object>
    ctx.state.local_learner = federated_model.create_local_learner()
    ctx.state.data = data

    return msg.create_reply(RecordSet())


@app.train()
def train(msg: Message, ctx: Context):
    """
    Train function executed by the Flower client.

    Loads the ``'fancy_model'`` parameters from the server, trains the local
    learner on this client's data and replies with the updated parameters
    (``'fancy_model_returned'``) and the training metrics (``'train_metrics'``).
    """

    # Retrieve the local learner and the client split from the context
    local_learner = ctx.state.local_learner
    data = ctx.state.data

    # Retrieve configuration sent by the server - example
    # fancy_config = msg.content.configs_records['fancy_config']
    # local_epochs = fancy_config['local_epochs']

    # Retrieve the model parameters sent by the server
    fancy_parameters = msg.content.parameters_records['fancy_model']
    local_learner.set_parameters(fancy_parameters)

    # Prepare the data using the local learner
    local_learner.prepare_data(data)

    # Perform local training and obtain training metrics
    train_metrics = local_learner.train_round()

    # Retrieve the trained model parameters
    new_parameters_records = local_learner.get_parameters()

    # Construct a reply message carrying updated model parameters and generated metrics
    reply_content = RecordSet()
    reply_content.parameters_records['fancy_model_returned'] = new_parameters_records
    reply_content.metrics_records['train_metrics'] = train_metrics

    # Store the metrics and the local learner in the context for future reference
    ctx.state.metrics_records['prev'] = train_metrics
    ctx.state.local_learner = local_learner

    # Return the reply message to the server
    return msg.create_reply(reply_content)
83
+
84
@app.evaluate()
def eval(msg: Message, ctx: Context):
    """
    Evaluate handler for the Flower client.

    Loads the server's ``'fancy_model'`` parameters into the learner kept in
    ``ctx.state``, evaluates it on this client's data split and replies with
    an ``'eval_metrics'`` MetricsRecord.
    """
    # Objects stored in the context by the query handler
    learner = ctx.state.local_learner
    client_data = ctx.state.data

    # Example of reading server configuration:
    # fancy_config = msg.content.configs_records['fancy_config']
    # local_epochs = fancy_config['local_epochs']

    # Install the server's parameters, then build the learner's data pipeline
    learner.set_parameters(msg.content.parameters_records['fancy_model'])
    learner.prepare_data(client_data)

    metrics = learner.evaluate()

    reply = RecordSet()
    reply.metrics_records['eval_metrics'] = metrics

    # Keep the latest metrics and learner in the context for later rounds
    ctx.state.metrics_records['prev'] = metrics
    ctx.state.local_learner = learner

    return msg.create_reply(reply)