flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240708__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +9 -6
- flwr/client/grpc_adapter_client/connection.py +2 -1
- flwr/client/grpc_client/connection.py +2 -1
- flwr/client/grpc_rere_client/connection.py +9 -3
- flwr/client/rest_client/connection.py +10 -4
- flwr/common/config.py +75 -2
- flwr/common/context.py +8 -2
- flwr/common/typing.py +1 -0
- flwr/proto/common_pb2.py +24 -0
- flwr/proto/common_pb2.pyi +7 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +23 -19
- flwr/proto/driver_pb2.pyi +18 -1
- flwr/proto/exec_pb2.py +15 -11
- flwr/proto/exec_pb2.pyi +19 -1
- flwr/proto/run_pb2.py +11 -7
- flwr/proto/run_pb2.pyi +19 -1
- flwr/server/driver/grpc_driver.py +77 -139
- flwr/server/run_serverapp.py +20 -12
- flwr/server/superlink/driver/driver_servicer.py +5 -1
- flwr/server/superlink/state/in_memory_state.py +10 -2
- flwr/server/superlink/state/sqlite_state.py +22 -7
- flwr/server/superlink/state/state.py +7 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +16 -5
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +2 -3
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/RECORD +34 -30
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/entry_points.txt +0 -0
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
|
-
from logging import DEBUG,
|
|
20
|
-
from typing import Iterable, List, Optional,
|
|
19
|
+
from logging import DEBUG, WARNING
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -27,8 +27,6 @@ from flwr.common.logger import log
|
|
|
27
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
28
|
from flwr.common.typing import Run
|
|
29
29
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
30
|
-
CreateRunRequest,
|
|
31
|
-
CreateRunResponse,
|
|
32
30
|
GetNodesRequest,
|
|
33
31
|
GetNodesResponse,
|
|
34
32
|
PullTaskResRequest,
|
|
@@ -53,167 +51,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
53
51
|
"""
|
|
54
52
|
|
|
55
53
|
|
|
56
|
-
class
|
|
57
|
-
"""`
|
|
54
|
+
class GrpcDriver(Driver):
|
|
55
|
+
"""`GrpcDriver` provides an interface to the Driver API.
|
|
58
56
|
|
|
59
57
|
Parameters
|
|
60
58
|
----------
|
|
61
|
-
|
|
62
|
-
The
|
|
63
|
-
|
|
59
|
+
run_id : int
|
|
60
|
+
The identifier of the run.
|
|
61
|
+
driver_service_address : str (default: "[::]:9091")
|
|
62
|
+
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
64
63
|
root_certificates : Optional[bytes] (default: None)
|
|
65
64
|
The PEM-encoded root certificates as a byte string.
|
|
66
65
|
If provided, a secure connection using the certificates will be
|
|
67
66
|
established to an SSL-enabled Flower server.
|
|
68
67
|
"""
|
|
69
68
|
|
|
70
|
-
def __init__(
|
|
69
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
71
70
|
self,
|
|
71
|
+
run_id: int,
|
|
72
72
|
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
73
73
|
root_certificates: Optional[bytes] = None,
|
|
74
74
|
) -> None:
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
77
|
-
self.
|
|
78
|
-
self.
|
|
75
|
+
self._run_id = run_id
|
|
76
|
+
self._addr = driver_service_address
|
|
77
|
+
self._cert = root_certificates
|
|
78
|
+
self._run: Optional[Run] = None
|
|
79
|
+
self._grpc_stub: Optional[DriverStub] = None
|
|
80
|
+
self._channel: Optional[grpc.Channel] = None
|
|
81
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
82
|
+
|
|
83
|
+
@property
|
|
84
|
+
def _is_connected(self) -> bool:
|
|
85
|
+
"""Check if connected to the Driver API server."""
|
|
86
|
+
return self._channel is not None
|
|
79
87
|
|
|
80
|
-
def
|
|
81
|
-
"""
|
|
82
|
-
return self.channel is not None
|
|
88
|
+
def _connect(self) -> None:
|
|
89
|
+
"""Connect to the Driver API.
|
|
83
90
|
|
|
84
|
-
|
|
85
|
-
"""
|
|
91
|
+
This will not call GetRun.
|
|
92
|
+
"""
|
|
86
93
|
event(EventType.DRIVER_CONNECT)
|
|
87
|
-
if self.
|
|
94
|
+
if self._is_connected:
|
|
88
95
|
log(WARNING, "Already connected")
|
|
89
96
|
return
|
|
90
|
-
self.
|
|
91
|
-
server_address=self.
|
|
92
|
-
insecure=(self.
|
|
93
|
-
root_certificates=self.
|
|
97
|
+
self._channel = create_channel(
|
|
98
|
+
server_address=self._addr,
|
|
99
|
+
insecure=(self._cert is None),
|
|
100
|
+
root_certificates=self._cert,
|
|
94
101
|
)
|
|
95
|
-
self.
|
|
96
|
-
log(DEBUG, "[Driver] Connected to %s", self.
|
|
102
|
+
self._grpc_stub = DriverStub(self._channel)
|
|
103
|
+
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
97
104
|
|
|
98
|
-
def
|
|
105
|
+
def _disconnect(self) -> None:
|
|
99
106
|
"""Disconnect from the Driver API."""
|
|
100
107
|
event(EventType.DRIVER_DISCONNECT)
|
|
101
|
-
if
|
|
108
|
+
if not self._is_connected:
|
|
102
109
|
log(DEBUG, "Already disconnected")
|
|
103
110
|
return
|
|
104
|
-
channel = self.
|
|
105
|
-
self.
|
|
106
|
-
self.
|
|
111
|
+
channel: grpc.Channel = self._channel
|
|
112
|
+
self._channel = None
|
|
113
|
+
self._grpc_stub = None
|
|
107
114
|
channel.close()
|
|
108
115
|
log(DEBUG, "[Driver] Disconnected")
|
|
109
116
|
|
|
110
|
-
def
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
-
|
|
128
|
-
# Call gRPC Driver API
|
|
129
|
-
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
-
return res
|
|
131
|
-
|
|
132
|
-
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
133
|
-
"""Get client IDs."""
|
|
134
|
-
# Check if channel is open
|
|
135
|
-
if self.stub is None:
|
|
136
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
137
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
138
|
-
|
|
139
|
-
# Call gRPC Driver API
|
|
140
|
-
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
141
|
-
return res
|
|
142
|
-
|
|
143
|
-
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
144
|
-
"""Schedule tasks."""
|
|
145
|
-
# Check if channel is open
|
|
146
|
-
if self.stub is None:
|
|
147
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
148
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
149
|
-
|
|
150
|
-
# Call gRPC Driver API
|
|
151
|
-
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
152
|
-
return res
|
|
153
|
-
|
|
154
|
-
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
155
|
-
"""Get task results."""
|
|
156
|
-
# Check if channel is open
|
|
157
|
-
if self.stub is None:
|
|
158
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
159
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
160
|
-
|
|
161
|
-
# Call Driver API
|
|
162
|
-
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
163
|
-
return res
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
class GrpcDriver(Driver):
|
|
167
|
-
"""`Driver` class provides an interface to the Driver API.
|
|
168
|
-
|
|
169
|
-
Parameters
|
|
170
|
-
----------
|
|
171
|
-
run_id : int
|
|
172
|
-
The identifier of the run.
|
|
173
|
-
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
-
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
-
If None, an instance connected to "[::]:9091" will be created.
|
|
176
|
-
"""
|
|
177
|
-
|
|
178
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
179
|
-
self,
|
|
180
|
-
run_id: int,
|
|
181
|
-
stub: Optional[GrpcDriverStub] = None,
|
|
182
|
-
) -> None:
|
|
183
|
-
self._run_id = run_id
|
|
184
|
-
self._run: Optional[Run] = None
|
|
185
|
-
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
186
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
117
|
+
def _init_run(self) -> None:
|
|
118
|
+
# Check if is initialized
|
|
119
|
+
if self._run is not None:
|
|
120
|
+
return
|
|
121
|
+
# Get the run info
|
|
122
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
123
|
+
res: GetRunResponse = self._stub.GetRun(req)
|
|
124
|
+
if not res.HasField("run"):
|
|
125
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
126
|
+
self._run = Run(
|
|
127
|
+
run_id=res.run.run_id,
|
|
128
|
+
fab_id=res.run.fab_id,
|
|
129
|
+
fab_version=res.run.fab_version,
|
|
130
|
+
override_config=dict(res.run.override_config.items()),
|
|
131
|
+
)
|
|
187
132
|
|
|
188
133
|
@property
|
|
189
134
|
def run(self) -> Run:
|
|
190
135
|
"""Run information."""
|
|
191
|
-
self.
|
|
192
|
-
return Run(**vars(
|
|
136
|
+
self._init_run()
|
|
137
|
+
return Run(**vars(self._run))
|
|
193
138
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
# Get the run info
|
|
201
|
-
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
-
res = self.stub.get_run(req)
|
|
203
|
-
if not res.HasField("run"):
|
|
204
|
-
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
-
self._run = Run(
|
|
206
|
-
run_id=res.run.run_id,
|
|
207
|
-
fab_id=res.run.fab_id,
|
|
208
|
-
fab_version=res.run.fab_version,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
return self.stub, self._run.run_id
|
|
139
|
+
@property
|
|
140
|
+
def _stub(self) -> DriverStub:
|
|
141
|
+
"""Driver stub."""
|
|
142
|
+
if not self._is_connected:
|
|
143
|
+
self._connect()
|
|
144
|
+
return cast(DriverStub, self._grpc_stub)
|
|
212
145
|
|
|
213
146
|
def _check_message(self, message: Message) -> None:
|
|
214
147
|
# Check if the message is valid
|
|
215
148
|
if not (
|
|
216
|
-
|
|
149
|
+
# Assume self._run being initialized
|
|
150
|
+
message.metadata.run_id == self._run_id
|
|
217
151
|
and message.metadata.src_node_id == self.node.node_id
|
|
218
152
|
and message.metadata.message_id == ""
|
|
219
153
|
and message.metadata.reply_to_message == ""
|
|
@@ -234,7 +168,7 @@ class GrpcDriver(Driver):
|
|
|
234
168
|
This method constructs a new `Message` with given content and metadata.
|
|
235
169
|
The `run_id` and `src_node_id` will be set automatically.
|
|
236
170
|
"""
|
|
237
|
-
|
|
171
|
+
self._init_run()
|
|
238
172
|
if ttl:
|
|
239
173
|
warnings.warn(
|
|
240
174
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -245,7 +179,7 @@ class GrpcDriver(Driver):
|
|
|
245
179
|
|
|
246
180
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
247
181
|
metadata = Metadata(
|
|
248
|
-
run_id=
|
|
182
|
+
run_id=self._run_id,
|
|
249
183
|
message_id="", # Will be set by the server
|
|
250
184
|
src_node_id=self.node.node_id,
|
|
251
185
|
dst_node_id=dst_node_id,
|
|
@@ -258,9 +192,11 @@ class GrpcDriver(Driver):
|
|
|
258
192
|
|
|
259
193
|
def get_node_ids(self) -> List[int]:
|
|
260
194
|
"""Get node IDs."""
|
|
261
|
-
|
|
195
|
+
self._init_run()
|
|
262
196
|
# Call GrpcDriverStub method
|
|
263
|
-
res =
|
|
197
|
+
res: GetNodesResponse = self._stub.GetNodes(
|
|
198
|
+
GetNodesRequest(run_id=self._run_id)
|
|
199
|
+
)
|
|
264
200
|
return [node.node_id for node in res.nodes]
|
|
265
201
|
|
|
266
202
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -269,7 +205,7 @@ class GrpcDriver(Driver):
|
|
|
269
205
|
This method takes an iterable of messages and sends each message
|
|
270
206
|
to the node specified in `dst_node_id`.
|
|
271
207
|
"""
|
|
272
|
-
|
|
208
|
+
self._init_run()
|
|
273
209
|
# Construct TaskIns
|
|
274
210
|
task_ins_list: List[TaskIns] = []
|
|
275
211
|
for msg in messages:
|
|
@@ -280,7 +216,9 @@ class GrpcDriver(Driver):
|
|
|
280
216
|
# Add to list
|
|
281
217
|
task_ins_list.append(taskins)
|
|
282
218
|
# Call GrpcDriverStub method
|
|
283
|
-
res =
|
|
219
|
+
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
220
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
221
|
+
)
|
|
284
222
|
return list(res.task_ids)
|
|
285
223
|
|
|
286
224
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -289,9 +227,9 @@ class GrpcDriver(Driver):
|
|
|
289
227
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
290
228
|
set of given message IDs.
|
|
291
229
|
"""
|
|
292
|
-
|
|
230
|
+
self._init_run()
|
|
293
231
|
# Pull TaskRes
|
|
294
|
-
res =
|
|
232
|
+
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
295
233
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
296
234
|
)
|
|
297
235
|
# Convert TaskRes to Message
|
|
@@ -331,7 +269,7 @@ class GrpcDriver(Driver):
|
|
|
331
269
|
def close(self) -> None:
|
|
332
270
|
"""Disconnect from the SuperLink if connected."""
|
|
333
271
|
# Check if `connect` was called before
|
|
334
|
-
if not self.
|
|
272
|
+
if not self._is_connected:
|
|
335
273
|
return
|
|
336
274
|
# Disconnect
|
|
337
|
-
self.
|
|
275
|
+
self._disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -25,10 +25,13 @@ from flwr.common import Context, EventType, RecordSet, event
|
|
|
25
25
|
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
|
|
26
26
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
27
27
|
from flwr.common.object_ref import load_app
|
|
28
|
-
from flwr.proto.driver_pb2 import
|
|
28
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
|
+
CreateRunRequest,
|
|
30
|
+
CreateRunResponse,
|
|
31
|
+
)
|
|
29
32
|
|
|
30
33
|
from .driver import Driver
|
|
31
|
-
from .driver.grpc_driver import GrpcDriver
|
|
34
|
+
from .driver.grpc_driver import GrpcDriver
|
|
32
35
|
from .server_app import LoadServerAppError, ServerApp
|
|
33
36
|
|
|
34
37
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -144,22 +147,27 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
144
147
|
"For more details, use: ``flower-server-app -h``"
|
|
145
148
|
)
|
|
146
149
|
|
|
147
|
-
|
|
148
|
-
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
149
|
-
)
|
|
150
|
+
# Initialize GrpcDriver
|
|
150
151
|
if args.run_id is not None:
|
|
151
152
|
# User provided `--run-id`, but not `server-app`
|
|
152
|
-
|
|
153
|
+
driver = GrpcDriver(
|
|
154
|
+
run_id=args.run_id,
|
|
155
|
+
driver_service_address=args.superlink,
|
|
156
|
+
root_certificates=root_certificates,
|
|
157
|
+
)
|
|
153
158
|
else:
|
|
154
159
|
# User provided `server-app`, but not `--run-id`
|
|
155
160
|
# Create run if run_id is not provided
|
|
156
|
-
|
|
161
|
+
driver = GrpcDriver(
|
|
162
|
+
run_id=0, # Will be overwritten
|
|
163
|
+
driver_service_address=args.superlink,
|
|
164
|
+
root_certificates=root_certificates,
|
|
165
|
+
)
|
|
166
|
+
# Create run
|
|
157
167
|
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
158
|
-
res =
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
# Initialize GrpcDriver
|
|
162
|
-
driver = GrpcDriver(run_id=run_id, stub=stub)
|
|
168
|
+
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
169
|
+
# Overwrite driver._run_id
|
|
170
|
+
driver._run_id = res.run_id # pylint: disable=W0212
|
|
163
171
|
|
|
164
172
|
# Dynamically obtain ServerApp path based on run_id
|
|
165
173
|
if args.run_id is not None:
|
|
@@ -69,7 +69,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
69
69
|
"""Create run ID."""
|
|
70
70
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
71
71
|
state: State = self.state_factory.state()
|
|
72
|
-
run_id = state.create_run(
|
|
72
|
+
run_id = state.create_run(
|
|
73
|
+
request.fab_id,
|
|
74
|
+
request.fab_version,
|
|
75
|
+
dict(request.override_config.items()),
|
|
76
|
+
)
|
|
73
77
|
return CreateRunResponse(run_id=run_id)
|
|
74
78
|
|
|
75
79
|
def PushTaskIns(
|
|
@@ -275,7 +275,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
275
275
|
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
276
276
|
return self.public_key_to_node_id.get(client_public_key)
|
|
277
277
|
|
|
278
|
-
def create_run(
|
|
278
|
+
def create_run(
|
|
279
|
+
self,
|
|
280
|
+
fab_id: str,
|
|
281
|
+
fab_version: str,
|
|
282
|
+
override_config: Dict[str, str],
|
|
283
|
+
) -> int:
|
|
279
284
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
280
285
|
# Sample a random int64 as run_id
|
|
281
286
|
with self.lock:
|
|
@@ -283,7 +288,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
283
288
|
|
|
284
289
|
if run_id not in self.run_ids:
|
|
285
290
|
self.run_ids[run_id] = Run(
|
|
286
|
-
run_id=run_id,
|
|
291
|
+
run_id=run_id,
|
|
292
|
+
fab_id=fab_id,
|
|
293
|
+
fab_version=fab_version,
|
|
294
|
+
override_config=override_config,
|
|
287
295
|
)
|
|
288
296
|
return run_id
|
|
289
297
|
log(ERROR, "Unexpected run creation failure.")
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""SQLite based implemenation of server state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import json
|
|
18
19
|
import re
|
|
19
20
|
import sqlite3
|
|
20
21
|
import time
|
|
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
|
61
62
|
|
|
62
63
|
SQL_CREATE_TABLE_RUN = """
|
|
63
64
|
CREATE TABLE IF NOT EXISTS run(
|
|
64
|
-
run_id
|
|
65
|
-
fab_id
|
|
66
|
-
fab_version
|
|
65
|
+
run_id INTEGER UNIQUE,
|
|
66
|
+
fab_id TEXT,
|
|
67
|
+
fab_version TEXT,
|
|
68
|
+
override_config TEXT
|
|
67
69
|
);
|
|
68
70
|
"""
|
|
69
71
|
|
|
@@ -613,7 +615,12 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
613
615
|
return node_id
|
|
614
616
|
return None
|
|
615
617
|
|
|
616
|
-
def create_run(
|
|
618
|
+
def create_run(
|
|
619
|
+
self,
|
|
620
|
+
fab_id: str,
|
|
621
|
+
fab_version: str,
|
|
622
|
+
override_config: Dict[str, str],
|
|
623
|
+
) -> int:
|
|
617
624
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
618
625
|
# Sample a random int64 as run_id
|
|
619
626
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -622,8 +629,13 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
622
629
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
623
630
|
# If run_id does not exist
|
|
624
631
|
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
|
|
625
|
-
query =
|
|
626
|
-
|
|
632
|
+
query = (
|
|
633
|
+
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
|
|
634
|
+
"VALUES (?, ?, ?, ?);"
|
|
635
|
+
)
|
|
636
|
+
self.query(
|
|
637
|
+
query, (run_id, fab_id, fab_version, json.dumps(override_config))
|
|
638
|
+
)
|
|
627
639
|
return run_id
|
|
628
640
|
log(ERROR, "Unexpected run creation failure.")
|
|
629
641
|
return 0
|
|
@@ -687,7 +699,10 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
687
699
|
try:
|
|
688
700
|
row = self.query(query, (run_id,))[0]
|
|
689
701
|
return Run(
|
|
690
|
-
run_id=run_id,
|
|
702
|
+
run_id=run_id,
|
|
703
|
+
fab_id=row["fab_id"],
|
|
704
|
+
fab_version=row["fab_version"],
|
|
705
|
+
override_config=json.loads(row["override_config"]),
|
|
691
706
|
)
|
|
692
707
|
except sqlite3.IntegrityError:
|
|
693
708
|
log(ERROR, "`run_id` does not exist.")
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import List, Optional, Set
|
|
19
|
+
from typing import Dict, List, Optional, Set
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common.typing import Run
|
|
@@ -157,7 +157,12 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
157
157
|
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
158
158
|
|
|
159
159
|
@abc.abstractmethod
|
|
160
|
-
def create_run(
|
|
160
|
+
def create_run(
|
|
161
|
+
self,
|
|
162
|
+
fab_id: str,
|
|
163
|
+
fab_version: str,
|
|
164
|
+
override_config: Dict[str, str],
|
|
165
|
+
) -> int:
|
|
161
166
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
162
167
|
|
|
163
168
|
@abc.abstractmethod
|
|
@@ -209,7 +209,7 @@ def _main_loop(
|
|
|
209
209
|
serverapp_th = None
|
|
210
210
|
try:
|
|
211
211
|
# Create run (with empty fab_id and fab_version)
|
|
212
|
-
run_id_ = state_factory.state().create_run("", "")
|
|
212
|
+
run_id_ = state_factory.state().create_run("", "", {})
|
|
213
213
|
|
|
214
214
|
if run_id:
|
|
215
215
|
_override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id)
|
flwr/superexec/app.py
CHANGED
|
@@ -77,6 +77,7 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
|
|
|
77
77
|
parser.add_argument(
|
|
78
78
|
"executor",
|
|
79
79
|
help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
|
|
80
|
+
default="flwr.superexec.deployment:executor",
|
|
80
81
|
)
|
|
81
82
|
parser.add_argument(
|
|
82
83
|
"--address",
|
flwr/superexec/deployment.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import subprocess
|
|
18
18
|
import sys
|
|
19
19
|
from logging import ERROR, INFO
|
|
20
|
-
from typing import Optional
|
|
20
|
+
from typing import Dict, Optional
|
|
21
21
|
|
|
22
22
|
from typing_extensions import override
|
|
23
23
|
|
|
@@ -53,18 +53,29 @@ class DeploymentEngine(Executor):
|
|
|
53
53
|
)
|
|
54
54
|
self.stub = DriverStub(channel)
|
|
55
55
|
|
|
56
|
-
def _create_run(
|
|
56
|
+
def _create_run(
|
|
57
|
+
self,
|
|
58
|
+
fab_id: str,
|
|
59
|
+
fab_version: str,
|
|
60
|
+
override_config: Dict[str, str],
|
|
61
|
+
) -> int:
|
|
57
62
|
if self.stub is None:
|
|
58
63
|
self._connect()
|
|
59
64
|
|
|
60
65
|
assert self.stub is not None
|
|
61
66
|
|
|
62
|
-
req = CreateRunRequest(
|
|
67
|
+
req = CreateRunRequest(
|
|
68
|
+
fab_id=fab_id,
|
|
69
|
+
fab_version=fab_version,
|
|
70
|
+
override_config=override_config,
|
|
71
|
+
)
|
|
63
72
|
res = self.stub.CreateRun(request=req)
|
|
64
73
|
return int(res.run_id)
|
|
65
74
|
|
|
66
75
|
@override
|
|
67
|
-
def start_run(
|
|
76
|
+
def start_run(
|
|
77
|
+
self, fab_file: bytes, override_config: Dict[str, str]
|
|
78
|
+
) -> Optional[RunTracker]:
|
|
68
79
|
"""Start run using the Flower Deployment Engine."""
|
|
69
80
|
try:
|
|
70
81
|
# Install FAB to flwr dir
|
|
@@ -79,7 +90,7 @@ class DeploymentEngine(Executor):
|
|
|
79
90
|
)
|
|
80
91
|
|
|
81
92
|
# Call SuperLink to create run
|
|
82
|
-
run_id: int = self._create_run(fab_id, fab_version)
|
|
93
|
+
run_id: int = self._create_run(fab_id, fab_version, override_config)
|
|
83
94
|
log(INFO, "Created run %s", str(run_id))
|
|
84
95
|
|
|
85
96
|
# Start ServerApp
|
flwr/superexec/exec_servicer.py
CHANGED
|
@@ -45,7 +45,10 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
45
45
|
"""Create run ID."""
|
|
46
46
|
log(INFO, "ExecServicer.StartRun")
|
|
47
47
|
|
|
48
|
-
run = self.executor.start_run(
|
|
48
|
+
run = self.executor.start_run(
|
|
49
|
+
request.fab_file,
|
|
50
|
+
dict(request.override_config.items()),
|
|
51
|
+
)
|
|
49
52
|
|
|
50
53
|
if run is None:
|
|
51
54
|
log(ERROR, "Executor failed to start run")
|
flwr/superexec/executor.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from subprocess import Popen
|
|
20
|
-
from typing import Optional
|
|
20
|
+
from typing import Dict, Optional
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
@dataclass
|
|
@@ -33,8 +33,7 @@ class Executor(ABC):
|
|
|
33
33
|
|
|
34
34
|
@abstractmethod
|
|
35
35
|
def start_run(
|
|
36
|
-
self,
|
|
37
|
-
fab_file: bytes,
|
|
36
|
+
self, fab_file: bytes, override_config: Dict[str, str]
|
|
38
37
|
) -> Optional[RunTracker]:
|
|
39
38
|
"""Start a run using the given Flower FAB ID and version.
|
|
40
39
|
|