flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240509__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +44 -18
- flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/__init__.py +0 -2
- flwr/server/app.py +118 -2
- flwr/server/compat/app.py +5 -56
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +18 -4
- flwr/server/server.py +2 -5
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +14 -9
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
flwr/common/record/recordset.py
CHANGED
|
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
|
|
|
24
24
|
from .typeddict import TypedDict
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
@dataclass
|
|
27
28
|
class RecordSetData:
|
|
28
29
|
"""Inner data container for the RecordSet class."""
|
|
29
30
|
|
|
@@ -82,7 +83,6 @@ class RecordSetData:
|
|
|
82
83
|
)
|
|
83
84
|
|
|
84
85
|
|
|
85
|
-
@dataclass
|
|
86
86
|
class RecordSet:
|
|
87
87
|
"""RecordSet stores groups of parameters, metrics and configs."""
|
|
88
88
|
|
|
@@ -97,22 +97,34 @@ class RecordSet:
|
|
|
97
97
|
metrics_records=metrics_records,
|
|
98
98
|
configs_records=configs_records,
|
|
99
99
|
)
|
|
100
|
-
|
|
100
|
+
self.__dict__["_data"] = data
|
|
101
101
|
|
|
102
102
|
@property
|
|
103
103
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
|
104
104
|
"""Dictionary holding ParametersRecord instances."""
|
|
105
|
-
data = cast(RecordSetData,
|
|
105
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
106
106
|
return data.parameters_records
|
|
107
107
|
|
|
108
108
|
@property
|
|
109
109
|
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
|
|
110
110
|
"""Dictionary holding MetricsRecord instances."""
|
|
111
|
-
data = cast(RecordSetData,
|
|
111
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
112
112
|
return data.metrics_records
|
|
113
113
|
|
|
114
114
|
@property
|
|
115
115
|
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
|
|
116
116
|
"""Dictionary holding ConfigsRecord instances."""
|
|
117
|
-
data = cast(RecordSetData,
|
|
117
|
+
data = cast(RecordSetData, self.__dict__["_data"])
|
|
118
118
|
return data.configs_records
|
|
119
|
+
|
|
120
|
+
def __repr__(self) -> str:
|
|
121
|
+
"""Return a string representation of this instance."""
|
|
122
|
+
flds = ("parameters_records", "metrics_records", "configs_records")
|
|
123
|
+
view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
|
|
124
|
+
return f"{self.__class__.__qualname__}({view})"
|
|
125
|
+
|
|
126
|
+
def __eq__(self, other: object) -> bool:
|
|
127
|
+
"""Compare two instances of the class."""
|
|
128
|
+
if not isinstance(other, self.__class__):
|
|
129
|
+
raise NotImplementedError
|
|
130
|
+
return self.__dict__ == other.__dict__
|
|
@@ -18,8 +18,9 @@
|
|
|
18
18
|
import base64
|
|
19
19
|
from typing import Tuple, cast
|
|
20
20
|
|
|
21
|
+
from cryptography.exceptions import InvalidSignature
|
|
21
22
|
from cryptography.fernet import Fernet
|
|
22
|
-
from cryptography.hazmat.primitives import hashes, serialization
|
|
23
|
+
from cryptography.hazmat.primitives import hashes, hmac, serialization
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
25
26
|
|
|
@@ -98,3 +99,36 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
|
|
|
98
99
|
# The input key must be url safe
|
|
99
100
|
fernet = Fernet(key)
|
|
100
101
|
return fernet.decrypt(ciphertext)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_hmac(key: bytes, message: bytes) -> bytes:
|
|
105
|
+
"""Compute hmac of a message using key as hash."""
|
|
106
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
107
|
+
computed_hmac.update(message)
|
|
108
|
+
return computed_hmac.finalize()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
112
|
+
"""Verify hmac of a message using key as hash."""
|
|
113
|
+
computed_hmac = hmac.HMAC(key, hashes.SHA256())
|
|
114
|
+
computed_hmac.update(message)
|
|
115
|
+
try:
|
|
116
|
+
computed_hmac.verify(hmac_value)
|
|
117
|
+
return True
|
|
118
|
+
except InvalidSignature:
|
|
119
|
+
return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def ssh_types_to_elliptic_curve(
|
|
123
|
+
private_key: serialization.SSHPrivateKeyTypes,
|
|
124
|
+
public_key: serialization.SSHPublicKeyTypes,
|
|
125
|
+
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
|
|
126
|
+
"""Cast SSH key types to elliptic curve."""
|
|
127
|
+
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
|
|
128
|
+
public_key, ec.EllipticCurvePublicKey
|
|
129
|
+
):
|
|
130
|
+
return (private_key, public_key)
|
|
131
|
+
|
|
132
|
+
raise TypeError(
|
|
133
|
+
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
|
|
134
|
+
)
|
flwr/server/__init__.py
CHANGED
|
@@ -24,7 +24,6 @@ from .app import start_server as start_server
|
|
|
24
24
|
from .client_manager import ClientManager as ClientManager
|
|
25
25
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
|
26
26
|
from .compat import LegacyContext as LegacyContext
|
|
27
|
-
from .compat import start_driver as start_driver
|
|
28
27
|
from .driver import Driver as Driver
|
|
29
28
|
from .history import History as History
|
|
30
29
|
from .run_serverapp import run_server_app as run_server_app
|
|
@@ -45,7 +44,6 @@ __all__ = [
|
|
|
45
44
|
"ServerApp",
|
|
46
45
|
"ServerConfig",
|
|
47
46
|
"SimpleClientManager",
|
|
48
|
-
"start_driver",
|
|
49
47
|
"start_server",
|
|
50
48
|
"strategy",
|
|
51
49
|
"workflow",
|
flwr/server/app.py
CHANGED
|
@@ -16,15 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import asyncio
|
|
19
|
+
import csv
|
|
19
20
|
import importlib.util
|
|
20
21
|
import sys
|
|
21
22
|
import threading
|
|
22
23
|
from logging import ERROR, INFO, WARN
|
|
23
24
|
from os.path import isfile
|
|
24
25
|
from pathlib import Path
|
|
25
|
-
from typing import List, Optional, Tuple
|
|
26
|
+
from typing import List, Optional, Sequence, Set, Tuple
|
|
26
27
|
|
|
27
28
|
import grpc
|
|
29
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
31
|
+
load_ssh_private_key,
|
|
32
|
+
load_ssh_public_key,
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
30
36
|
from flwr.common.address import parse_address
|
|
@@ -35,7 +41,12 @@ from flwr.common.constant import (
|
|
|
35
41
|
TRANSPORT_TYPE_VCE,
|
|
36
42
|
)
|
|
37
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
38
|
-
from flwr.common.logger import log
|
|
44
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
|
45
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
|
+
private_key_to_bytes,
|
|
47
|
+
public_key_to_bytes,
|
|
48
|
+
ssh_types_to_elliptic_curve,
|
|
49
|
+
)
|
|
39
50
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
40
51
|
add_FleetServicer_to_server,
|
|
41
52
|
)
|
|
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
51
62
|
start_grpc_server,
|
|
52
63
|
)
|
|
53
64
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
65
|
+
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
54
66
|
from .superlink.fleet.vce import start_vce
|
|
55
67
|
from .superlink.state import StateFactory
|
|
56
68
|
|
|
@@ -184,6 +196,9 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
184
196
|
def run_driver_api() -> None:
|
|
185
197
|
"""Run Flower server (Driver API)."""
|
|
186
198
|
log(INFO, "Starting Flower server (Driver API)")
|
|
199
|
+
# Running `flower-driver-api` is deprecated
|
|
200
|
+
warn_deprecated_feature("flower-driver-api")
|
|
201
|
+
log(WARN, "Use `flower-superlink` instead")
|
|
187
202
|
event(EventType.RUN_DRIVER_API_ENTER)
|
|
188
203
|
args = _parse_args_run_driver_api().parse_args()
|
|
189
204
|
|
|
@@ -221,6 +236,9 @@ def run_driver_api() -> None:
|
|
|
221
236
|
def run_fleet_api() -> None:
|
|
222
237
|
"""Run Flower server (Fleet API)."""
|
|
223
238
|
log(INFO, "Starting Flower server (Fleet API)")
|
|
239
|
+
# Running `flower-fleet-api` is deprecated
|
|
240
|
+
warn_deprecated_feature("flower-fleet-api")
|
|
241
|
+
log(WARN, "Use `flower-superlink` instead")
|
|
224
242
|
event(EventType.RUN_FLEET_API_ENTER)
|
|
225
243
|
args = _parse_args_run_fleet_api().parse_args()
|
|
226
244
|
|
|
@@ -354,10 +372,33 @@ def run_superlink() -> None:
|
|
|
354
372
|
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
355
373
|
host, port, is_v6 = parsed_address
|
|
356
374
|
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
375
|
+
|
|
376
|
+
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
377
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
378
|
+
if maybe_keys is not None:
|
|
379
|
+
(
|
|
380
|
+
client_public_keys,
|
|
381
|
+
server_private_key,
|
|
382
|
+
server_public_key,
|
|
383
|
+
) = maybe_keys
|
|
384
|
+
state = state_factory.state()
|
|
385
|
+
state.store_client_public_keys(client_public_keys)
|
|
386
|
+
state.store_server_private_public_key(
|
|
387
|
+
private_key_to_bytes(server_private_key),
|
|
388
|
+
public_key_to_bytes(server_public_key),
|
|
389
|
+
)
|
|
390
|
+
log(
|
|
391
|
+
INFO,
|
|
392
|
+
"Client authentication enabled with %d known public keys",
|
|
393
|
+
len(client_public_keys),
|
|
394
|
+
)
|
|
395
|
+
interceptors = [AuthenticateServerInterceptor(state)]
|
|
396
|
+
|
|
357
397
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
358
398
|
address=address,
|
|
359
399
|
state_factory=state_factory,
|
|
360
400
|
certificates=certificates,
|
|
401
|
+
interceptors=interceptors,
|
|
361
402
|
)
|
|
362
403
|
grpc_servers.append(fleet_server)
|
|
363
404
|
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
|
@@ -390,6 +431,70 @@ def run_superlink() -> None:
|
|
|
390
431
|
driver_server.wait_for_termination(timeout=1)
|
|
391
432
|
|
|
392
433
|
|
|
434
|
+
def _try_setup_client_authentication(
|
|
435
|
+
args: argparse.Namespace,
|
|
436
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
437
|
+
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
438
|
+
if not args.require_client_authentication:
|
|
439
|
+
return None
|
|
440
|
+
|
|
441
|
+
if certificates is None:
|
|
442
|
+
sys.exit(
|
|
443
|
+
"Client authentication only works over secure connections. "
|
|
444
|
+
"Please provide certificate paths using '--certificates' when "
|
|
445
|
+
"enabling '--require-client-authentication'."
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
client_keys_file_path = Path(args.require_client_authentication[0])
|
|
449
|
+
if not client_keys_file_path.exists():
|
|
450
|
+
sys.exit(
|
|
451
|
+
"The provided path to the client public keys CSV file does not exist: "
|
|
452
|
+
f"{client_keys_file_path}. "
|
|
453
|
+
"Please provide the CSV file path containing known client public keys "
|
|
454
|
+
"to '--require-client-authentication'."
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
client_public_keys: Set[bytes] = set()
|
|
458
|
+
ssh_private_key = load_ssh_private_key(
|
|
459
|
+
Path(args.require_client_authentication[1]).read_bytes(),
|
|
460
|
+
None,
|
|
461
|
+
)
|
|
462
|
+
ssh_public_key = load_ssh_public_key(
|
|
463
|
+
Path(args.require_client_authentication[2]).read_bytes()
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
try:
|
|
467
|
+
server_private_key, server_public_key = ssh_types_to_elliptic_curve(
|
|
468
|
+
ssh_private_key, ssh_public_key
|
|
469
|
+
)
|
|
470
|
+
except TypeError:
|
|
471
|
+
sys.exit(
|
|
472
|
+
"The file paths provided could not be read as a private and public "
|
|
473
|
+
"key pair. Client authentication requires an elliptic curve public and "
|
|
474
|
+
"private key pair. Please provide the file paths containing elliptic "
|
|
475
|
+
"curve private and public keys to '--require-client-authentication'."
|
|
476
|
+
)
|
|
477
|
+
|
|
478
|
+
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
479
|
+
reader = csv.reader(csvfile)
|
|
480
|
+
for row in reader:
|
|
481
|
+
for element in row:
|
|
482
|
+
public_key = load_ssh_public_key(element.encode())
|
|
483
|
+
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
484
|
+
client_public_keys.add(public_key_to_bytes(public_key))
|
|
485
|
+
else:
|
|
486
|
+
sys.exit(
|
|
487
|
+
"Error: Unable to parse the public keys in the .csv "
|
|
488
|
+
"file. Please ensure that the .csv file contains valid "
|
|
489
|
+
"SSH public keys and try again."
|
|
490
|
+
)
|
|
491
|
+
return (
|
|
492
|
+
client_public_keys,
|
|
493
|
+
server_private_key,
|
|
494
|
+
server_public_key,
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
|
|
393
498
|
def _try_obtain_certificates(
|
|
394
499
|
args: argparse.Namespace,
|
|
395
500
|
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
@@ -417,6 +522,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
417
522
|
address: str,
|
|
418
523
|
state_factory: StateFactory,
|
|
419
524
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
525
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
420
526
|
) -> grpc.Server:
|
|
421
527
|
"""Run Fleet API (gRPC, request-response)."""
|
|
422
528
|
# Create Fleet API gRPC server
|
|
@@ -429,6 +535,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
429
535
|
server_address=address,
|
|
430
536
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
431
537
|
certificates=certificates,
|
|
538
|
+
interceptors=interceptors,
|
|
432
539
|
)
|
|
433
540
|
|
|
434
541
|
log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
|
|
@@ -606,6 +713,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
606
713
|
"Flower will just create a state in memory.",
|
|
607
714
|
default=DATABASE,
|
|
608
715
|
)
|
|
716
|
+
parser.add_argument(
|
|
717
|
+
"--require-client-authentication",
|
|
718
|
+
nargs=3,
|
|
719
|
+
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
|
|
720
|
+
type=str,
|
|
721
|
+
help="Provide three file paths: (1) a .csv file containing a list of "
|
|
722
|
+
"known client public keys for authentication, (2) the server's private "
|
|
723
|
+
"key file, and (3) the server's public key file.",
|
|
724
|
+
)
|
|
609
725
|
|
|
610
726
|
|
|
611
727
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
flwr/server/compat/app.py
CHANGED
|
@@ -15,14 +15,11 @@
|
|
|
15
15
|
"""Flower driver app."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import sys
|
|
19
18
|
from logging import INFO
|
|
20
|
-
from
|
|
21
|
-
from typing import Optional, Union
|
|
19
|
+
from typing import Optional
|
|
22
20
|
|
|
23
21
|
from flwr.common import EventType, event
|
|
24
|
-
from flwr.common.
|
|
25
|
-
from flwr.common.logger import log, warn_deprecated_feature
|
|
22
|
+
from flwr.common.logger import log
|
|
26
23
|
from flwr.server.client_manager import ClientManager
|
|
27
24
|
from flwr.server.history import History
|
|
28
25
|
from flwr.server.server import Server, init_defaults, run_fl
|
|
@@ -32,33 +29,21 @@ from flwr.server.strategy import Strategy
|
|
|
32
29
|
from ..driver import Driver
|
|
33
30
|
from .app_utils import start_update_client_manager_thread
|
|
34
31
|
|
|
35
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
36
|
-
|
|
37
|
-
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
38
|
-
[Driver] Error: Not connected.
|
|
39
|
-
|
|
40
|
-
Call `connect()` on the `Driver` instance before calling any of the other `Driver`
|
|
41
|
-
methods.
|
|
42
|
-
"""
|
|
43
|
-
|
|
44
32
|
|
|
45
33
|
def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
46
34
|
*,
|
|
47
|
-
|
|
35
|
+
driver: Driver,
|
|
48
36
|
server: Optional[Server] = None,
|
|
49
37
|
config: Optional[ServerConfig] = None,
|
|
50
38
|
strategy: Optional[Strategy] = None,
|
|
51
39
|
client_manager: Optional[ClientManager] = None,
|
|
52
|
-
root_certificates: Optional[Union[bytes, str]] = None,
|
|
53
|
-
driver: Optional[Driver] = None,
|
|
54
40
|
) -> History:
|
|
55
41
|
"""Start a Flower Driver API server.
|
|
56
42
|
|
|
57
43
|
Parameters
|
|
58
44
|
----------
|
|
59
|
-
|
|
60
|
-
The
|
|
61
|
-
Defaults to `"[::]:8080"`.
|
|
45
|
+
driver : Driver
|
|
46
|
+
The Driver object to use.
|
|
62
47
|
server : Optional[flwr.server.Server] (default: None)
|
|
63
48
|
A server implementation, either `flwr.server.Server` or a subclass
|
|
64
49
|
thereof. If no instance is provided, then `start_driver` will create
|
|
@@ -74,50 +59,14 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
74
59
|
An implementation of the class `flwr.server.ClientManager`. If no
|
|
75
60
|
implementation is provided, then `start_driver` will use
|
|
76
61
|
`flwr.server.SimpleClientManager`.
|
|
77
|
-
root_certificates : Optional[Union[bytes, str]] (default: None)
|
|
78
|
-
The PEM-encoded root certificates as a byte string or a path string.
|
|
79
|
-
If provided, a secure connection using the certificates will be
|
|
80
|
-
established to an SSL-enabled Flower server.
|
|
81
|
-
driver : Optional[Driver] (default: None)
|
|
82
|
-
The Driver object to use.
|
|
83
62
|
|
|
84
63
|
Returns
|
|
85
64
|
-------
|
|
86
65
|
hist : flwr.server.history.History
|
|
87
66
|
Object containing training and evaluation metrics.
|
|
88
|
-
|
|
89
|
-
Examples
|
|
90
|
-
--------
|
|
91
|
-
Starting a driver that connects to an insecure server:
|
|
92
|
-
|
|
93
|
-
>>> start_driver()
|
|
94
|
-
|
|
95
|
-
Starting a driver that connects to an SSL-enabled server:
|
|
96
|
-
|
|
97
|
-
>>> start_driver(
|
|
98
|
-
>>> root_certificates=Path("/crts/root.pem").read_bytes()
|
|
99
|
-
>>> )
|
|
100
67
|
"""
|
|
101
68
|
event(EventType.START_DRIVER_ENTER)
|
|
102
69
|
|
|
103
|
-
if driver is None:
|
|
104
|
-
# Not passing a `Driver` object is deprecated
|
|
105
|
-
warn_deprecated_feature("start_driver")
|
|
106
|
-
|
|
107
|
-
# Parse IP address
|
|
108
|
-
parsed_address = parse_address(server_address)
|
|
109
|
-
if not parsed_address:
|
|
110
|
-
sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
|
|
111
|
-
host, port, is_v6 = parsed_address
|
|
112
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
113
|
-
|
|
114
|
-
# Create the Driver
|
|
115
|
-
if isinstance(root_certificates, str):
|
|
116
|
-
root_certificates = Path(root_certificates).read_bytes()
|
|
117
|
-
driver = Driver(
|
|
118
|
-
driver_service_address=address, root_certificates=root_certificates
|
|
119
|
-
)
|
|
120
|
-
|
|
121
70
|
# Initialize the Driver API server and config
|
|
122
71
|
initialized_server, initialized_config = init_defaults(
|
|
123
72
|
server=server,
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -16,16 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
|
-
from flwr.common import serde
|
|
25
|
-
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
24
|
from flwr.server.client_proxy import ClientProxy
|
|
27
25
|
|
|
28
|
-
from ..driver.driver import
|
|
26
|
+
from ..driver.driver import Driver
|
|
29
27
|
|
|
30
28
|
SLEEP_TIME = 1
|
|
31
29
|
|
|
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
|
|
|
33
31
|
class DriverClientProxy(ClientProxy):
|
|
34
32
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
33
|
|
|
36
|
-
def __init__(
|
|
37
|
-
self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
|
|
38
|
-
):
|
|
34
|
+
def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
|
|
39
35
|
super().__init__(str(node_id))
|
|
40
36
|
self.node_id = node_id
|
|
41
37
|
self.driver = driver
|
|
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
|
|
|
116
112
|
timeout: Optional[float],
|
|
117
113
|
group_id: Optional[int],
|
|
118
114
|
) -> RecordSet:
|
|
119
|
-
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
|
120
|
-
task_id="",
|
|
121
|
-
group_id=str(group_id) if group_id is not None else "",
|
|
122
|
-
run_id=self.run_id,
|
|
123
|
-
task=task_pb2.Task( # pylint: disable=E1101
|
|
124
|
-
producer=node_pb2.Node( # pylint: disable=E1101
|
|
125
|
-
node_id=0,
|
|
126
|
-
anonymous=True,
|
|
127
|
-
),
|
|
128
|
-
consumer=node_pb2.Node( # pylint: disable=E1101
|
|
129
|
-
node_id=self.node_id,
|
|
130
|
-
anonymous=self.anonymous,
|
|
131
|
-
),
|
|
132
|
-
task_type=task_type,
|
|
133
|
-
recordset=serde.recordset_to_proto(recordset),
|
|
134
|
-
ttl=DEFAULT_TTL,
|
|
135
|
-
),
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
# This would normally be recorded upon common.Message creation
|
|
139
|
-
# but this compatibility stack doesn't create Messages,
|
|
140
|
-
# so we need to inject `created_at` manually (needed for
|
|
141
|
-
# taskins validation by server.utils.validator)
|
|
142
|
-
task_ins.task.created_at = time.time()
|
|
143
115
|
|
|
144
|
-
|
|
145
|
-
|
|
116
|
+
# Create message
|
|
117
|
+
message = self.driver.create_message(
|
|
118
|
+
content=recordset,
|
|
119
|
+
message_type=task_type,
|
|
120
|
+
dst_node_id=self.node_id,
|
|
121
|
+
group_id=str(group_id) if group_id else "",
|
|
122
|
+
ttl=timeout,
|
|
146
123
|
)
|
|
147
124
|
|
|
148
|
-
#
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
raise ValueError("Unexpected number of task_ids")
|
|
125
|
+
# Push message
|
|
126
|
+
message_ids = list(self.driver.push_messages(messages=[message]))
|
|
127
|
+
if len(message_ids) != 1:
|
|
128
|
+
raise ValueError("Unexpected number of message_ids")
|
|
153
129
|
|
|
154
|
-
|
|
155
|
-
if
|
|
156
|
-
raise ValueError(f"Failed to
|
|
130
|
+
message_id = message_ids[0]
|
|
131
|
+
if message_id == "":
|
|
132
|
+
raise ValueError(f"Failed to send message to node {self.node_id}")
|
|
157
133
|
|
|
158
134
|
if timeout:
|
|
159
135
|
start_time = time.time()
|
|
160
136
|
|
|
161
137
|
while True:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
pull_task_res_res.task_res_list
|
|
172
|
-
)
|
|
173
|
-
if len(task_res_list) == 1:
|
|
174
|
-
task_res = task_res_list[0]
|
|
175
|
-
|
|
176
|
-
# This will raise an Exception if task_res carries an `error`
|
|
177
|
-
validate_task_res(task_res=task_res)
|
|
178
|
-
|
|
179
|
-
return serde.recordset_from_proto(task_res.task.recordset)
|
|
138
|
+
messages = list(self.driver.pull_messages(message_ids))
|
|
139
|
+
if len(messages) == 1:
|
|
140
|
+
msg: Message = messages[0]
|
|
141
|
+
if msg.has_error():
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
+
"It originated during client-side execution of a message."
|
|
145
|
+
)
|
|
146
|
+
return msg.content
|
|
180
147
|
|
|
181
148
|
if timeout is not None and time.time() > start_time + timeout:
|
|
182
149
|
raise RuntimeError("Timeout reached")
|
|
183
150
|
time.sleep(SLEEP_TIME)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def validate_task_res(
|
|
187
|
-
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
188
|
-
) -> None:
|
|
189
|
-
"""Validate if a TaskRes is empty or not."""
|
|
190
|
-
if not task_res.HasField("task"):
|
|
191
|
-
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
192
|
-
if task_res.task.HasField("error"):
|
|
193
|
-
raise ValueError("Exception during client-side task execution")
|
|
194
|
-
if not task_res.task.HasField("recordset"):
|
|
195
|
-
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|