flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/server/app.py
CHANGED
|
@@ -16,15 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import asyncio
|
|
19
|
+
import csv
|
|
19
20
|
import importlib.util
|
|
20
21
|
import sys
|
|
21
22
|
import threading
|
|
22
23
|
from logging import ERROR, INFO, WARN
|
|
23
24
|
from os.path import isfile
|
|
24
25
|
from pathlib import Path
|
|
25
|
-
from typing import List, Optional, Tuple
|
|
26
|
+
from typing import List, Optional, Sequence, Set, Tuple
|
|
26
27
|
|
|
27
28
|
import grpc
|
|
29
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
31
|
+
load_ssh_private_key,
|
|
32
|
+
load_ssh_public_key,
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
30
36
|
from flwr.common.address import parse_address
|
|
@@ -36,6 +42,11 @@ from flwr.common.constant import (
|
|
|
36
42
|
)
|
|
37
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
38
44
|
from flwr.common.logger import log
|
|
45
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
|
+
private_key_to_bytes,
|
|
47
|
+
public_key_to_bytes,
|
|
48
|
+
ssh_types_to_elliptic_curve,
|
|
49
|
+
)
|
|
39
50
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
40
51
|
add_FleetServicer_to_server,
|
|
41
52
|
)
|
|
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
51
62
|
start_grpc_server,
|
|
52
63
|
)
|
|
53
64
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
65
|
+
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
54
66
|
from .superlink.fleet.vce import start_vce
|
|
55
67
|
from .superlink.state import StateFactory
|
|
56
68
|
|
|
@@ -291,9 +303,11 @@ def run_fleet_api() -> None:
|
|
|
291
303
|
|
|
292
304
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
293
305
|
def run_superlink() -> None:
|
|
294
|
-
"""Run Flower
|
|
295
|
-
log(INFO, "Starting Flower
|
|
306
|
+
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
307
|
+
log(INFO, "Starting Flower SuperLink")
|
|
308
|
+
|
|
296
309
|
event(EventType.RUN_SUPERLINK_ENTER)
|
|
310
|
+
|
|
297
311
|
args = _parse_args_run_superlink().parse_args()
|
|
298
312
|
|
|
299
313
|
# Parse IP address
|
|
@@ -352,10 +366,33 @@ def run_superlink() -> None:
|
|
|
352
366
|
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
353
367
|
host, port, is_v6 = parsed_address
|
|
354
368
|
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
369
|
+
|
|
370
|
+
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
371
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
372
|
+
if maybe_keys is not None:
|
|
373
|
+
(
|
|
374
|
+
client_public_keys,
|
|
375
|
+
server_private_key,
|
|
376
|
+
server_public_key,
|
|
377
|
+
) = maybe_keys
|
|
378
|
+
state = state_factory.state()
|
|
379
|
+
state.store_client_public_keys(client_public_keys)
|
|
380
|
+
state.store_server_private_public_key(
|
|
381
|
+
private_key_to_bytes(server_private_key),
|
|
382
|
+
public_key_to_bytes(server_public_key),
|
|
383
|
+
)
|
|
384
|
+
log(
|
|
385
|
+
INFO,
|
|
386
|
+
"Client authentication enabled with %d known public keys",
|
|
387
|
+
len(client_public_keys),
|
|
388
|
+
)
|
|
389
|
+
interceptors = [AuthenticateServerInterceptor(state)]
|
|
390
|
+
|
|
355
391
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
356
392
|
address=address,
|
|
357
393
|
state_factory=state_factory,
|
|
358
394
|
certificates=certificates,
|
|
395
|
+
interceptors=interceptors,
|
|
359
396
|
)
|
|
360
397
|
grpc_servers.append(fleet_server)
|
|
361
398
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
|
@@ -388,6 +425,70 @@ def run_superlink() -> None:
|
|
|
388
425
|
driver_server.wait_for_termination(timeout=1)
|
|
389
426
|
|
|
390
427
|
|
|
428
|
+
def _try_setup_client_authentication(
|
|
429
|
+
args: argparse.Namespace,
|
|
430
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
431
|
+
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
432
|
+
if not args.require_client_authentication:
|
|
433
|
+
return None
|
|
434
|
+
|
|
435
|
+
if certificates is None:
|
|
436
|
+
sys.exit(
|
|
437
|
+
"Client authentication only works over secure connections. "
|
|
438
|
+
"Please provide certificate paths using '--certificates' when "
|
|
439
|
+
"enabling '--require-client-authentication'."
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
client_keys_file_path = Path(args.require_client_authentication[0])
|
|
443
|
+
if not client_keys_file_path.exists():
|
|
444
|
+
sys.exit(
|
|
445
|
+
"The provided path to the client public keys CSV file does not exist: "
|
|
446
|
+
f"{client_keys_file_path}. "
|
|
447
|
+
"Please provide the CSV file path containing known client public keys "
|
|
448
|
+
"to '--require-client-authentication'."
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
client_public_keys: Set[bytes] = set()
|
|
452
|
+
ssh_private_key = load_ssh_private_key(
|
|
453
|
+
Path(args.require_client_authentication[1]).read_bytes(),
|
|
454
|
+
None,
|
|
455
|
+
)
|
|
456
|
+
ssh_public_key = load_ssh_public_key(
|
|
457
|
+
Path(args.require_client_authentication[2]).read_bytes()
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
try:
|
|
461
|
+
server_private_key, server_public_key = ssh_types_to_elliptic_curve(
|
|
462
|
+
ssh_private_key, ssh_public_key
|
|
463
|
+
)
|
|
464
|
+
except TypeError:
|
|
465
|
+
sys.exit(
|
|
466
|
+
"The file paths provided could not be read as a private and public "
|
|
467
|
+
"key pair. Client authentication requires an elliptic curve public and "
|
|
468
|
+
"private key pair. Please provide the file paths containing elliptic "
|
|
469
|
+
"curve private and public keys to '--require-client-authentication'."
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
473
|
+
reader = csv.reader(csvfile)
|
|
474
|
+
for row in reader:
|
|
475
|
+
for element in row:
|
|
476
|
+
public_key = load_ssh_public_key(element.encode())
|
|
477
|
+
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
478
|
+
client_public_keys.add(public_key_to_bytes(public_key))
|
|
479
|
+
else:
|
|
480
|
+
sys.exit(
|
|
481
|
+
"Error: Unable to parse the public keys in the .csv "
|
|
482
|
+
"file. Please ensure that the .csv file contains valid "
|
|
483
|
+
"SSH public keys and try again."
|
|
484
|
+
)
|
|
485
|
+
return (
|
|
486
|
+
client_public_keys,
|
|
487
|
+
server_private_key,
|
|
488
|
+
server_public_key,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
|
|
391
492
|
def _try_obtain_certificates(
|
|
392
493
|
args: argparse.Namespace,
|
|
393
494
|
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
@@ -415,6 +516,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
415
516
|
address: str,
|
|
416
517
|
state_factory: StateFactory,
|
|
417
518
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
519
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
418
520
|
) -> grpc.Server:
|
|
419
521
|
"""Run Fleet API (gRPC, request-response)."""
|
|
420
522
|
# Create Fleet API gRPC server
|
|
@@ -427,6 +529,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
427
529
|
server_address=address,
|
|
428
530
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
429
531
|
certificates=certificates,
|
|
532
|
+
interceptors=interceptors,
|
|
430
533
|
)
|
|
431
534
|
|
|
432
535
|
log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
|
|
@@ -568,9 +671,7 @@ def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
|
|
|
568
671
|
def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
569
672
|
"""Parse command line arguments for both Driver API and Fleet API."""
|
|
570
673
|
parser = argparse.ArgumentParser(
|
|
571
|
-
description="
|
|
572
|
-
"(meaning, a Driver API and a Fleet API), "
|
|
573
|
-
"that clients will be able to connect to.",
|
|
674
|
+
description="Start a Flower SuperLink",
|
|
574
675
|
)
|
|
575
676
|
|
|
576
677
|
_add_args_common(parser=parser)
|
|
@@ -606,6 +707,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
606
707
|
"Flower will just create a state in memory.",
|
|
607
708
|
default=DATABASE,
|
|
608
709
|
)
|
|
710
|
+
parser.add_argument(
|
|
711
|
+
"--require-client-authentication",
|
|
712
|
+
nargs=3,
|
|
713
|
+
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
|
|
714
|
+
type=str,
|
|
715
|
+
help="Provide three file paths: (1) a .csv file containing a list of "
|
|
716
|
+
"known client public keys for authentication, (2) the server's private "
|
|
717
|
+
"key file, and (3) the server's public key file.",
|
|
718
|
+
)
|
|
609
719
|
|
|
610
720
|
|
|
611
721
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
flwr/server/compat/app.py
CHANGED
|
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
|
|
|
29
29
|
from flwr.server.server_config import ServerConfig
|
|
30
30
|
from flwr.server.strategy import Strategy
|
|
31
31
|
|
|
32
|
-
from ..driver import Driver
|
|
32
|
+
from ..driver import Driver, GrpcDriver
|
|
33
33
|
from .app_utils import start_update_client_manager_thread
|
|
34
34
|
|
|
35
35
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
114
114
|
# Create the Driver
|
|
115
115
|
if isinstance(root_certificates, str):
|
|
116
116
|
root_certificates = Path(root_certificates).read_bytes()
|
|
117
|
-
driver =
|
|
117
|
+
driver = GrpcDriver(
|
|
118
118
|
driver_service_address=address, root_certificates=root_certificates
|
|
119
119
|
)
|
|
120
120
|
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,16 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
|
-
from flwr.common import serde
|
|
25
|
-
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
24
|
from flwr.server.client_proxy import ClientProxy
|
|
27
25
|
|
|
28
|
-
from ..driver.
|
|
26
|
+
from ..driver.driver import Driver
|
|
29
27
|
|
|
30
28
|
SLEEP_TIME = 1
|
|
31
29
|
|
|
@@ -33,7 +31,7 @@ SLEEP_TIME = 1
|
|
|
33
31
|
class DriverClientProxy(ClientProxy):
|
|
34
32
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
33
|
|
|
36
|
-
def __init__(self, node_id: int, driver:
|
|
34
|
+
def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
|
|
37
35
|
super().__init__(str(node_id))
|
|
38
36
|
self.node_id = node_id
|
|
39
37
|
self.driver = driver
|
|
@@ -114,80 +112,39 @@ class DriverClientProxy(ClientProxy):
|
|
|
114
112
|
timeout: Optional[float],
|
|
115
113
|
group_id: Optional[int],
|
|
116
114
|
) -> RecordSet:
|
|
117
|
-
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
|
118
|
-
task_id="",
|
|
119
|
-
group_id=str(group_id) if group_id is not None else "",
|
|
120
|
-
run_id=self.run_id,
|
|
121
|
-
task=task_pb2.Task( # pylint: disable=E1101
|
|
122
|
-
producer=node_pb2.Node( # pylint: disable=E1101
|
|
123
|
-
node_id=0,
|
|
124
|
-
anonymous=True,
|
|
125
|
-
),
|
|
126
|
-
consumer=node_pb2.Node( # pylint: disable=E1101
|
|
127
|
-
node_id=self.node_id,
|
|
128
|
-
anonymous=self.anonymous,
|
|
129
|
-
),
|
|
130
|
-
task_type=task_type,
|
|
131
|
-
recordset=serde.recordset_to_proto(recordset),
|
|
132
|
-
ttl=DEFAULT_TTL,
|
|
133
|
-
),
|
|
134
|
-
)
|
|
135
|
-
|
|
136
|
-
# This would normally be recorded upon common.Message creation
|
|
137
|
-
# but this compatibility stack doesn't create Messages,
|
|
138
|
-
# so we need to inject `created_at` manually (needed for
|
|
139
|
-
# taskins validation by server.utils.validator)
|
|
140
|
-
task_ins.task.created_at = time.time()
|
|
141
115
|
|
|
142
|
-
|
|
143
|
-
|
|
116
|
+
# Create message
|
|
117
|
+
message = self.driver.create_message(
|
|
118
|
+
content=recordset,
|
|
119
|
+
message_type=task_type,
|
|
120
|
+
dst_node_id=self.node_id,
|
|
121
|
+
group_id=str(group_id) if group_id else "",
|
|
122
|
+
ttl=timeout,
|
|
144
123
|
)
|
|
145
124
|
|
|
146
|
-
#
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
raise ValueError("Unexpected number of task_ids")
|
|
125
|
+
# Push message
|
|
126
|
+
message_ids = list(self.driver.push_messages(messages=[message]))
|
|
127
|
+
if len(message_ids) != 1:
|
|
128
|
+
raise ValueError("Unexpected number of message_ids")
|
|
151
129
|
|
|
152
|
-
|
|
153
|
-
if
|
|
154
|
-
raise ValueError(f"Failed to
|
|
130
|
+
message_id = message_ids[0]
|
|
131
|
+
if message_id == "":
|
|
132
|
+
raise ValueError(f"Failed to send message to node {self.node_id}")
|
|
155
133
|
|
|
156
134
|
if timeout:
|
|
157
135
|
start_time = time.time()
|
|
158
136
|
|
|
159
137
|
while True:
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
pull_task_res_res.task_res_list
|
|
170
|
-
)
|
|
171
|
-
if len(task_res_list) == 1:
|
|
172
|
-
task_res = task_res_list[0]
|
|
173
|
-
|
|
174
|
-
# This will raise an Exception if task_res carries an `error`
|
|
175
|
-
validate_task_res(task_res=task_res)
|
|
176
|
-
|
|
177
|
-
return serde.recordset_from_proto(task_res.task.recordset)
|
|
138
|
+
messages = list(self.driver.pull_messages(message_ids))
|
|
139
|
+
if len(messages) == 1:
|
|
140
|
+
msg: Message = messages[0]
|
|
141
|
+
if msg.has_error():
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
+
"It originated during client-side execution of a message."
|
|
145
|
+
)
|
|
146
|
+
return msg.content
|
|
178
147
|
|
|
179
148
|
if timeout is not None and time.time() > start_time + timeout:
|
|
180
149
|
raise RuntimeError("Timeout reached")
|
|
181
150
|
time.sleep(SLEEP_TIME)
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
def validate_task_res(
|
|
185
|
-
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
186
|
-
) -> None:
|
|
187
|
-
"""Validate if a TaskRes is empty or not."""
|
|
188
|
-
if not task_res.HasField("task"):
|
|
189
|
-
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
190
|
-
if task_res.task.HasField("error"):
|
|
191
|
-
raise ValueError("Exception during client-side task execution")
|
|
192
|
-
if not task_res.task.HasField("recordset"):
|
|
193
|
-
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|
flwr/server/driver/__init__.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,79 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Driver (abstract base class)."""
|
|
16
16
|
|
|
17
|
-
import time
|
|
18
|
-
import warnings
|
|
19
|
-
from typing import Iterable, List, Optional, Tuple
|
|
20
17
|
|
|
21
|
-
from
|
|
22
|
-
from
|
|
23
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
24
|
-
CreateRunRequest,
|
|
25
|
-
GetNodesRequest,
|
|
26
|
-
PullTaskResRequest,
|
|
27
|
-
PushTaskInsRequest,
|
|
28
|
-
)
|
|
29
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
30
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Iterable, List, Optional
|
|
31
20
|
|
|
32
|
-
from .
|
|
21
|
+
from flwr.common import Message, RecordSet
|
|
33
22
|
|
|
34
23
|
|
|
35
|
-
class Driver:
|
|
36
|
-
"""
|
|
37
|
-
|
|
38
|
-
Parameters
|
|
39
|
-
----------
|
|
40
|
-
driver_service_address : Optional[str]
|
|
41
|
-
The IPv4 or IPv6 address of the Driver API server.
|
|
42
|
-
Defaults to `"[::]:9091"`.
|
|
43
|
-
certificates : bytes (default: None)
|
|
44
|
-
Tuple containing root certificate, server certificate, and private key
|
|
45
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
46
|
-
three bytes elements in the following order:
|
|
47
|
-
|
|
48
|
-
* CA certificate.
|
|
49
|
-
* server certificate.
|
|
50
|
-
* server private key.
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
def __init__(
|
|
54
|
-
self,
|
|
55
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
56
|
-
root_certificates: Optional[bytes] = None,
|
|
57
|
-
) -> None:
|
|
58
|
-
self.addr = driver_service_address
|
|
59
|
-
self.root_certificates = root_certificates
|
|
60
|
-
self.grpc_driver: Optional[GrpcDriver] = None
|
|
61
|
-
self.run_id: Optional[int] = None
|
|
62
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
63
|
-
|
|
64
|
-
def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
|
|
65
|
-
# Check if the GrpcDriver is initialized
|
|
66
|
-
if self.grpc_driver is None or self.run_id is None:
|
|
67
|
-
# Connect and create run
|
|
68
|
-
self.grpc_driver = GrpcDriver(
|
|
69
|
-
driver_service_address=self.addr,
|
|
70
|
-
root_certificates=self.root_certificates,
|
|
71
|
-
)
|
|
72
|
-
self.grpc_driver.connect()
|
|
73
|
-
res = self.grpc_driver.create_run(CreateRunRequest())
|
|
74
|
-
self.run_id = res.run_id
|
|
75
|
-
return self.grpc_driver, self.run_id
|
|
76
|
-
|
|
77
|
-
def _check_message(self, message: Message) -> None:
|
|
78
|
-
# Check if the message is valid
|
|
79
|
-
if not (
|
|
80
|
-
message.metadata.run_id == self.run_id
|
|
81
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
82
|
-
and message.metadata.message_id == ""
|
|
83
|
-
and message.metadata.reply_to_message == ""
|
|
84
|
-
and message.metadata.ttl > 0
|
|
85
|
-
):
|
|
86
|
-
raise ValueError(f"Invalid message: {message}")
|
|
24
|
+
class Driver(ABC):
|
|
25
|
+
"""Abstract base Driver class for the Driver API."""
|
|
87
26
|
|
|
27
|
+
@abstractmethod
|
|
88
28
|
def create_message( # pylint: disable=too-many-arguments
|
|
89
29
|
self,
|
|
90
30
|
content: RecordSet,
|
|
@@ -122,35 +62,12 @@ class Driver:
|
|
|
122
62
|
message : Message
|
|
123
63
|
A new `Message` instance with the specified content and metadata.
|
|
124
64
|
"""
|
|
125
|
-
_, run_id = self._get_grpc_driver_and_run_id()
|
|
126
|
-
if ttl:
|
|
127
|
-
warnings.warn(
|
|
128
|
-
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
129
|
-
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
130
|
-
"version of Flower.",
|
|
131
|
-
stacklevel=2,
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
135
|
-
metadata = Metadata(
|
|
136
|
-
run_id=run_id,
|
|
137
|
-
message_id="", # Will be set by the server
|
|
138
|
-
src_node_id=self.node.node_id,
|
|
139
|
-
dst_node_id=dst_node_id,
|
|
140
|
-
reply_to_message="",
|
|
141
|
-
group_id=group_id,
|
|
142
|
-
ttl=ttl_,
|
|
143
|
-
message_type=message_type,
|
|
144
|
-
)
|
|
145
|
-
return Message(metadata=metadata, content=content)
|
|
146
65
|
|
|
66
|
+
@abstractmethod
|
|
147
67
|
def get_node_ids(self) -> List[int]:
|
|
148
68
|
"""Get node IDs."""
|
|
149
|
-
grpc_driver, run_id = self._get_grpc_driver_and_run_id()
|
|
150
|
-
# Call GrpcDriver method
|
|
151
|
-
res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
|
|
152
|
-
return [node.node_id for node in res.nodes]
|
|
153
69
|
|
|
70
|
+
@abstractmethod
|
|
154
71
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
155
72
|
"""Push messages to specified node IDs.
|
|
156
73
|
|
|
@@ -168,20 +85,8 @@ class Driver:
|
|
|
168
85
|
An iterable of IDs for the messages that were sent, which can be used
|
|
169
86
|
to pull replies.
|
|
170
87
|
"""
|
|
171
|
-
grpc_driver, _ = self._get_grpc_driver_and_run_id()
|
|
172
|
-
# Construct TaskIns
|
|
173
|
-
task_ins_list: List[TaskIns] = []
|
|
174
|
-
for msg in messages:
|
|
175
|
-
# Check message
|
|
176
|
-
self._check_message(msg)
|
|
177
|
-
# Convert Message to TaskIns
|
|
178
|
-
taskins = message_to_taskins(msg)
|
|
179
|
-
# Add to list
|
|
180
|
-
task_ins_list.append(taskins)
|
|
181
|
-
# Call GrpcDriver method
|
|
182
|
-
res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
183
|
-
return list(res.task_ids)
|
|
184
88
|
|
|
89
|
+
@abstractmethod
|
|
185
90
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
186
91
|
"""Pull messages based on message IDs.
|
|
187
92
|
|
|
@@ -198,15 +103,8 @@ class Driver:
|
|
|
198
103
|
messages : Iterable[Message]
|
|
199
104
|
An iterable of messages received.
|
|
200
105
|
"""
|
|
201
|
-
grpc_driver, _ = self._get_grpc_driver_and_run_id()
|
|
202
|
-
# Pull TaskRes
|
|
203
|
-
res = grpc_driver.pull_task_res(
|
|
204
|
-
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
205
|
-
)
|
|
206
|
-
# Convert TaskRes to Message
|
|
207
|
-
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
208
|
-
return msgs
|
|
209
106
|
|
|
107
|
+
@abstractmethod
|
|
210
108
|
def send_and_receive(
|
|
211
109
|
self,
|
|
212
110
|
messages: Iterable[Message],
|
|
@@ -240,28 +138,3 @@ class Driver:
|
|
|
240
138
|
replies for all sent messages. A message remains valid until its TTL,
|
|
241
139
|
which is not affected by `timeout`.
|
|
242
140
|
"""
|
|
243
|
-
# Push messages
|
|
244
|
-
msg_ids = set(self.push_messages(messages))
|
|
245
|
-
|
|
246
|
-
# Pull messages
|
|
247
|
-
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
248
|
-
ret: List[Message] = []
|
|
249
|
-
while timeout is None or time.time() < end_time:
|
|
250
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
251
|
-
ret.extend(res_msgs)
|
|
252
|
-
msg_ids.difference_update(
|
|
253
|
-
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
254
|
-
)
|
|
255
|
-
if len(msg_ids) == 0:
|
|
256
|
-
break
|
|
257
|
-
# Sleep
|
|
258
|
-
time.sleep(3)
|
|
259
|
-
return ret
|
|
260
|
-
|
|
261
|
-
def close(self) -> None:
|
|
262
|
-
"""Disconnect from the SuperLink if connected."""
|
|
263
|
-
# Check if GrpcDriver is initialized
|
|
264
|
-
if self.grpc_driver is None:
|
|
265
|
-
return
|
|
266
|
-
# Disconnect
|
|
267
|
-
self.grpc_driver.disconnect()
|