flwr-nightly 1.13.0.dev20241105__py3-none-any.whl → 1.13.0.dev20241107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/run/run.py +16 -5
- flwr/client/app.py +10 -6
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/run_info_store.py +1 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +1 -1
- flwr/common/context.py +9 -4
- flwr/common/object_ref.py +40 -33
- flwr/common/serde.py +2 -0
- flwr/proto/exec_pb2.py +14 -17
- flwr/proto/exec_pb2.pyi +6 -20
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/server/app.py +140 -107
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +2 -6
- flwr/server/driver/inmemory_driver.py +1 -3
- flwr/server/run_serverapp.py +5 -2
- flwr/server/serverapp/app.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +2 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +15 -16
- flwr/server/superlink/linkstate/linkstate.py +18 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +41 -21
- flwr/server/superlink/linkstate/utils.py +14 -30
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
- flwr/simulation/__init__.py +2 -0
- flwr/simulation/run_simulation.py +4 -1
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/deployment.py +8 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +4 -3
- flwr/superexec/simulation.py +8 -8
- {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/RECORD +42 -34
- {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -47,6 +47,7 @@ from flwr.common.constant import (
|
|
|
47
47
|
ISOLATION_MODE_SUBPROCESS,
|
|
48
48
|
MISSING_EXTRA_REST,
|
|
49
49
|
SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
50
|
+
SIMULATIONIO_API_DEFAULT_ADDRESS,
|
|
50
51
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
51
52
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
52
53
|
TRANSPORT_TYPE_REST,
|
|
@@ -63,6 +64,7 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
|
63
64
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
64
65
|
from flwr.superexec.app import load_executor
|
|
65
66
|
from flwr.superexec.exec_grpc import run_exec_api_grpc
|
|
67
|
+
from flwr.superexec.simulation import SimulationEngine
|
|
66
68
|
|
|
67
69
|
from .client_manager import ClientManager
|
|
68
70
|
from .history import History
|
|
@@ -79,6 +81,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
79
81
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
80
82
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
81
83
|
from .superlink.linkstate import LinkStateFactory
|
|
84
|
+
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
|
82
85
|
|
|
83
86
|
DATABASE = ":flwr-in-memory-state:"
|
|
84
87
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
@@ -215,6 +218,7 @@ def run_superlink() -> None:
|
|
|
215
218
|
# Parse IP addresses
|
|
216
219
|
serverappio_address, _, _ = _format_address(args.serverappio_api_address)
|
|
217
220
|
exec_address, _, _ = _format_address(args.exec_api_address)
|
|
221
|
+
simulationio_address, _, _ = _format_address(args.simulationio_api_address)
|
|
218
222
|
|
|
219
223
|
# Obtain certificates
|
|
220
224
|
certificates = _try_obtain_certificates(args)
|
|
@@ -225,128 +229,148 @@ def run_superlink() -> None:
|
|
|
225
229
|
# Initialize FfsFactory
|
|
226
230
|
ffs_factory = FfsFactory(args.storage_dir)
|
|
227
231
|
|
|
228
|
-
# Start
|
|
229
|
-
|
|
230
|
-
|
|
232
|
+
# Start Exec API
|
|
233
|
+
executor = load_executor(args)
|
|
234
|
+
exec_server: grpc.Server = run_exec_api_grpc(
|
|
235
|
+
address=exec_address,
|
|
231
236
|
state_factory=state_factory,
|
|
232
237
|
ffs_factory=ffs_factory,
|
|
238
|
+
executor=executor,
|
|
233
239
|
certificates=certificates,
|
|
240
|
+
config=parse_config_args(
|
|
241
|
+
[args.executor_config] if args.executor_config else args.executor_config
|
|
242
|
+
),
|
|
234
243
|
)
|
|
235
|
-
grpc_servers = [
|
|
244
|
+
grpc_servers = [exec_server]
|
|
236
245
|
|
|
237
|
-
#
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
if args.fleet_api_type in [
|
|
241
|
-
TRANSPORT_TYPE_GRPC_RERE,
|
|
242
|
-
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
243
|
-
]:
|
|
244
|
-
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
|
|
245
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
246
|
-
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
|
|
247
|
-
|
|
248
|
-
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
249
|
-
|
|
250
|
-
num_workers = args.fleet_api_num_workers
|
|
251
|
-
if num_workers != 1:
|
|
252
|
-
log(
|
|
253
|
-
WARN,
|
|
254
|
-
"The Fleet API currently supports only 1 worker. "
|
|
255
|
-
"You have specified %d workers. "
|
|
256
|
-
"Support for multiple workers will be added in future releases. "
|
|
257
|
-
"Proceeding with a single worker.",
|
|
258
|
-
args.fleet_api_num_workers,
|
|
259
|
-
)
|
|
260
|
-
num_workers = 1
|
|
246
|
+
# Determine Exec plugin
|
|
247
|
+
# If simulation is used, don't start ServerAppIo and Fleet APIs
|
|
248
|
+
sim_exec = isinstance(executor, SimulationEngine)
|
|
261
249
|
|
|
262
|
-
|
|
263
|
-
if (
|
|
264
|
-
importlib.util.find_spec("requests")
|
|
265
|
-
and importlib.util.find_spec("starlette")
|
|
266
|
-
and importlib.util.find_spec("uvicorn")
|
|
267
|
-
) is None:
|
|
268
|
-
sys.exit(MISSING_EXTRA_REST)
|
|
269
|
-
|
|
270
|
-
_, ssl_certfile, ssl_keyfile = (
|
|
271
|
-
certificates if certificates is not None else (None, None, None)
|
|
272
|
-
)
|
|
273
|
-
|
|
274
|
-
fleet_thread = threading.Thread(
|
|
275
|
-
target=_run_fleet_api_rest,
|
|
276
|
-
args=(
|
|
277
|
-
host,
|
|
278
|
-
port,
|
|
279
|
-
ssl_keyfile,
|
|
280
|
-
ssl_certfile,
|
|
281
|
-
state_factory,
|
|
282
|
-
ffs_factory,
|
|
283
|
-
num_workers,
|
|
284
|
-
),
|
|
285
|
-
)
|
|
286
|
-
fleet_thread.start()
|
|
287
|
-
bckg_threads.append(fleet_thread)
|
|
288
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
289
|
-
maybe_keys = _try_setup_node_authentication(args, certificates)
|
|
290
|
-
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
291
|
-
if maybe_keys is not None:
|
|
292
|
-
(
|
|
293
|
-
node_public_keys,
|
|
294
|
-
server_private_key,
|
|
295
|
-
server_public_key,
|
|
296
|
-
) = maybe_keys
|
|
297
|
-
state = state_factory.state()
|
|
298
|
-
state.store_node_public_keys(node_public_keys)
|
|
299
|
-
state.store_server_private_public_key(
|
|
300
|
-
private_key_to_bytes(server_private_key),
|
|
301
|
-
public_key_to_bytes(server_public_key),
|
|
302
|
-
)
|
|
303
|
-
log(
|
|
304
|
-
INFO,
|
|
305
|
-
"Node authentication enabled with %d known public keys",
|
|
306
|
-
len(node_public_keys),
|
|
307
|
-
)
|
|
308
|
-
interceptors = [AuthenticateServerInterceptor(state)]
|
|
250
|
+
bckg_threads = []
|
|
309
251
|
|
|
310
|
-
|
|
311
|
-
|
|
252
|
+
if sim_exec:
|
|
253
|
+
simulationio_server: grpc.Server = run_simulationio_api_grpc(
|
|
254
|
+
address=simulationio_address,
|
|
312
255
|
state_factory=state_factory,
|
|
313
256
|
ffs_factory=ffs_factory,
|
|
314
257
|
certificates=certificates,
|
|
315
|
-
interceptors=interceptors,
|
|
316
258
|
)
|
|
317
|
-
grpc_servers.append(
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
259
|
+
grpc_servers.append(simulationio_server)
|
|
260
|
+
|
|
261
|
+
else:
|
|
262
|
+
# Start ServerAppIo API
|
|
263
|
+
serverappio_server: grpc.Server = run_serverappio_api_grpc(
|
|
264
|
+
address=serverappio_address,
|
|
321
265
|
state_factory=state_factory,
|
|
322
266
|
ffs_factory=ffs_factory,
|
|
323
267
|
certificates=certificates,
|
|
324
268
|
)
|
|
325
|
-
grpc_servers.append(
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
269
|
+
grpc_servers.append(serverappio_server)
|
|
270
|
+
|
|
271
|
+
# Start Fleet API
|
|
272
|
+
if not args.fleet_api_address:
|
|
273
|
+
if args.fleet_api_type in [
|
|
274
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
275
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
276
|
+
]:
|
|
277
|
+
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
|
|
278
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
279
|
+
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
|
|
280
|
+
|
|
281
|
+
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
282
|
+
|
|
283
|
+
num_workers = args.fleet_api_num_workers
|
|
284
|
+
if num_workers != 1:
|
|
285
|
+
log(
|
|
286
|
+
WARN,
|
|
287
|
+
"The Fleet API currently supports only 1 worker. "
|
|
288
|
+
"You have specified %d workers. "
|
|
289
|
+
"Support for multiple workers will be added in future releases. "
|
|
290
|
+
"Proceeding with a single worker.",
|
|
291
|
+
args.fleet_api_num_workers,
|
|
292
|
+
)
|
|
293
|
+
num_workers = 1
|
|
294
|
+
|
|
295
|
+
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
296
|
+
if (
|
|
297
|
+
importlib.util.find_spec("requests")
|
|
298
|
+
and importlib.util.find_spec("starlette")
|
|
299
|
+
and importlib.util.find_spec("uvicorn")
|
|
300
|
+
) is None:
|
|
301
|
+
sys.exit(MISSING_EXTRA_REST)
|
|
302
|
+
|
|
303
|
+
_, ssl_certfile, ssl_keyfile = (
|
|
304
|
+
certificates if certificates is not None else (None, None, None)
|
|
305
|
+
)
|
|
341
306
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
307
|
+
fleet_thread = threading.Thread(
|
|
308
|
+
target=_run_fleet_api_rest,
|
|
309
|
+
args=(
|
|
310
|
+
host,
|
|
311
|
+
port,
|
|
312
|
+
ssl_keyfile,
|
|
313
|
+
ssl_certfile,
|
|
314
|
+
state_factory,
|
|
315
|
+
ffs_factory,
|
|
316
|
+
num_workers,
|
|
317
|
+
),
|
|
318
|
+
)
|
|
319
|
+
fleet_thread.start()
|
|
320
|
+
bckg_threads.append(fleet_thread)
|
|
321
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
322
|
+
maybe_keys = _try_setup_node_authentication(args, certificates)
|
|
323
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
324
|
+
if maybe_keys is not None:
|
|
325
|
+
(
|
|
326
|
+
node_public_keys,
|
|
327
|
+
server_private_key,
|
|
328
|
+
server_public_key,
|
|
329
|
+
) = maybe_keys
|
|
330
|
+
state = state_factory.state()
|
|
331
|
+
state.store_node_public_keys(node_public_keys)
|
|
332
|
+
state.store_server_private_public_key(
|
|
333
|
+
private_key_to_bytes(server_private_key),
|
|
334
|
+
public_key_to_bytes(server_public_key),
|
|
335
|
+
)
|
|
336
|
+
log(
|
|
337
|
+
INFO,
|
|
338
|
+
"Node authentication enabled with %d known public keys",
|
|
339
|
+
len(node_public_keys),
|
|
340
|
+
)
|
|
341
|
+
interceptors = [AuthenticateServerInterceptor(state)]
|
|
342
|
+
|
|
343
|
+
fleet_server = _run_fleet_api_grpc_rere(
|
|
344
|
+
address=fleet_address,
|
|
345
|
+
state_factory=state_factory,
|
|
346
|
+
ffs_factory=ffs_factory,
|
|
347
|
+
certificates=certificates,
|
|
348
|
+
interceptors=interceptors,
|
|
349
|
+
)
|
|
350
|
+
grpc_servers.append(fleet_server)
|
|
351
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
|
|
352
|
+
fleet_server = _run_fleet_api_grpc_adapter(
|
|
353
|
+
address=fleet_address,
|
|
354
|
+
state_factory=state_factory,
|
|
355
|
+
ffs_factory=ffs_factory,
|
|
356
|
+
certificates=certificates,
|
|
357
|
+
)
|
|
358
|
+
grpc_servers.append(fleet_server)
|
|
359
|
+
else:
|
|
360
|
+
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
361
|
+
|
|
362
|
+
if args.isolation == ISOLATION_MODE_SUBPROCESS:
|
|
363
|
+
# Scheduler thread
|
|
364
|
+
scheduler_th = threading.Thread(
|
|
365
|
+
target=_flwr_serverapp_scheduler,
|
|
366
|
+
args=(
|
|
367
|
+
state_factory,
|
|
368
|
+
args.serverappio_api_address,
|
|
369
|
+
args.ssl_ca_certfile,
|
|
370
|
+
),
|
|
371
|
+
)
|
|
372
|
+
scheduler_th.start()
|
|
373
|
+
bckg_threads.append(scheduler_th)
|
|
350
374
|
|
|
351
375
|
# Graceful shutdown
|
|
352
376
|
register_exit_handlers(
|
|
@@ -361,7 +385,7 @@ def run_superlink() -> None:
|
|
|
361
385
|
for thread in bckg_threads:
|
|
362
386
|
if not thread.is_alive():
|
|
363
387
|
sys.exit(1)
|
|
364
|
-
|
|
388
|
+
exec_server.wait_for_termination(timeout=1)
|
|
365
389
|
|
|
366
390
|
|
|
367
391
|
def _flwr_serverapp_scheduler(
|
|
@@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
|
657
681
|
_add_args_serverappio_api(parser=parser)
|
|
658
682
|
_add_args_fleet_api(parser=parser)
|
|
659
683
|
_add_args_exec_api(parser=parser)
|
|
684
|
+
_add_args_simulationio_api(parser=parser)
|
|
660
685
|
|
|
661
686
|
return parser
|
|
662
687
|
|
|
@@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
|
|
|
790
815
|
"For example:\n\n`--executor-config 'verbose=true "
|
|
791
816
|
'root-certificates="certificates/superlink-ca.crt"\'`',
|
|
792
817
|
)
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None:
|
|
821
|
+
parser.add_argument(
|
|
822
|
+
"--simulationio-api-address",
|
|
823
|
+
help="SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
824
|
+
default=SIMULATIONIO_API_DEFAULT_ADDRESS,
|
|
825
|
+
)
|
flwr/server/driver/driver.py
CHANGED
|
@@ -27,7 +27,7 @@ class Driver(ABC):
|
|
|
27
27
|
"""Abstract base Driver class for the ServerAppIo API."""
|
|
28
28
|
|
|
29
29
|
@abstractmethod
|
|
30
|
-
def
|
|
30
|
+
def set_run(self, run_id: int) -> None:
|
|
31
31
|
"""Request a run to the SuperLink with a given `run_id`.
|
|
32
32
|
|
|
33
33
|
If a Run with the specified `run_id` exists, a local Run
|
|
@@ -112,12 +112,8 @@ class GrpcDriver(Driver):
|
|
|
112
112
|
channel.close()
|
|
113
113
|
log(DEBUG, "[Driver] Disconnected")
|
|
114
114
|
|
|
115
|
-
def
|
|
116
|
-
"""
|
|
117
|
-
# Check if is initialized
|
|
118
|
-
if self._run is not None:
|
|
119
|
-
return
|
|
120
|
-
|
|
115
|
+
def set_run(self, run_id: int) -> None:
|
|
116
|
+
"""Set the run."""
|
|
121
117
|
# Get the run info
|
|
122
118
|
req = GetRunRequest(run_id=run_id)
|
|
123
119
|
res: GetRunResponse = self._stub.GetRun(req)
|
|
@@ -62,10 +62,8 @@ class InMemoryDriver(Driver):
|
|
|
62
62
|
):
|
|
63
63
|
raise ValueError(f"Invalid message: {message}")
|
|
64
64
|
|
|
65
|
-
def
|
|
65
|
+
def set_run(self, run_id: int) -> None:
|
|
66
66
|
"""Initialize the run."""
|
|
67
|
-
if self._run is not None:
|
|
68
|
-
return
|
|
69
67
|
run = self.state.get_run(run_id)
|
|
70
68
|
if run is None:
|
|
71
69
|
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -174,7 +174,7 @@ def run_server_app() -> None:
|
|
|
174
174
|
root_certificates=root_certificates,
|
|
175
175
|
)
|
|
176
176
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
177
|
-
driver.
|
|
177
|
+
driver.set_run(args.run_id)
|
|
178
178
|
run_ = driver.run
|
|
179
179
|
if not run_.fab_hash:
|
|
180
180
|
raise ValueError("FAB hash not provided.")
|
|
@@ -188,6 +188,7 @@ def run_server_app() -> None:
|
|
|
188
188
|
|
|
189
189
|
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
190
190
|
config = get_project_config(app_path)
|
|
191
|
+
run_id = run_.run_id
|
|
191
192
|
else:
|
|
192
193
|
# User provided `app_dir`, but not `--run-id`
|
|
193
194
|
# Create run if run_id is not provided
|
|
@@ -203,7 +204,8 @@ def run_server_app() -> None:
|
|
|
203
204
|
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
|
|
204
205
|
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
205
206
|
# Fetch full `Run` using `run_id`
|
|
206
|
-
driver.
|
|
207
|
+
driver.set_run(res.run_id) # pylint: disable=W0212
|
|
208
|
+
run_id = res.run_id
|
|
207
209
|
|
|
208
210
|
# Obtain server app reference and the run config
|
|
209
211
|
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
@@ -221,6 +223,7 @@ def run_server_app() -> None:
|
|
|
221
223
|
|
|
222
224
|
# Initialize Context
|
|
223
225
|
context = Context(
|
|
226
|
+
run_id=run_id,
|
|
224
227
|
node_id=0,
|
|
225
228
|
node_config={},
|
|
226
229
|
state=RecordSet(),
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -189,7 +189,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
189
189
|
run = run_from_proto(res.run)
|
|
190
190
|
fab = fab_from_proto(res.fab)
|
|
191
191
|
|
|
192
|
-
driver.
|
|
192
|
+
driver.set_run(run.run_id)
|
|
193
193
|
|
|
194
194
|
# Start log uploader for this run
|
|
195
195
|
log_uploader = start_log_uploader(
|
|
@@ -23,6 +23,7 @@ from uuid import UUID
|
|
|
23
23
|
|
|
24
24
|
import grpc
|
|
25
25
|
|
|
26
|
+
from flwr.common import ConfigsRecord
|
|
26
27
|
from flwr.common.constant import Status
|
|
27
28
|
from flwr.common.logger import log
|
|
28
29
|
from flwr.common.serde import (
|
|
@@ -112,6 +113,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
112
113
|
request.fab_version,
|
|
113
114
|
fab_hash,
|
|
114
115
|
user_config_from_proto(request.override_config),
|
|
116
|
+
ConfigsRecord(),
|
|
115
117
|
)
|
|
116
118
|
return CreateRunResponse(run_id=run_id)
|
|
117
119
|
|
|
@@ -30,6 +30,7 @@ from flwr.common.constant import (
|
|
|
30
30
|
RUN_ID_NUM_BYTES,
|
|
31
31
|
Status,
|
|
32
32
|
)
|
|
33
|
+
from flwr.common.record import ConfigsRecord
|
|
33
34
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
34
35
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
35
36
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
@@ -39,7 +40,6 @@ from .utils import (
|
|
|
39
40
|
generate_rand_int_from_bytes,
|
|
40
41
|
has_valid_sub_status,
|
|
41
42
|
is_valid_transition,
|
|
42
|
-
make_node_unavailable_taskres,
|
|
43
43
|
)
|
|
44
44
|
|
|
45
45
|
|
|
@@ -69,6 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
69
69
|
# Map run_id to RunRecord
|
|
70
70
|
self.run_ids: dict[int, RunRecord] = {}
|
|
71
71
|
self.contexts: dict[int, Context] = {}
|
|
72
|
+
self.federation_options: dict[int, ConfigsRecord] = {}
|
|
72
73
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
73
74
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
74
75
|
|
|
@@ -255,21 +256,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
255
256
|
task_res_list.append(task_res)
|
|
256
257
|
replied_task_ids.add(reply_to)
|
|
257
258
|
|
|
258
|
-
# Check if the node is offline
|
|
259
|
-
for task_id in task_ids - replied_task_ids:
|
|
260
|
-
task_ins = self.task_ins_store.get(task_id)
|
|
261
|
-
if task_ins is None:
|
|
262
|
-
continue
|
|
263
|
-
node_id = task_ins.task.consumer.node_id
|
|
264
|
-
online_until, _ = self.node_ids[node_id]
|
|
265
|
-
# Generate a TaskRes containing an error reply if the node is offline.
|
|
266
|
-
if online_until < time.time():
|
|
267
|
-
err_taskres = make_node_unavailable_taskres(
|
|
268
|
-
ref_taskins=task_ins,
|
|
269
|
-
)
|
|
270
|
-
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
|
|
271
|
-
task_res_list.append(err_taskres)
|
|
272
|
-
|
|
273
259
|
# Mark all of them as delivered
|
|
274
260
|
delivered_at = now().isoformat()
|
|
275
261
|
for task_res in task_res_list:
|
|
@@ -378,12 +364,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
378
364
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
379
365
|
return self.public_key_to_node_id.get(node_public_key)
|
|
380
366
|
|
|
367
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
381
368
|
def create_run(
|
|
382
369
|
self,
|
|
383
370
|
fab_id: Optional[str],
|
|
384
371
|
fab_version: Optional[str],
|
|
385
372
|
fab_hash: Optional[str],
|
|
386
373
|
override_config: UserConfig,
|
|
374
|
+
federation_options: ConfigsRecord,
|
|
387
375
|
) -> int:
|
|
388
376
|
"""Create a new run for the specified `fab_hash`."""
|
|
389
377
|
# Sample a random int64 as run_id
|
|
@@ -407,6 +395,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
407
395
|
pending_at=now().isoformat(),
|
|
408
396
|
)
|
|
409
397
|
self.run_ids[run_id] = run_record
|
|
398
|
+
|
|
399
|
+
# Record federation options. Leave empty if not passed
|
|
400
|
+
self.federation_options[run_id] = federation_options
|
|
410
401
|
return run_id
|
|
411
402
|
log(ERROR, "Unexpected run creation failure.")
|
|
412
403
|
return 0
|
|
@@ -514,6 +505,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
514
505
|
|
|
515
506
|
return pending_run_id
|
|
516
507
|
|
|
508
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
|
|
509
|
+
"""Retrieve the federation options for the specified `run_id`."""
|
|
510
|
+
with self.lock:
|
|
511
|
+
if run_id not in self.run_ids:
|
|
512
|
+
log(ERROR, "`run_id` is invalid")
|
|
513
|
+
return None
|
|
514
|
+
return self.federation_options[run_id]
|
|
515
|
+
|
|
517
516
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
518
517
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
519
518
|
with self.lock:
|
|
@@ -20,6 +20,7 @@ from typing import Optional
|
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common import Context
|
|
23
|
+
from flwr.common.record import ConfigsRecord
|
|
23
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
24
25
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
25
26
|
|
|
@@ -152,12 +153,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
152
153
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
153
154
|
|
|
154
155
|
@abc.abstractmethod
|
|
155
|
-
def create_run(
|
|
156
|
+
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
156
157
|
self,
|
|
157
158
|
fab_id: Optional[str],
|
|
158
159
|
fab_version: Optional[str],
|
|
159
160
|
fab_hash: Optional[str],
|
|
160
161
|
override_config: UserConfig,
|
|
162
|
+
federation_options: ConfigsRecord,
|
|
161
163
|
) -> int:
|
|
162
164
|
"""Create a new run for the specified `fab_hash`."""
|
|
163
165
|
|
|
@@ -227,6 +229,21 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
227
229
|
there is no Run pending.
|
|
228
230
|
"""
|
|
229
231
|
|
|
232
|
+
@abc.abstractmethod
|
|
233
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
|
|
234
|
+
"""Retrieve the federation options for the specified `run_id`.
|
|
235
|
+
|
|
236
|
+
Parameters
|
|
237
|
+
----------
|
|
238
|
+
run_id : int
|
|
239
|
+
The identifier of the run.
|
|
240
|
+
|
|
241
|
+
Returns
|
|
242
|
+
-------
|
|
243
|
+
Optional[ConfigsRecord]
|
|
244
|
+
The federation options for the run if it exists; None otherwise.
|
|
245
|
+
"""
|
|
246
|
+
|
|
230
247
|
@abc.abstractmethod
|
|
231
248
|
def store_server_private_public_key(
|
|
232
249
|
self, private_key: bytes, public_key: bytes
|
|
@@ -33,6 +33,7 @@ from flwr.common.constant import (
|
|
|
33
33
|
RUN_ID_NUM_BYTES,
|
|
34
34
|
Status,
|
|
35
35
|
)
|
|
36
|
+
from flwr.common.record import ConfigsRecord
|
|
36
37
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
37
38
|
|
|
38
39
|
# pylint: disable=E0611
|
|
@@ -45,6 +46,8 @@ from flwr.server.utils.validator import validate_task_ins_or_res
|
|
|
45
46
|
|
|
46
47
|
from .linkstate import LinkState
|
|
47
48
|
from .utils import (
|
|
49
|
+
configsrecord_from_bytes,
|
|
50
|
+
configsrecord_to_bytes,
|
|
48
51
|
context_from_bytes,
|
|
49
52
|
context_to_bytes,
|
|
50
53
|
convert_sint64_to_uint64,
|
|
@@ -54,7 +57,6 @@ from .utils import (
|
|
|
54
57
|
generate_rand_int_from_bytes,
|
|
55
58
|
has_valid_sub_status,
|
|
56
59
|
is_valid_transition,
|
|
57
|
-
make_node_unavailable_taskres,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
62
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -95,7 +97,8 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
95
97
|
running_at TEXT,
|
|
96
98
|
finished_at TEXT,
|
|
97
99
|
sub_status TEXT,
|
|
98
|
-
details TEXT
|
|
100
|
+
details TEXT,
|
|
101
|
+
federation_options BLOB
|
|
99
102
|
);
|
|
100
103
|
"""
|
|
101
104
|
|
|
@@ -636,20 +639,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
636
639
|
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
637
640
|
task_ins_rows = self.query(query, data)
|
|
638
641
|
|
|
639
|
-
# Make TaskRes containing node unavailabe error
|
|
640
|
-
for row in task_ins_rows:
|
|
641
|
-
for row in rows:
|
|
642
|
-
# Convert values from sint64 to uint64
|
|
643
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
644
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
645
|
-
)
|
|
646
|
-
|
|
647
|
-
task_ins = dict_to_task_ins(row)
|
|
648
|
-
err_taskres = make_node_unavailable_taskres(
|
|
649
|
-
ref_taskins=task_ins,
|
|
650
|
-
)
|
|
651
|
-
result.append(err_taskres)
|
|
652
|
-
|
|
653
642
|
return result
|
|
654
643
|
|
|
655
644
|
def num_task_ins(self) -> int:
|
|
@@ -810,12 +799,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
810
799
|
return uint64_node_id
|
|
811
800
|
return None
|
|
812
801
|
|
|
802
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
813
803
|
def create_run(
|
|
814
804
|
self,
|
|
815
805
|
fab_id: Optional[str],
|
|
816
806
|
fab_version: Optional[str],
|
|
817
807
|
fab_hash: Optional[str],
|
|
818
808
|
override_config: UserConfig,
|
|
809
|
+
federation_options: ConfigsRecord,
|
|
819
810
|
) -> int:
|
|
820
811
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
821
812
|
# Sample a random int64 as run_id
|
|
@@ -830,15 +821,29 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
830
821
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
831
822
|
query = (
|
|
832
823
|
"INSERT INTO run "
|
|
833
|
-
"(run_id, fab_id, fab_version, fab_hash, override_config,
|
|
834
|
-
"starting_at, running_at, finished_at,
|
|
835
|
-
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
824
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config, "
|
|
825
|
+
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
826
|
+
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
836
827
|
)
|
|
837
828
|
if fab_hash:
|
|
838
829
|
fab_id, fab_version = "", ""
|
|
839
830
|
override_config_json = json.dumps(override_config)
|
|
840
|
-
data = [
|
|
841
|
-
|
|
831
|
+
data = [
|
|
832
|
+
sint64_run_id,
|
|
833
|
+
fab_id,
|
|
834
|
+
fab_version,
|
|
835
|
+
fab_hash,
|
|
836
|
+
override_config_json,
|
|
837
|
+
configsrecord_to_bytes(federation_options),
|
|
838
|
+
]
|
|
839
|
+
data += [
|
|
840
|
+
now().isoformat(),
|
|
841
|
+
"",
|
|
842
|
+
"",
|
|
843
|
+
"",
|
|
844
|
+
"",
|
|
845
|
+
"",
|
|
846
|
+
]
|
|
842
847
|
self.query(query, tuple(data))
|
|
843
848
|
return uint64_run_id
|
|
844
849
|
log(ERROR, "Unexpected run creation failure.")
|
|
@@ -1003,6 +1008,21 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1003
1008
|
|
|
1004
1009
|
return pending_run_id
|
|
1005
1010
|
|
|
1011
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
|
|
1012
|
+
"""Retrieve the federation options for the specified `run_id`."""
|
|
1013
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
1014
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
1015
|
+
query = "SELECT federation_options FROM run WHERE run_id = ?;"
|
|
1016
|
+
rows = self.query(query, (sint64_run_id,))
|
|
1017
|
+
|
|
1018
|
+
# Check if the run_id exists
|
|
1019
|
+
if not rows:
|
|
1020
|
+
log(ERROR, "`run_id` is invalid")
|
|
1021
|
+
return None
|
|
1022
|
+
|
|
1023
|
+
row = rows[0]
|
|
1024
|
+
return configsrecord_from_bytes(row["federation_options"])
|
|
1025
|
+
|
|
1006
1026
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
1007
1027
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
1008
1028
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|