flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (66)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +19 -14
  4. flwr/cli/new/new.py +51 -22
  5. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
  10. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  11. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  13. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
  14. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  15. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  16. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  17. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  19. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  20. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  21. flwr/cli/run/run.py +1 -1
  22. flwr/cli/utils.py +18 -17
  23. flwr/client/__init__.py +3 -1
  24. flwr/client/app.py +20 -142
  25. flwr/client/grpc_client/connection.py +8 -2
  26. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  27. flwr/client/grpc_rere_client/connection.py +33 -4
  28. flwr/client/mod/centraldp_mods.py +4 -2
  29. flwr/client/mod/localdp_mod.py +9 -3
  30. flwr/client/rest_client/connection.py +92 -169
  31. flwr/client/supernode/__init__.py +24 -0
  32. flwr/client/supernode/app.py +281 -0
  33. flwr/common/grpc.py +5 -1
  34. flwr/common/logger.py +37 -4
  35. flwr/common/message.py +105 -86
  36. flwr/common/record/parametersrecord.py +0 -1
  37. flwr/common/record/recordset.py +78 -27
  38. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  39. flwr/common/telemetry.py +4 -0
  40. flwr/server/app.py +116 -6
  41. flwr/server/compat/app.py +2 -2
  42. flwr/server/compat/app_utils.py +1 -1
  43. flwr/server/compat/driver_client_proxy.py +27 -70
  44. flwr/server/driver/__init__.py +2 -1
  45. flwr/server/driver/driver.py +12 -139
  46. flwr/server/driver/grpc_driver.py +199 -13
  47. flwr/server/run_serverapp.py +18 -4
  48. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  49. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  50. flwr/server/superlink/driver/driver_servicer.py +1 -1
  51. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  52. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  53. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
  54. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  55. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  56. flwr/server/superlink/state/in_memory_state.py +89 -12
  57. flwr/server/superlink/state/sqlite_state.py +133 -16
  58. flwr/server/superlink/state/state.py +56 -6
  59. flwr/simulation/__init__.py +2 -2
  60. flwr/simulation/app.py +16 -1
  61. flwr/simulation/run_simulation.py +10 -7
  62. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  63. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
  64. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
  65. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  66. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,17 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
16
-
15
+ """Flower gRPC Driver."""
17
16
 
17
+ import time
18
+ import warnings
18
19
  from logging import DEBUG, ERROR, WARNING
19
- from typing import Optional
20
+ from typing import Iterable, List, Optional, Tuple
20
21
 
21
22
  import grpc
22
23
 
23
- from flwr.common import EventType, event
24
+ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
24
25
  from flwr.common.grpc import create_channel
25
26
  from flwr.common.logger import log
27
+ from flwr.common.serde import message_from_taskres, message_to_taskins
26
28
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
27
29
  CreateRunRequest,
28
30
  CreateRunResponse,
@@ -34,19 +36,23 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
34
36
  PushTaskInsResponse,
35
37
  )
36
38
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
+
42
+ from .driver import Driver
37
43
 
38
44
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
39
45
 
40
46
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
41
47
  [Driver] Error: Not connected.
42
48
 
43
- Call `connect()` on the `GrpcDriver` instance before calling any of the other
44
- `GrpcDriver` methods.
49
+ Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
+ `GrpcDriverHelper` methods.
45
51
  """
46
52
 
47
53
 
48
- class GrpcDriver:
49
- """`GrpcDriver` provides access to the gRPC Driver API/service."""
54
+ class GrpcDriverHelper:
55
+ """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
50
56
 
51
57
  def __init__(
52
58
  self,
@@ -89,7 +95,7 @@ class GrpcDriver:
89
95
  # Check if channel is open
90
96
  if self.stub is None:
91
97
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
92
- raise ConnectionError("`GrpcDriver` instance not connected")
98
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
93
99
 
94
100
  # Call Driver API
95
101
  res: CreateRunResponse = self.stub.CreateRun(request=req)
@@ -100,7 +106,7 @@ class GrpcDriver:
100
106
  # Check if channel is open
101
107
  if self.stub is None:
102
108
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
103
- raise ConnectionError("`GrpcDriver` instance not connected")
109
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
104
110
 
105
111
  # Call gRPC Driver API
106
112
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -111,7 +117,7 @@ class GrpcDriver:
111
117
  # Check if channel is open
112
118
  if self.stub is None:
113
119
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
114
- raise ConnectionError("`GrpcDriver` instance not connected")
120
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
121
 
116
122
  # Call gRPC Driver API
117
123
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -122,8 +128,188 @@ class GrpcDriver:
122
128
  # Check if channel is open
123
129
  if self.stub is None:
124
130
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
125
- raise ConnectionError("`GrpcDriver` instance not connected")
131
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
126
132
 
127
133
  # Call Driver API
128
134
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
129
135
  return res
136
+
137
+
138
+ class GrpcDriver(Driver):
139
+ """`Driver` class provides an interface to the Driver API.
140
+
141
+ Parameters
142
+ ----------
143
+ driver_service_address : Optional[str]
144
+ The IPv4 or IPv6 address of the Driver API server.
145
+ Defaults to `"[::]:9091"`.
146
+ certificates : bytes (default: None)
147
+ Tuple containing root certificate, server certificate, and private key
148
+ to start a secure SSL-enabled server. The tuple is expected to have
149
+ three bytes elements in the following order:
150
+
151
+ * CA certificate.
152
+ * server certificate.
153
+ * server private key.
154
+ fab_id : str (default: None)
155
+ The identifier of the FAB used in the run.
156
+ fab_version : str (default: None)
157
+ The version of the FAB used in the run.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
163
+ root_certificates: Optional[bytes] = None,
164
+ fab_id: Optional[str] = None,
165
+ fab_version: Optional[str] = None,
166
+ ) -> None:
167
+ self.addr = driver_service_address
168
+ self.root_certificates = root_certificates
169
+ self.driver_helper: Optional[GrpcDriverHelper] = None
170
+ self.run_id: Optional[int] = None
171
+ self.fab_id = fab_id if fab_id is not None else ""
172
+ self.fab_version = fab_version if fab_version is not None else ""
173
+ self.node = Node(node_id=0, anonymous=True)
174
+
175
+ def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
176
+ # Check if the GrpcDriverHelper is initialized
177
+ if self.driver_helper is None or self.run_id is None:
178
+ # Connect and create run
179
+ self.driver_helper = GrpcDriverHelper(
180
+ driver_service_address=self.addr,
181
+ root_certificates=self.root_certificates,
182
+ )
183
+ self.driver_helper.connect()
184
+ req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
185
+ res = self.driver_helper.create_run(req)
186
+ self.run_id = res.run_id
187
+ return self.driver_helper, self.run_id
188
+
189
+ def _check_message(self, message: Message) -> None:
190
+ # Check if the message is valid
191
+ if not (
192
+ message.metadata.run_id == self.run_id
193
+ and message.metadata.src_node_id == self.node.node_id
194
+ and message.metadata.message_id == ""
195
+ and message.metadata.reply_to_message == ""
196
+ and message.metadata.ttl > 0
197
+ ):
198
+ raise ValueError(f"Invalid message: {message}")
199
+
200
+ def create_message( # pylint: disable=too-many-arguments
201
+ self,
202
+ content: RecordSet,
203
+ message_type: str,
204
+ dst_node_id: int,
205
+ group_id: str,
206
+ ttl: Optional[float] = None,
207
+ ) -> Message:
208
+ """Create a new message with specified parameters.
209
+
210
+ This method constructs a new `Message` with given content and metadata.
211
+ The `run_id` and `src_node_id` will be set automatically.
212
+ """
213
+ _, run_id = self._get_grpc_driver_helper_and_run_id()
214
+ if ttl:
215
+ warnings.warn(
216
+ "A custom TTL was set, but note that the SuperLink does not enforce "
217
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
218
+ "version of Flower.",
219
+ stacklevel=2,
220
+ )
221
+
222
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
223
+ metadata = Metadata(
224
+ run_id=run_id,
225
+ message_id="", # Will be set by the server
226
+ src_node_id=self.node.node_id,
227
+ dst_node_id=dst_node_id,
228
+ reply_to_message="",
229
+ group_id=group_id,
230
+ ttl=ttl_,
231
+ message_type=message_type,
232
+ )
233
+ return Message(metadata=metadata, content=content)
234
+
235
+ def get_node_ids(self) -> List[int]:
236
+ """Get node IDs."""
237
+ grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
238
+ # Call GrpcDriverHelper method
239
+ res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
240
+ return [node.node_id for node in res.nodes]
241
+
242
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
243
+ """Push messages to specified node IDs.
244
+
245
+ This method takes an iterable of messages and sends each message
246
+ to the node specified in `dst_node_id`.
247
+ """
248
+ grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
249
+ # Construct TaskIns
250
+ task_ins_list: List[TaskIns] = []
251
+ for msg in messages:
252
+ # Check message
253
+ self._check_message(msg)
254
+ # Convert Message to TaskIns
255
+ taskins = message_to_taskins(msg)
256
+ # Add to list
257
+ task_ins_list.append(taskins)
258
+ # Call GrpcDriverHelper method
259
+ res = grpc_driver_helper.push_task_ins(
260
+ PushTaskInsRequest(task_ins_list=task_ins_list)
261
+ )
262
+ return list(res.task_ids)
263
+
264
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
265
+ """Pull messages based on message IDs.
266
+
267
+ This method is used to collect messages from the SuperLink that correspond to a
268
+ set of given message IDs.
269
+ """
270
+ grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
271
+ # Pull TaskRes
272
+ res = grpc_driver.pull_task_res(
273
+ PullTaskResRequest(node=self.node, task_ids=message_ids)
274
+ )
275
+ # Convert TaskRes to Message
276
+ msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
277
+ return msgs
278
+
279
+ def send_and_receive(
280
+ self,
281
+ messages: Iterable[Message],
282
+ *,
283
+ timeout: Optional[float] = None,
284
+ ) -> Iterable[Message]:
285
+ """Push messages to specified node IDs and pull the reply messages.
286
+
287
+ This method sends a list of messages to their destination node IDs and then
288
+ waits for the replies. It continues to pull replies until either all replies are
289
+ received or the specified timeout duration is exceeded.
290
+ """
291
+ # Push messages
292
+ msg_ids = set(self.push_messages(messages))
293
+
294
+ # Pull messages
295
+ end_time = time.time() + (timeout if timeout is not None else 0.0)
296
+ ret: List[Message] = []
297
+ while timeout is None or time.time() < end_time:
298
+ res_msgs = self.pull_messages(msg_ids)
299
+ ret.extend(res_msgs)
300
+ msg_ids.difference_update(
301
+ {msg.metadata.reply_to_message for msg in res_msgs}
302
+ )
303
+ if len(msg_ids) == 0:
304
+ break
305
+ # Sleep
306
+ time.sleep(3)
307
+ return ret
308
+
309
+ def close(self) -> None:
310
+ """Disconnect from the SuperLink if connected."""
311
+ # Check if GrpcDriverHelper is initialized
312
+ if self.driver_helper is None:
313
+ return
314
+ # Disconnect
315
+ self.driver_helper.disconnect()
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.logger import log, update_console_handler
26
26
  from flwr.common.object_ref import load_app
27
27
 
28
- from .driver.driver import Driver
28
+ from .driver import Driver, GrpcDriver
29
29
  from .server_app import LoadServerAppError, ServerApp
30
30
 
31
31
 
@@ -128,13 +128,15 @@ def run_server_app() -> None:
128
128
  server_app_dir = args.dir
129
129
  server_app_attr = getattr(args, "server-app")
130
130
 
131
- # Initialize Driver
132
- driver = Driver(
131
+ # Initialize GrpcDriver
132
+ driver = GrpcDriver(
133
133
  driver_service_address=args.server,
134
134
  root_certificates=root_certificates,
135
+ fab_id=args.fab_id,
136
+ fab_version=args.fab_version,
135
137
  )
136
138
 
137
- # Run the Server App with the Driver
139
+ # Run the ServerApp with the Driver
138
140
  run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
139
141
 
140
142
  # Clean up
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
183
185
  "app from there."
184
186
  " Default: current working directory.",
185
187
  )
188
+ parser.add_argument(
189
+ "--fab-id",
190
+ default=None,
191
+ type=str,
192
+ help="The identifier of the FAB used in the run.",
193
+ )
194
+ parser.add_argument(
195
+ "--fab-version",
196
+ default=None,
197
+ type=str,
198
+ help="The version of the FAB used in the run.",
199
+ )
186
200
 
187
201
  return parser
@@ -200,7 +200,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
200
200
 
201
201
  log(
202
202
  INFO,
203
- "aggregate_fit: parameters are clipped by value: %s.",
203
+ "aggregate_fit: parameters are clipped by value: %.4f.",
204
204
  self.clipping_norm,
205
205
  )
206
206
 
@@ -234,7 +234,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
234
234
  )
235
235
  log(
236
236
  INFO,
237
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
237
+ "aggregate_fit: central DP noise with "
238
+ "standard deviation: %.4f added to parameters.",
238
239
  compute_stdv(
239
240
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
240
241
  ),
@@ -424,7 +425,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
424
425
  )
425
426
  log(
426
427
  INFO,
427
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
428
+ "aggregate_fit: central DP noise with "
429
+ "standard deviation: %.4f added to parameters.",
428
430
  compute_stdv(
429
431
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
430
432
  ),
@@ -158,7 +158,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
158
158
  )
159
159
  log(
160
160
  INFO,
161
- "aggregate_fit: parameters are clipped by value: %s.",
161
+ "aggregate_fit: parameters are clipped by value: %.4f.",
162
162
  self.clipping_norm,
163
163
  )
164
164
  # Convert back to parameters
@@ -180,7 +180,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
180
180
 
181
181
  log(
182
182
  INFO,
183
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
183
+ "aggregate_fit: central DP noise with "
184
+ "standard deviation: %.4f added to parameters.",
184
185
  compute_stdv(
185
186
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
186
187
  ),
@@ -337,11 +338,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
337
338
  )
338
339
  log(
339
340
  INFO,
340
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
341
+ "aggregate_fit: central DP noise with "
342
+ "standard deviation: %.4f added to parameters.",
341
343
  compute_stdv(
342
344
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
343
345
  ),
344
346
  )
347
+
345
348
  return aggregated_params, metrics
346
349
 
347
350
  def aggregate_evaluate(
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
64
64
  """Create run ID."""
65
65
  log(INFO, "DriverServicer.CreateRun")
66
66
  state: State = self.state_factory.state()
67
- run_id = state.create_run()
67
+ run_id = state.create_run(request.fab_id, request.fab_version)
68
68
  return CreateRunResponse(run_id=run_id)
69
69
 
70
70
  def PushTaskIns(
@@ -18,7 +18,7 @@
18
18
  import concurrent.futures
19
19
  import sys
20
20
  from logging import ERROR
21
- from typing import Any, Callable, Optional, Tuple, Union
21
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
22
22
 
23
23
  import grpc
24
24
 
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
162
162
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
163
163
  keepalive_time_ms: int = 210000,
164
164
  certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
165
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
165
166
  ) -> grpc.Server:
166
167
  """Create a gRPC server with a single servicer.
167
168
 
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
249
250
  # returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
250
251
  maximum_concurrent_rpcs=max_concurrent_workers,
251
252
  options=options,
253
+ interceptors=interceptors,
252
254
  )
253
255
  add_servicer_to_server_fn(servicer, server)
254
256
 
@@ -0,0 +1,215 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower server interceptor."""
16
+
17
+
18
+ import base64
19
+ from logging import WARNING
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.logger import log
26
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
27
+ bytes_to_private_key,
28
+ bytes_to_public_key,
29
+ generate_shared_key,
30
+ verify_hmac,
31
+ )
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
+ CreateNodeRequest,
34
+ CreateNodeResponse,
35
+ DeleteNodeRequest,
36
+ DeleteNodeResponse,
37
+ GetRunRequest,
38
+ GetRunResponse,
39
+ PingRequest,
40
+ PingResponse,
41
+ PullTaskInsRequest,
42
+ PullTaskInsResponse,
43
+ PushTaskResRequest,
44
+ PushTaskResResponse,
45
+ )
46
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
+ from flwr.server.superlink.state import State
48
+
49
+ _PUBLIC_KEY_HEADER = "public-key"
50
+ _AUTH_TOKEN_HEADER = "auth-token"
51
+
52
+ Request = Union[
53
+ CreateNodeRequest,
54
+ DeleteNodeRequest,
55
+ PullTaskInsRequest,
56
+ PushTaskResRequest,
57
+ GetRunRequest,
58
+ PingRequest,
59
+ ]
60
+
61
+ Response = Union[
62
+ CreateNodeResponse,
63
+ DeleteNodeResponse,
64
+ PullTaskInsResponse,
65
+ PushTaskResResponse,
66
+ GetRunResponse,
67
+ PingResponse,
68
+ ]
69
+
70
+
71
+ def _get_value_from_tuples(
72
+ key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
73
+ ) -> bytes:
74
+ value = next((value for key, value in tuples if key == key_string), "")
75
+ if isinstance(value, str):
76
+ return value.encode()
77
+
78
+ return value
79
+
80
+
81
+ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
82
+ """Server interceptor for client authentication."""
83
+
84
+ def __init__(self, state: State):
85
+ self.state = state
86
+
87
+ self.client_public_keys = state.get_client_public_keys()
88
+ if len(self.client_public_keys) == 0:
89
+ log(WARNING, "Authentication enabled, but no known public keys configured")
90
+
91
+ private_key = self.state.get_server_private_key()
92
+ public_key = self.state.get_server_public_key()
93
+
94
+ if private_key is None or public_key is None:
95
+ raise ValueError("Error loading authentication keys")
96
+
97
+ self.server_private_key = bytes_to_private_key(private_key)
98
+ self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
99
+
100
+ def intercept_service(
101
+ self,
102
+ continuation: Callable[[Any], Any],
103
+ handler_call_details: grpc.HandlerCallDetails,
104
+ ) -> grpc.RpcMethodHandler:
105
+ """Flower server interceptor authentication logic.
106
+
107
+ Intercept all unary calls from clients and authenticate clients by validating
108
+ auth metadata sent by the client. Continue RPC call if client is authenticated,
109
+ else, terminate RPC call by setting context to abort.
110
+ """
111
+ # One of the method handlers in
112
+ # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
113
+ method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
114
+ return self._generic_auth_unary_method_handler(method_handler)
115
+
116
+ def _generic_auth_unary_method_handler(
117
+ self, method_handler: grpc.RpcMethodHandler
118
+ ) -> grpc.RpcMethodHandler:
119
+ def _generic_method_handler(
120
+ request: Request,
121
+ context: grpc.ServicerContext,
122
+ ) -> Response:
123
+ client_public_key_bytes = base64.urlsafe_b64decode(
124
+ _get_value_from_tuples(
125
+ _PUBLIC_KEY_HEADER, context.invocation_metadata()
126
+ )
127
+ )
128
+ if client_public_key_bytes not in self.client_public_keys:
129
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
130
+
131
+ if isinstance(request, CreateNodeRequest):
132
+ return self._create_authenticated_node(
133
+ client_public_key_bytes, request, context
134
+ )
135
+
136
+ # Verify hmac value
137
+ hmac_value = base64.urlsafe_b64decode(
138
+ _get_value_from_tuples(
139
+ _AUTH_TOKEN_HEADER, context.invocation_metadata()
140
+ )
141
+ )
142
+ public_key = bytes_to_public_key(client_public_key_bytes)
143
+
144
+ if not self._verify_hmac(public_key, request, hmac_value):
145
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
146
+
147
+ # Verify node_id
148
+ node_id = self.state.get_node_id(client_public_key_bytes)
149
+
150
+ if not self._verify_node_id(node_id, request):
151
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
152
+
153
+ return method_handler.unary_unary(request, context) # type: ignore
154
+
155
+ return grpc.unary_unary_rpc_method_handler(
156
+ _generic_method_handler,
157
+ request_deserializer=method_handler.request_deserializer,
158
+ response_serializer=method_handler.response_serializer,
159
+ )
160
+
161
+ def _verify_node_id(
162
+ self,
163
+ node_id: Optional[int],
164
+ request: Union[
165
+ DeleteNodeRequest,
166
+ PullTaskInsRequest,
167
+ PushTaskResRequest,
168
+ GetRunRequest,
169
+ PingRequest,
170
+ ],
171
+ ) -> bool:
172
+ if node_id is None:
173
+ return False
174
+ if isinstance(request, PushTaskResRequest):
175
+ if len(request.task_res_list) == 0:
176
+ return False
177
+ return request.task_res_list[0].task.producer.node_id == node_id
178
+ if isinstance(request, GetRunRequest):
179
+ return node_id in self.state.get_nodes(request.run_id)
180
+ return request.node.node_id == node_id
181
+
182
+ def _verify_hmac(
183
+ self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
184
+ ) -> bool:
185
+ shared_secret = generate_shared_key(self.server_private_key, public_key)
186
+ return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)
187
+
188
+ def _create_authenticated_node(
189
+ self,
190
+ public_key_bytes: bytes,
191
+ request: CreateNodeRequest,
192
+ context: grpc.ServicerContext,
193
+ ) -> CreateNodeResponse:
194
+ context.send_initial_metadata(
195
+ (
196
+ (
197
+ _PUBLIC_KEY_HEADER,
198
+ self.encoded_server_public_key,
199
+ ),
200
+ )
201
+ )
202
+
203
+ node_id = self.state.get_node_id(public_key_bytes)
204
+
205
+ # Handle `CreateNode` here instead of calling the default method handler
206
+ # Return previously assigned `node_id` for the provided `public_key`
207
+ if node_id is not None:
208
+ self.state.acknowledge_ping(node_id, request.ping_interval)
209
+ return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
210
+
211
+ # No `node_id` exists for the provided `public_key`
212
+ # Handle `CreateNode` here instead of calling the default method handler
213
+ # Note: the innermost `CreateNode` method will never be called
214
+ node_id = self.state.create_node(request.ping_interval, public_key_bytes)
215
+ return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -33,6 +33,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  PushTaskResRequest,
34
34
  PushTaskResResponse,
35
35
  Reconnect,
36
+ Run,
36
37
  )
37
38
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
38
39
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
@@ -109,4 +110,6 @@ def get_run(
109
110
  request: GetRunRequest, state: State # pylint: disable=W0613
110
111
  ) -> GetRunResponse:
111
112
  """Get run information."""
112
- return GetRunResponse()
113
+ run_id, fab_id, fab_version = state.get_run(request.run_id)
114
+ run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version)
115
+ return GetRunResponse(run=run)
@@ -15,7 +15,7 @@
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
17
  import pathlib
18
- from logging import ERROR, INFO
18
+ from logging import DEBUG, ERROR, INFO
19
19
  from typing import Callable, Dict, List, Tuple, Union
20
20
 
21
21
  import ray
@@ -46,7 +46,7 @@ class RayBackend(Backend):
46
46
  ) -> None:
47
47
  """Prepare RayBackend by initialising Ray and creating the ActorPool."""
48
48
  log(INFO, "Initialising: %s", self.__class__.__name__)
49
- log(INFO, "Backend config: %s", backend_config)
49
+ log(DEBUG, "Backend config: %s", backend_config)
50
50
 
51
51
  if not pathlib.Path(work_dir).exists():
52
52
  raise ValueError(f"Specified work_dir {work_dir} does not exist.")
@@ -109,7 +109,7 @@ class RayBackend(Backend):
109
109
  else:
110
110
  client_resources = {"num_cpus": 2, "num_gpus": 0.0}
111
111
  log(
112
- INFO,
112
+ DEBUG,
113
113
  "`%s` not specified in backend config. Applying default setting: %s",
114
114
  self.client_resources_key,
115
115
  client_resources,
@@ -129,7 +129,7 @@ class RayBackend(Backend):
129
129
  async def build(self) -> None:
130
130
  """Build pool of Ray actors that this backend will submit jobs to."""
131
131
  await self.pool.add_actors_to_pool(self.pool.actors_capacity)
132
- log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)
132
+ log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
133
 
134
134
  async def process_message(
135
135
  self,
@@ -173,4 +173,4 @@ class RayBackend(Backend):
173
173
  """Terminate all actors in actor pool."""
174
174
  await self.pool.terminate_all_actors()
175
175
  ray.shutdown()
176
- log(INFO, "Terminated %s", self.__class__.__name__)
176
+ log(DEBUG, "Terminated %s", self.__class__.__name__)