FedModelKit 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FedModelKit/README.md +25 -0
- FedModelKit/__init__.py +17 -0
- FedModelKit/aggregator.py +41 -0
- FedModelKit/cli.py +97 -0
- FedModelKit/default_create_functions.py +114 -0
- FedModelKit/interface.py +130 -0
- FedModelKit/local_learner.py +61 -0
- FedModelKit/py.typed +0 -0
- FedModelKit/src/utils.py +65 -0
- FedModelKit/templates/__init__template.py +0 -0
- FedModelKit/templates/client_app_template.py +118 -0
- FedModelKit/templates/ds_template.ipynb +332 -0
- FedModelKit/templates/extern_pyproject_template.toml +17 -0
- FedModelKit/templates/images/doSendModels.png +0 -0
- FedModelKit/templates/images/doWaitsForJobs.png +0 -0
- FedModelKit/templates/images/dsAggregateModels.png +0 -0
- FedModelKit/templates/images/dsDoneSubmittingJobs.png +0 -0
- FedModelKit/templates/images/dsSendsJobs.png +0 -0
- FedModelKit/templates/images/overview.png +0 -0
- FedModelKit/templates/main_template.py +27 -0
- FedModelKit/templates/pyproject_template.toml +53 -0
- FedModelKit/templates/readme_template.md +48 -0
- FedModelKit/templates/server_app_template.py +204 -0
- FedModelKit/templates/task_template.py +140 -0
- FedModelKit/templates/uv_template.lock +2812 -0
- FedModelKit/templates.py +76 -0
- fedmodelkit-0.5.0.dist-info/METADATA +283 -0
- fedmodelkit-0.5.0.dist-info/RECORD +31 -0
- fedmodelkit-0.5.0.dist-info/WHEEL +4 -0
- fedmodelkit-0.5.0.dist-info/entry_points.txt +2 -0
- fedmodelkit-0.5.0.dist-info/licenses/LICENSE +23 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cells": [
|
|
3
|
+
{
|
|
4
|
+
"cell_type": "markdown",
|
|
5
|
+
"metadata": {},
|
|
6
|
+
"source": [
|
|
7
|
+
"# Data Scientist (DS)"
|
|
8
|
+
]
|
|
9
|
+
},
|
|
10
|
+
{
|
|
11
|
+
"cell_type": "markdown",
|
|
12
|
+
"metadata": {},
|
|
13
|
+
"source": [
|
|
14
|
+
"Change `LOCAL_TEST` to True if you want to run the clients locally to test. \n",
|
|
15
|
+
"With `LOCAL_TEST = False`, please have your syftbox client running. Installation instructions here https://www.syftbox.net/."
|
|
16
|
+
]
|
|
17
|
+
},
|
|
18
|
+
{
|
|
19
|
+
"cell_type": "code",
|
|
20
|
+
"execution_count": null,
|
|
21
|
+
"metadata": {},
|
|
22
|
+
"outputs": [],
|
|
23
|
+
"source": [
|
|
24
|
+
"LOCAL_TEST = True"
|
|
25
|
+
]
|
|
26
|
+
},
|
|
27
|
+
{
|
|
28
|
+
"cell_type": "markdown",
|
|
29
|
+
"metadata": {},
|
|
30
|
+
"source": [
|
|
31
|
+
"## Some paths and constants "
|
|
32
|
+
]
|
|
33
|
+
},
|
|
34
|
+
{
|
|
35
|
+
"cell_type": "code",
|
|
36
|
+
"execution_count": null,
|
|
37
|
+
"metadata": {},
|
|
38
|
+
"outputs": [],
|
|
39
|
+
"source": [
|
|
40
|
+
"from pathlib import Path\n",
|
|
41
|
+
"\n",
|
|
42
|
+
"SYFTBOX_DATASET_NAME = \"pima-indians-diabetes-database\""
|
|
43
|
+
]
|
|
44
|
+
},
|
|
45
|
+
{
|
|
46
|
+
"cell_type": "markdown",
|
|
47
|
+
"metadata": {},
|
|
48
|
+
"source": [
|
|
49
|
+
"## Log into the data owners' datasites"
|
|
50
|
+
]
|
|
51
|
+
},
|
|
52
|
+
{
|
|
53
|
+
"cell_type": "code",
|
|
54
|
+
"execution_count": null,
|
|
55
|
+
"metadata": {},
|
|
56
|
+
"outputs": [],
|
|
57
|
+
"source": [
|
|
58
|
+
"if LOCAL_TEST:\n",
|
|
59
|
+
" from syft_rds.orchestra import setup_rds_server\n",
|
|
60
|
+
"\n",
|
|
61
|
+
" print(\"Running locally!\")\n",
|
|
62
|
+
"\n",
|
|
63
|
+
" DS = \"ds@openmined.org\"\n",
|
|
64
|
+
" print(\"DS email: \", DS)\n",
|
|
65
|
+
"\n",
|
|
66
|
+
" DO1 = \"do1@openmined.org\"\n",
|
|
67
|
+
" DO2 = \"do2@openmined.org\"\n",
|
|
68
|
+
"\n",
|
|
69
|
+
" ds_stack = setup_rds_server(email=DS, key=\"flwr\", root_dir=Path(\".\"))\n",
|
|
70
|
+
" do_client_1 = ds_stack.init_session(host=DO1)\n",
|
|
71
|
+
" do_client_2 = ds_stack.init_session(host=DO2)\n",
|
|
72
|
+
"else:\n",
|
|
73
|
+
" import syft_rds as sy\n",
|
|
74
|
+
" from syft_core import Client\n",
|
|
75
|
+
"\n",
|
|
76
|
+
" DS = Client.load().email\n",
|
|
77
|
+
" print(\"DS email: \", DS)\n",
|
|
78
|
+
"\n",
|
|
79
|
+
" DO1 = \"flower-test-group-1@openmined.org\"\n",
|
|
80
|
+
" DO2 = \"flower-test-group-2@openmined.org\"\n",
|
|
81
|
+
"\n",
|
|
82
|
+
" do_client_1 = sy.init_session(host=DO1)\n",
|
|
83
|
+
" print(\"Logged into: \", do_client_1.host)\n",
|
|
84
|
+
"\n",
|
|
85
|
+
" do_client_2 = sy.init_session(host=DO2)\n",
|
|
86
|
+
" print(\"Logged into: \", do_client_2.host)\n",
|
|
87
|
+
"\n",
|
|
88
|
+
"do_clients = [do_client_1, do_client_2]\n",
|
|
89
|
+
"do_emails = [DO1, DO2]"
|
|
90
|
+
]
|
|
91
|
+
},
|
|
92
|
+
{
|
|
93
|
+
"cell_type": "markdown",
|
|
94
|
+
"metadata": {},
|
|
95
|
+
"source": [
|
|
96
|
+
"## Explore the datasets"
|
|
97
|
+
]
|
|
98
|
+
},
|
|
99
|
+
{
|
|
100
|
+
"cell_type": "code",
|
|
101
|
+
"execution_count": null,
|
|
102
|
+
"metadata": {},
|
|
103
|
+
"outputs": [],
|
|
104
|
+
"source": [
|
|
105
|
+
"SYFTBOX_DATASET_NAME"
|
|
106
|
+
]
|
|
107
|
+
},
|
|
108
|
+
{
|
|
109
|
+
"cell_type": "code",
|
|
110
|
+
"execution_count": null,
|
|
111
|
+
"metadata": {},
|
|
112
|
+
"outputs": [],
|
|
113
|
+
"source": [
|
|
114
|
+
"mock_paths = []\n",
|
|
115
|
+
"for client in do_clients:\n",
|
|
116
|
+
" dataset = client.dataset.get(name=SYFTBOX_DATASET_NAME)\n",
|
|
117
|
+
" mock_paths.append(dataset.get_mock_path())\n",
|
|
118
|
+
" print(f\"Client {client.host}'s dataset: \\n{dataset}\\n\")"
|
|
119
|
+
]
|
|
120
|
+
},
|
|
121
|
+
{
|
|
122
|
+
"cell_type": "markdown",
|
|
123
|
+
"metadata": {},
|
|
124
|
+
"source": [
|
|
125
|
+
"## Bootstrap and run the `syft_flwr` simulation"
|
|
126
|
+
]
|
|
127
|
+
},
|
|
128
|
+
{
|
|
129
|
+
"cell_type": "code",
|
|
130
|
+
"execution_count": null,
|
|
131
|
+
"metadata": {},
|
|
132
|
+
"outputs": [],
|
|
133
|
+
"source": [
|
|
134
|
+
"SYFT_FLWR_PROJECT_PATH = Path(\"./fl-diabetes-prediction\")\n",
|
|
135
|
+
"assert SYFT_FLWR_PROJECT_PATH.exists()"
|
|
136
|
+
]
|
|
137
|
+
},
|
|
138
|
+
{
|
|
139
|
+
"cell_type": "code",
|
|
140
|
+
"execution_count": null,
|
|
141
|
+
"metadata": {},
|
|
142
|
+
"outputs": [],
|
|
143
|
+
"source": [
|
|
144
|
+
"import syft_flwr\n",
|
|
145
|
+
"\n",
|
|
146
|
+
"try:\n",
|
|
147
|
+
" !rm -rf {SYFT_FLWR_PROJECT_PATH / \"main.py\"}\n",
|
|
148
|
+
" syft_flwr.bootstrap(SYFT_FLWR_PROJECT_PATH, aggregator=DS, datasites=do_emails)\n",
|
|
149
|
+
" print(\"Bootstrapped project successfully ✅\")\n",
|
|
150
|
+
"except Exception as e:\n",
|
|
151
|
+
" print(e)"
|
|
152
|
+
]
|
|
153
|
+
},
|
|
154
|
+
{
|
|
155
|
+
"cell_type": "markdown",
|
|
156
|
+
"metadata": {},
|
|
157
|
+
"source": [
|
|
158
|
+
"## Run `flwr` and `syft_flwr` simulations (optional)"
|
|
159
|
+
]
|
|
160
|
+
},
|
|
161
|
+
{
|
|
162
|
+
"cell_type": "code",
|
|
163
|
+
"execution_count": null,
|
|
164
|
+
"metadata": {},
|
|
165
|
+
"outputs": [],
|
|
166
|
+
"source": [
|
|
167
|
+
"RUN_SIMULATION = True"
|
|
168
|
+
]
|
|
169
|
+
},
|
|
170
|
+
{
|
|
171
|
+
"cell_type": "code",
|
|
172
|
+
"execution_count": null,
|
|
173
|
+
"metadata": {},
|
|
174
|
+
"outputs": [],
|
|
175
|
+
"source": [
|
|
176
|
+
"if RUN_SIMULATION:\n",
|
|
177
|
+
" !flwr run {SYFT_FLWR_PROJECT_PATH}"
|
|
178
|
+
]
|
|
179
|
+
},
|
|
180
|
+
{
|
|
181
|
+
"cell_type": "code",
|
|
182
|
+
"execution_count": null,
|
|
183
|
+
"metadata": {},
|
|
184
|
+
"outputs": [],
|
|
185
|
+
"source": [
|
|
186
|
+
"# clean up\n",
|
|
187
|
+
"!rm -rf {SYFT_FLWR_PROJECT_PATH / \"fl_diabetes_prediction\" / \"__pycache__\"}\n",
|
|
188
|
+
"!rm -rf weights/"
|
|
189
|
+
]
|
|
190
|
+
},
|
|
191
|
+
{
|
|
192
|
+
"cell_type": "code",
|
|
193
|
+
"execution_count": null,
|
|
194
|
+
"metadata": {},
|
|
195
|
+
"outputs": [],
|
|
196
|
+
"source": [
|
|
197
|
+
"mock_paths"
|
|
198
|
+
]
|
|
199
|
+
},
|
|
200
|
+
{
|
|
201
|
+
"cell_type": "code",
|
|
202
|
+
"execution_count": null,
|
|
203
|
+
"metadata": {},
|
|
204
|
+
"outputs": [],
|
|
205
|
+
"source": [
|
|
206
|
+
"if RUN_SIMULATION:\n",
|
|
207
|
+
" print(f\"running syft_flwr simulation with mock paths: {mock_paths}\")\n",
|
|
208
|
+
" syft_flwr.run(SYFT_FLWR_PROJECT_PATH, mock_paths)"
|
|
209
|
+
]
|
|
210
|
+
},
|
|
211
|
+
{
|
|
212
|
+
"cell_type": "markdown",
|
|
213
|
+
"metadata": {},
|
|
214
|
+
"source": [
|
|
215
|
+
"## Submit jobs"
|
|
216
|
+
]
|
|
217
|
+
},
|
|
218
|
+
{
|
|
219
|
+
"cell_type": "markdown",
|
|
220
|
+
"metadata": {},
|
|
221
|
+
"source": [
|
|
222
|
+
"<img src=\"./images/dsSendsJobs.png\" width=\"80%\" alt=\"DS Submits Jobs\">"
|
|
223
|
+
]
|
|
224
|
+
},
|
|
225
|
+
{
|
|
226
|
+
"cell_type": "code",
|
|
227
|
+
"execution_count": null,
|
|
228
|
+
"metadata": {},
|
|
229
|
+
"outputs": [],
|
|
230
|
+
"source": [
|
|
231
|
+
"# clean up before submitting jobs\n",
|
|
232
|
+
"!rm -rf {SYFT_FLWR_PROJECT_PATH / \"fl_diabetes_prediction\" / \"__pycache__\"}\n",
|
|
233
|
+
"!rm -rf {SYFT_FLWR_PROJECT_PATH / \"simulation_logs\"}\n",
|
|
234
|
+
"!rm -rf weights/"
|
|
235
|
+
]
|
|
236
|
+
},
|
|
237
|
+
{
|
|
238
|
+
"cell_type": "code",
|
|
239
|
+
"execution_count": null,
|
|
240
|
+
"metadata": {},
|
|
241
|
+
"outputs": [],
|
|
242
|
+
"source": [
|
|
243
|
+
"for client in do_clients:\n",
|
|
244
|
+
" print(f\"sending job to {client.host}\")\n",
|
|
245
|
+
" job = client.jobs.submit(\n",
|
|
246
|
+
" name=\"Syft Flower Experiment\",\n",
|
|
247
|
+
" description=\"Syft Flower Federated Learning Experiment\",\n",
|
|
248
|
+
" user_code_path=SYFT_FLWR_PROJECT_PATH,\n",
|
|
249
|
+
" dataset_name=SYFTBOX_DATASET_NAME,\n",
|
|
250
|
+
" tags=[\"federated learning\", \"fl\", \"syft_flwr\", \"flwr\"],\n",
|
|
251
|
+
" entrypoint=\"main.py\",\n",
|
|
252
|
+
" )\n",
|
|
253
|
+
" print(job)"
|
|
254
|
+
]
|
|
255
|
+
},
|
|
256
|
+
{
|
|
257
|
+
"cell_type": "markdown",
|
|
258
|
+
"metadata": {},
|
|
259
|
+
"source": [
|
|
260
|
+
"<img src=\"./images/dsDoneSubmittingJobs.png\" width=\"40%\" alt=\"DS waits for jobs to be approved\">"
|
|
261
|
+
]
|
|
262
|
+
},
|
|
263
|
+
{
|
|
264
|
+
"cell_type": "markdown",
|
|
265
|
+
"metadata": {},
|
|
266
|
+
"source": [
|
|
267
|
+
"## DS starts the FL server code"
|
|
268
|
+
]
|
|
269
|
+
},
|
|
270
|
+
{
|
|
271
|
+
"cell_type": "code",
|
|
272
|
+
"execution_count": null,
|
|
273
|
+
"metadata": {},
|
|
274
|
+
"outputs": [],
|
|
275
|
+
"source": [
|
|
276
|
+
"import os\n",
|
|
277
|
+
"\n",
|
|
278
|
+
"if LOCAL_TEST:\n",
|
|
279
|
+
" os.environ[\"SYFTBOX_CLIENT_CONFIG_PATH\"] = str(ds_stack.client.config_path)\n",
|
|
280
|
+
"\n",
|
|
281
|
+
"os.environ[\"LOGURU_LEVEL\"] = \"DEBUG\"\n",
|
|
282
|
+
"os.environ[\"SYFT_FLWR_MSG_TIMEOUT\"] = \"60\"\n",
|
|
283
|
+
"\n",
|
|
284
|
+
"!uv run {str(SYFT_FLWR_PROJECT_PATH / \"main.py\")} --active"
|
|
285
|
+
]
|
|
286
|
+
},
|
|
287
|
+
{
|
|
288
|
+
"cell_type": "markdown",
|
|
289
|
+
"metadata": {},
|
|
290
|
+
"source": [
|
|
291
|
+
"By running the FL server code, the DS aggregates the models trained on DOs' private local data into an improved global model\n",
|
|
292
|
+
"\n",
|
|
293
|
+
"<img src=\"./images/dsAggregateModels.png\" width=\"30%\" alt=\"DS Aggregates Models\">"
|
|
294
|
+
]
|
|
295
|
+
},
|
|
296
|
+
{
|
|
297
|
+
"cell_type": "markdown",
|
|
298
|
+
"metadata": {},
|
|
299
|
+
"source": [
|
|
300
|
+
"## DS Observes the Results"
|
|
301
|
+
]
|
|
302
|
+
},
|
|
303
|
+
{
|
|
304
|
+
"cell_type": "markdown",
|
|
305
|
+
"metadata": {},
|
|
306
|
+
"source": [
|
|
307
|
+
"Now the DS can monitor the aggregated models trained on the DOs' private datasets in the `weights` folder"
|
|
308
|
+
]
|
|
309
|
+
}
|
|
310
|
+
],
|
|
311
|
+
"metadata": {
|
|
312
|
+
"kernelspec": {
|
|
313
|
+
"display_name": ".venv",
|
|
314
|
+
"language": "python",
|
|
315
|
+
"name": "python3"
|
|
316
|
+
},
|
|
317
|
+
"language_info": {
|
|
318
|
+
"codemirror_mode": {
|
|
319
|
+
"name": "ipython",
|
|
320
|
+
"version": 3
|
|
321
|
+
},
|
|
322
|
+
"file_extension": ".py",
|
|
323
|
+
"mimetype": "text/x-python",
|
|
324
|
+
"name": "python",
|
|
325
|
+
"nbconvert_exporter": "python",
|
|
326
|
+
"pygments_lexer": "ipython3",
|
|
327
|
+
"version": "3.12.9"
|
|
328
|
+
}
|
|
329
|
+
},
|
|
330
|
+
"nbformat": 4,
|
|
331
|
+
"nbformat_minor": 2
|
|
332
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "EXPERIMENT_NAME"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Description of the experiment"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.13"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"flwr-datasets>=0.5.0",
|
|
9
|
+
"flwr[simulation]==1.17.0",
|
|
10
|
+
"imblearn>=0.0",
|
|
11
|
+
"jupyterlab>=4.4.4",
|
|
12
|
+
"loguru>=0.7.3",
|
|
13
|
+
"pandas>=2.3.0",
|
|
14
|
+
"scikit-learn==1.6.1",
|
|
15
|
+
"syft-flwr>=0.1.5",
|
|
16
|
+
"torch==2.7.0",
|
|
17
|
+
]
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from syft_core import Client
|
|
5
|
+
|
|
6
|
+
from syft_flwr.config import load_flwr_pyproject
|
|
7
|
+
from syft_flwr.run import syftbox_run_flwr_client, syftbox_run_flwr_server
|
|
8
|
+
|
|
9
|
+
DATA_DIR = os.getenv("DATA_DIR")
|
|
10
|
+
OUTPUT_DIR = os.getenv("OUTPUT_DIR")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
flower_project_dir = Path(__file__).parent.absolute()
|
|
14
|
+
client = Client.load()
|
|
15
|
+
config = load_flwr_pyproject(flower_project_dir)
|
|
16
|
+
|
|
17
|
+
is_client = client.email in config["tool"]["syft_flwr"]["datasites"]
|
|
18
|
+
is_server = client.email in config["tool"]["syft_flwr"]["aggregator"]
|
|
19
|
+
|
|
20
|
+
if is_client:
|
|
21
|
+
# run by each DO
|
|
22
|
+
syftbox_run_flwr_client(flower_project_dir)
|
|
23
|
+
elif is_server:
|
|
24
|
+
# run by the DS
|
|
25
|
+
syftbox_run_flwr_server(flower_project_dir)
|
|
26
|
+
else:
|
|
27
|
+
raise ValueError(f"{client.email} is not in config.datasites or config.aggregator")
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = [
|
|
3
|
+
"hatchling",
|
|
4
|
+
]
|
|
5
|
+
build-backend = "hatchling.build"
|
|
6
|
+
|
|
7
|
+
[project]
|
|
8
|
+
name = "EXPERIMENT_NAME"
|
|
9
|
+
version = "1.0.0"
|
|
10
|
+
description = "Description of the experiment"
|
|
11
|
+
license = "Apache-2.0"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"flwr[simulation]==1.17.0",
|
|
14
|
+
"flwr-datasets>=0.5.0",
|
|
15
|
+
"torch==2.7.0",
|
|
16
|
+
"imblearn",
|
|
17
|
+
"pandas",
|
|
18
|
+
"scikit-learn==1.6.1",
|
|
19
|
+
"loguru",
|
|
20
|
+
"jupyter",
|
|
21
|
+
"syft_flwr==0.1.5",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[tool.hatch.build.targets.wheel]
|
|
25
|
+
packages = [
|
|
26
|
+
".",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[tool.flwr.app]
|
|
30
|
+
publisher = "OpenMined"
|
|
31
|
+
|
|
32
|
+
[tool.flwr.app.components]
|
|
33
|
+
serverapp = "EXPERIMENT_NAME.server_app:app"
|
|
34
|
+
clientapp = "EXPERIMENT_NAME.client_app:app"
|
|
35
|
+
|
|
36
|
+
[tool.flwr.app.config]
|
|
37
|
+
num-server-rounds = 3
|
|
38
|
+
partition-id = 0
|
|
39
|
+
num-partitions = 1
|
|
40
|
+
|
|
41
|
+
[tool.flwr.federations]
|
|
42
|
+
default = "local-simulation"
|
|
43
|
+
|
|
44
|
+
[tool.flwr.federations.local-simulation.options]
|
|
45
|
+
num-supernodes = 2
|
|
46
|
+
|
|
47
|
+
[tool.syft_flwr]
|
|
48
|
+
app_name = "ds@openmined.org_fl-diabetes-prediction_1751887833"
|
|
49
|
+
datasites = [
|
|
50
|
+
"do1@openmined.org",
|
|
51
|
+
"do2@openmined.org",
|
|
52
|
+
]
|
|
53
|
+
aggregator = "ds@openmined.org"
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Diabetes Prediction with `syft_flwr`
|
|
2
|
+
|
|
3
|
+
## Introduction
|
|
4
|
+
|
|
5
|
+
In this tutorial, we'll walk through a practical federated learning implementation for diabetes prediction using [syft_flwr](https://github.com/OpenMined/syft-flwr) — a framework that combines the flexibility of [Flower](https://github.com/adap/flower/) (a popular federated learning framework) with the privacy-preserving networking capabilities of [syftbox](https://www.syftbox.net/).
|
|
6
|
+
|
|
7
|
+

|
|
8
|
+
|
|
9
|
+
Dataset: https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database/
|
|
10
|
+
|
|
11
|
+
## Set up
|
|
12
|
+
|
|
13
|
+
### Set up the Python virtual environment
|
|
14
|
+
This assumes that you have Python and the [uv](https://docs.astral.sh/uv/) package manager installed. Now let's create a Python virtual environment with all dependencies installed:
|
|
15
|
+
```bash
|
|
16
|
+
uv sync
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
### Install and run `syftbox` client
|
|
20
|
+
Make sure you have the `syftbox` client running in a terminal:
|
|
21
|
+
1. Install `syftbox`: `curl -fsSL https://syftbox.net/install.sh | sh`
|
|
22
|
+
2. Follow the instructions to start your `syftbox` client
|
|
23
|
+
|
|
24
|
+
Once `syftbox` is installed and running in the background, you can proceed to run the notebooks with the installed Python environment in your favorite IDE.
|
|
25
|
+
|
|
26
|
+
### Local Setup
|
|
27
|
+
At the start of the notebooks, you can set the flag `LOCAL_TEST` to `True` if you want to run all the clients (2 data owners and 1 data scientist) locally to test the whole workflow, where all clients' local datasites will be saved locally under the `flwr` folder. For this, you need to run the notebooks `do1.ipynb`, `do2.ipynb`, and then `ds.ipynb`, in that order.
|
|
28
|
+
|
|
29
|
+
If running locally, you don't need your `syftbox` client running.
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
### Distributed setup
|
|
33
|
+
Set `LOCAL_TEST` to `False` and make sure you have your `syftbox` client running if you want to run the clients over the `syftbox` network.
|
|
34
|
+
|
|
35
|
+
1. For the data scientist's workflow (prepare code, observe mock datasets on the data owner's datasites, submit jobs), please look into the `ds.ipynb` notebook. Following this notebook will help you submit jobs to two datasites named `flower-test-group-1@openmined.org` and `flower-test-group-2@openmined.org` that host 2 partitions of the `pima-indians-diabetes-database`, and they will approve your job automatically.
|
|
36
|
+
|
|
37
|
+
Optionally, you can look at the `local_training.ipynb` to see the DS's process of processing data and training the neural network locally.
|
|
38
|
+
|
|
39
|
+
2. For the data owner's workflow (uploading the dataset, monitoring and approving jobs), please take a look at the `do.ipynb` notebook. Following this notebook, you will learn how to upload your partition of the `pima-indians-diabetes-database` so others can submit jobs to you.
|
|
40
|
+
|
|
41
|
+
## References
|
|
42
|
+
- https://syftbox.net
|
|
43
|
+
- https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database/
|
|
44
|
+
- https://github.com/OpenMined/syftbox
|
|
45
|
+
- https://github.com/OpenMined/syft-flwr
|
|
46
|
+
- https://github.com/adap/flower/
|
|
47
|
+
- https://github.com/OpenMined/rds
|
|
48
|
+
- https://github.com/elarsiad/diabetes-prediction-keras
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import List
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import flwr as fl
|
|
6
|
+
from flwr.common import (
|
|
7
|
+
Context,
|
|
8
|
+
NDArrays,
|
|
9
|
+
Message,
|
|
10
|
+
MessageType,
|
|
11
|
+
Metrics,
|
|
12
|
+
RecordSet,
|
|
13
|
+
ConfigsRecord,
|
|
14
|
+
DEFAULT_TTL,
|
|
15
|
+
)
|
|
16
|
+
from flwr.server import Driver
|
|
17
|
+
|
|
18
|
+
import FedModelKit as msi
|
|
19
|
+
|
|
20
|
+
from model_example import create_local_learner #type: ignore[import]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Run via `flower-server-app server:app`
|
|
24
|
+
app = fl.server.ServerApp()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@app.main()
|
|
30
|
+
def main(driver: Driver, context: Context) -> None:
|
|
31
|
+
"""
|
|
32
|
+
Main function to run the federated learning server.
|
|
33
|
+
|
|
34
|
+
Structure:
|
|
35
|
+
- Send a query message to clients for creating the local learner and loading the data
|
|
36
|
+
- Start global epochs loop for training and evaluation
|
|
37
|
+
- Send training messages to clients
|
|
38
|
+
- Aggregate parameters received from clients
|
|
39
|
+
- Send evaluation messages to clients
|
|
40
|
+
- Aggregate evaluation metrics
|
|
41
|
+
"""
|
|
42
|
+
print("Starting test run")
|
|
43
|
+
|
|
44
|
+
# Get node IDs of connected clients
|
|
45
|
+
node_ids = driver.get_node_ids()
|
|
46
|
+
|
|
47
|
+
# Initialize the federated model
|
|
48
|
+
federated_model = msi.FederatedModel(create_local_learner=create_local_learner,
|
|
49
|
+
model_name='simple_lr')
|
|
50
|
+
global_model = federated_model.create_local_learner()
|
|
51
|
+
aggregation_strategy = federated_model.create_aggregator()
|
|
52
|
+
|
|
53
|
+
# Send a query message to clients for creating the local learner and loading the data
|
|
54
|
+
messages = []
|
|
55
|
+
for idx, node_id in enumerate(node_ids):
|
|
56
|
+
# Create messages to send to clients
|
|
57
|
+
recordset = RecordSet()
|
|
58
|
+
|
|
59
|
+
# Add a config with information to send the client for the query
|
|
60
|
+
recordset.configs_records["fancy_config"] = ConfigsRecord({"num_clients": len(node_ids), "client_id": idx})
|
|
61
|
+
|
|
62
|
+
# Create a query message for each client
|
|
63
|
+
message = driver.create_message(
|
|
64
|
+
content=recordset,
|
|
65
|
+
message_type=MessageType.QUERY,
|
|
66
|
+
dst_node_id=node_id,
|
|
67
|
+
group_id=str(1),
|
|
68
|
+
ttl=DEFAULT_TTL,
|
|
69
|
+
)
|
|
70
|
+
messages.append(message)
|
|
71
|
+
|
|
72
|
+
# Send the query messages to clients
|
|
73
|
+
message_ids = driver.push_messages(messages)
|
|
74
|
+
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
75
|
+
|
|
76
|
+
# Wait for results from clients
|
|
77
|
+
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
78
|
+
all_replies: List[Message] = []
|
|
79
|
+
while True:
|
|
80
|
+
replies = driver.pull_messages(message_ids=message_ids)
|
|
81
|
+
print(f"Got {len(list(replies))} results")
|
|
82
|
+
all_replies += replies
|
|
83
|
+
if len(all_replies) == len(message_ids):
|
|
84
|
+
break
|
|
85
|
+
time.sleep(12)
|
|
86
|
+
|
|
87
|
+
# Filter out messages with errors
|
|
88
|
+
all_replies = [
|
|
89
|
+
msg
|
|
90
|
+
for msg in all_replies
|
|
91
|
+
if msg.has_content()
|
|
92
|
+
]
|
|
93
|
+
print(f"Received {len(all_replies)} answers")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Run federated training and evaluation for a fixed number of rounds
|
|
97
|
+
for server_round in range(3):
|
|
98
|
+
print(f"Commencing server train and evaluation round {server_round + 1}")
|
|
99
|
+
|
|
100
|
+
messages = []
|
|
101
|
+
for idx, node_id in enumerate(node_ids):
|
|
102
|
+
# Create messages to send to clients
|
|
103
|
+
recordset = RecordSet()
|
|
104
|
+
|
|
105
|
+
# Add model parameters to record
|
|
106
|
+
recordset.parameters_records["fancy_model"] = global_model.get_parameters()
|
|
107
|
+
# Add a config with information to send the client for training
|
|
108
|
+
recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})
|
|
109
|
+
|
|
110
|
+
# Create a training message for each client
|
|
111
|
+
message = driver.create_message(
|
|
112
|
+
content=recordset,
|
|
113
|
+
message_type=MessageType.TRAIN,
|
|
114
|
+
dst_node_id=node_id,
|
|
115
|
+
group_id=str(server_round),
|
|
116
|
+
ttl=DEFAULT_TTL,
|
|
117
|
+
)
|
|
118
|
+
messages.append(message)
|
|
119
|
+
|
|
120
|
+
# Send training messages to clients
|
|
121
|
+
message_ids = driver.push_messages(messages)
|
|
122
|
+
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
123
|
+
|
|
124
|
+
# Wait for results from clients
|
|
125
|
+
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
126
|
+
all_replies: List[Message] = []
|
|
127
|
+
while True:
|
|
128
|
+
replies = driver.pull_messages(message_ids=message_ids)
|
|
129
|
+
print(f"Got {len(list(replies))} results")
|
|
130
|
+
all_replies += replies
|
|
131
|
+
if len(all_replies) == len(message_ids):
|
|
132
|
+
break
|
|
133
|
+
time.sleep(12)
|
|
134
|
+
|
|
135
|
+
# Filter out messages with errors
|
|
136
|
+
all_replies = [
|
|
137
|
+
msg
|
|
138
|
+
for msg in all_replies
|
|
139
|
+
if msg.has_content()
|
|
140
|
+
]
|
|
141
|
+
print(f"Received {len(all_replies)} results")
|
|
142
|
+
|
|
143
|
+
# Print metrics received from clients
|
|
144
|
+
for reply in all_replies:
|
|
145
|
+
print(reply.content.metrics_records)
|
|
146
|
+
|
|
147
|
+
# Aggregate parameters received from clients
|
|
148
|
+
parameter_records_list = [reply.content.parameters_records["fancy_model_returned"] for reply in all_replies]
|
|
149
|
+
new_parameter_record = aggregation_strategy.aggregate_parameters(parameter_records_list)
|
|
150
|
+
global_model.set_parameters(new_parameter_record)
|
|
151
|
+
|
|
152
|
+
# Evaluate the updated global model
|
|
153
|
+
messages = []
|
|
154
|
+
for idx, node_id in enumerate(node_ids):
|
|
155
|
+
# Create evaluation messages for clients
|
|
156
|
+
recordset = RecordSet()
|
|
157
|
+
|
|
158
|
+
# Add updated model parameters to record
|
|
159
|
+
recordset.parameters_records["fancy_model"] = new_parameter_record
|
|
160
|
+
# Add a config with information to send the client for evaluation
|
|
161
|
+
recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})
|
|
162
|
+
|
|
163
|
+
# Create an evaluation message for each client
|
|
164
|
+
message = driver.create_message(
|
|
165
|
+
content=recordset,
|
|
166
|
+
message_type=MessageType.EVALUATE,
|
|
167
|
+
dst_node_id=node_id,
|
|
168
|
+
group_id=str(server_round),
|
|
169
|
+
ttl=DEFAULT_TTL,
|
|
170
|
+
)
|
|
171
|
+
messages.append(message)
|
|
172
|
+
|
|
173
|
+
# Send evaluation messages to clients
|
|
174
|
+
message_ids = driver.push_messages(messages)
|
|
175
|
+
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
176
|
+
|
|
177
|
+
# Wait for evaluation results from clients
|
|
178
|
+
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
179
|
+
all_replies: List[Message] = []
|
|
180
|
+
while True:
|
|
181
|
+
replies = driver.pull_messages(message_ids=message_ids)
|
|
182
|
+
print(f"Got {len(list(replies))} results")
|
|
183
|
+
all_replies += replies
|
|
184
|
+
if len(all_replies) == len(message_ids):
|
|
185
|
+
break
|
|
186
|
+
time.sleep(3)
|
|
187
|
+
|
|
188
|
+
# Filter out messages with errors
|
|
189
|
+
all_replies = [
|
|
190
|
+
msg
|
|
191
|
+
for msg in all_replies
|
|
192
|
+
if msg.has_content()
|
|
193
|
+
]
|
|
194
|
+
print(f"Received {len(all_replies)} results")
|
|
195
|
+
|
|
196
|
+
# Print evaluation metrics received from clients
|
|
197
|
+
metrics_records_list = [reply.content.metrics_records['eval_metrics'] for reply in all_replies]
|
|
198
|
+
for i, reply in enumerate(all_replies):
|
|
199
|
+
print(f"Client {i+1} metrics: ", reply.content.metrics_records['eval_metrics'])
|
|
200
|
+
|
|
201
|
+
# Aggregate evaluation metrics
|
|
202
|
+
print("Aggregated metrics result: ", aggregation_strategy.aggregate_metrics(metrics_records_list))
|
|
203
|
+
|
|
204
|
+
print("🎉🎉🎉 Successfully completed federated learning run! 🎉🎉🎉")
|