flwr-nightly 1.9.0.dev20240423__py3-none-any.whl → 1.9.0.dev20240425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic.

Files changed (30)
  1. flwr/cli/config_utils.py +18 -46
  2. flwr/cli/new/new.py +37 -17
  3. flwr/cli/new/templates/app/README.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  5. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +6 -3
  7. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +6 -3
  8. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +6 -3
  9. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +6 -3
  10. flwr/cli/run/run.py +1 -1
  11. flwr/cli/utils.py +18 -17
  12. flwr/client/grpc_client/connection.py +6 -1
  13. flwr/client/grpc_rere_client/client_interceptor.py +150 -0
  14. flwr/client/grpc_rere_client/connection.py +17 -2
  15. flwr/client/rest_client/connection.py +5 -1
  16. flwr/common/grpc.py +5 -1
  17. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +20 -1
  18. flwr/server/compat/app_utils.py +1 -1
  19. flwr/server/compat/driver_client_proxy.py +27 -72
  20. flwr/server/driver/grpc_driver.py +17 -8
  21. flwr/server/run_serverapp.py +14 -0
  22. flwr/server/superlink/driver/driver_servicer.py +1 -1
  23. flwr/server/superlink/state/in_memory_state.py +37 -1
  24. flwr/server/superlink/state/sqlite_state.py +71 -4
  25. flwr/server/superlink/state/state.py +26 -0
  26. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/RECORD +30 -29
  28. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/entry_points.txt +0 -0
flwr/common/grpc.py CHANGED
@@ -16,7 +16,7 @@
 
 
 from logging import DEBUG
-from typing import Optional
+from typing import Optional, Sequence
 
 import grpc
 
@@ -30,6 +30,7 @@ def create_channel(
     insecure: bool,
     root_certificates: Optional[bytes] = None,
     max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
+    interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
 ) -> grpc.Channel:
     """Create a gRPC channel, either secure or insecure."""
     # Check for conflicting parameters
@@ -57,4 +58,7 @@ def create_channel(
         )
         log(DEBUG, "Opened secure gRPC connection using certificates")
 
+    if interceptors is not None:
+        channel = grpc.intercept_channel(channel, interceptors)
+
     return channel
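
Illustrative sketch (not part of the diff): how the new `interceptors` parameter of `create_channel` could be exercised. The address, the `insecure` flag, and the logging interceptor below are assumptions for demonstration only; the diff itself only adds the parameter and the `grpc.intercept_channel` call.

import grpc

from flwr.common.grpc import create_channel


class LoggingInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Toy interceptor that prints the RPC method before forwarding the call."""

    def intercept_unary_unary(self, continuation, client_call_details, request):
        print(f"Calling {client_call_details.method}")
        return continuation(client_call_details, request)


# Hypothetical usage: the returned channel is wrapped via
# grpc.intercept_channel(), as wired up in the hunk above.
channel = create_channel(
    "127.0.0.1:9092",  # placeholder address
    insecure=True,
    interceptors=[LoggingInterceptor()],
)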

flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -18,8 +18,9 @@
 import base64
 from typing import Tuple, cast
 
+from cryptography.exceptions import InvalidSignature
 from cryptography.fernet import Fernet
-from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives import hashes, hmac, serialization
 from cryptography.hazmat.primitives.asymmetric import ec
 from cryptography.hazmat.primitives.kdf.hkdf import HKDF
 
@@ -98,3 +99,21 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
     # The input key must be url safe
     fernet = Fernet(key)
     return fernet.decrypt(ciphertext)
+
+
+def compute_hmac(key: bytes, message: bytes) -> bytes:
+    """Compute hmac of a message using key as hash."""
+    computed_hmac = hmac.HMAC(key, hashes.SHA256())
+    computed_hmac.update(message)
+    return computed_hmac.finalize()
+
+
+def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
+    """Verify hmac of a message using key as hash."""
+    computed_hmac = hmac.HMAC(key, hashes.SHA256())
+    computed_hmac.update(message)
+    try:
+        computed_hmac.verify(hmac_value)
+        return True
+    except InvalidSignature:
+        return False
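
Illustrative sketch (not part of the diff): what the new `compute_hmac`/`verify_hmac` helpers do. The random 32-byte key is a stand-in; in Flower the key would be a shared secret derived elsewhere in this module.

import os

from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    compute_hmac,
    verify_hmac,
)

key = os.urandom(32)       # stand-in key; normally a derived shared secret
message = b"task payload"

tag = compute_hmac(key, message)               # 32-byte SHA-256 HMAC
assert verify_hmac(key, message, tag)          # True for the untouched message
assert not verify_hmac(key, b"tampered", tag)  # False once the message changes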

flwr/server/compat/app_utils.py CHANGED
@@ -89,7 +89,7 @@ def _update_client_manager(
         for node_id in new_nodes:
             client_proxy = DriverClientProxy(
                 node_id=node_id,
-                driver=driver.grpc_driver_helper,  # type: ignore
+                driver=driver,
                 anonymous=False,
                 run_id=driver.run_id,  # type: ignore
             )

flwr/server/compat/driver_client_proxy.py CHANGED
@@ -16,16 +16,14 @@
 
 
 import time
-from typing import List, Optional
+from typing import Optional
 
 from flwr import common
-from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
+from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
 from flwr.common import recordset_compat as compat
-from flwr.common import serde
-from flwr.proto import driver_pb2, node_pb2, task_pb2  # pylint: disable=E0611
 from flwr.server.client_proxy import ClientProxy
 
-from ..driver.grpc_driver import GrpcDriverHelper
+from ..driver.driver import Driver
 
 SLEEP_TIME = 1
 
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
 class DriverClientProxy(ClientProxy):
     """Flower client proxy which delegates work using the Driver API."""
 
-    def __init__(
-        self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
-    ):
+    def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
         super().__init__(str(node_id))
         self.node_id = node_id
         self.driver = driver
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
         timeout: Optional[float],
         group_id: Optional[int],
     ) -> RecordSet:
-        task_ins = task_pb2.TaskIns(  # pylint: disable=E1101
-            task_id="",
-            group_id=str(group_id) if group_id is not None else "",
-            run_id=self.run_id,
-            task=task_pb2.Task(  # pylint: disable=E1101
-                producer=node_pb2.Node(  # pylint: disable=E1101
-                    node_id=0,
-                    anonymous=True,
-                ),
-                consumer=node_pb2.Node(  # pylint: disable=E1101
-                    node_id=self.node_id,
-                    anonymous=self.anonymous,
-                ),
-                task_type=task_type,
-                recordset=serde.recordset_to_proto(recordset),
-                ttl=DEFAULT_TTL,
-            ),
-        )
-
-        # This would normally be recorded upon common.Message creation
-        # but this compatibility stack doesn't create Messages,
-        # so we need to inject `created_at` manually (needed for
-        # taskins validation by server.utils.validator)
-        task_ins.task.created_at = time.time()
 
-        push_task_ins_req = driver_pb2.PushTaskInsRequest(  # pylint: disable=E1101
-            task_ins_list=[task_ins]
+        # Create message
+        message = self.driver.create_message(
+            content=recordset,
+            message_type=task_type,
+            dst_node_id=self.node_id,
+            group_id=str(group_id) if group_id else "",
+            ttl=timeout,
         )
 
-        # Send TaskIns to Driver API
-        push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
-
-        if len(push_task_ins_res.task_ids) != 1:
-            raise ValueError("Unexpected number of task_ids")
+        # Push message
+        message_ids = list(self.driver.push_messages(messages=[message]))
+        if len(message_ids) != 1:
+            raise ValueError("Unexpected number of message_ids")
 
-        task_id = push_task_ins_res.task_ids[0]
-        if task_id == "":
-            raise ValueError(f"Failed to schedule task for node {self.node_id}")
+        message_id = message_ids[0]
+        if message_id == "":
+            raise ValueError(f"Failed to send message to node {self.node_id}")
 
         if timeout:
            start_time = time.time()
 
         while True:
-            pull_task_res_req = driver_pb2.PullTaskResRequest(  # pylint: disable=E1101
-                node=node_pb2.Node(node_id=0, anonymous=True),  # pylint: disable=E1101
-                task_ids=[task_id],
-            )
-
-            # Ask Driver API for TaskRes
-            pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
-
-            task_res_list: List[task_pb2.TaskRes] = list(  # pylint: disable=E1101
-                pull_task_res_res.task_res_list
-            )
-            if len(task_res_list) == 1:
-                task_res = task_res_list[0]
-
-                # This will raise an Exception if task_res carries an `error`
-                validate_task_res(task_res=task_res)
-
-                return serde.recordset_from_proto(task_res.task.recordset)
+            messages = list(self.driver.pull_messages(message_ids))
+            if len(messages) == 1:
+                msg: Message = messages[0]
+                if msg.has_error():
+                    raise ValueError(
+                        f"Message contains an Error (reason: {msg.error.reason}). "
+                        "It originated during client-side execution of a message."
+                    )
+                return msg.content
 
             if timeout is not None and time.time() > start_time + timeout:
                 raise RuntimeError("Timeout reached")
             time.sleep(SLEEP_TIME)
-
-
-def validate_task_res(
-    task_res: task_pb2.TaskRes,  # pylint: disable=E1101
-) -> None:
-    """Validate if a TaskRes is empty or not."""
-    if not task_res.HasField("task"):
-        raise ValueError("Invalid TaskRes, field `task` missing")
-    if task_res.task.HasField("error"):
-        raise ValueError("Exception during client-side task execution")
-    if not task_res.task.HasField("recordset"):
-        raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
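
Illustrative sketch (not part of the diff): the Driver-based round trip that DriverClientProxy now performs (create_message -> push_messages -> pull_messages). The node id, message type, and empty RecordSet are placeholders, and pushing requires a reachable SuperLink.

from flwr.common import MessageType, RecordSet
from flwr.server.driver.grpc_driver import GrpcDriver

driver = GrpcDriver()  # connects lazily, like the proxy above

# Same three steps the proxy now performs: create -> push -> poll for replies.
message = driver.create_message(
    content=RecordSet(),             # placeholder payload
    message_type=MessageType.QUERY,  # placeholder message type
    dst_node_id=123,                 # placeholder node id
    group_id="",
    ttl=None,
)
message_ids = list(driver.push_messages(messages=[message]))
replies = list(driver.pull_messages(message_ids))  # empty until a node answers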

flwr/server/driver/grpc_driver.py CHANGED
@@ -151,31 +151,40 @@ class GrpcDriver(Driver):
         * CA certificate.
         * server certificate.
         * server private key.
+    fab_id : str (default: None)
+        The identifier of the FAB used in the run.
+    fab_version : str (default: None)
+        The version of the FAB used in the run.
     """
 
     def __init__(
         self,
         driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
         root_certificates: Optional[bytes] = None,
+        fab_id: Optional[str] = None,
+        fab_version: Optional[str] = None,
     ) -> None:
         self.addr = driver_service_address
         self.root_certificates = root_certificates
-        self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
+        self.driver_helper: Optional[GrpcDriverHelper] = None
         self.run_id: Optional[int] = None
+        self.fab_id = fab_id if fab_id is not None else ""
+        self.fab_version = fab_version if fab_version is not None else ""
         self.node = Node(node_id=0, anonymous=True)
 
     def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
         # Check if the GrpcDriverHelper is initialized
-        if self.grpc_driver_helper is None or self.run_id is None:
+        if self.driver_helper is None or self.run_id is None:
             # Connect and create run
-            self.grpc_driver_helper = GrpcDriverHelper(
+            self.driver_helper = GrpcDriverHelper(
                 driver_service_address=self.addr,
                 root_certificates=self.root_certificates,
             )
-            self.grpc_driver_helper.connect()
-            res = self.grpc_driver_helper.create_run(CreateRunRequest())
+            self.driver_helper.connect()
+            req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
+            res = self.driver_helper.create_run(req)
             self.run_id = res.run_id
-        return self.grpc_driver_helper, self.run_id
+        return self.driver_helper, self.run_id
 
     def _check_message(self, message: Message) -> None:
         # Check if the message is valid
@@ -300,7 +309,7 @@ class GrpcDriver(Driver):
     def close(self) -> None:
         """Disconnect from the SuperLink if connected."""
         # Check if GrpcDriverHelper is initialized
-        if self.grpc_driver_helper is None:
+        if self.driver_helper is None:
             return
         # Disconnect
-        self.grpc_driver_helper.disconnect()
+        self.driver_helper.disconnect()
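
Illustrative sketch (not part of the diff): constructing a GrpcDriver with the new FAB metadata. The address and FAB values are placeholders; when omitted, `fab_id` and `fab_version` default to empty strings.

from flwr.server.driver.grpc_driver import GrpcDriver

# FAB id/version are placeholder values for demonstration only.
driver = GrpcDriver(
    driver_service_address="127.0.0.1:9091",  # placeholder SuperLink address
    fab_id="flwrlabs/example-app",
    fab_version="1.0.0",
)
# On first use the driver connects and sends
# CreateRunRequest(fab_id=..., fab_version=...), so the run carries the FAB
# metadata instead of the previous hard-coded "None/None"/"None" values.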

flwr/server/run_serverapp.py CHANGED
@@ -132,6 +132,8 @@ def run_server_app() -> None:
     driver = GrpcDriver(
         driver_service_address=args.server,
         root_certificates=root_certificates,
+        fab_id=args.fab_id,
+        fab_version=args.fab_version,
     )
 
     # Run the ServerApp with the Driver
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
         "app from there."
         " Default: current working directory.",
     )
+    parser.add_argument(
+        "--fab-id",
+        default=None,
+        type=str,
+        help="The identifier of the FAB used in the run.",
+    )
+    parser.add_argument(
+        "--fab-version",
+        default=None,
+        type=str,
+        help="The version of the FAB used in the run.",
+    )
 
     return parser

flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
         """Create run ID."""
         log(INFO, "DriverServicer.CreateRun")
         state: State = self.state_factory.state()
-        run_id = state.create_run("None/None", "None")
+        run_id = state.create_run(request.fab_id, request.fab_version)
         return CreateRunResponse(run_id=run_id)
 
     def PushTaskIns(

flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -30,7 +30,7 @@ from flwr.server.utils import validate_task_ins_or_res
 from .utils import make_node_unavailable_taskres
 
 
-class InMemoryState(State):
+class InMemoryState(State):  # pylint: disable=R0902
     """In-memory State implementation."""
 
     def __init__(self) -> None:
@@ -40,6 +40,9 @@ class InMemoryState(State):
         self.run_ids: Dict[int, Tuple[str, str]] = {}
         self.task_ins_store: Dict[UUID, TaskIns] = {}
         self.task_res_store: Dict[UUID, TaskRes] = {}
+        self.client_public_keys: Set[bytes] = set()
+        self.server_public_key: Optional[bytes] = None
+        self.server_private_key: Optional[bytes] = None
         self.lock = threading.Lock()
 
     def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
@@ -251,6 +254,39 @@ class InMemoryState(State):
         log(ERROR, "Unexpected run creation failure.")
         return 0
 
+    def store_server_public_private_key(
+        self, public_key: bytes, private_key: bytes
+    ) -> None:
+        """Store `server_public_key` and `server_private_key` in state."""
+        with self.lock:
+            if self.server_private_key is None and self.server_public_key is None:
+                self.server_private_key = private_key
+                self.server_public_key = public_key
+            else:
+                raise RuntimeError("Server public and private key already set")
+
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+        return self.server_private_key
+
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+        return self.server_public_key
+
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+        with self.lock:
+            self.client_public_keys = public_keys
+
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+        with self.lock:
+            self.client_public_keys.add(public_key)
+
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+        return self.client_public_keys
+
     def get_run(self, run_id: int) -> Tuple[int, str, str]:
         """Retrieve information about the run with the specified `run_id`."""
         with self.lock:
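
Illustrative sketch (not part of the diff): the new key-storage API on InMemoryState. The byte strings stand in for serialized public/private keys.

from flwr.server.superlink.state.in_memory_state import InMemoryState

state = InMemoryState()

# The server keypair can be stored only once; a second call raises RuntimeError.
state.store_server_public_private_key(b"server-pub", b"server-priv")
assert state.get_server_public_key() == b"server-pub"
assert state.get_server_private_key() == b"server-priv"

# Client public keys: bulk replace, then add a single key.
state.store_client_public_keys({b"client-a", b"client-b"})
state.store_client_public_key(b"client-c")
assert state.get_client_public_keys() == {b"client-a", b"client-b", b"client-c"}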

flwr/server/superlink/state/sqlite_state.py CHANGED
@@ -20,7 +20,7 @@ import re
 import sqlite3
 import time
 from logging import DEBUG, ERROR
-from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import UUID, uuid4
 
 from flwr.common import log, now
@@ -40,6 +40,19 @@ CREATE TABLE IF NOT EXISTS node(
 );
 """
 
+SQL_CREATE_TABLE_CREDENTIAL = """
+CREATE TABLE IF NOT EXISTS credential(
+    public_key BLOB PRIMARY KEY,
+    private_key BLOB
+);
+"""
+
+SQL_CREATE_TABLE_PUBLIC_KEY = """
+CREATE TABLE IF NOT EXISTS public_key(
+    public_key BLOB UNIQUE
+);
+"""
+
 SQL_CREATE_INDEX_ONLINE_UNTIL = """
 CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
 """
@@ -72,7 +85,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
 );
 """
 
-
 SQL_CREATE_TABLE_TASK_RES = """
 CREATE TABLE IF NOT EXISTS task_res(
     task_id TEXT UNIQUE,
@@ -96,7 +108,7 @@ CREATE TABLE IF NOT EXISTS task_res(
 DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
 
 
-class SqliteState(State):
+class SqliteState(State):  # pylint: disable=R0904
     """SQLite-based state implementation."""
 
     def __init__(
@@ -134,6 +146,8 @@ class SqliteState(State):
         cur.execute(SQL_CREATE_TABLE_TASK_INS)
         cur.execute(SQL_CREATE_TABLE_TASK_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
+        cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
+        cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
         cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
         res = cur.execute("SELECT name FROM sqlite_schema;")
 
@@ -142,7 +156,7 @@ class SqliteState(State):
     def query(
         self,
         query: str,
-        data: Optional[Union[List[DictOrTuple], DictOrTuple]] = None,
+        data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
     ) -> List[Dict[str, Any]]:
         """Execute a SQL query."""
         if self.conn is None:
@@ -575,6 +589,59 @@ class SqliteState(State):
         log(ERROR, "Unexpected run creation failure.")
         return 0
 
+    def store_server_public_private_key(
+        self, public_key: bytes, private_key: bytes
+    ) -> None:
+        """Store `server_public_key` and `server_private_key` in state."""
+        query = "SELECT COUNT(*) FROM credential"
+        count = self.query(query)[0]["COUNT(*)"]
+        if count < 1:
+            query = (
+                "INSERT OR REPLACE INTO credential (public_key, private_key) "
+                "VALUES (:public_key, :private_key)"
+            )
+            self.query(query, {"public_key": public_key, "private_key": private_key})
+        else:
+            raise RuntimeError("Server public and private key already set")
+
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+        query = "SELECT private_key FROM credential"
+        rows = self.query(query)
+        try:
+            private_key: Optional[bytes] = rows[0]["private_key"]
+        except IndexError:
+            private_key = None
+        return private_key
+
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+        query = "SELECT public_key FROM credential"
+        rows = self.query(query)
+        try:
+            public_key: Optional[bytes] = rows[0]["public_key"]
+        except IndexError:
+            public_key = None
+        return public_key
+
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (?)"
+        data = [(key,) for key in public_keys]
+        self.query(query, data)
+
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
+        self.query(query, {"public_key": public_key})
+
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+        query = "SELECT public_key FROM public_key"
+        rows = self.query(query)
+        result: Set[bytes] = {row["public_key"] for row in rows}
+        return result
+
     def get_run(self, run_id: int) -> Tuple[int, str, str]:
         """Retrieve information about the run with the specified `run_id`."""
         query = "SELECT * FROM run WHERE run_id = ?;"
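
Illustrative sketch (not part of the diff): the semantics of the two new tables, shown with the standard-library sqlite3 module directly rather than through SqliteState. The byte strings are placeholders for serialized keys.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS credential("
    "public_key BLOB PRIMARY KEY, private_key BLOB);"
)
conn.execute("CREATE TABLE IF NOT EXISTS public_key(public_key BLOB UNIQUE);")

# One credential row per SuperLink: store_server_public_private_key() checks
# COUNT(*) before inserting and raises if a keypair is already present.
conn.execute(
    "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (?, ?)",
    (b"server-pub", b"server-priv"),
)

# Client public keys: the UNIQUE constraint rejects duplicate entries.
conn.executemany(
    "INSERT INTO public_key (public_key) VALUES (?)",
    [(b"client-a",), (b"client-b",)],
)
print({row[0] for row in conn.execute("SELECT public_key FROM public_key")})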

flwr/server/superlink/state/state.py CHANGED
@@ -171,6 +171,32 @@ class State(abc.ABC):
         - `fab_version`: The version of the FAB used in the specified run.
         """
 
+    @abc.abstractmethod
+    def store_server_public_private_key(
+        self, public_key: bytes, private_key: bytes
+    ) -> None:
+        """Store `server_public_key` and `server_private_key` in state."""
+
+    @abc.abstractmethod
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+
+    @abc.abstractmethod
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+
+    @abc.abstractmethod
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+
     @abc.abstractmethod
     def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
         """Acknowledge a ping received from a node, serving as a heartbeat.

{flwr_nightly-1.9.0.dev20240423.dist-info → flwr_nightly-1.9.0.dev20240425.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.9.0.dev20240423
+Version: 1.9.0.dev20240425
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0