flwr-nightly 1.9.0.dev20240420-py3-none-any.whl → 1.9.0.dev20240509-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (71):
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +44 -18
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  10. flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
  11. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  12. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  14. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  15. flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
  16. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  17. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  18. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
  19. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  20. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  21. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  22. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  23. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  24. flwr/cli/run/run.py +1 -1
  25. flwr/cli/utils.py +18 -17
  26. flwr/client/__init__.py +1 -1
  27. flwr/client/app.py +17 -93
  28. flwr/client/grpc_client/connection.py +6 -1
  29. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  30. flwr/client/grpc_rere_client/connection.py +17 -2
  31. flwr/client/mod/centraldp_mods.py +4 -2
  32. flwr/client/mod/localdp_mod.py +9 -3
  33. flwr/client/rest_client/connection.py +5 -1
  34. flwr/client/supernode/__init__.py +2 -0
  35. flwr/client/supernode/app.py +181 -7
  36. flwr/common/grpc.py +5 -1
  37. flwr/common/logger.py +37 -4
  38. flwr/common/message.py +105 -86
  39. flwr/common/record/parametersrecord.py +0 -1
  40. flwr/common/record/recordset.py +17 -5
  41. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  42. flwr/server/__init__.py +0 -2
  43. flwr/server/app.py +118 -2
  44. flwr/server/compat/app.py +5 -56
  45. flwr/server/compat/app_utils.py +1 -1
  46. flwr/server/compat/driver_client_proxy.py +27 -72
  47. flwr/server/driver/__init__.py +3 -0
  48. flwr/server/driver/driver.py +12 -242
  49. flwr/server/driver/grpc_driver.py +315 -0
  50. flwr/server/history.py +20 -20
  51. flwr/server/run_serverapp.py +18 -4
  52. flwr/server/server.py +2 -5
  53. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  54. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  55. flwr/server/superlink/driver/driver_servicer.py +1 -1
  56. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  57. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  58. flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
  59. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  60. flwr/server/superlink/state/in_memory_state.py +76 -8
  61. flwr/server/superlink/state/sqlite_state.py +116 -11
  62. flwr/server/superlink/state/state.py +35 -3
  63. flwr/simulation/__init__.py +2 -2
  64. flwr/simulation/app.py +16 -1
  65. flwr/simulation/run_simulation.py +14 -9
  66. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
  67. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
  68. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
  69. flwr/server/driver/abc_driver.py +0 -140
  70. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
  71. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
flwr/server/run_serverapp.py CHANGED
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.logger import log, update_console_handler
26
26
  from flwr.common.object_ref import load_app
27
27
 
28
- from .driver.driver import Driver
28
+ from .driver import Driver, GrpcDriver
29
29
  from .server_app import LoadServerAppError, ServerApp
30
30
 
31
31
 
@@ -128,13 +128,15 @@ def run_server_app() -> None:
128
128
  server_app_dir = args.dir
129
129
  server_app_attr = getattr(args, "server-app")
130
130
 
131
- # Initialize Driver
132
- driver = Driver(
131
+ # Initialize GrpcDriver
132
+ driver = GrpcDriver(
133
133
  driver_service_address=args.server,
134
134
  root_certificates=root_certificates,
135
+ fab_id=args.fab_id,
136
+ fab_version=args.fab_version,
135
137
  )
136
138
 
137
- # Run the Server App with the Driver
139
+ # Run the ServerApp with the Driver
138
140
  run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
139
141
 
140
142
  # Clean up
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
183
185
  "app from there."
184
186
  " Default: current working directory.",
185
187
  )
188
+ parser.add_argument(
189
+ "--fab-id",
190
+ default=None,
191
+ type=str,
192
+ help="The identifier of the FAB used in the run.",
193
+ )
194
+ parser.add_argument(
195
+ "--fab-version",
196
+ default=None,
197
+ type=str,
198
+ help="The version of the FAB used in the run.",
199
+ )
186
200
 
187
201
  return parser
flwr/server/server.py CHANGED
@@ -487,11 +487,8 @@ def run_fl(
487
487
  log(INFO, "")
488
488
  log(INFO, "[SUMMARY]")
489
489
  log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
490
- for idx, line in enumerate(io.StringIO(str(hist))):
491
- if idx == 0:
492
- log(INFO, "%s", line.strip("\n"))
493
- else:
494
- log(INFO, "\t%s", line.strip("\n"))
490
+ for line in io.StringIO(str(hist)):
491
+ log(INFO, "\t%s", line.strip("\n"))
495
492
  log(INFO, "")
496
493
 
497
494
  # Graceful shutdown
flwr/server/strategy/dp_adaptive_clipping.py CHANGED
@@ -200,7 +200,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
200
200
 
201
201
  log(
202
202
  INFO,
203
- "aggregate_fit: parameters are clipped by value: %s.",
203
+ "aggregate_fit: parameters are clipped by value: %.4f.",
204
204
  self.clipping_norm,
205
205
  )
206
206
 
@@ -234,7 +234,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
234
234
  )
235
235
  log(
236
236
  INFO,
237
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
237
+ "aggregate_fit: central DP noise with "
238
+ "standard deviation: %.4f added to parameters.",
238
239
  compute_stdv(
239
240
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
240
241
  ),
@@ -424,7 +425,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
424
425
  )
425
426
  log(
426
427
  INFO,
427
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
428
+ "aggregate_fit: central DP noise with "
429
+ "standard deviation: %.4f added to parameters.",
428
430
  compute_stdv(
429
431
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
430
432
  ),
flwr/server/strategy/dp_fixed_clipping.py CHANGED
@@ -158,7 +158,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
158
158
  )
159
159
  log(
160
160
  INFO,
161
- "aggregate_fit: parameters are clipped by value: %s.",
161
+ "aggregate_fit: parameters are clipped by value: %.4f.",
162
162
  self.clipping_norm,
163
163
  )
164
164
  # Convert back to parameters
@@ -180,7 +180,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
180
180
 
181
181
  log(
182
182
  INFO,
183
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
183
+ "aggregate_fit: central DP noise with "
184
+ "standard deviation: %.4f added to parameters.",
184
185
  compute_stdv(
185
186
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
186
187
  ),
@@ -337,11 +338,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
337
338
  )
338
339
  log(
339
340
  INFO,
340
- "aggregate_fit: central DP noise with standard deviation: %s added to parameters.",
341
+ "aggregate_fit: central DP noise with "
342
+ "standard deviation: %.4f added to parameters.",
341
343
  compute_stdv(
342
344
  self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
343
345
  ),
344
346
  )
347
+
345
348
  return aggregated_params, metrics
346
349
 
347
350
  def aggregate_evaluate(
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
64
64
  """Create run ID."""
65
65
  log(INFO, "DriverServicer.CreateRun")
66
66
  state: State = self.state_factory.state()
67
- run_id = state.create_run("None/None", "None")
67
+ run_id = state.create_run(request.fab_id, request.fab_version)
68
68
  return CreateRunResponse(run_id=run_id)
69
69
 
70
70
  def PushTaskIns(
flwr/server/superlink/fleet/grpc_bidi/grpc_server.py CHANGED
@@ -18,7 +18,7 @@
18
18
  import concurrent.futures
19
19
  import sys
20
20
  from logging import ERROR
21
- from typing import Any, Callable, Optional, Tuple, Union
21
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
22
22
 
23
23
  import grpc
24
24
 
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
162
162
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
163
163
  keepalive_time_ms: int = 210000,
164
164
  certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
165
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
165
166
  ) -> grpc.Server:
166
167
  """Create a gRPC server with a single servicer.
167
168
 
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
249
250
  # returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
250
251
  maximum_concurrent_rpcs=max_concurrent_workers,
251
252
  options=options,
253
+ interceptors=interceptors,
252
254
  )
253
255
  add_servicer_to_server_fn(servicer, server)
254
256
 
flwr/server/superlink/fleet/grpc_rere/server_interceptor.py ADDED
@@ -0,0 +1,215 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower server interceptor."""
16
+
17
+
18
+ import base64
19
+ from logging import WARNING
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.logger import log
26
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
27
+ bytes_to_private_key,
28
+ bytes_to_public_key,
29
+ generate_shared_key,
30
+ verify_hmac,
31
+ )
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
+ CreateNodeRequest,
34
+ CreateNodeResponse,
35
+ DeleteNodeRequest,
36
+ DeleteNodeResponse,
37
+ GetRunRequest,
38
+ GetRunResponse,
39
+ PingRequest,
40
+ PingResponse,
41
+ PullTaskInsRequest,
42
+ PullTaskInsResponse,
43
+ PushTaskResRequest,
44
+ PushTaskResResponse,
45
+ )
46
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
+ from flwr.server.superlink.state import State
48
+
49
+ _PUBLIC_KEY_HEADER = "public-key"
50
+ _AUTH_TOKEN_HEADER = "auth-token"
51
+
52
+ Request = Union[
53
+ CreateNodeRequest,
54
+ DeleteNodeRequest,
55
+ PullTaskInsRequest,
56
+ PushTaskResRequest,
57
+ GetRunRequest,
58
+ PingRequest,
59
+ ]
60
+
61
+ Response = Union[
62
+ CreateNodeResponse,
63
+ DeleteNodeResponse,
64
+ PullTaskInsResponse,
65
+ PushTaskResResponse,
66
+ GetRunResponse,
67
+ PingResponse,
68
+ ]
69
+
70
+
71
+ def _get_value_from_tuples(
72
+ key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
73
+ ) -> bytes:
74
+ value = next((value for key, value in tuples if key == key_string), "")
75
+ if isinstance(value, str):
76
+ return value.encode()
77
+
78
+ return value
79
+
80
+
81
+ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
82
+ """Server interceptor for client authentication."""
83
+
84
+ def __init__(self, state: State):
85
+ self.state = state
86
+
87
+ self.client_public_keys = state.get_client_public_keys()
88
+ if len(self.client_public_keys) == 0:
89
+ log(WARNING, "Authentication enabled, but no known public keys configured")
90
+
91
+ private_key = self.state.get_server_private_key()
92
+ public_key = self.state.get_server_public_key()
93
+
94
+ if private_key is None or public_key is None:
95
+ raise ValueError("Error loading authentication keys")
96
+
97
+ self.server_private_key = bytes_to_private_key(private_key)
98
+ self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
99
+
100
+ def intercept_service(
101
+ self,
102
+ continuation: Callable[[Any], Any],
103
+ handler_call_details: grpc.HandlerCallDetails,
104
+ ) -> grpc.RpcMethodHandler:
105
+ """Flower server interceptor authentication logic.
106
+
107
+ Intercept all unary calls from clients and authenticate clients by validating
108
+ auth metadata sent by the client. Continue RPC call if client is authenticated,
109
+ else, terminate RPC call by setting context to abort.
110
+ """
111
+ # One of the method handlers in
112
+ # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
113
+ method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
114
+ return self._generic_auth_unary_method_handler(method_handler)
115
+
116
+ def _generic_auth_unary_method_handler(
117
+ self, method_handler: grpc.RpcMethodHandler
118
+ ) -> grpc.RpcMethodHandler:
119
+ def _generic_method_handler(
120
+ request: Request,
121
+ context: grpc.ServicerContext,
122
+ ) -> Response:
123
+ client_public_key_bytes = base64.urlsafe_b64decode(
124
+ _get_value_from_tuples(
125
+ _PUBLIC_KEY_HEADER, context.invocation_metadata()
126
+ )
127
+ )
128
+ if client_public_key_bytes not in self.client_public_keys:
129
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
130
+
131
+ if isinstance(request, CreateNodeRequest):
132
+ return self._create_authenticated_node(
133
+ client_public_key_bytes, request, context
134
+ )
135
+
136
+ # Verify hmac value
137
+ hmac_value = base64.urlsafe_b64decode(
138
+ _get_value_from_tuples(
139
+ _AUTH_TOKEN_HEADER, context.invocation_metadata()
140
+ )
141
+ )
142
+ public_key = bytes_to_public_key(client_public_key_bytes)
143
+
144
+ if not self._verify_hmac(public_key, request, hmac_value):
145
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
146
+
147
+ # Verify node_id
148
+ node_id = self.state.get_node_id(client_public_key_bytes)
149
+
150
+ if not self._verify_node_id(node_id, request):
151
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
152
+
153
+ return method_handler.unary_unary(request, context) # type: ignore
154
+
155
+ return grpc.unary_unary_rpc_method_handler(
156
+ _generic_method_handler,
157
+ request_deserializer=method_handler.request_deserializer,
158
+ response_serializer=method_handler.response_serializer,
159
+ )
160
+
161
+ def _verify_node_id(
162
+ self,
163
+ node_id: Optional[int],
164
+ request: Union[
165
+ DeleteNodeRequest,
166
+ PullTaskInsRequest,
167
+ PushTaskResRequest,
168
+ GetRunRequest,
169
+ PingRequest,
170
+ ],
171
+ ) -> bool:
172
+ if node_id is None:
173
+ return False
174
+ if isinstance(request, PushTaskResRequest):
175
+ if len(request.task_res_list) == 0:
176
+ return False
177
+ return request.task_res_list[0].task.producer.node_id == node_id
178
+ if isinstance(request, GetRunRequest):
179
+ return node_id in self.state.get_nodes(request.run_id)
180
+ return request.node.node_id == node_id
181
+
182
+ def _verify_hmac(
183
+ self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
184
+ ) -> bool:
185
+ shared_secret = generate_shared_key(self.server_private_key, public_key)
186
+ return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)
187
+
188
+ def _create_authenticated_node(
189
+ self,
190
+ public_key_bytes: bytes,
191
+ request: CreateNodeRequest,
192
+ context: grpc.ServicerContext,
193
+ ) -> CreateNodeResponse:
194
+ context.send_initial_metadata(
195
+ (
196
+ (
197
+ _PUBLIC_KEY_HEADER,
198
+ self.encoded_server_public_key,
199
+ ),
200
+ )
201
+ )
202
+
203
+ node_id = self.state.get_node_id(public_key_bytes)
204
+
205
+ # Handle `CreateNode` here instead of calling the default method handler
206
+ # Return previously assigned `node_id` for the provided `public_key`
207
+ if node_id is not None:
208
+ self.state.acknowledge_ping(node_id, request.ping_interval)
209
+ return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
210
+
211
+ # No `node_id` exists for the provided `public_key`
212
+ # Handle `CreateNode` here instead of calling the default method handler
213
+ # Note: the innermost `CreateNode` method will never be called
214
+ node_id = self.state.create_node(request.ping_interval, public_key_bytes)
215
+ return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
17
  import pathlib
18
- from logging import ERROR, INFO
18
+ from logging import DEBUG, ERROR, INFO, WARNING
19
19
  from typing import Callable, Dict, List, Tuple, Union
20
20
 
21
21
  import ray
@@ -46,7 +46,7 @@ class RayBackend(Backend):
46
46
  ) -> None:
47
47
  """Prepare RayBackend by initialising Ray and creating the ActorPool."""
48
48
  log(INFO, "Initialising: %s", self.__class__.__name__)
49
- log(INFO, "Backend config: %s", backend_config)
49
+ log(DEBUG, "Backend config: %s", backend_config)
50
50
 
51
51
  if not pathlib.Path(work_dir).exists():
52
52
  raise ValueError(f"Specified work_dir {work_dir} does not exist.")
@@ -55,7 +55,10 @@ class RayBackend(Backend):
55
55
  runtime_env = (
56
56
  self._configure_runtime_env(work_dir=work_dir) if work_dir else None
57
57
  )
58
- init_ray(runtime_env=runtime_env)
58
+ if backend_config.get("silent", False):
59
+ init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
60
+ else:
61
+ init_ray(runtime_env=runtime_env)
59
62
 
60
63
  # Validate client resources
61
64
  self.client_resources_key = "client_resources"
@@ -109,7 +112,7 @@ class RayBackend(Backend):
109
112
  else:
110
113
  client_resources = {"num_cpus": 2, "num_gpus": 0.0}
111
114
  log(
112
- INFO,
115
+ DEBUG,
113
116
  "`%s` not specified in backend config. Applying default setting: %s",
114
117
  self.client_resources_key,
115
118
  client_resources,
@@ -129,7 +132,7 @@ class RayBackend(Backend):
129
132
  async def build(self) -> None:
130
133
  """Build pool of Ray actors that this backend will submit jobs to."""
131
134
  await self.pool.add_actors_to_pool(self.pool.actors_capacity)
132
- log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)
135
+ log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
136
 
134
137
  async def process_message(
135
138
  self,
@@ -173,4 +176,4 @@ class RayBackend(Backend):
173
176
  """Terminate all actors in actor pool."""
174
177
  await self.pool.terminate_all_actors()
175
178
  ray.shutdown()
176
- log(INFO, "Terminated %s", self.__class__.__name__)
179
+ log(DEBUG, "Terminated %s", self.__class__.__name__)
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -293,7 +293,7 @@ def start_vce(
293
293
  node_states[node_id] = NodeState()
294
294
 
295
295
  # Load backend config
296
- log(INFO, "Supported backends: %s", list(supported_backends.keys()))
296
+ log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
297
297
  backend_config = json.loads(backend_config_json_stream)
298
298
 
299
299
  try:
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -30,16 +30,24 @@ from flwr.server.utils import validate_task_ins_or_res
30
30
  from .utils import make_node_unavailable_taskres
31
31
 
32
32
 
33
- class InMemoryState(State):
33
+ class InMemoryState(State): # pylint: disable=R0902,R0904
34
34
  """In-memory State implementation."""
35
35
 
36
36
  def __init__(self) -> None:
37
+
37
38
  # Map node_id to (online_until, ping_interval)
38
39
  self.node_ids: Dict[int, Tuple[float, float]] = {}
40
+ self.public_key_to_node_id: Dict[bytes, int] = {}
41
+
39
42
  # Map run_id to (fab_id, fab_version)
40
43
  self.run_ids: Dict[int, Tuple[str, str]] = {}
41
44
  self.task_ins_store: Dict[UUID, TaskIns] = {}
42
45
  self.task_res_store: Dict[UUID, TaskRes] = {}
46
+
47
+ self.client_public_keys: Set[bytes] = set()
48
+ self.server_public_key: Optional[bytes] = None
49
+ self.server_private_key: Optional[bytes] = None
50
+
43
51
  self.lock = threading.Lock()
44
52
 
45
53
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
@@ -202,23 +210,46 @@ class InMemoryState(State):
202
210
  """
203
211
  return len(self.task_res_store)
204
212
 
205
- def create_node(self, ping_interval: float) -> int:
213
+ def create_node(
214
+ self, ping_interval: float, public_key: Optional[bytes] = None
215
+ ) -> int:
206
216
  """Create, store in state, and return `node_id`."""
207
217
  # Sample a random int64 as node_id
208
218
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
209
219
 
210
220
  with self.lock:
211
- if node_id not in self.node_ids:
212
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
213
- return node_id
214
- log(ERROR, "Unexpected node registration failure.")
215
- return 0
221
+ if node_id in self.node_ids:
222
+ log(ERROR, "Unexpected node registration failure.")
223
+ return 0
224
+
225
+ if public_key is not None:
226
+ if (
227
+ public_key in self.public_key_to_node_id
228
+ or node_id in self.public_key_to_node_id.values()
229
+ ):
230
+ log(ERROR, "Unexpected node registration failure.")
231
+ return 0
232
+
233
+ self.public_key_to_node_id[public_key] = node_id
234
+
235
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
236
+ return node_id
216
237
 
217
- def delete_node(self, node_id: int) -> None:
238
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
218
239
  """Delete a client node."""
219
240
  with self.lock:
220
241
  if node_id not in self.node_ids:
221
242
  raise ValueError(f"Node {node_id} not found")
243
+
244
+ if public_key is not None:
245
+ if (
246
+ public_key not in self.public_key_to_node_id
247
+ or node_id not in self.public_key_to_node_id.values()
248
+ ):
249
+ raise ValueError("Public key or node_id not found")
250
+
251
+ del self.public_key_to_node_id[public_key]
252
+
222
253
  del self.node_ids[node_id]
223
254
 
224
255
  def get_nodes(self, run_id: int) -> Set[int]:
@@ -239,6 +270,10 @@ class InMemoryState(State):
239
270
  if online_until > current_time
240
271
  }
241
272
 
273
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
274
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
275
+ return self.public_key_to_node_id.get(client_public_key)
276
+
242
277
  def create_run(self, fab_id: str, fab_version: str) -> int:
243
278
  """Create a new run for the specified `fab_id` and `fab_version`."""
244
279
  # Sample a random int64 as run_id
@@ -251,6 +286,39 @@ class InMemoryState(State):
251
286
  log(ERROR, "Unexpected run creation failure.")
252
287
  return 0
253
288
 
289
+ def store_server_private_public_key(
290
+ self, private_key: bytes, public_key: bytes
291
+ ) -> None:
292
+ """Store `server_private_key` and `server_public_key` in state."""
293
+ with self.lock:
294
+ if self.server_private_key is None and self.server_public_key is None:
295
+ self.server_private_key = private_key
296
+ self.server_public_key = public_key
297
+ else:
298
+ raise RuntimeError("Server private and public key already set")
299
+
300
+ def get_server_private_key(self) -> Optional[bytes]:
301
+ """Retrieve `server_private_key` in urlsafe bytes."""
302
+ return self.server_private_key
303
+
304
+ def get_server_public_key(self) -> Optional[bytes]:
305
+ """Retrieve `server_public_key` in urlsafe bytes."""
306
+ return self.server_public_key
307
+
308
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
309
+ """Store a set of `client_public_keys` in state."""
310
+ with self.lock:
311
+ self.client_public_keys = public_keys
312
+
313
+ def store_client_public_key(self, public_key: bytes) -> None:
314
+ """Store a `client_public_key` in state."""
315
+ with self.lock:
316
+ self.client_public_keys.add(public_key)
317
+
318
+ def get_client_public_keys(self) -> Set[bytes]:
319
+ """Retrieve all currently stored `client_public_keys` as a set."""
320
+ return self.client_public_keys
321
+
254
322
  def get_run(self, run_id: int) -> Tuple[int, str, str]:
255
323
  """Retrieve information about the run with the specified `run_id`."""
256
324
  with self.lock: