flwr-nightly 1.22.0.dev20250913__py3-none-any.whl → 1.22.0.dev20250916__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. flwr/cli/new/new.py +5 -5
  2. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  3. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  4. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  5. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  6. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  7. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
  9. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +2 -2
  10. flwr/serverapp/strategy/__init__.py +8 -0
  11. flwr/serverapp/strategy/fedavg.py +23 -2
  12. flwr/serverapp/strategy/fedavgm.py +198 -0
  13. flwr/serverapp/strategy/fedmedian.py +71 -0
  14. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  15. flwr/serverapp/strategy/fedxgb_bagging.py +82 -0
  16. flwr/serverapp/strategy/strategy_utils.py +48 -0
  17. flwr/serverapp/strategy/strategy_utils_tests.py +20 -1
  18. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/METADATA +6 -16
  19. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/RECORD +22 -18
  20. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  21. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  22. flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  23. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/WHEEL +0 -0
  24. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ flwr/cli/login/__init__.py,sha256=B1SXKU3HCQhWfFDMJhlC7FOl8UsvH4mxysxeBnrfyUE,80
18
18
  flwr/cli/login/login.py,sha256=RM1Jiv_VFm3oz4rTHSr3D87X90lW3WzErjBBU7WviWY,4309
19
19
  flwr/cli/ls.py,sha256=3YK7cpoImJ7PbjlP_JgYRQWz1GymX2q7Reu-mKJEpao,10957
20
20
  flwr/cli/new/__init__.py,sha256=QA1E2QtzPvFCjLTUHnFnJbufuFiGyT_0Y53Wpbvg1F0,790
21
- flwr/cli/new/new.py,sha256=46QuAi7Act3_TbD0IkejUhognXPXlo2r3LRPvN8pEkA,10503
21
+ flwr/cli/new/new.py,sha256=KyTs9Fbm4eoJ5DohhuTkYNJJX5rDC0p-YTPtNatYXrI,10529
22
22
  flwr/cli/new/templates/__init__.py,sha256=FpjWCfIySU2DB4kh0HOXLAjlZNNFDTVU4w3HoE2TzcI,725
23
23
  flwr/cli/new/templates/app/.gitignore.tpl,sha256=HZJcGQoxp7aUzaPg8Uqch3kNrIESwr9yjimDxJYgXVY,3104
24
24
  flwr/cli/new/templates/app/LICENSE.tpl,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
@@ -29,14 +29,14 @@ flwr/cli/new/templates/app/__init__.py,sha256=LbR0ksGiF566JcHM_H5m1Tc4-oYUEilWFl
29
29
  flwr/cli/new/templates/app/code/__init__.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
30
30
  flwr/cli/new/templates/app/code/__init__.py,sha256=zXa2YU1swzHxOKDQbwlDMEwVPOUswVeosjkiXNMTgFo,736
31
31
  flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=J0Gn74E7khpLyKJVNqOPu7ev93vkcu1PZugsbxtABMw,52
32
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl,sha256=mKIS8MK_X8T9NlmcX1-_c9Bbexc-ueqDIBI7uN6c4dE,45
32
+ flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl,sha256=mKIS8MK_X8T9NlmcX1-_c9Bbexc-ueqDIBI7uN6c4dE,45
33
33
  flwr/cli/new/templates/app/code/client.baseline.py.tpl,sha256=IYlCZqnaxT2ucP1ReffRNohOkYwNrhtrnDoQBBcrThY,1901
34
34
  flwr/cli/new/templates/app/code/client.huggingface.py.tpl,sha256=SIZZ3s-6u8IU8cFfsqu6ZU8zjhfI1m1SWauOSUcW8TA,3015
35
35
  flwr/cli/new/templates/app/code/client.jax.py.tpl,sha256=uFCIPwAHYiRAgh2W3nRni_Oig02ZzRF-ofUG5O19zcE,2125
36
36
  flwr/cli/new/templates/app/code/client.mlx.py.tpl,sha256=CHU2IBIzI2YENZZuvTsAlSdL94DK19wMYMIhr-JgwZ8,3422
37
37
  flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=1_WEoOPe9jJeK-7FZgYuDUqY8mC0vxgqA83d-h201Gk,1381
38
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
39
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl,sha256=fYoh-dTu07LkqNYvwcxQnbgVvH4Yo4eiGEcyHECbsnU,2473
38
+ flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fYoh-dTu07LkqNYvwcxQnbgVvH4Yo4eiGEcyHECbsnU,2473
39
+ flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
40
40
  flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=0qqEe-RRjkHGOH8gsD9e83ae-kyyYixhyBgzVHjYpzk,3500
41
41
  flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=8o55KXpsbF_rv6o98ZNYJDCazjwMp_RPTaSzDfT7Qlw,2682
42
42
  flwr/cli/new/templates/app/code/dataset.baseline.py.tpl,sha256=jbd_exHAk2-Blu_kVutjPO6a_dkJQWb232zxSeXIZ1k,1453
@@ -52,8 +52,8 @@ flwr/cli/new/templates/app/code/server.huggingface.py.tpl,sha256=_2Mv-SqGSMf7sMd
52
52
  flwr/cli/new/templates/app/code/server.jax.py.tpl,sha256=RW-rh7ogcJ3_BD66bJxTw-ZoP7c-4SK8hVHc-e0SSVY,1029
53
53
  flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=J8rIe6RL2ndODVJD79xShRKBH70HljFSCi4s_RJ-xLQ,1200
54
54
  flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=T3hcKbPw3uL5lXEP-MuVJXIBXjzva5sWJXfpQqarUwA,955
55
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
56
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl,sha256=epARqfcQ-EQsdZwaaaUp5y4OSTBT6CiFGlNRocw-23A,1158
55
+ flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=epARqfcQ-EQsdZwaaaUp5y4OSTBT6CiFGlNRocw-23A,1158
56
+ flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
57
57
  flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=ehQ5VRgBn92WeFl6kupwJnuxSNkKvE-EvKde6A9mNQo,1377
58
58
  flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=2-WTOPd-ewdLd9QmSlflIH7ix7zxAzPEOZoyiPBOy8c,1010
59
59
  flwr/cli/new/templates/app/code/strategy.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
@@ -61,8 +61,8 @@ flwr/cli/new/templates/app/code/task.huggingface.py.tpl,sha256=piBbY3Dg60bQnCg15
61
61
  flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=Fb0XgdTAQplM-ZCusI081XA9asO3gHptH772S-Xcyy8,1525
62
62
  flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=YxH5z4s5kOh5_9DIY9pvzqURckLDfgdanTA68_iM_Wo,2946
63
63
  flwr/cli/new/templates/app/code/task.numpy.py.tpl,sha256=CwUJPnN3z6GjP8-KVGWzx7RYRJsl0wLFZ72xscvl3RM,126
64
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
65
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbzwy9arg5o2lzXqG2kNrLIUU,3446
64
+ flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbzwy9arg5o2lzXqG2kNrLIUU,3446
65
+ flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
66
66
  flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9aKmmmA0Jvpx5Oum1Yu9lWY,1850
67
67
  flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
68
68
  flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
@@ -72,8 +72,8 @@ flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=xHGF38i7oFpvnFv
72
72
  flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=fdDhwmPoMirJ095cU_vFCBf0ILQlAoa1fdnHb2LM1yk,1471
73
73
  flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=PAjPT2v06sBZxacNiyMJloDwocCK5tFcGQmMXOoBqc8,1542
74
74
  flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=Kb_O2iQfzwc6FTy3fWqtQYc3FwY6x9SUgQPGqZR_ILg,1409
75
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
76
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
75
+ flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
76
+ flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
77
77
  flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=apauU_PUmLEbt2rjckKniEbzdRs1EnMri_qgtHtBJZ8,1484
78
78
  flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=LQpDKJTEnRKj5Ygn5FkT44SxlnLVprkPlbrGaFf5Q50,1508
79
79
  flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
@@ -332,17 +332,21 @@ flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAy
332
332
  flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
333
333
  flwr/serverapp/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
334
334
  flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
335
- flwr/serverapp/strategy/__init__.py,sha256=yAYBZUkp4aNmcTLsvormEc9HyO34oEoFN45LiHgujE0,1229
335
+ flwr/serverapp/strategy/__init__.py,sha256=MHrU_tz_myWqzG3h4gZdIpt2DDN-JdNK-HHIcrz1-Ns,1448
336
336
  flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
337
337
  flwr/serverapp/strategy/fedadagrad.py,sha256=fD65P6OEERa_pxq847e1UZpA083AcWR44XavYB0naGM,6343
338
338
  flwr/serverapp/strategy/fedadam.py,sha256=s3xPIqhopy6yPTeFxevSPnc7a6BcKnKsvo2AaO6Z_xs,7138
339
- flwr/serverapp/strategy/fedavg.py,sha256=53L06lZLkbGV0TRZrUWvPaocvFTT1PAhTvu9UkKq1zE,11294
339
+ flwr/serverapp/strategy/fedavg.py,sha256=Bq_nlmngzJbjqX1fF1mevXGVN6-pwglHv-6yNrs6lkA,12035
340
+ flwr/serverapp/strategy/fedavgm.py,sha256=VlByltWzUYCoiVIWPFRqsqLKNWjlOlO2INK8SUxEjzk,8327
341
+ flwr/serverapp/strategy/fedmedian.py,sha256=b31Dk0LQBbQxi_f-jeSbWHI7iOBugcuBSN2Az-_a75E,2596
340
342
  flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTjR2tE,8477
343
+ flwr/serverapp/strategy/fedtrimmedavg.py,sha256=4-QxgAQGo_7vB_L7qDYy28d95OBt9MeDa92yaTRMHqk,7166
344
+ flwr/serverapp/strategy/fedxgb_bagging.py,sha256=ktDjzov4y0BRecioq788umCEtcuwElou9olBizQKOnM,3282
341
345
  flwr/serverapp/strategy/fedyogi.py,sha256=1Ripr4Hi2cdeTOLiFOXtMKvOxR3BsUQwc7bbTrXN4LM,6653
342
346
  flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
343
347
  flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
344
- flwr/serverapp/strategy/strategy_utils.py,sha256=9ga93Se21I_k7zYiw343EMC2qCTQ8rUG5ZEm8HVEuFs,9246
345
- flwr/serverapp/strategy/strategy_utils_tests.py,sha256=o32XHujd9PLCB-YZMI2AttWLlvUXHe9yuxgiCrCkpgU,10209
348
+ flwr/serverapp/strategy/strategy_utils.py,sha256=hiwS7k-Hx6_c4NZXoKpHucS5CBKb7f8GppXRBSMt3Us,10851
349
+ flwr/serverapp/strategy/strategy_utils_tests.py,sha256=_adS23Lrv1QA6V_3oZ7P_csMd8RqDObFeIhOkFnNtTg,10690
346
350
  flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
347
351
  flwr/simulation/app.py,sha256=LbGLMvN9Ap119yBqsUcNNmVLRnCySnr4VechqcQ1hpA,10401
348
352
  flwr/simulation/legacy_app.py,sha256=nMISQqW0otJL1-2Kfd94O6BLlGS2IEmEPKTM2WGKrIs,15861
@@ -403,7 +407,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
403
407
  flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
404
408
  flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
405
409
  flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
406
- flwr_nightly-1.22.0.dev20250913.dist-info/METADATA,sha256=taZ5hyFAPFrevCeD1fE30C3M-BaOJVn2vpR-z-f_eA8,15967
407
- flwr_nightly-1.22.0.dev20250913.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
408
- flwr_nightly-1.22.0.dev20250913.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
409
- flwr_nightly-1.22.0.dev20250913.dist-info/RECORD,,
410
+ flwr_nightly-1.22.0.dev20250916.dist-info/METADATA,sha256=5fd7FMKBNE9N1UWd12_xAPWowIbjr948mx-erdTIBBM,14559
411
+ flwr_nightly-1.22.0.dev20250916.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
412
+ flwr_nightly-1.22.0.dev20250916.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
413
+ flwr_nightly-1.22.0.dev20250916.dist-info/RECORD,,
@@ -1,80 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- from $import_name.task import Net, load_data
8
- from $import_name.task import test as test_fn
9
- from $import_name.task import train as train_fn
10
-
11
- # Flower ClientApp
12
- app = ClientApp()
13
-
14
-
15
- @app.train()
16
- def train(msg: Message, context: Context):
17
- """Train the model on local data."""
18
-
19
- # Load the model and initialize it with the received weights
20
- model = Net()
21
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- model.to(device)
24
-
25
- # Load the data
26
- partition_id = context.node_config["partition-id"]
27
- num_partitions = context.node_config["num-partitions"]
28
- trainloader, _ = load_data(partition_id, num_partitions)
29
-
30
- # Call the training function
31
- train_loss = train_fn(
32
- model,
33
- trainloader,
34
- context.run_config["local-epochs"],
35
- msg.content["config"]["lr"],
36
- device,
37
- )
38
-
39
- # Construct and return reply Message
40
- model_record = ArrayRecord(model.state_dict())
41
- metrics = {
42
- "train_loss": train_loss,
43
- "num-examples": len(trainloader.dataset),
44
- }
45
- metric_record = MetricRecord(metrics)
46
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
- return Message(content=content, reply_to=msg)
48
-
49
-
50
- @app.evaluate()
51
- def evaluate(msg: Message, context: Context):
52
- """Evaluate the model on local data."""
53
-
54
- # Load the model and initialize it with the received weights
55
- model = Net()
56
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
- model.to(device)
59
-
60
- # Load the data
61
- partition_id = context.node_config["partition-id"]
62
- num_partitions = context.node_config["num-partitions"]
63
- _, valloader = load_data(partition_id, num_partitions)
64
-
65
- # Call the evaluation function
66
- eval_loss, eval_acc = test_fn(
67
- model,
68
- valloader,
69
- device,
70
- )
71
-
72
- # Construct and return reply Message
73
- metrics = {
74
- "eval_loss": eval_loss,
75
- "eval_acc": eval_acc,
76
- "num-examples": len(valloader.dataset),
77
- }
78
- metric_record = MetricRecord(metrics)
79
- content = RecordDict({"metrics": metric_record})
80
- return Message(content=content, reply_to=msg)
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, ConfigRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- fraction_train: float = context.run_config["fraction-train"]
20
- num_rounds: int = context.run_config["num-server-rounds"]
21
- lr: float = context.run_config["lr"]
22
-
23
- # Load global model
24
- global_model = Net()
25
- arrays = ArrayRecord(global_model.state_dict())
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg(fraction_train=fraction_train)
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- train_config=ConfigRecord({"lr": lr}),
35
- num_rounds=num_rounds,
36
- )
37
-
38
- # Save final model to disk
39
- print("\nSaving final model to disk...")
40
- state_dict = result.arrays.to_torch_state_dict()
41
- torch.save(state_dict, "final_model.pt")