flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.0.dev20240824__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package registry's advisory for this release for more details.

Files changed (29):
  1. flwr/cli/build.py +1 -1
  2. flwr/cli/new/templates/app/README.md.tpl +7 -30
  3. flwr/cli/run/run.py +10 -0
  4. flwr/client/client.py +22 -1
  5. flwr/client/numpy_client.py +22 -1
  6. flwr/client/rest_client/connection.py +1 -1
  7. flwr/common/address.py +43 -0
  8. flwr/server/app.py +19 -13
  9. flwr/server/run_serverapp.py +15 -1
  10. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
  11. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
  12. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  13. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
  14. flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
  15. flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
  16. flwr/server/superlink/fleet/vce/backend/raybackend.py +13 -4
  17. flwr/server/superlink/fleet/vce/vce_api.py +2 -6
  18. flwr/server/superlink/state/in_memory_state.py +15 -15
  19. flwr/server/superlink/state/sqlite_state.py +10 -10
  20. flwr/server/superlink/state/state.py +8 -8
  21. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  22. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
  23. flwr/simulation/run_simulation.py +48 -17
  24. flwr/superexec/simulation.py +20 -1
  25. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/METADATA +2 -2
  26. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/RECORD +29 -29
  27. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/LICENSE +0 -0
  28. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/WHEEL +0 -0
  29. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/entry_points.txt +0 -0
@@ -15,17 +15,29 @@
15
15
  """Experimental REST API server."""
16
16
 
17
17
 
18
+ from __future__ import annotations
19
+
18
20
  import sys
21
+ from typing import Awaitable, Callable, TypeVar
22
+
23
+ from google.protobuf.message import Message as GrpcMessage
19
24
 
20
25
  from flwr.common.constant import MISSING_EXTRA_REST
26
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
21
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
28
  CreateNodeRequest,
29
+ CreateNodeResponse,
23
30
  DeleteNodeRequest,
31
+ DeleteNodeResponse,
24
32
  PingRequest,
33
+ PingResponse,
25
34
  PullTaskInsRequest,
35
+ PullTaskInsResponse,
26
36
  PushTaskResRequest,
37
+ PushTaskResResponse,
27
38
  )
28
- from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
39
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
+ from flwr.server.superlink.ffs.ffs import Ffs
29
41
  from flwr.server.superlink.fleet.message_handler import message_handler
30
42
  from flwr.server.superlink.state import State
31
43
 
@@ -40,172 +52,108 @@ except ModuleNotFoundError:
40
52
  sys.exit(MISSING_EXTRA_REST)
41
53
 
42
54
 
43
- async def create_node(request: Request) -> Response:
44
- """Create Node."""
45
- _check_headers(request.headers)
55
+ GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
56
+ GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
46
57
 
47
- # Get the request body as raw bytes
48
- create_node_request_bytes: bytes = await request.body()
58
+ GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
59
+ RestEndPoint = Callable[[Request], Awaitable[Response]]
49
60
 
50
- # Deserialize ProtoBuf
51
- create_node_request_proto = CreateNodeRequest()
52
- create_node_request_proto.ParseFromString(create_node_request_bytes)
53
61
 
54
- # Get state from app
55
- state: State = app.state.STATE_FACTORY.state()
62
+ def rest_request_response(
63
+ grpc_request_type: type[GrpcRequest],
64
+ ) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
65
+ """Convert an async gRPC-based function into a RESTful HTTP endpoint."""
56
66
 
57
- # Handle message
58
- create_node_response_proto = message_handler.create_node(
59
- request=create_node_request_proto, state=state
60
- )
67
+ def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
68
+ async def wrapper(request: Request) -> Response:
69
+ _check_headers(request.headers)
61
70
 
62
- # Return serialized ProtoBuf
63
- create_node_response_bytes = create_node_response_proto.SerializeToString()
64
- return Response(
65
- status_code=200,
66
- content=create_node_response_bytes,
67
- headers={"Content-Type": "application/protobuf"},
68
- )
71
+ # Get the request body as raw bytes
72
+ grpc_req_bytes: bytes = await request.body()
69
73
 
74
+ # Deserialize ProtoBuf
75
+ grpc_req = grpc_request_type.FromString(grpc_req_bytes)
76
+ grpc_res = await func(grpc_req)
77
+ return Response(
78
+ status_code=200,
79
+ content=grpc_res.SerializeToString(),
80
+ headers={"Content-Type": "application/protobuf"},
81
+ )
70
82
 
71
- async def delete_node(request: Request) -> Response:
72
- """Delete Node Id."""
73
- _check_headers(request.headers)
83
+ return wrapper
74
84
 
75
- # Get the request body as raw bytes
76
- delete_node_request_bytes: bytes = await request.body()
85
+ return decorator
77
86
 
78
- # Deserialize ProtoBuf
79
- delete_node_request_proto = DeleteNodeRequest()
80
- delete_node_request_proto.ParseFromString(delete_node_request_bytes)
81
87
 
88
+ @rest_request_response(CreateNodeRequest)
89
+ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
90
+ """Create Node."""
82
91
  # Get state from app
83
92
  state: State = app.state.STATE_FACTORY.state()
84
93
 
85
94
  # Handle message
86
- delete_node_response_proto = message_handler.delete_node(
87
- request=delete_node_request_proto, state=state
88
- )
95
+ return message_handler.create_node(request=request, state=state)
89
96
 
90
- # Return serialized ProtoBuf
91
- delete_node_response_bytes = delete_node_response_proto.SerializeToString()
92
- return Response(
93
- status_code=200,
94
- content=delete_node_response_bytes,
95
- headers={"Content-Type": "application/protobuf"},
96
- )
97
97
 
98
+ @rest_request_response(DeleteNodeRequest)
99
+ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
100
+ """Delete Node Id."""
101
+ # Get state from app
102
+ state: State = app.state.STATE_FACTORY.state()
98
103
 
99
- async def pull_task_ins(request: Request) -> Response:
100
- """Pull TaskIns."""
101
- _check_headers(request.headers)
102
-
103
- # Get the request body as raw bytes
104
- pull_task_ins_request_bytes: bytes = await request.body()
104
+ # Handle message
105
+ return message_handler.delete_node(request=request, state=state)
105
106
 
106
- # Deserialize ProtoBuf
107
- pull_task_ins_request_proto = PullTaskInsRequest()
108
- pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
109
107
 
108
+ @rest_request_response(PullTaskInsRequest)
109
+ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
110
+ """Pull TaskIns."""
110
111
  # Get state from app
111
112
  state: State = app.state.STATE_FACTORY.state()
112
113
 
113
114
  # Handle message
114
- pull_task_ins_response_proto = message_handler.pull_task_ins(
115
- request=pull_task_ins_request_proto,
116
- state=state,
117
- )
118
-
119
- # Return serialized ProtoBuf
120
- pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
121
- return Response(
122
- status_code=200,
123
- content=pull_task_ins_response_bytes,
124
- headers={"Content-Type": "application/protobuf"},
125
- )
115
+ return message_handler.pull_task_ins(request=request, state=state)
126
116
 
127
117
 
128
- async def push_task_res(request: Request) -> Response: # Check if token is needed here
118
+ # Check if token is needed here
119
+ @rest_request_response(PushTaskResRequest)
120
+ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
129
121
  """Push TaskRes."""
130
- _check_headers(request.headers)
131
-
132
- # Get the request body as raw bytes
133
- push_task_res_request_bytes: bytes = await request.body()
134
-
135
- # Deserialize ProtoBuf
136
- push_task_res_request_proto = PushTaskResRequest()
137
- push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
138
-
139
122
  # Get state from app
140
123
  state: State = app.state.STATE_FACTORY.state()
141
124
 
142
125
  # Handle message
143
- push_task_res_response_proto = message_handler.push_task_res(
144
- request=push_task_res_request_proto,
145
- state=state,
146
- )
147
-
148
- # Return serialized ProtoBuf
149
- push_task_res_response_bytes = push_task_res_response_proto.SerializeToString()
150
- return Response(
151
- status_code=200,
152
- content=push_task_res_response_bytes,
153
- headers={"Content-Type": "application/protobuf"},
154
- )
126
+ return message_handler.push_task_res(request=request, state=state)
155
127
 
156
128
 
157
- async def ping(request: Request) -> Response:
129
+ @rest_request_response(PingRequest)
130
+ async def ping(request: PingRequest) -> PingResponse:
158
131
  """Ping."""
159
- _check_headers(request.headers)
160
-
161
- # Get the request body as raw bytes
162
- ping_request_bytes: bytes = await request.body()
163
-
164
- # Deserialize ProtoBuf
165
- ping_request_proto = PingRequest()
166
- ping_request_proto.ParseFromString(ping_request_bytes)
167
-
168
132
  # Get state from app
169
133
  state: State = app.state.STATE_FACTORY.state()
170
134
 
171
135
  # Handle message
172
- ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
173
-
174
- # Return serialized ProtoBuf
175
- ping_response_bytes = ping_response_proto.SerializeToString()
176
- return Response(
177
- status_code=200,
178
- content=ping_response_bytes,
179
- headers={"Content-Type": "application/protobuf"},
180
- )
136
+ return message_handler.ping(request=request, state=state)
181
137
 
182
138
 
183
- async def get_run(request: Request) -> Response:
139
+ @rest_request_response(GetRunRequest)
140
+ async def get_run(request: GetRunRequest) -> GetRunResponse:
184
141
  """GetRun."""
185
- _check_headers(request.headers)
186
-
187
- # Get the request body as raw bytes
188
- get_run_request_bytes: bytes = await request.body()
189
-
190
- # Deserialize ProtoBuf
191
- get_run_request_proto = GetRunRequest()
192
- get_run_request_proto.ParseFromString(get_run_request_bytes)
193
-
194
142
  # Get state from app
195
143
  state: State = app.state.STATE_FACTORY.state()
196
144
 
197
145
  # Handle message
198
- get_run_response_proto = message_handler.get_run(
199
- request=get_run_request_proto, state=state
200
- )
146
+ return message_handler.get_run(request=request, state=state)
147
+
201
148
 
202
- # Return serialized ProtoBuf
203
- get_run_response_bytes = get_run_response_proto.SerializeToString()
204
- return Response(
205
- status_code=200,
206
- content=get_run_response_bytes,
207
- headers={"Content-Type": "application/protobuf"},
208
- )
149
+ @rest_request_response(GetFabRequest)
150
+ async def get_fab(request: GetFabRequest) -> GetFabResponse:
151
+ """GetRun."""
152
+ # Get ffs from app
153
+ ffs: Ffs = app.state.FFS_FACTORY.state()
154
+
155
+ # Handle message
156
+ return message_handler.get_fab(request=request, ffs=ffs)
209
157
 
210
158
 
211
159
  routes = [
@@ -215,6 +163,7 @@ routes = [
215
163
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
216
164
  Route("/api/v0/fleet/ping", ping, methods=["POST"]),
217
165
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
166
+ Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
218
167
  ]
219
168
 
220
169
  app: Starlette = Starlette(
@@ -33,7 +33,7 @@ class Backend(ABC):
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- def build(self) -> None:
36
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
37
37
  """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
@@ -60,7 +60,6 @@ class Backend(ABC):
60
60
  @abstractmethod
61
61
  def process_message(
62
62
  self,
63
- app: Callable[[], ClientApp],
64
63
  message: Message,
65
64
  context: Context,
66
65
  ) -> Tuple[Message, Context]:
@@ -16,7 +16,7 @@
16
16
 
17
17
  import sys
18
18
  from logging import DEBUG, ERROR
19
- from typing import Callable, Dict, Tuple, Union
19
+ from typing import Callable, Dict, Optional, Tuple, Union
20
20
 
21
21
  import ray
22
22
 
@@ -63,6 +63,8 @@ class RayBackend(Backend):
63
63
  actor_kwargs=actor_kwargs,
64
64
  )
65
65
 
66
+ self.app_fn: Optional[Callable[[], ClientApp]] = None
67
+
66
68
  def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
67
69
  client_resources_config = config.get(self.client_resources_key)
68
70
  client_resources: ClientResourcesDict = {}
@@ -126,14 +128,15 @@ class RayBackend(Backend):
126
128
  """Report whether the pool has idle actors."""
127
129
  return self.pool.is_actor_available()
128
130
 
129
- def build(self) -> None:
131
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
130
132
  """Build pool of Ray actors that this backend will submit jobs to."""
131
133
  self.pool.add_actors_to_pool(self.pool.actors_capacity)
134
+ # Set ClientApp callable that ray actors will use
135
+ self.app_fn = app_fn
132
136
  log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
137
 
134
138
  def process_message(
135
139
  self,
136
- app: Callable[[], ClientApp],
137
140
  message: Message,
138
141
  context: Context,
139
142
  ) -> Tuple[Message, Context]:
@@ -143,11 +146,17 @@ class RayBackend(Backend):
143
146
  """
144
147
  partition_id = context.node_config[PARTITION_ID_KEY]
145
148
 
149
+ if self.app_fn is None:
150
+ raise ValueError(
151
+ "Unspecified function to load a `ClientApp`. "
152
+ "Call the backend's `build()` method before processing messages."
153
+ )
154
+
146
155
  try:
147
156
  # Submit a task to the pool
148
157
  future = self.pool.submit(
149
158
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
- (app, message, str(partition_id), context),
159
+ (self.app_fn, message, str(partition_id), context),
151
160
  )
152
161
 
153
162
  # Fetch result
@@ -87,7 +87,6 @@ def _register_node_states(
87
87
 
88
88
  # pylint: disable=too-many-arguments,too-many-locals
89
89
  def worker(
90
- app_fn: Callable[[], ClientApp],
91
90
  taskins_queue: "Queue[TaskIns]",
92
91
  taskres_queue: "Queue[TaskRes]",
93
92
  node_states: Dict[int, NodeState],
@@ -110,9 +109,7 @@ def worker(
110
109
  message = message_from_taskins(task_ins)
111
110
 
112
111
  # Let backend process message
113
- out_mssg, updated_context = backend.process_message(
114
- app_fn, message, context
115
- )
112
+ out_mssg, updated_context = backend.process_message(message, context)
116
113
 
117
114
  # Update Context
118
115
  node_states[node_id].update_context(
@@ -193,7 +190,7 @@ def run_api(
193
190
  backend = backend_fn()
194
191
 
195
192
  # Build backend
196
- backend.build()
193
+ backend.build(app_fn)
197
194
 
198
195
  # Add workers (they submit Messages to Backend)
199
196
  state = state_factory.state()
@@ -223,7 +220,6 @@ def run_api(
223
220
  _ = [
224
221
  executor.submit(
225
222
  worker,
226
- app_fn,
227
223
  taskins_queue,
228
224
  taskres_queue,
229
225
  node_states,
@@ -45,7 +45,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
45
45
  self.task_ins_store: Dict[UUID, TaskIns] = {}
46
46
  self.task_res_store: Dict[UUID, TaskRes] = {}
47
47
 
48
- self.client_public_keys: Set[bytes] = set()
48
+ self.node_public_keys: Set[bytes] = set()
49
49
  self.server_public_key: Optional[bytes] = None
50
50
  self.server_private_key: Optional[bytes] = None
51
51
 
@@ -237,7 +237,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
237
237
  return node_id
238
238
 
239
239
  def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
240
- """Delete a client node."""
240
+ """Delete a node."""
241
241
  with self.lock:
242
242
  if node_id not in self.node_ids:
243
243
  raise ValueError(f"Node {node_id} not found")
@@ -254,7 +254,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
254
254
  del self.node_ids[node_id]
255
255
 
256
256
  def get_nodes(self, run_id: int) -> Set[int]:
257
- """Return all available client nodes.
257
+ """Return all available nodes.
258
258
 
259
259
  Constraints
260
260
  -----------
@@ -271,9 +271,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
271
271
  if online_until > current_time
272
272
  }
273
273
 
274
- def get_node_id(self, client_public_key: bytes) -> Optional[int]:
275
- """Retrieve stored `node_id` filtered by `client_public_keys`."""
276
- return self.public_key_to_node_id.get(client_public_key)
274
+ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
275
+ """Retrieve stored `node_id` filtered by `node_public_keys`."""
276
+ return self.public_key_to_node_id.get(node_public_key)
277
277
 
278
278
  def create_run(
279
279
  self,
@@ -318,19 +318,19 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
318
318
  """Retrieve `server_public_key` in urlsafe bytes."""
319
319
  return self.server_public_key
320
320
 
321
- def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
322
- """Store a set of `client_public_keys` in state."""
321
+ def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
322
+ """Store a set of `node_public_keys` in state."""
323
323
  with self.lock:
324
- self.client_public_keys = public_keys
324
+ self.node_public_keys = public_keys
325
325
 
326
- def store_client_public_key(self, public_key: bytes) -> None:
327
- """Store a `client_public_key` in state."""
326
+ def store_node_public_key(self, public_key: bytes) -> None:
327
+ """Store a `node_public_key` in state."""
328
328
  with self.lock:
329
- self.client_public_keys.add(public_key)
329
+ self.node_public_keys.add(public_key)
330
330
 
331
- def get_client_public_keys(self) -> Set[bytes]:
332
- """Retrieve all currently stored `client_public_keys` as a set."""
333
- return self.client_public_keys
331
+ def get_node_public_keys(self) -> Set[bytes]:
332
+ """Retrieve all currently stored `node_public_keys` as a set."""
333
+ return self.node_public_keys
334
334
 
335
335
  def get_run(self, run_id: int) -> Optional[Run]:
336
336
  """Retrieve information about the run with the specified `run_id`."""
@@ -569,7 +569,7 @@ class SqliteState(State): # pylint: disable=R0904
569
569
  return node_id
570
570
 
571
571
  def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
572
- """Delete a client node."""
572
+ """Delete a node."""
573
573
  query = "DELETE FROM node WHERE node_id = ?"
574
574
  params = (node_id,)
575
575
 
@@ -607,10 +607,10 @@ class SqliteState(State): # pylint: disable=R0904
607
607
  result: Set[int] = {row["node_id"] for row in rows}
608
608
  return result
609
609
 
610
- def get_node_id(self, client_public_key: bytes) -> Optional[int]:
611
- """Retrieve stored `node_id` filtered by `client_public_keys`."""
610
+ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
611
+ """Retrieve stored `node_id` filtered by `node_public_keys`."""
612
612
  query = "SELECT node_id FROM node WHERE public_key = :public_key;"
613
- row = self.query(query, {"public_key": client_public_key})
613
+ row = self.query(query, {"public_key": node_public_key})
614
614
  if len(row) > 0:
615
615
  node_id: int = row[0]["node_id"]
616
616
  return node_id
@@ -684,19 +684,19 @@ class SqliteState(State): # pylint: disable=R0904
684
684
  public_key = None
685
685
  return public_key
686
686
 
687
- def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
688
- """Store a set of `client_public_keys` in state."""
687
+ def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
688
+ """Store a set of `node_public_keys` in state."""
689
689
  query = "INSERT INTO public_key (public_key) VALUES (?)"
690
690
  data = [(key,) for key in public_keys]
691
691
  self.query(query, data)
692
692
 
693
- def store_client_public_key(self, public_key: bytes) -> None:
694
- """Store a `client_public_key` in state."""
693
+ def store_node_public_key(self, public_key: bytes) -> None:
694
+ """Store a `node_public_key` in state."""
695
695
  query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
696
696
  self.query(query, {"public_key": public_key})
697
697
 
698
- def get_client_public_keys(self) -> Set[bytes]:
699
- """Retrieve all currently stored `client_public_keys` as a set."""
698
+ def get_node_public_keys(self) -> Set[bytes]:
699
+ """Retrieve all currently stored `node_public_keys` as a set."""
700
700
  query = "SELECT public_key FROM public_key"
701
701
  rows = self.query(query)
702
702
  result: Set[bytes] = {row["public_key"] for row in rows}
@@ -153,8 +153,8 @@ class State(abc.ABC): # pylint: disable=R0904
153
153
  """
154
154
 
155
155
  @abc.abstractmethod
156
- def get_node_id(self, client_public_key: bytes) -> Optional[int]:
157
- """Retrieve stored `node_id` filtered by `client_public_keys`."""
156
+ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
157
+ """Retrieve stored `node_id` filtered by `node_public_keys`."""
158
158
 
159
159
  @abc.abstractmethod
160
160
  def create_run(
@@ -199,16 +199,16 @@ class State(abc.ABC): # pylint: disable=R0904
199
199
  """Retrieve `server_public_key` in urlsafe bytes."""
200
200
 
201
201
  @abc.abstractmethod
202
- def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
203
- """Store a set of `client_public_keys` in state."""
202
+ def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
203
+ """Store a set of `node_public_keys` in state."""
204
204
 
205
205
  @abc.abstractmethod
206
- def store_client_public_key(self, public_key: bytes) -> None:
207
- """Store a `client_public_key` in state."""
206
+ def store_node_public_key(self, public_key: bytes) -> None:
207
+ """Store a `node_public_key` in state."""
208
208
 
209
209
  @abc.abstractmethod
210
- def get_client_public_keys(self) -> Set[bytes]:
211
- """Retrieve all currently stored `client_public_keys` as a set."""
210
+ def get_node_public_keys(self) -> Set[bytes]:
211
+ """Retrieve all currently stored `node_public_keys` as a set."""
212
212
 
213
213
  @abc.abstractmethod
214
214
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
@@ -35,6 +35,7 @@ class SecAggWorkflow(SecAggPlusWorkflow):
35
35
  contributions to compute the weighted average of model parameters.
36
36
 
37
37
  The protocol involves four main stages:
38
+
38
39
  - 'setup': Send SecAgg configuration to clients and collect their public keys.
39
40
  - 'share keys': Broadcast public keys among clients and collect encrypted secret
40
41
  key shares.
@@ -99,6 +99,7 @@ class SecAggPlusWorkflow:
99
99
  contributions to compute the weighted average of model parameters.
100
100
 
101
101
  The protocol involves four main stages:
102
+
102
103
  - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
103
104
  - 'share keys': Broadcast public keys among clients and collect encrypted secret
104
105
  key shares.