flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +2 -15
- flwr/cli/config_utils.py +11 -4
- flwr/cli/install.py +196 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +135 -97
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -4
- flwr/client/mod/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -2
- flwr/client/supernode/app.py +29 -5
- flwr/common/object_ref.py +13 -9
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +57 -214
- flwr/server/run_serverapp.py +29 -5
- flwr/server/server_app.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +46 -41
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/server/app.py
CHANGED
|
@@ -40,7 +40,7 @@ from flwr.common.constant import (
|
|
|
40
40
|
TRANSPORT_TYPE_REST,
|
|
41
41
|
)
|
|
42
42
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
43
|
-
from flwr.common.logger import log
|
|
43
|
+
from flwr.common.logger import log
|
|
44
44
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
45
45
|
private_key_to_bytes,
|
|
46
46
|
public_key_to_bytes,
|
|
@@ -190,120 +190,6 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
190
190
|
return hist
|
|
191
191
|
|
|
192
192
|
|
|
193
|
-
def run_driver_api() -> None:
|
|
194
|
-
"""Run Flower server (Driver API)."""
|
|
195
|
-
log(INFO, "Starting Flower server (Driver API)")
|
|
196
|
-
# Running `flower-driver-api` is deprecated
|
|
197
|
-
warn_deprecated_feature("flower-driver-api")
|
|
198
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
199
|
-
event(EventType.RUN_DRIVER_API_ENTER)
|
|
200
|
-
args = _parse_args_run_driver_api().parse_args()
|
|
201
|
-
|
|
202
|
-
# Parse IP address
|
|
203
|
-
parsed_address = parse_address(args.driver_api_address)
|
|
204
|
-
if not parsed_address:
|
|
205
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
206
|
-
host, port, is_v6 = parsed_address
|
|
207
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
208
|
-
|
|
209
|
-
# Obtain certificates
|
|
210
|
-
certificates = _try_obtain_certificates(args)
|
|
211
|
-
|
|
212
|
-
# Initialize StateFactory
|
|
213
|
-
state_factory = StateFactory(args.database)
|
|
214
|
-
|
|
215
|
-
# Start server
|
|
216
|
-
grpc_server: grpc.Server = run_driver_api_grpc(
|
|
217
|
-
address=address,
|
|
218
|
-
state_factory=state_factory,
|
|
219
|
-
certificates=certificates,
|
|
220
|
-
)
|
|
221
|
-
|
|
222
|
-
# Graceful shutdown
|
|
223
|
-
register_exit_handlers(
|
|
224
|
-
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
|
225
|
-
grpc_servers=[grpc_server],
|
|
226
|
-
bckg_threads=[],
|
|
227
|
-
)
|
|
228
|
-
|
|
229
|
-
# Block
|
|
230
|
-
grpc_server.wait_for_termination()
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def run_fleet_api() -> None:
|
|
234
|
-
"""Run Flower server (Fleet API)."""
|
|
235
|
-
log(INFO, "Starting Flower server (Fleet API)")
|
|
236
|
-
# Running `flower-fleet-api` is deprecated
|
|
237
|
-
warn_deprecated_feature("flower-fleet-api")
|
|
238
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
239
|
-
event(EventType.RUN_FLEET_API_ENTER)
|
|
240
|
-
args = _parse_args_run_fleet_api().parse_args()
|
|
241
|
-
|
|
242
|
-
# Obtain certificates
|
|
243
|
-
certificates = _try_obtain_certificates(args)
|
|
244
|
-
|
|
245
|
-
# Initialize StateFactory
|
|
246
|
-
state_factory = StateFactory(args.database)
|
|
247
|
-
|
|
248
|
-
grpc_servers = []
|
|
249
|
-
bckg_threads = []
|
|
250
|
-
|
|
251
|
-
# Start Fleet API
|
|
252
|
-
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
253
|
-
if (
|
|
254
|
-
importlib.util.find_spec("requests")
|
|
255
|
-
and importlib.util.find_spec("starlette")
|
|
256
|
-
and importlib.util.find_spec("uvicorn")
|
|
257
|
-
) is None:
|
|
258
|
-
sys.exit(MISSING_EXTRA_REST)
|
|
259
|
-
address_arg = args.rest_fleet_api_address
|
|
260
|
-
parsed_address = parse_address(address_arg)
|
|
261
|
-
if not parsed_address:
|
|
262
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
263
|
-
host, port, _ = parsed_address
|
|
264
|
-
fleet_thread = threading.Thread(
|
|
265
|
-
target=_run_fleet_api_rest,
|
|
266
|
-
args=(
|
|
267
|
-
host,
|
|
268
|
-
port,
|
|
269
|
-
args.ssl_keyfile,
|
|
270
|
-
args.ssl_certfile,
|
|
271
|
-
state_factory,
|
|
272
|
-
args.rest_fleet_api_workers,
|
|
273
|
-
),
|
|
274
|
-
)
|
|
275
|
-
fleet_thread.start()
|
|
276
|
-
bckg_threads.append(fleet_thread)
|
|
277
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
278
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
279
|
-
parsed_address = parse_address(address_arg)
|
|
280
|
-
if not parsed_address:
|
|
281
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
282
|
-
host, port, is_v6 = parsed_address
|
|
283
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
284
|
-
fleet_server = _run_fleet_api_grpc_rere(
|
|
285
|
-
address=address,
|
|
286
|
-
state_factory=state_factory,
|
|
287
|
-
certificates=certificates,
|
|
288
|
-
)
|
|
289
|
-
grpc_servers.append(fleet_server)
|
|
290
|
-
else:
|
|
291
|
-
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
292
|
-
|
|
293
|
-
# Graceful shutdown
|
|
294
|
-
register_exit_handlers(
|
|
295
|
-
event_type=EventType.RUN_FLEET_API_LEAVE,
|
|
296
|
-
grpc_servers=grpc_servers,
|
|
297
|
-
bckg_threads=bckg_threads,
|
|
298
|
-
)
|
|
299
|
-
|
|
300
|
-
# Block
|
|
301
|
-
if len(grpc_servers) > 0:
|
|
302
|
-
grpc_servers[0].wait_for_termination()
|
|
303
|
-
elif len(bckg_threads) > 0:
|
|
304
|
-
bckg_threads[0].join()
|
|
305
|
-
|
|
306
|
-
|
|
307
193
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
308
194
|
def run_superlink() -> None:
|
|
309
195
|
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
@@ -314,11 +200,15 @@ def run_superlink() -> None:
|
|
|
314
200
|
args = _parse_args_run_superlink().parse_args()
|
|
315
201
|
|
|
316
202
|
# Parse IP address
|
|
317
|
-
|
|
318
|
-
if not
|
|
203
|
+
parsed_driver_address = parse_address(args.driver_api_address)
|
|
204
|
+
if not parsed_driver_address:
|
|
319
205
|
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
320
|
-
|
|
321
|
-
|
|
206
|
+
driver_host, driver_port, driver_is_v6 = parsed_driver_address
|
|
207
|
+
driver_address = (
|
|
208
|
+
f"[{driver_host}]:{driver_port}"
|
|
209
|
+
if driver_is_v6
|
|
210
|
+
else f"{driver_host}:{driver_port}"
|
|
211
|
+
)
|
|
322
212
|
|
|
323
213
|
# Obtain certificates
|
|
324
214
|
certificates = _try_obtain_certificates(args)
|
|
@@ -328,13 +218,38 @@ def run_superlink() -> None:
|
|
|
328
218
|
|
|
329
219
|
# Start Driver API
|
|
330
220
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
331
|
-
address=
|
|
221
|
+
address=driver_address,
|
|
332
222
|
state_factory=state_factory,
|
|
333
223
|
certificates=certificates,
|
|
334
224
|
)
|
|
335
225
|
|
|
336
226
|
grpc_servers = [driver_server]
|
|
337
227
|
bckg_threads = []
|
|
228
|
+
if not args.fleet_api_address:
|
|
229
|
+
args.fleet_api_address = (
|
|
230
|
+
ADDRESS_FLEET_API_GRPC_RERE
|
|
231
|
+
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
|
|
232
|
+
else ADDRESS_FLEET_API_REST
|
|
233
|
+
)
|
|
234
|
+
parsed_fleet_address = parse_address(args.fleet_api_address)
|
|
235
|
+
if not parsed_fleet_address:
|
|
236
|
+
sys.exit(f"Fleet IP address ({args.fleet_api_address}) cannot be parsed.")
|
|
237
|
+
fleet_host, fleet_port, fleet_is_v6 = parsed_fleet_address
|
|
238
|
+
fleet_address = (
|
|
239
|
+
f"[{fleet_host}]:{fleet_port}" if fleet_is_v6 else f"{fleet_host}:{fleet_port}"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
num_workers = args.fleet_api_num_workers
|
|
243
|
+
if num_workers != 1:
|
|
244
|
+
log(
|
|
245
|
+
WARN,
|
|
246
|
+
"The Fleet API currently supports only 1 worker. "
|
|
247
|
+
"You have specified %d workers. "
|
|
248
|
+
"Support for multiple workers will be added in future releases. "
|
|
249
|
+
"Proceeding with a single worker.",
|
|
250
|
+
args.fleet_api_num_workers,
|
|
251
|
+
)
|
|
252
|
+
num_workers = 1
|
|
338
253
|
|
|
339
254
|
# Start Fleet API
|
|
340
255
|
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
@@ -344,35 +259,25 @@ def run_superlink() -> None:
|
|
|
344
259
|
and importlib.util.find_spec("uvicorn")
|
|
345
260
|
) is None:
|
|
346
261
|
sys.exit(MISSING_EXTRA_REST)
|
|
347
|
-
|
|
348
|
-
parsed_address = parse_address(address_arg)
|
|
262
|
+
|
|
349
263
|
_, ssl_certfile, ssl_keyfile = (
|
|
350
264
|
certificates if certificates is not None else (None, None, None)
|
|
351
265
|
)
|
|
352
|
-
|
|
353
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
354
|
-
host, port, _ = parsed_address
|
|
266
|
+
|
|
355
267
|
fleet_thread = threading.Thread(
|
|
356
268
|
target=_run_fleet_api_rest,
|
|
357
269
|
args=(
|
|
358
|
-
|
|
359
|
-
|
|
270
|
+
fleet_host,
|
|
271
|
+
fleet_port,
|
|
360
272
|
ssl_keyfile,
|
|
361
273
|
ssl_certfile,
|
|
362
274
|
state_factory,
|
|
363
|
-
|
|
275
|
+
num_workers,
|
|
364
276
|
),
|
|
365
277
|
)
|
|
366
278
|
fleet_thread.start()
|
|
367
279
|
bckg_threads.append(fleet_thread)
|
|
368
280
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
369
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
370
|
-
parsed_address = parse_address(address_arg)
|
|
371
|
-
if not parsed_address:
|
|
372
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
373
|
-
host, port, is_v6 = parsed_address
|
|
374
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
375
|
-
|
|
376
281
|
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
377
282
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
378
283
|
if maybe_keys is not None:
|
|
@@ -395,7 +300,7 @@ def run_superlink() -> None:
|
|
|
395
300
|
interceptors = [AuthenticateServerInterceptor(state)]
|
|
396
301
|
|
|
397
302
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
398
|
-
address=
|
|
303
|
+
address=fleet_address,
|
|
399
304
|
state_factory=state_factory,
|
|
400
305
|
certificates=certificates,
|
|
401
306
|
interceptors=interceptors,
|
|
@@ -596,7 +501,7 @@ def _run_fleet_api_rest(
|
|
|
596
501
|
ssl_keyfile: Optional[str],
|
|
597
502
|
ssl_certfile: Optional[str],
|
|
598
503
|
state_factory: StateFactory,
|
|
599
|
-
|
|
504
|
+
num_workers: int,
|
|
600
505
|
) -> None:
|
|
601
506
|
"""Run Driver API (REST-based)."""
|
|
602
507
|
try:
|
|
@@ -605,12 +510,7 @@ def _run_fleet_api_rest(
|
|
|
605
510
|
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
|
|
606
511
|
except ModuleNotFoundError:
|
|
607
512
|
sys.exit(MISSING_EXTRA_REST)
|
|
608
|
-
|
|
609
|
-
raise ValueError(
|
|
610
|
-
f"The supported number of workers for the Fleet API (REST server) is "
|
|
611
|
-
f"1. Instead given {workers}. The functionality of >1 workers will be "
|
|
612
|
-
f"added in the future releases."
|
|
613
|
-
)
|
|
513
|
+
|
|
614
514
|
log(INFO, "Starting Flower REST server")
|
|
615
515
|
|
|
616
516
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
|
@@ -624,44 +524,10 @@ def _run_fleet_api_rest(
|
|
|
624
524
|
access_log=True,
|
|
625
525
|
ssl_keyfile=ssl_keyfile,
|
|
626
526
|
ssl_certfile=ssl_certfile,
|
|
627
|
-
workers=
|
|
527
|
+
workers=num_workers,
|
|
628
528
|
)
|
|
629
529
|
|
|
630
530
|
|
|
631
|
-
def _parse_args_run_driver_api() -> argparse.ArgumentParser:
|
|
632
|
-
"""Parse command line arguments for Driver API."""
|
|
633
|
-
parser = argparse.ArgumentParser(
|
|
634
|
-
description="Start a Flower Driver API server. "
|
|
635
|
-
"This server will be responsible for "
|
|
636
|
-
"receiving TaskIns from the Driver script and "
|
|
637
|
-
"sending them to the Fleet API. Once the client nodes "
|
|
638
|
-
"are done, they will send the TaskRes back to this Driver API server (through"
|
|
639
|
-
" the Fleet API) which will then send them back to the Driver script.",
|
|
640
|
-
)
|
|
641
|
-
|
|
642
|
-
_add_args_common(parser=parser)
|
|
643
|
-
_add_args_driver_api(parser=parser)
|
|
644
|
-
|
|
645
|
-
return parser
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
|
|
649
|
-
"""Parse command line arguments for Fleet API."""
|
|
650
|
-
parser = argparse.ArgumentParser(
|
|
651
|
-
description="Start a Flower Fleet API server."
|
|
652
|
-
"This server will be responsible for "
|
|
653
|
-
"sending TaskIns (received from the Driver API) to the client nodes "
|
|
654
|
-
"and of receiving TaskRes sent back from those same client nodes once "
|
|
655
|
-
"they are done. Then, this Fleet API server can send those "
|
|
656
|
-
"TaskRes back to the Driver API.",
|
|
657
|
-
)
|
|
658
|
-
|
|
659
|
-
_add_args_common(parser=parser)
|
|
660
|
-
_add_args_fleet_api(parser=parser)
|
|
661
|
-
|
|
662
|
-
return parser
|
|
663
|
-
|
|
664
|
-
|
|
665
531
|
def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
666
532
|
"""Parse command line arguments for both Driver API and Fleet API."""
|
|
667
533
|
parser = argparse.ArgumentParser(
|
|
@@ -732,50 +598,27 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
732
598
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
733
599
|
parser.add_argument(
|
|
734
600
|
"--driver-api-address",
|
|
735
|
-
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name)",
|
|
601
|
+
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
736
602
|
default=ADDRESS_DRIVER_API,
|
|
737
603
|
)
|
|
738
604
|
|
|
739
605
|
|
|
740
606
|
def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
741
607
|
# Fleet API transport layer type
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
"--grpc-rere",
|
|
745
|
-
action="store_const",
|
|
746
|
-
dest="fleet_api_type",
|
|
747
|
-
const=TRANSPORT_TYPE_GRPC_RERE,
|
|
608
|
+
parser.add_argument(
|
|
609
|
+
"--fleet-api-type",
|
|
748
610
|
default=TRANSPORT_TYPE_GRPC_RERE,
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
"--rest",
|
|
753
|
-
action="store_const",
|
|
754
|
-
dest="fleet_api_type",
|
|
755
|
-
const=TRANSPORT_TYPE_REST,
|
|
756
|
-
help="Start a Fleet API server (REST, experimental)",
|
|
757
|
-
)
|
|
758
|
-
|
|
759
|
-
# Fleet API gRPC-rere options
|
|
760
|
-
grpc_rere_group = parser.add_argument_group(
|
|
761
|
-
"Fleet API (gRPC-rere) server options", ""
|
|
762
|
-
)
|
|
763
|
-
grpc_rere_group.add_argument(
|
|
764
|
-
"--grpc-rere-fleet-api-address",
|
|
765
|
-
help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)",
|
|
766
|
-
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
611
|
+
type=str,
|
|
612
|
+
choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
|
|
613
|
+
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
|
|
767
614
|
)
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
rest_group.add_argument(
|
|
772
|
-
"--rest-fleet-api-address",
|
|
773
|
-
help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)",
|
|
774
|
-
default=ADDRESS_FLEET_API_REST,
|
|
615
|
+
parser.add_argument(
|
|
616
|
+
"--fleet-api-address",
|
|
617
|
+
help="Fleet API server address (IPv4, IPv6, or a domain name).",
|
|
775
618
|
)
|
|
776
|
-
|
|
777
|
-
"--
|
|
778
|
-
help="Set the number of concurrent workers for the Fleet API REST server.",
|
|
779
|
-
type=int,
|
|
619
|
+
parser.add_argument(
|
|
620
|
+
"--fleet-api-num-workers",
|
|
780
621
|
default=1,
|
|
622
|
+
type=int,
|
|
623
|
+
help="Set the number of concurrent workers for the Fleet API server.",
|
|
781
624
|
)
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,12 +22,14 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
-
from flwr.common.logger import log, update_console_handler
|
|
25
|
+
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
27
|
|
|
28
28
|
from .driver import Driver, GrpcDriver
|
|
29
29
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
30
|
|
|
31
|
+
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
32
|
+
|
|
31
33
|
|
|
32
34
|
def run(
|
|
33
35
|
driver: Driver,
|
|
@@ -76,6 +78,23 @@ def run_server_app() -> None:
|
|
|
76
78
|
|
|
77
79
|
args = _parse_args_run_server_app().parse_args()
|
|
78
80
|
|
|
81
|
+
if args.server != ADDRESS_DRIVER_API:
|
|
82
|
+
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
83
|
+
warn_deprecated_feature(warn)
|
|
84
|
+
|
|
85
|
+
if args.superlink != ADDRESS_DRIVER_API:
|
|
86
|
+
# if `--superlink` also passed, then
|
|
87
|
+
# warn user that this argument overrides what was passed with `--server`
|
|
88
|
+
log(
|
|
89
|
+
WARN,
|
|
90
|
+
"Both `--server` and `--superlink` were passed. "
|
|
91
|
+
"`--server` will be ignored. Connecting to the Superlink Driver API "
|
|
92
|
+
"at %s.",
|
|
93
|
+
args.superlink,
|
|
94
|
+
)
|
|
95
|
+
else:
|
|
96
|
+
args.superlink = args.server
|
|
97
|
+
|
|
79
98
|
update_console_handler(
|
|
80
99
|
level=DEBUG if args.verbose else INFO,
|
|
81
100
|
timestamps=args.verbose,
|
|
@@ -95,7 +114,7 @@ def run_server_app() -> None:
|
|
|
95
114
|
WARN,
|
|
96
115
|
"Option `--insecure` was set. "
|
|
97
116
|
"Starting insecure HTTP client connected to %s.",
|
|
98
|
-
args.
|
|
117
|
+
args.superlink,
|
|
99
118
|
)
|
|
100
119
|
root_certificates = None
|
|
101
120
|
else:
|
|
@@ -109,7 +128,7 @@ def run_server_app() -> None:
|
|
|
109
128
|
DEBUG,
|
|
110
129
|
"Starting secure HTTPS client connected to %s "
|
|
111
130
|
"with the following certificates: %s.",
|
|
112
|
-
args.
|
|
131
|
+
args.superlink,
|
|
113
132
|
cert_path,
|
|
114
133
|
)
|
|
115
134
|
|
|
@@ -130,7 +149,7 @@ def run_server_app() -> None:
|
|
|
130
149
|
|
|
131
150
|
# Initialize GrpcDriver
|
|
132
151
|
driver = GrpcDriver(
|
|
133
|
-
driver_service_address=args.
|
|
152
|
+
driver_service_address=args.superlink,
|
|
134
153
|
root_certificates=root_certificates,
|
|
135
154
|
fab_id=args.fab_id,
|
|
136
155
|
fab_version=args.fab_version,
|
|
@@ -175,9 +194,14 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
175
194
|
)
|
|
176
195
|
parser.add_argument(
|
|
177
196
|
"--server",
|
|
178
|
-
default=
|
|
197
|
+
default=ADDRESS_DRIVER_API,
|
|
179
198
|
help="Server address",
|
|
180
199
|
)
|
|
200
|
+
parser.add_argument(
|
|
201
|
+
"--superlink",
|
|
202
|
+
default=ADDRESS_DRIVER_API,
|
|
203
|
+
help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
204
|
+
)
|
|
181
205
|
parser.add_argument(
|
|
182
206
|
"--dir",
|
|
183
207
|
default="",
|
flwr/server/server_app.py
CHANGED
|
@@ -39,7 +39,7 @@ class ServerApp:
|
|
|
39
39
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
40
40
|
>>> strategy = FedAvg()
|
|
41
41
|
>>>
|
|
42
|
-
>>> app = ServerApp(
|
|
42
|
+
>>> app = ServerApp(
|
|
43
43
|
>>> server_config=server_config,
|
|
44
44
|
>>> strategy=strategy,
|
|
45
45
|
>>> )
|
|
@@ -106,7 +106,7 @@ class ServerApp:
|
|
|
106
106
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
107
107
|
>>> strategy = FedAvg()
|
|
108
108
|
>>>
|
|
109
|
-
>>> app = ServerApp(
|
|
109
|
+
>>> app = ServerApp(
|
|
110
110
|
>>> server_config=server_config,
|
|
111
111
|
>>> strategy=strategy,
|
|
112
112
|
>>> )
|
|
@@ -35,6 +35,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
35
35
|
PushTaskInsResponse,
|
|
36
36
|
)
|
|
37
37
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
39
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
39
40
|
from flwr.server.superlink.state import State, StateFactory
|
|
40
41
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
@@ -129,6 +130,12 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
129
130
|
context.set_code(grpc.StatusCode.OK)
|
|
130
131
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
131
132
|
|
|
133
|
+
def GetRun(
|
|
134
|
+
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
135
|
+
) -> GetRunResponse:
|
|
136
|
+
"""Get run information."""
|
|
137
|
+
raise NotImplementedError
|
|
138
|
+
|
|
132
139
|
|
|
133
140
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
134
141
|
if validation_error:
|
|
@@ -26,8 +26,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
26
26
|
CreateNodeResponse,
|
|
27
27
|
DeleteNodeRequest,
|
|
28
28
|
DeleteNodeResponse,
|
|
29
|
-
GetRunRequest,
|
|
30
|
-
GetRunResponse,
|
|
31
29
|
PingRequest,
|
|
32
30
|
PingResponse,
|
|
33
31
|
PullTaskInsRequest,
|
|
@@ -35,6 +33,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
35
33
|
PushTaskResRequest,
|
|
36
34
|
PushTaskResResponse,
|
|
37
35
|
)
|
|
36
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
37
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
39
38
|
from flwr.server.superlink.state import StateFactory
|
|
40
39
|
|
|
@@ -34,8 +34,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
34
34
|
CreateNodeResponse,
|
|
35
35
|
DeleteNodeRequest,
|
|
36
36
|
DeleteNodeResponse,
|
|
37
|
-
GetRunRequest,
|
|
38
|
-
GetRunResponse,
|
|
39
37
|
PingRequest,
|
|
40
38
|
PingResponse,
|
|
41
39
|
PullTaskInsRequest,
|
|
@@ -44,6 +42,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
44
42
|
PushTaskResResponse,
|
|
45
43
|
)
|
|
46
44
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
47
46
|
from flwr.server.superlink.state import State
|
|
48
47
|
|
|
49
48
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
@@ -24,8 +24,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
24
24
|
CreateNodeResponse,
|
|
25
25
|
DeleteNodeRequest,
|
|
26
26
|
DeleteNodeResponse,
|
|
27
|
-
GetRunRequest,
|
|
28
|
-
GetRunResponse,
|
|
29
27
|
PingRequest,
|
|
30
28
|
PingResponse,
|
|
31
29
|
PullTaskInsRequest,
|
|
@@ -33,9 +31,13 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
33
31
|
PushTaskResRequest,
|
|
34
32
|
PushTaskResResponse,
|
|
35
33
|
Reconnect,
|
|
36
|
-
Run,
|
|
37
34
|
)
|
|
38
35
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
36
|
+
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
37
|
+
GetRunRequest,
|
|
38
|
+
GetRunResponse,
|
|
39
|
+
Run,
|
|
40
|
+
)
|
|
39
41
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
40
42
|
from flwr.server.superlink.state import State
|
|
41
43
|
|
|
@@ -21,11 +21,11 @@ from flwr.common.constant import MISSING_EXTRA_REST
|
|
|
21
21
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
22
|
CreateNodeRequest,
|
|
23
23
|
DeleteNodeRequest,
|
|
24
|
-
GetRunRequest,
|
|
25
24
|
PingRequest,
|
|
26
25
|
PullTaskInsRequest,
|
|
27
26
|
PushTaskResRequest,
|
|
28
27
|
)
|
|
28
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
29
29
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
30
30
|
from flwr.server.superlink.state import State
|
|
31
31
|
|