flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.0.dev20240824__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -1
- flwr/cli/new/templates/app/README.md.tpl +7 -30
- flwr/cli/run/run.py +10 -0
- flwr/client/client.py +22 -1
- flwr/client/numpy_client.py +22 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/common/address.py +43 -0
- flwr/server/app.py +19 -13
- flwr/server/run_serverapp.py +15 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
- flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +13 -4
- flwr/server/superlink/fleet/vce/vce_api.py +2 -6
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
- flwr/simulation/run_simulation.py +48 -17
- flwr/superexec/simulation.py +20 -1
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/METADATA +2 -2
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/RECORD +29 -29
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.0.dev20240824.dist-info}/entry_points.txt +0 -0
|
@@ -15,17 +15,29 @@
|
|
|
15
15
|
"""Experimental REST API server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
import sys
|
|
21
|
+
from typing import Awaitable, Callable, TypeVar
|
|
22
|
+
|
|
23
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
19
24
|
|
|
20
25
|
from flwr.common.constant import MISSING_EXTRA_REST
|
|
26
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
21
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
28
|
CreateNodeRequest,
|
|
29
|
+
CreateNodeResponse,
|
|
23
30
|
DeleteNodeRequest,
|
|
31
|
+
DeleteNodeResponse,
|
|
24
32
|
PingRequest,
|
|
33
|
+
PingResponse,
|
|
25
34
|
PullTaskInsRequest,
|
|
35
|
+
PullTaskInsResponse,
|
|
26
36
|
PushTaskResRequest,
|
|
37
|
+
PushTaskResResponse,
|
|
27
38
|
)
|
|
28
|
-
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
29
41
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
30
42
|
from flwr.server.superlink.state import State
|
|
31
43
|
|
|
@@ -40,172 +52,108 @@ except ModuleNotFoundError:
|
|
|
40
52
|
sys.exit(MISSING_EXTRA_REST)
|
|
41
53
|
|
|
42
54
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
_check_headers(request.headers)
|
|
55
|
+
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
|
56
|
+
GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
|
|
46
57
|
|
|
47
|
-
|
|
48
|
-
|
|
58
|
+
GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
|
|
59
|
+
RestEndPoint = Callable[[Request], Awaitable[Response]]
|
|
49
60
|
|
|
50
|
-
# Deserialize ProtoBuf
|
|
51
|
-
create_node_request_proto = CreateNodeRequest()
|
|
52
|
-
create_node_request_proto.ParseFromString(create_node_request_bytes)
|
|
53
61
|
|
|
54
|
-
|
|
55
|
-
|
|
62
|
+
def rest_request_response(
|
|
63
|
+
grpc_request_type: type[GrpcRequest],
|
|
64
|
+
) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
|
|
65
|
+
"""Convert an async gRPC-based function into a RESTful HTTP endpoint."""
|
|
56
66
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
)
|
|
67
|
+
def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
|
|
68
|
+
async def wrapper(request: Request) -> Response:
|
|
69
|
+
_check_headers(request.headers)
|
|
61
70
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
return Response(
|
|
65
|
-
status_code=200,
|
|
66
|
-
content=create_node_response_bytes,
|
|
67
|
-
headers={"Content-Type": "application/protobuf"},
|
|
68
|
-
)
|
|
71
|
+
# Get the request body as raw bytes
|
|
72
|
+
grpc_req_bytes: bytes = await request.body()
|
|
69
73
|
|
|
74
|
+
# Deserialize ProtoBuf
|
|
75
|
+
grpc_req = grpc_request_type.FromString(grpc_req_bytes)
|
|
76
|
+
grpc_res = await func(grpc_req)
|
|
77
|
+
return Response(
|
|
78
|
+
status_code=200,
|
|
79
|
+
content=grpc_res.SerializeToString(),
|
|
80
|
+
headers={"Content-Type": "application/protobuf"},
|
|
81
|
+
)
|
|
70
82
|
|
|
71
|
-
|
|
72
|
-
"""Delete Node Id."""
|
|
73
|
-
_check_headers(request.headers)
|
|
83
|
+
return wrapper
|
|
74
84
|
|
|
75
|
-
|
|
76
|
-
delete_node_request_bytes: bytes = await request.body()
|
|
85
|
+
return decorator
|
|
77
86
|
|
|
78
|
-
# Deserialize ProtoBuf
|
|
79
|
-
delete_node_request_proto = DeleteNodeRequest()
|
|
80
|
-
delete_node_request_proto.ParseFromString(delete_node_request_bytes)
|
|
81
87
|
|
|
88
|
+
@rest_request_response(CreateNodeRequest)
|
|
89
|
+
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
90
|
+
"""Create Node."""
|
|
82
91
|
# Get state from app
|
|
83
92
|
state: State = app.state.STATE_FACTORY.state()
|
|
84
93
|
|
|
85
94
|
# Handle message
|
|
86
|
-
|
|
87
|
-
request=delete_node_request_proto, state=state
|
|
88
|
-
)
|
|
95
|
+
return message_handler.create_node(request=request, state=state)
|
|
89
96
|
|
|
90
|
-
# Return serialized ProtoBuf
|
|
91
|
-
delete_node_response_bytes = delete_node_response_proto.SerializeToString()
|
|
92
|
-
return Response(
|
|
93
|
-
status_code=200,
|
|
94
|
-
content=delete_node_response_bytes,
|
|
95
|
-
headers={"Content-Type": "application/protobuf"},
|
|
96
|
-
)
|
|
97
97
|
|
|
98
|
+
@rest_request_response(DeleteNodeRequest)
|
|
99
|
+
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
100
|
+
"""Delete Node Id."""
|
|
101
|
+
# Get state from app
|
|
102
|
+
state: State = app.state.STATE_FACTORY.state()
|
|
98
103
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
_check_headers(request.headers)
|
|
102
|
-
|
|
103
|
-
# Get the request body as raw bytes
|
|
104
|
-
pull_task_ins_request_bytes: bytes = await request.body()
|
|
104
|
+
# Handle message
|
|
105
|
+
return message_handler.delete_node(request=request, state=state)
|
|
105
106
|
|
|
106
|
-
# Deserialize ProtoBuf
|
|
107
|
-
pull_task_ins_request_proto = PullTaskInsRequest()
|
|
108
|
-
pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
|
|
109
107
|
|
|
108
|
+
@rest_request_response(PullTaskInsRequest)
|
|
109
|
+
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
110
|
+
"""Pull TaskIns."""
|
|
110
111
|
# Get state from app
|
|
111
112
|
state: State = app.state.STATE_FACTORY.state()
|
|
112
113
|
|
|
113
114
|
# Handle message
|
|
114
|
-
|
|
115
|
-
request=pull_task_ins_request_proto,
|
|
116
|
-
state=state,
|
|
117
|
-
)
|
|
118
|
-
|
|
119
|
-
# Return serialized ProtoBuf
|
|
120
|
-
pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
|
|
121
|
-
return Response(
|
|
122
|
-
status_code=200,
|
|
123
|
-
content=pull_task_ins_response_bytes,
|
|
124
|
-
headers={"Content-Type": "application/protobuf"},
|
|
125
|
-
)
|
|
115
|
+
return message_handler.pull_task_ins(request=request, state=state)
|
|
126
116
|
|
|
127
117
|
|
|
128
|
-
|
|
118
|
+
# Check if token is needed here
|
|
119
|
+
@rest_request_response(PushTaskResRequest)
|
|
120
|
+
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
129
121
|
"""Push TaskRes."""
|
|
130
|
-
_check_headers(request.headers)
|
|
131
|
-
|
|
132
|
-
# Get the request body as raw bytes
|
|
133
|
-
push_task_res_request_bytes: bytes = await request.body()
|
|
134
|
-
|
|
135
|
-
# Deserialize ProtoBuf
|
|
136
|
-
push_task_res_request_proto = PushTaskResRequest()
|
|
137
|
-
push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
|
|
138
|
-
|
|
139
122
|
# Get state from app
|
|
140
123
|
state: State = app.state.STATE_FACTORY.state()
|
|
141
124
|
|
|
142
125
|
# Handle message
|
|
143
|
-
|
|
144
|
-
request=push_task_res_request_proto,
|
|
145
|
-
state=state,
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
# Return serialized ProtoBuf
|
|
149
|
-
push_task_res_response_bytes = push_task_res_response_proto.SerializeToString()
|
|
150
|
-
return Response(
|
|
151
|
-
status_code=200,
|
|
152
|
-
content=push_task_res_response_bytes,
|
|
153
|
-
headers={"Content-Type": "application/protobuf"},
|
|
154
|
-
)
|
|
126
|
+
return message_handler.push_task_res(request=request, state=state)
|
|
155
127
|
|
|
156
128
|
|
|
157
|
-
|
|
129
|
+
@rest_request_response(PingRequest)
|
|
130
|
+
async def ping(request: PingRequest) -> PingResponse:
|
|
158
131
|
"""Ping."""
|
|
159
|
-
_check_headers(request.headers)
|
|
160
|
-
|
|
161
|
-
# Get the request body as raw bytes
|
|
162
|
-
ping_request_bytes: bytes = await request.body()
|
|
163
|
-
|
|
164
|
-
# Deserialize ProtoBuf
|
|
165
|
-
ping_request_proto = PingRequest()
|
|
166
|
-
ping_request_proto.ParseFromString(ping_request_bytes)
|
|
167
|
-
|
|
168
132
|
# Get state from app
|
|
169
133
|
state: State = app.state.STATE_FACTORY.state()
|
|
170
134
|
|
|
171
135
|
# Handle message
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
# Return serialized ProtoBuf
|
|
175
|
-
ping_response_bytes = ping_response_proto.SerializeToString()
|
|
176
|
-
return Response(
|
|
177
|
-
status_code=200,
|
|
178
|
-
content=ping_response_bytes,
|
|
179
|
-
headers={"Content-Type": "application/protobuf"},
|
|
180
|
-
)
|
|
136
|
+
return message_handler.ping(request=request, state=state)
|
|
181
137
|
|
|
182
138
|
|
|
183
|
-
|
|
139
|
+
@rest_request_response(GetRunRequest)
|
|
140
|
+
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
184
141
|
"""GetRun."""
|
|
185
|
-
_check_headers(request.headers)
|
|
186
|
-
|
|
187
|
-
# Get the request body as raw bytes
|
|
188
|
-
get_run_request_bytes: bytes = await request.body()
|
|
189
|
-
|
|
190
|
-
# Deserialize ProtoBuf
|
|
191
|
-
get_run_request_proto = GetRunRequest()
|
|
192
|
-
get_run_request_proto.ParseFromString(get_run_request_bytes)
|
|
193
|
-
|
|
194
142
|
# Get state from app
|
|
195
143
|
state: State = app.state.STATE_FACTORY.state()
|
|
196
144
|
|
|
197
145
|
# Handle message
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
)
|
|
146
|
+
return message_handler.get_run(request=request, state=state)
|
|
147
|
+
|
|
201
148
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
149
|
+
@rest_request_response(GetFabRequest)
|
|
150
|
+
async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
151
|
+
"""GetRun."""
|
|
152
|
+
# Get ffs from app
|
|
153
|
+
ffs: Ffs = app.state.FFS_FACTORY.state()
|
|
154
|
+
|
|
155
|
+
# Handle message
|
|
156
|
+
return message_handler.get_fab(request=request, ffs=ffs)
|
|
209
157
|
|
|
210
158
|
|
|
211
159
|
routes = [
|
|
@@ -215,6 +163,7 @@ routes = [
|
|
|
215
163
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
|
216
164
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
217
165
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
|
166
|
+
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
|
218
167
|
]
|
|
219
168
|
|
|
220
169
|
app: Starlette = Starlette(
|
|
@@ -33,7 +33,7 @@ class Backend(ABC):
|
|
|
33
33
|
"""Construct a backend."""
|
|
34
34
|
|
|
35
35
|
@abstractmethod
|
|
36
|
-
def build(self) -> None:
|
|
36
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
37
37
|
"""Build backend.
|
|
38
38
|
|
|
39
39
|
Different components need to be in place before workers in a backend are ready
|
|
@@ -60,7 +60,6 @@ class Backend(ABC):
|
|
|
60
60
|
@abstractmethod
|
|
61
61
|
def process_message(
|
|
62
62
|
self,
|
|
63
|
-
app: Callable[[], ClientApp],
|
|
64
63
|
message: Message,
|
|
65
64
|
context: Context,
|
|
66
65
|
) -> Tuple[Message, Context]:
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import sys
|
|
18
18
|
from logging import DEBUG, ERROR
|
|
19
|
-
from typing import Callable, Dict, Tuple, Union
|
|
19
|
+
from typing import Callable, Dict, Optional, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
22
22
|
|
|
@@ -63,6 +63,8 @@ class RayBackend(Backend):
|
|
|
63
63
|
actor_kwargs=actor_kwargs,
|
|
64
64
|
)
|
|
65
65
|
|
|
66
|
+
self.app_fn: Optional[Callable[[], ClientApp]] = None
|
|
67
|
+
|
|
66
68
|
def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
|
|
67
69
|
client_resources_config = config.get(self.client_resources_key)
|
|
68
70
|
client_resources: ClientResourcesDict = {}
|
|
@@ -126,14 +128,15 @@ class RayBackend(Backend):
|
|
|
126
128
|
"""Report whether the pool has idle actors."""
|
|
127
129
|
return self.pool.is_actor_available()
|
|
128
130
|
|
|
129
|
-
def build(self) -> None:
|
|
131
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
130
132
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
131
133
|
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
134
|
+
# Set ClientApp callable that ray actors will use
|
|
135
|
+
self.app_fn = app_fn
|
|
132
136
|
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
133
137
|
|
|
134
138
|
def process_message(
|
|
135
139
|
self,
|
|
136
|
-
app: Callable[[], ClientApp],
|
|
137
140
|
message: Message,
|
|
138
141
|
context: Context,
|
|
139
142
|
) -> Tuple[Message, Context]:
|
|
@@ -143,11 +146,17 @@ class RayBackend(Backend):
|
|
|
143
146
|
"""
|
|
144
147
|
partition_id = context.node_config[PARTITION_ID_KEY]
|
|
145
148
|
|
|
149
|
+
if self.app_fn is None:
|
|
150
|
+
raise ValueError(
|
|
151
|
+
"Unspecified function to load a `ClientApp`. "
|
|
152
|
+
"Call the backend's `build()` method before processing messages."
|
|
153
|
+
)
|
|
154
|
+
|
|
146
155
|
try:
|
|
147
156
|
# Submit a task to the pool
|
|
148
157
|
future = self.pool.submit(
|
|
149
158
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
|
150
|
-
(
|
|
159
|
+
(self.app_fn, message, str(partition_id), context),
|
|
151
160
|
)
|
|
152
161
|
|
|
153
162
|
# Fetch result
|
|
@@ -87,7 +87,6 @@ def _register_node_states(
|
|
|
87
87
|
|
|
88
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
89
|
def worker(
|
|
90
|
-
app_fn: Callable[[], ClientApp],
|
|
91
90
|
taskins_queue: "Queue[TaskIns]",
|
|
92
91
|
taskres_queue: "Queue[TaskRes]",
|
|
93
92
|
node_states: Dict[int, NodeState],
|
|
@@ -110,9 +109,7 @@ def worker(
|
|
|
110
109
|
message = message_from_taskins(task_ins)
|
|
111
110
|
|
|
112
111
|
# Let backend process message
|
|
113
|
-
out_mssg, updated_context = backend.process_message(
|
|
114
|
-
app_fn, message, context
|
|
115
|
-
)
|
|
112
|
+
out_mssg, updated_context = backend.process_message(message, context)
|
|
116
113
|
|
|
117
114
|
# Update Context
|
|
118
115
|
node_states[node_id].update_context(
|
|
@@ -193,7 +190,7 @@ def run_api(
|
|
|
193
190
|
backend = backend_fn()
|
|
194
191
|
|
|
195
192
|
# Build backend
|
|
196
|
-
backend.build()
|
|
193
|
+
backend.build(app_fn)
|
|
197
194
|
|
|
198
195
|
# Add workers (they submit Messages to Backend)
|
|
199
196
|
state = state_factory.state()
|
|
@@ -223,7 +220,6 @@ def run_api(
|
|
|
223
220
|
_ = [
|
|
224
221
|
executor.submit(
|
|
225
222
|
worker,
|
|
226
|
-
app_fn,
|
|
227
223
|
taskins_queue,
|
|
228
224
|
taskres_queue,
|
|
229
225
|
node_states,
|
|
@@ -45,7 +45,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
45
45
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
46
46
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
47
47
|
|
|
48
|
-
self.
|
|
48
|
+
self.node_public_keys: Set[bytes] = set()
|
|
49
49
|
self.server_public_key: Optional[bytes] = None
|
|
50
50
|
self.server_private_key: Optional[bytes] = None
|
|
51
51
|
|
|
@@ -237,7 +237,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
237
237
|
return node_id
|
|
238
238
|
|
|
239
239
|
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
240
|
-
"""Delete a
|
|
240
|
+
"""Delete a node."""
|
|
241
241
|
with self.lock:
|
|
242
242
|
if node_id not in self.node_ids:
|
|
243
243
|
raise ValueError(f"Node {node_id} not found")
|
|
@@ -254,7 +254,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
254
254
|
del self.node_ids[node_id]
|
|
255
255
|
|
|
256
256
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
257
|
-
"""Return all available
|
|
257
|
+
"""Return all available nodes.
|
|
258
258
|
|
|
259
259
|
Constraints
|
|
260
260
|
-----------
|
|
@@ -271,9 +271,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
271
271
|
if online_until > current_time
|
|
272
272
|
}
|
|
273
273
|
|
|
274
|
-
def get_node_id(self,
|
|
275
|
-
"""Retrieve stored `node_id` filtered by `
|
|
276
|
-
return self.public_key_to_node_id.get(
|
|
274
|
+
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
275
|
+
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
276
|
+
return self.public_key_to_node_id.get(node_public_key)
|
|
277
277
|
|
|
278
278
|
def create_run(
|
|
279
279
|
self,
|
|
@@ -318,19 +318,19 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
318
318
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
319
319
|
return self.server_public_key
|
|
320
320
|
|
|
321
|
-
def
|
|
322
|
-
"""Store a set of `
|
|
321
|
+
def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
322
|
+
"""Store a set of `node_public_keys` in state."""
|
|
323
323
|
with self.lock:
|
|
324
|
-
self.
|
|
324
|
+
self.node_public_keys = public_keys
|
|
325
325
|
|
|
326
|
-
def
|
|
327
|
-
"""Store a `
|
|
326
|
+
def store_node_public_key(self, public_key: bytes) -> None:
|
|
327
|
+
"""Store a `node_public_key` in state."""
|
|
328
328
|
with self.lock:
|
|
329
|
-
self.
|
|
329
|
+
self.node_public_keys.add(public_key)
|
|
330
330
|
|
|
331
|
-
def
|
|
332
|
-
"""Retrieve all currently stored `
|
|
333
|
-
return self.
|
|
331
|
+
def get_node_public_keys(self) -> Set[bytes]:
|
|
332
|
+
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
333
|
+
return self.node_public_keys
|
|
334
334
|
|
|
335
335
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
336
336
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
@@ -569,7 +569,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
569
569
|
return node_id
|
|
570
570
|
|
|
571
571
|
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
572
|
-
"""Delete a
|
|
572
|
+
"""Delete a node."""
|
|
573
573
|
query = "DELETE FROM node WHERE node_id = ?"
|
|
574
574
|
params = (node_id,)
|
|
575
575
|
|
|
@@ -607,10 +607,10 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
607
607
|
result: Set[int] = {row["node_id"] for row in rows}
|
|
608
608
|
return result
|
|
609
609
|
|
|
610
|
-
def get_node_id(self,
|
|
611
|
-
"""Retrieve stored `node_id` filtered by `
|
|
610
|
+
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
611
|
+
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
612
612
|
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
613
|
-
row = self.query(query, {"public_key":
|
|
613
|
+
row = self.query(query, {"public_key": node_public_key})
|
|
614
614
|
if len(row) > 0:
|
|
615
615
|
node_id: int = row[0]["node_id"]
|
|
616
616
|
return node_id
|
|
@@ -684,19 +684,19 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
684
684
|
public_key = None
|
|
685
685
|
return public_key
|
|
686
686
|
|
|
687
|
-
def
|
|
688
|
-
"""Store a set of `
|
|
687
|
+
def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
688
|
+
"""Store a set of `node_public_keys` in state."""
|
|
689
689
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
690
690
|
data = [(key,) for key in public_keys]
|
|
691
691
|
self.query(query, data)
|
|
692
692
|
|
|
693
|
-
def
|
|
694
|
-
"""Store a `
|
|
693
|
+
def store_node_public_key(self, public_key: bytes) -> None:
|
|
694
|
+
"""Store a `node_public_key` in state."""
|
|
695
695
|
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
696
696
|
self.query(query, {"public_key": public_key})
|
|
697
697
|
|
|
698
|
-
def
|
|
699
|
-
"""Retrieve all currently stored `
|
|
698
|
+
def get_node_public_keys(self) -> Set[bytes]:
|
|
699
|
+
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
700
700
|
query = "SELECT public_key FROM public_key"
|
|
701
701
|
rows = self.query(query)
|
|
702
702
|
result: Set[bytes] = {row["public_key"] for row in rows}
|
|
@@ -153,8 +153,8 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
153
153
|
"""
|
|
154
154
|
|
|
155
155
|
@abc.abstractmethod
|
|
156
|
-
def get_node_id(self,
|
|
157
|
-
"""Retrieve stored `node_id` filtered by `
|
|
156
|
+
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
157
|
+
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
158
158
|
|
|
159
159
|
@abc.abstractmethod
|
|
160
160
|
def create_run(
|
|
@@ -199,16 +199,16 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
199
199
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
200
200
|
|
|
201
201
|
@abc.abstractmethod
|
|
202
|
-
def
|
|
203
|
-
"""Store a set of `
|
|
202
|
+
def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
203
|
+
"""Store a set of `node_public_keys` in state."""
|
|
204
204
|
|
|
205
205
|
@abc.abstractmethod
|
|
206
|
-
def
|
|
207
|
-
"""Store a `
|
|
206
|
+
def store_node_public_key(self, public_key: bytes) -> None:
|
|
207
|
+
"""Store a `node_public_key` in state."""
|
|
208
208
|
|
|
209
209
|
@abc.abstractmethod
|
|
210
|
-
def
|
|
211
|
-
"""Retrieve all currently stored `
|
|
210
|
+
def get_node_public_keys(self) -> Set[bytes]:
|
|
211
|
+
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
212
212
|
|
|
213
213
|
@abc.abstractmethod
|
|
214
214
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
@@ -35,6 +35,7 @@ class SecAggWorkflow(SecAggPlusWorkflow):
|
|
|
35
35
|
contributions to compute the weighted average of model parameters.
|
|
36
36
|
|
|
37
37
|
The protocol involves four main stages:
|
|
38
|
+
|
|
38
39
|
- 'setup': Send SecAgg configuration to clients and collect their public keys.
|
|
39
40
|
- 'share keys': Broadcast public keys among clients and collect encrypted secret
|
|
40
41
|
key shares.
|
|
@@ -99,6 +99,7 @@ class SecAggPlusWorkflow:
|
|
|
99
99
|
contributions to compute the weighted average of model parameters.
|
|
100
100
|
|
|
101
101
|
The protocol involves four main stages:
|
|
102
|
+
|
|
102
103
|
- 'setup': Send SecAgg+ configuration to clients and collect their public keys.
|
|
103
104
|
- 'share keys': Broadcast public keys among clients and collect encrypted secret
|
|
104
105
|
key shares.
|