flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -1,80 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- from $import_name.task import Net, load_data
8
- from $import_name.task import test as test_fn
9
- from $import_name.task import train as train_fn
10
-
11
- # Flower ClientApp
12
- app = ClientApp()
13
-
14
-
15
- @app.train()
16
- def train(msg: Message, context: Context):
17
- """Train the model on local data."""
18
-
19
- # Load the model and initialize it with the received weights
20
- model = Net()
21
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- model.to(device)
24
-
25
- # Load the data
26
- partition_id = context.node_config["partition-id"]
27
- num_partitions = context.node_config["num-partitions"]
28
- trainloader, _ = load_data(partition_id, num_partitions)
29
-
30
- # Call the training function
31
- train_loss = train_fn(
32
- model,
33
- trainloader,
34
- context.run_config["local-epochs"],
35
- msg.content["config"]["lr"],
36
- device,
37
- )
38
-
39
- # Construct and return reply Message
40
- model_record = ArrayRecord(model.state_dict())
41
- metrics = {
42
- "train_loss": train_loss,
43
- "num-examples": len(trainloader.dataset),
44
- }
45
- metric_record = MetricRecord(metrics)
46
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
- return Message(content=content, reply_to=msg)
48
-
49
-
50
- @app.evaluate()
51
- def evaluate(msg: Message, context: Context):
52
- """Evaluate the model on local data."""
53
-
54
- # Load the model and initialize it with the received weights
55
- model = Net()
56
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
- model.to(device)
59
-
60
- # Load the data
61
- partition_id = context.node_config["partition-id"]
62
- num_partitions = context.node_config["num-partitions"]
63
- _, valloader = load_data(partition_id, num_partitions)
64
-
65
- # Call the evaluation function
66
- eval_loss, eval_acc = test_fn(
67
- model,
68
- valloader,
69
- device,
70
- )
71
-
72
- # Construct and return reply Message
73
- metrics = {
74
- "eval_loss": eval_loss,
75
- "eval_acc": eval_acc,
76
- "num-examples": len(valloader.dataset),
77
- }
78
- metric_record = MetricRecord(metrics)
79
- content = RecordDict({"metrics": metric_record})
80
- return Message(content=content, reply_to=msg)
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, ConfigRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- fraction_train: float = context.run_config["fraction-train"]
20
- num_rounds: int = context.run_config["num-server-rounds"]
21
- lr: float = context.run_config["lr"]
22
-
23
- # Load global model
24
- global_model = Net()
25
- arrays = ArrayRecord(global_model.state_dict())
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg(fraction_train=fraction_train)
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- train_config=ConfigRecord({"lr": lr}),
35
- num_rounds=num_rounds,
36
- )
37
-
38
- # Save final model to disk
39
- print("\nSaving final model to disk...")
40
- state_dict = result.arrays.to_torch_state_dict()
41
- torch.save(state_dict, "final_model.pt")
@@ -1,149 +0,0 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Abstract classes for Flower User Auth Plugin."""
16
-
17
-
18
- from abc import ABC, abstractmethod
19
- from collections.abc import Sequence
20
- from pathlib import Path
21
- from typing import Optional, Union
22
-
23
- from flwr.common.typing import AccountInfo
24
- from flwr.proto.control_pb2_grpc import ControlStub
25
-
26
- from ..typing import UserAuthCredentials, UserAuthLoginDetails
27
-
28
-
29
- class ControlAuthPlugin(ABC):
30
- """Abstract Flower Auth Plugin class for ControlServicer.
31
-
32
- Parameters
33
- ----------
34
- user_auth_config_path : Path
35
- Path to the YAML file containing the authentication configuration.
36
- verify_tls_cert : bool
37
- Boolean indicating whether to verify the TLS certificate
38
- when making requests to the server.
39
- """
40
-
41
- @abstractmethod
42
- def __init__(
43
- self,
44
- user_auth_config_path: Path,
45
- verify_tls_cert: bool,
46
- ):
47
- """Abstract constructor."""
48
-
49
- @abstractmethod
50
- def get_login_details(self) -> Optional[UserAuthLoginDetails]:
51
- """Get the login details."""
52
-
53
- @abstractmethod
54
- def validate_tokens_in_metadata(
55
- self, metadata: Sequence[tuple[str, Union[str, bytes]]]
56
- ) -> tuple[bool, Optional[AccountInfo]]:
57
- """Validate authentication tokens in the provided metadata."""
58
-
59
- @abstractmethod
60
- def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]:
61
- """Get authentication tokens."""
62
-
63
- @abstractmethod
64
- def refresh_tokens(
65
- self, metadata: Sequence[tuple[str, Union[str, bytes]]]
66
- ) -> tuple[
67
- Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
68
- ]:
69
- """Refresh authentication tokens in the provided metadata."""
70
-
71
-
72
- class ControlAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
73
- """Abstract Flower Authorization Plugin class for ControlServicer.
74
-
75
- Parameters
76
- ----------
77
- user_auth_config_path : Path
78
- Path to the YAML file containing the authorization configuration.
79
- verify_tls_cert : bool
80
- Boolean indicating whether to verify the TLS certificate
81
- when making requests to the server.
82
- """
83
-
84
- @abstractmethod
85
- def __init__(self, user_auth_config_path: Path, verify_tls_cert: bool):
86
- """Abstract constructor."""
87
-
88
- @abstractmethod
89
- def verify_user_authorization(self, account_info: AccountInfo) -> bool:
90
- """Verify user authorization request."""
91
-
92
-
93
- class CliAuthPlugin(ABC):
94
- """Abstract Flower Auth Plugin class for CLI.
95
-
96
- Parameters
97
- ----------
98
- credentials_path : Path
99
- Path to the user's authentication credentials file.
100
- """
101
-
102
- @staticmethod
103
- @abstractmethod
104
- def login(
105
- login_details: UserAuthLoginDetails,
106
- control_stub: ControlStub,
107
- ) -> UserAuthCredentials:
108
- """Authenticate the user and retrieve authentication credentials.
109
-
110
- Parameters
111
- ----------
112
- login_details : UserAuthLoginDetails
113
- An object containing the user's login details.
114
- control_stub : ControlStub
115
- A stub for executing RPC calls to the server.
116
-
117
- Returns
118
- -------
119
- UserAuthCredentials
120
- The authentication credentials obtained after login.
121
- """
122
-
123
- @abstractmethod
124
- def __init__(self, credentials_path: Path):
125
- """Abstract constructor."""
126
-
127
- @abstractmethod
128
- def store_tokens(self, credentials: UserAuthCredentials) -> None:
129
- """Store authentication tokens to the `credentials_path`.
130
-
131
- The credentials, including tokens, will be saved as a JSON file
132
- at `credentials_path`.
133
- """
134
-
135
- @abstractmethod
136
- def load_tokens(self) -> None:
137
- """Load authentication tokens from the `credentials_path`."""
138
-
139
- @abstractmethod
140
- def write_tokens_to_metadata(
141
- self, metadata: Sequence[tuple[str, Union[str, bytes]]]
142
- ) -> Sequence[tuple[str, Union[str, bytes]]]:
143
- """Write authentication tokens to the provided metadata."""
144
-
145
- @abstractmethod
146
- def read_tokens_from_metadata(
147
- self, metadata: Sequence[tuple[str, Union[str, bytes]]]
148
- ) -> Optional[UserAuthCredentials]:
149
- """Read authentication tokens from the provided metadata."""
@@ -1,352 +0,0 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Message-based Central differential privacy with fixed clipping.
16
-
17
- Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
18
- """
19
-
20
- from abc import ABC
21
- from collections import OrderedDict
22
- from collections.abc import Iterable
23
- from logging import INFO, WARNING
24
- from typing import Optional
25
-
26
- from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
27
- from flwr.common.differential_privacy import (
28
- add_gaussian_noise_inplace,
29
- compute_clip_model_update,
30
- compute_stdv,
31
- )
32
- from flwr.common.differential_privacy_constants import (
33
- CLIENTS_DISCREPANCY_WARNING,
34
- KEY_CLIPPING_NORM,
35
- )
36
- from flwr.server import Grid
37
-
38
- from .strategy import Strategy
39
-
40
-
41
- class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
42
- """Base class for DP strategies with fixed clipping.
43
-
44
- This class contains common functionality shared between server-side and
45
- client-side fixed clipping implementations.
46
-
47
- Parameters
48
- ----------
49
- strategy : Strategy
50
- The strategy to which DP functionalities will be added by this wrapper.
51
- noise_multiplier : float
52
- The noise multiplier for the Gaussian mechanism for model updates.
53
- A value of 1.0 or higher is recommended for strong privacy.
54
- clipping_norm : float
55
- The value of the clipping norm.
56
- num_sampled_clients : int
57
- The number of clients that are sampled on each round.
58
- """
59
-
60
- # pylint: disable=too-many-arguments,too-many-instance-attributes
61
- def __init__(
62
- self,
63
- strategy: Strategy,
64
- noise_multiplier: float,
65
- clipping_norm: float,
66
- num_sampled_clients: int,
67
- ) -> None:
68
- super().__init__()
69
-
70
- self.strategy = strategy
71
-
72
- if noise_multiplier < 0:
73
- raise ValueError("The noise multiplier should be a non-negative value.")
74
-
75
- if clipping_norm <= 0:
76
- raise ValueError("The clipping norm should be a positive value.")
77
-
78
- if num_sampled_clients <= 0:
79
- raise ValueError(
80
- "The number of sampled clients should be a positive value."
81
- )
82
-
83
- self.noise_multiplier = noise_multiplier
84
- self.clipping_norm = clipping_norm
85
- self.num_sampled_clients = num_sampled_clients
86
-
87
- def _validate_replies(self, replies: Iterable[Message]) -> bool:
88
- """Validate replies and log errors/warnings.
89
-
90
- Returns
91
- -------
92
- bool
93
- True if replies are valid for aggregation, False otherwise.
94
- """
95
- num_errors = 0
96
- num_replies_with_content = 0
97
- for msg in replies:
98
- if msg.has_error():
99
- log(
100
- INFO,
101
- "Received error in reply from node %d: %s",
102
- msg.metadata.src_node_id,
103
- msg.error,
104
- )
105
- num_errors += 1
106
- else:
107
- num_replies_with_content += 1
108
-
109
- # Errors are not allowed
110
- if num_errors:
111
- log(
112
- INFO,
113
- "aggregate_train: Some clients reported errors. Skipping aggregation.",
114
- )
115
- return False
116
-
117
- log(
118
- INFO,
119
- "aggregate_train: Received %s results and %s failures",
120
- num_replies_with_content,
121
- num_errors,
122
- )
123
-
124
- if num_replies_with_content != self.num_sampled_clients:
125
- log(
126
- WARNING,
127
- CLIENTS_DISCREPANCY_WARNING,
128
- num_replies_with_content,
129
- self.num_sampled_clients,
130
- )
131
-
132
- return True
133
-
134
- def _add_noise_to_aggregated_arrays(
135
- self, aggregated_arrays: ArrayRecord
136
- ) -> ArrayRecord:
137
- """Add Gaussian noise to aggregated arrays.
138
-
139
- Parameters
140
- ----------
141
- aggregated_arrays : ArrayRecord
142
- The aggregated arrays to add noise to.
143
-
144
- Returns
145
- -------
146
- ArrayRecord
147
- The aggregated arrays with noise added.
148
- """
149
- aggregated_ndarrays = aggregated_arrays.to_numpy_ndarrays()
150
- stdv = compute_stdv(
151
- self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
152
- )
153
- add_gaussian_noise_inplace(aggregated_ndarrays, stdv)
154
-
155
- log(
156
- INFO,
157
- "aggregate_fit: central DP noise with %.4f stdev added",
158
- stdv,
159
- )
160
-
161
- return ArrayRecord(
162
- OrderedDict(
163
- {
164
- k: Array(v)
165
- for k, v in zip(aggregated_arrays.keys(), aggregated_ndarrays)
166
- }
167
- )
168
- )
169
-
170
- def configure_evaluate(
171
- self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
172
- ) -> Iterable[Message]:
173
- """Configure the next round of federated evaluation."""
174
- return self.strategy.configure_evaluate(server_round, arrays, config, grid)
175
-
176
- def aggregate_evaluate(
177
- self,
178
- server_round: int,
179
- replies: Iterable[Message],
180
- ) -> Optional[MetricRecord]:
181
- """Aggregate MetricRecords in the received Messages."""
182
- return self.strategy.aggregate_evaluate(server_round, replies)
183
-
184
- def summary(self) -> None:
185
- """Log summary configuration of the strategy."""
186
- self.strategy.summary()
187
-
188
-
189
- class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippingBase):
190
- """Strategy wrapper for central DP with server-side fixed clipping.
191
-
192
- Parameters
193
- ----------
194
- strategy : Strategy
195
- The strategy to which DP functionalities will be added by this wrapper.
196
- noise_multiplier : float
197
- The noise multiplier for the Gaussian mechanism for model updates.
198
- A value of 1.0 or higher is recommended for strong privacy.
199
- clipping_norm : float
200
- The value of the clipping norm.
201
- num_sampled_clients : int
202
- The number of clients that are sampled on each round.
203
-
204
- Examples
205
- --------
206
- Create a strategy::
207
-
208
- strategy = fl.serverapp.FedAvg( ... )
209
-
210
- Wrap the strategy with the `DifferentialPrivacyServerSideFixedClipping` wrapper::
211
-
212
- dp_strategy = DifferentialPrivacyServerSideFixedClipping(
213
- strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
214
- )
215
- """
216
-
217
- def __init__(
218
- self,
219
- strategy: Strategy,
220
- noise_multiplier: float,
221
- clipping_norm: float,
222
- num_sampled_clients: int,
223
- ) -> None:
224
- super().__init__(strategy, noise_multiplier, clipping_norm, num_sampled_clients)
225
- self.current_arrays: ArrayRecord = ArrayRecord()
226
-
227
- def __repr__(self) -> str:
228
- """Compute a string representation of the strategy."""
229
- return "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
230
-
231
- def configure_train(
232
- self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
233
- ) -> Iterable[Message]:
234
- """Configure the next round of training."""
235
- self.current_arrays = arrays
236
- return self.strategy.configure_train(server_round, arrays, config, grid)
237
-
238
- def aggregate_train(
239
- self,
240
- server_round: int,
241
- replies: Iterable[Message],
242
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
243
- """Aggregate ArrayRecords and MetricRecords in the received Messages."""
244
- if not self._validate_replies(replies):
245
- return None, None
246
-
247
- # Clip arrays in replies
248
- current_ndarrays = self.current_arrays.to_numpy_ndarrays()
249
- for reply in replies:
250
- for arr_name, record in reply.content.array_records.items():
251
- # Clip
252
- reply_ndarrays = record.to_numpy_ndarrays()
253
- compute_clip_model_update(
254
- param1=reply_ndarrays,
255
- param2=current_ndarrays,
256
- clipping_norm=self.clipping_norm,
257
- )
258
- # Replace content while preserving keys
259
- reply.content[arr_name] = ArrayRecord(
260
- OrderedDict(
261
- {k: Array(v) for k, v in zip(record.keys(), reply_ndarrays)}
262
- )
263
- )
264
- log(
265
- INFO,
266
- "aggregate_fit: parameters are clipped by value: %.4f.",
267
- self.clipping_norm,
268
- )
269
-
270
- # Pass the new parameters for aggregation
271
- aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
272
- server_round, replies
273
- )
274
-
275
- # Add Gaussian noise to the aggregated arrays
276
- if aggregated_arrays:
277
- aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)
278
-
279
- return aggregated_arrays, aggregated_metrics
280
-
281
-
282
- class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippingBase):
283
- """Strategy wrapper for central DP with client-side fixed clipping.
284
-
285
- Use `fixedclipping_mod` modifier at the client side.
286
-
287
- In comparison to `DifferentialPrivacyServerSideFixedClipping`,
288
- which performs clipping on the server-side,
289
- `DifferentialPrivacyClientSideFixedClipping` expects clipping to happen
290
- on the client-side, usually by using the built-in `fixedclipping_mod`.
291
-
292
- Parameters
293
- ----------
294
- strategy : Strategy
295
- The strategy to which DP functionalities will be added by this wrapper.
296
- noise_multiplier : float
297
- The noise multiplier for the Gaussian mechanism for model updates.
298
- A value of 1.0 or higher is recommended for strong privacy.
299
- clipping_norm : float
300
- The value of the clipping norm.
301
- num_sampled_clients : int
302
- The number of clients that are sampled on each round.
303
-
304
- Examples
305
- --------
306
- Create a strategy::
307
-
308
- strategy = fl.serverapp.FedAvg(...)
309
-
310
- Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper::
311
-
312
- dp_strategy = DifferentialPrivacyClientSideFixedClipping(
313
- strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
314
- )
315
-
316
- On the client, add the `fixedclipping_mod` to the client-side mods::
317
-
318
- app = fl.client.ClientApp(mods=[fixedclipping_mod])
319
- """
320
-
321
- def __repr__(self) -> str:
322
- """Compute a string representation of the strategy."""
323
- return "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
324
-
325
- def configure_train(
326
- self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
327
- ) -> Iterable[Message]:
328
- """Configure the next round of training."""
329
- # Inject clipping norm in config
330
- config[KEY_CLIPPING_NORM] = self.clipping_norm
331
- # Call parent method
332
- return self.strategy.configure_train(server_round, arrays, config, grid)
333
-
334
- def aggregate_train(
335
- self,
336
- server_round: int,
337
- replies: Iterable[Message],
338
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
339
- """Aggregate ArrayRecords and MetricRecords in the received Messages."""
340
- if not self._validate_replies(replies):
341
- return None, None
342
-
343
- # Aggregate
344
- aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
345
- server_round, replies
346
- )
347
-
348
- # Add Gaussian noise to the aggregated arrays
349
- if aggregated_arrays:
350
- aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)
351
-
352
- return aggregated_arrays, aggregated_metrics