FedModelKit 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,332 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Data Scientist (DS)"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "Change `LOCAL_TEST` to True if you want to run the clients locally to test. \n",
15
+ "With `LOCAL_TEST = False`, please have your syftbox client running. Installation instructions are available at https://www.syftbox.net/."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "LOCAL_TEST = True"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "## Some paths and constants "
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "from pathlib import Path\n",
41
+ "\n",
42
+ "SYFTBOX_DATASET_NAME = \"pima-indians-diabetes-database\""
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "metadata": {},
48
+ "source": [
49
+ "## Log into the data owners' datasites"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "if LOCAL_TEST:\n",
59
+ " from syft_rds.orchestra import setup_rds_server\n",
60
+ "\n",
61
+ " print(\"Running locally!\")\n",
62
+ "\n",
63
+ " DS = \"ds@openmined.org\"\n",
64
+ " print(\"DS email: \", DS)\n",
65
+ "\n",
66
+ " DO1 = \"do1@openmined.org\"\n",
67
+ " DO2 = \"do2@openmined.org\"\n",
68
+ "\n",
69
+ " ds_stack = setup_rds_server(email=DS, key=\"flwr\", root_dir=Path(\".\"))\n",
70
+ " do_client_1 = ds_stack.init_session(host=DO1)\n",
71
+ " do_client_2 = ds_stack.init_session(host=DO2)\n",
72
+ "else:\n",
73
+ " import syft_rds as sy\n",
74
+ " from syft_core import Client\n",
75
+ "\n",
76
+ " DS = Client.load().email\n",
77
+ " print(\"DS email: \", DS)\n",
78
+ "\n",
79
+ " DO1 = \"flower-test-group-1@openmined.org\"\n",
80
+ " DO2 = \"flower-test-group-2@openmined.org\"\n",
81
+ "\n",
82
+ " do_client_1 = sy.init_session(host=DO1)\n",
83
+ " print(\"Logged into: \", do_client_1.host)\n",
84
+ "\n",
85
+ " do_client_2 = sy.init_session(host=DO2)\n",
86
+ " print(\"Logged into: \", do_client_2.host)\n",
87
+ "\n",
88
+ "do_clients = [do_client_1, do_client_2]\n",
89
+ "do_emails = [DO1, DO2]"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {},
95
+ "source": [
96
+ "## Explore the datasets"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "SYFTBOX_DATASET_NAME"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "mock_paths = []\n",
115
+ "for client in do_clients:\n",
116
+ " dataset = client.dataset.get(name=SYFTBOX_DATASET_NAME)\n",
117
+ " mock_paths.append(dataset.get_mock_path())\n",
118
+ " print(f\"Client {client.host}'s dataset: \\n{dataset}\\n\")"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "## Bootstrap the `syft_flwr` project"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "SYFT_FLWR_PROJECT_PATH = Path(\"./fl-diabetes-prediction\")\n",
135
+ "assert SYFT_FLWR_PROJECT_PATH.exists()"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "import syft_flwr\n",
145
+ "\n",
146
+ "try:\n",
147
+ " !rm -rf {SYFT_FLWR_PROJECT_PATH / \"main.py\"}\n",
148
+ " syft_flwr.bootstrap(SYFT_FLWR_PROJECT_PATH, aggregator=DS, datasites=do_emails)\n",
149
+ " print(\"Bootstrapped project successfully ✅\")\n",
150
+ "except Exception as e:\n",
151
+ " print(e)"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {},
157
+ "source": [
158
+ "## Run `flwr` and `syft_flwr` simulations (optional)"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "RUN_SIMULATION = True"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "if RUN_SIMULATION:\n",
177
+ " !flwr run {SYFT_FLWR_PROJECT_PATH}"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "# clean up\n",
187
+ "!rm -rf {SYFT_FLWR_PROJECT_PATH / \"fl_diabetes_prediction\" / \"__pycache__\"}\n",
188
+ "!rm -rf weights/"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "mock_paths"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "if RUN_SIMULATION:\n",
207
+ " print(f\"running syft_flwr simulation with mock paths: {mock_paths}\")\n",
208
+ " syft_flwr.run(SYFT_FLWR_PROJECT_PATH, mock_paths)"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {},
214
+ "source": [
215
+ "## Submit jobs"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {},
221
+ "source": [
222
+ "<img src=\"./images/dsSendsJobs.png\" width=\"80%\" alt=\"DS Submits Jobs\">"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "# clean up before submitting jobs\n",
232
+ "!rm -rf {SYFT_FLWR_PROJECT_PATH / \"fl_diabetes_prediction\" / \"__pycache__\"}\n",
233
+ "!rm -rf {SYFT_FLWR_PROJECT_PATH / \"simulation_logs\"}\n",
234
+ "!rm -rf weights/"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "for client in do_clients:\n",
244
+ " print(f\"sending job to {client.host}\")\n",
245
+ " job = client.jobs.submit(\n",
246
+ " name=\"Syft Flower Experiment\",\n",
247
+ " description=\"Syft Flower Federated Learning Experiment\",\n",
248
+ " user_code_path=SYFT_FLWR_PROJECT_PATH,\n",
249
+ " dataset_name=SYFTBOX_DATASET_NAME,\n",
250
+ " tags=[\"federated learning\", \"fl\", \"syft_flwr\", \"flwr\"],\n",
251
+ " entrypoint=\"main.py\",\n",
252
+ " )\n",
253
+ " print(job)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {},
259
+ "source": [
260
+ "<img src=\"./images/dsDoneSubmittingJobs.png\" width=\"40%\" alt=\"DS waits for jobs to be approved\">"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "metadata": {},
266
+ "source": [
267
+ "## DS starts the FL server code"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "import os\n",
277
+ "\n",
278
+ "if LOCAL_TEST:\n",
279
+ " os.environ[\"SYFTBOX_CLIENT_CONFIG_PATH\"] = str(ds_stack.client.config_path)\n",
280
+ "\n",
281
+ "os.environ[\"LOGURU_LEVEL\"] = \"DEBUG\"\n",
282
+ "os.environ[\"SYFT_FLWR_MSG_TIMEOUT\"] = \"60\"\n",
283
+ "\n",
284
+ "!uv run {str(SYFT_FLWR_PROJECT_PATH / \"main.py\")} --active"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "metadata": {},
290
+ "source": [
291
+ "By running the FL server code, the DS aggregates the models trained on the DOs' private local data into an improved global model.\n",
292
+ "\n",
293
+ "<img src=\"./images/dsAggregateModels.png\" width=\"30%\" alt=\"DS Aggregates Models\">"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "markdown",
298
+ "metadata": {},
299
+ "source": [
300
+ "## DS Observes the Results"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "markdown",
305
+ "metadata": {},
306
+ "source": [
307
+ "Now the DS can monitor the aggregated models trained on the DOs' private datasets in the `weights` folder."
308
+ ]
309
+ }
310
+ ],
311
+ "metadata": {
312
+ "kernelspec": {
313
+ "display_name": ".venv",
314
+ "language": "python",
315
+ "name": "python3"
316
+ },
317
+ "language_info": {
318
+ "codemirror_mode": {
319
+ "name": "ipython",
320
+ "version": 3
321
+ },
322
+ "file_extension": ".py",
323
+ "mimetype": "text/x-python",
324
+ "name": "python",
325
+ "nbconvert_exporter": "python",
326
+ "pygments_lexer": "ipython3",
327
+ "version": "3.12.9"
328
+ }
329
+ },
330
+ "nbformat": 4,
331
+ "nbformat_minor": 2
332
+ }
@@ -0,0 +1,17 @@
1
+ [project]
2
+ name = "EXPERIMENT_NAME"
3
+ version = "0.1.0"
4
+ description = "Description of the experiment"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "flwr-datasets>=0.5.0",
9
+ "flwr[simulation]==1.17.0",
10
+ "imblearn>=0.0",
11
+ "jupyterlab>=4.4.4",
12
+ "loguru>=0.7.3",
13
+ "pandas>=2.3.0",
14
+ "scikit-learn==1.6.1",
15
+ "syft-flwr>=0.1.5",
16
+ "torch==2.7.0",
17
+ ]
@@ -0,0 +1,27 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from syft_core import Client
5
+
6
+ from syft_flwr.config import load_flwr_pyproject
7
+ from syft_flwr.run import syftbox_run_flwr_client, syftbox_run_flwr_server
8
+
9
# Read from the environment but not referenced again in this script.
# NOTE(review): presumably provided by the job runner for the user code — confirm.
DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")


# The Flower project lives next to this entrypoint; its pyproject.toml carries
# the [tool.syft_flwr] section that names the participants.
flower_project_dir = Path(__file__).parent.absolute()
client = Client.load()
config = load_flwr_pyproject(flower_project_dir)

syft_flwr_config = config["tool"]["syft_flwr"]

# `datasites` is a list of emails, so `in` is an exact membership test.
is_client = client.email in syft_flwr_config["datasites"]

# `aggregator` is a single email string in pyproject.toml. Using `in` directly
# on a string is a substring test ("s@openmined.org" in "ds@openmined.org" is
# True), which would let an unrelated identity start the server. Normalise to a
# list of whole emails first; a list-valued `aggregator` keeps working.
aggregator = syft_flwr_config["aggregator"]
aggregator_emails = [aggregator] if isinstance(aggregator, str) else list(aggregator)
is_server = client.email in aggregator_emails

if is_client:
    # run by each DO
    syftbox_run_flwr_client(flower_project_dir)
elif is_server:
    # run by the DS
    syftbox_run_flwr_server(flower_project_dir)
else:
    raise ValueError(f"{client.email} is not in config.datasites or config.aggregator")
@@ -0,0 +1,53 @@
1
+ [build-system]
2
+ requires = [
3
+ "hatchling",
4
+ ]
5
+ build-backend = "hatchling.build"
6
+
7
+ [project]
8
+ name = "EXPERIMENT_NAME"
9
+ version = "1.0.0"
10
+ description = "Description of the experiment"
11
+ license = "Apache-2.0"
12
+ dependencies = [
13
+ "flwr[simulation]==1.17.0",
14
+ "flwr-datasets>=0.5.0",
15
+ "torch==2.7.0",
16
+ "imblearn",
17
+ "pandas",
18
+ "scikit-learn==1.6.1",
19
+ "loguru",
20
+ "jupyter",
21
+ "syft_flwr==0.1.5",
22
+ ]
23
+
24
+ [tool.hatch.build.targets.wheel]
25
+ packages = [
26
+ ".",
27
+ ]
28
+
29
+ [tool.flwr.app]
30
+ publisher = "OpenMined"
31
+
32
+ [tool.flwr.app.components]
33
+ serverapp = "EXPERIMENT_NAME.server_app:app"
34
+ clientapp = "EXPERIMENT_NAME.client_app:app"
35
+
36
+ [tool.flwr.app.config]
37
+ num-server-rounds = 3
38
+ partition-id = 0
39
+ num-partitions = 1
40
+
41
+ [tool.flwr.federations]
42
+ default = "local-simulation"
43
+
44
+ [tool.flwr.federations.local-simulation.options]
45
+ num-supernodes = 2
46
+
47
+ [tool.syft_flwr]
48
+ app_name = "ds@openmined.org_fl-diabetes-prediction_1751887833"
49
+ datasites = [
50
+ "do1@openmined.org",
51
+ "do2@openmined.org",
52
+ ]
53
+ aggregator = "ds@openmined.org"
@@ -0,0 +1,48 @@
1
+ # Diabetes Prediction with `syft_flwr`
2
+
3
+ ## Introduction
4
+
5
+ In this tutorial, we'll walk through a practical federated learning implementation for diabetes prediction using [syft_flwr](https://github.com/OpenMined/syft-flwr) — a framework that combines the flexibility of [Flower](https://github.com/adap/flower/) (a popular federated learning framework) with the privacy-preserving networking capabilities of [syftbox](https://www.syftbox.net/).
6
+
7
+ ![overview](./images/overview.png)
8
+
9
+ Dataset: https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database/
10
+
11
+ ## Set up
12
+
13
+ ### Setup python virtual environment
14
+ This assumes you have Python and the [uv](https://docs.astral.sh/uv/) package manager installed. Create a virtual environment with all dependencies installed:
15
+ ```bash
16
+ uv sync
17
+ ```
18
+
19
+ ### Install and run `syftbox` client
20
+ Make sure you have syftbox client running in a terminal:
21
+ 1. Install `syftbox`: `curl -fsSL https://syftbox.net/install.sh | sh`
22
+ 2. Follow the instructions to start your `syftbox` client
23
+
24
+ When you have `syftbox` installed and running in the background, you can proceed and run the notebooks with the installed Python environment in your favorite IDE.
25
+
26
+ ### Local Setup
27
+ At the start of the notebooks, you can set the flag `LOCAL_TEST` to `True` if you want to run all the clients (2 data owners and 1 data scientist) locally to test the whole workflow, where all clients' local datasites will be saved locally under the `flwr` folder. For this, run the notebooks `do1.ipynb`, `do2.ipynb`, and then `ds.ipynb`, in that order.
28
+
29
+ If running locally, you don't need your `syftbox client` running.
30
+
31
+
32
+ ### Distributed setup
33
+ Set `LOCAL_TEST` to `False` and make sure you have your `syftbox` client running if you want to run the clients over the `syftbox` network.
34
+
35
+ 1. For the data scientist's workflow (prepare code, observe mock datasets on the data owner's datasites, submit jobs), please look into the `ds.ipynb` notebook. Following this notebook will help you submit jobs to two datasites named `flower-test-group-1@openmined.org` and `flower-test-group-2@openmined.org` that host 2 partitions of the `pima-indians-diabetes-database`, and they will approve your job automatically.
36
+
37
+ Optionally, you can look at the `local_training.ipynb` to see the DS's process of processing data and training the neural network locally.
38
+
39
+ 2. For the data owner's workflow (upload the dataset, monitor and approve jobs), please take a look at the `do.ipynb` notebook. Following this notebook, you will learn how to upload your partition of the `pima-indians-diabetes-database` so others can submit jobs to you.
40
+
41
+ ## References
42
+ - https://syftbox.net
43
+ - https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database/
44
+ - https://github.com/OpenMined/syftbox
45
+ - https://github.com/OpenMined/syft-flwr
46
+ - https://github.com/adap/flower/
47
+ - https://github.com/OpenMined/rds
48
+ - https://github.com/elarsiad/diabetes-prediction-keras
@@ -0,0 +1,204 @@
1
+
2
+ from typing import List
3
+ import time
4
+
5
+ import flwr as fl
6
+ from flwr.common import (
7
+ Context,
8
+ NDArrays,
9
+ Message,
10
+ MessageType,
11
+ Metrics,
12
+ RecordSet,
13
+ ConfigsRecord,
14
+ DEFAULT_TTL,
15
+ )
16
+ from flwr.server import Driver
17
+
18
+ import FedModelKit as msi
19
+
20
+ from model_example import create_local_learner #type: ignore[import]
21
+
22
+
23
# Run via `flower-server-app server:app`
app = fl.server.ServerApp()


# Tunables of this example run, named so they are not repeated as magic numbers.
_NUM_SERVER_ROUNDS = 3
_LOCAL_EPOCHS = 3
_SLOW_POLL_SECONDS = 12  # pause between polls for query and training replies
_FAST_POLL_SECONDS = 3  # pause between polls for evaluation replies


def _make_recordset(config, parameters=None):
    """Build the RecordSet sent to one client.

    Args:
        config: Plain dict stored under the "fancy_config" configs record.
        parameters: Optional parameters record stored under "fancy_model";
            omitted from the RecordSet when None.

    Returns:
        The populated RecordSet.
    """
    recordset = RecordSet()
    if parameters is not None:
        recordset.parameters_records["fancy_model"] = parameters
    recordset.configs_records["fancy_config"] = ConfigsRecord(config)
    return recordset


def _collect_replies(driver: Driver, messages: List[Message], poll_interval: float) -> List[Message]:
    """Push `messages`, poll until every pushed message is answered, drop error replies.

    Args:
        driver: Driver used to push the messages and pull the replies.
        messages: Messages to send.
        poll_interval: Seconds to sleep between two polls.

    Returns:
        The replies for which `has_content()` is true.
    """
    # Materialise once: the ids are logged, filtered and polled with, so a
    # one-shot iterable must not be consumed by the log statement.
    message_ids = list(driver.push_messages(messages))
    print(f"Pushed {len(message_ids)} messages: {message_ids}")

    # NOTE(review): empty ids presumably mark messages that were not accepted,
    # so no reply can arrive for them — confirm against the Driver docs.
    message_ids = [message_id for message_id in message_ids if message_id != ""]

    all_replies: List[Message] = []
    while True:
        # Same reason as above: count and extend from one materialised list.
        replies = list(driver.pull_messages(message_ids=message_ids))
        print(f"Got {len(replies)} results")
        all_replies += replies
        # `>=` rather than `==` so an unexpected extra reply cannot spin forever.
        if len(all_replies) >= len(message_ids):
            break
        time.sleep(poll_interval)

    # Filter out messages with errors
    return [msg for msg in all_replies if msg.has_content()]


def _broadcast(driver: Driver, node_ids, build_content, message_type, group_id: str, poll_interval: float) -> List[Message]:
    """Send one message per node and return the successful replies.

    Args:
        driver: Driver used to create, push and pull messages.
        node_ids: Sequence of destination node ids.
        build_content: Callable taking the node's position in `node_ids` and
            returning the RecordSet to send to that node.
        message_type: Message type of the created messages.
        group_id: Group id attached to every created message.
        poll_interval: Seconds to sleep between two polls for replies.

    Returns:
        The replies that carry content (see `_collect_replies`).
    """
    messages = [
        driver.create_message(
            content=build_content(idx),
            message_type=message_type,
            dst_node_id=node_id,
            group_id=group_id,
            ttl=DEFAULT_TTL,
        )
        for idx, node_id in enumerate(node_ids)
    ]
    return _collect_replies(driver, messages, poll_interval)


@app.main()
def main(driver: Driver, context: Context) -> None:
    """
    Main function to run the federated learning server.

    Structure:
    - Send a query message to clients for creating the local learner and loading the data
    - Start global epochs loop for training and evaluation
        - Send training messages to clients
        - Aggregate parameters received from clients
        - Send evaluation messages to clients
        - Aggregate evaluation metrics

    Args:
        driver: Driver used to reach the connected client nodes.
        context: Run context supplied by the ServerApp (not read here).

    Raises:
        RuntimeError: If a training round yields no reply with content, since
            there is then nothing to aggregate.
    """
    print("Starting test run")

    # Get node IDs of connected clients. Materialised because the ids are
    # counted and enumerated again in every phase.
    node_ids = list(driver.get_node_ids())
    num_clients = len(node_ids)

    # Initialize the federated model
    federated_model = msi.FederatedModel(create_local_learner=create_local_learner,
                                         model_name='simple_lr')
    global_model = federated_model.create_local_learner()
    aggregation_strategy = federated_model.create_aggregator()

    # Send a query message to clients for creating the local learner and loading the data
    all_replies = _broadcast(
        driver,
        node_ids,
        lambda idx: _make_recordset({"num_clients": num_clients, "client_id": idx}),
        MessageType.QUERY,
        str(1),
        _SLOW_POLL_SECONDS,
    )
    print(f"Received {len(all_replies)} answers")

    # Run federated training and evaluation for a fixed number of rounds
    for server_round in range(_NUM_SERVER_ROUNDS):
        print(f"Commencing server train and evaluation round {server_round + 1}")
        group_id = str(server_round)

        # Send the current global parameters to every client for local training
        all_replies = _broadcast(
            driver,
            node_ids,
            lambda idx: _make_recordset(
                {"local_epochs": _LOCAL_EPOCHS}, global_model.get_parameters()
            ),
            MessageType.TRAIN,
            group_id,
            _SLOW_POLL_SECONDS,
        )
        print(f"Received {len(all_replies)} results")

        # Print metrics received from clients
        for reply in all_replies:
            print(reply.content.metrics_records)

        if not all_replies:
            raise RuntimeError(
                f"No training results received in round {server_round + 1}; cannot aggregate"
            )

        # Aggregate parameters received from clients
        parameter_records_list = [
            reply.content.parameters_records["fancy_model_returned"] for reply in all_replies
        ]
        new_parameter_record = aggregation_strategy.aggregate_parameters(parameter_records_list)
        global_model.set_parameters(new_parameter_record)

        # Evaluate the updated global model on every client
        all_replies = _broadcast(
            driver,
            node_ids,
            lambda idx: _make_recordset({"local_epochs": _LOCAL_EPOCHS}, new_parameter_record),
            MessageType.EVALUATE,
            group_id,
            _FAST_POLL_SECONDS,
        )
        print(f"Received {len(all_replies)} results")

        # Print evaluation metrics received from clients
        metrics_records_list = [
            reply.content.metrics_records['eval_metrics'] for reply in all_replies
        ]
        for i, metrics in enumerate(metrics_records_list):
            print(f"Client {i+1} metrics: ", metrics)

        # Aggregate evaluation metrics (skipped when every evaluation reply failed)
        if metrics_records_list:
            print("Aggregated metrics result: ", aggregation_strategy.aggregate_metrics(metrics_records_list))
        else:
            print("No evaluation metrics received; skipping metric aggregation")

    print("🎉🎉🎉 Successfully completed federated learning run! 🎉🎉🎉")