flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.12.0.dev20240906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/new/new.py +24 -10
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/run/run.py +2 -2
- flwr/client/__init__.py +0 -4
- flwr/client/grpc_rere_client/client_interceptor.py +13 -4
- flwr/client/supernode/app.py +3 -1
- flwr/common/config.py +14 -11
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +13 -13
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +1 -3
- flwr/server/run_serverapp.py +15 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/simulation/run_simulation.py +23 -6
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +3 -1
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/RECORD +43 -35
- flwr_nightly-1.12.0.dev20240906.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/WHEEL +0 -0
|
@@ -8,15 +8,15 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets>=0.
|
|
13
|
-
"hydra-core==1.3.2",
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets>=0.3.0",
|
|
14
13
|
"trl==0.8.1",
|
|
15
14
|
"bitsandbytes==0.43.0",
|
|
16
15
|
"scipy==1.13.0",
|
|
17
16
|
"peft==0.6.2",
|
|
18
17
|
"transformers==4.39.3",
|
|
19
18
|
"sentencepiece==0.2.0",
|
|
19
|
+
"omegaconf==2.3.0",
|
|
20
20
|
]
|
|
21
21
|
|
|
22
22
|
[tool.hatch.build.targets.wheel]
|
|
@@ -26,14 +26,41 @@ packages = ["."]
|
|
|
26
26
|
publisher = "$username"
|
|
27
27
|
|
|
28
28
|
[tool.flwr.app.components]
|
|
29
|
-
serverapp = "$import_name.app
|
|
30
|
-
clientapp = "$import_name.app
|
|
29
|
+
serverapp = "$import_name.server_app:app"
|
|
30
|
+
clientapp = "$import_name.client_app:app"
|
|
31
31
|
|
|
32
32
|
[tool.flwr.app.config]
|
|
33
|
-
|
|
33
|
+
model.name = "mistralai/Mistral-7B-v0.3"
|
|
34
|
+
model.quantization = 4
|
|
35
|
+
model.gradient-checkpointing = true
|
|
36
|
+
model.lora.peft-lora-r = 32
|
|
37
|
+
model.lora.peft-lora-alpha = 64
|
|
38
|
+
train.save-every-round = 5
|
|
39
|
+
train.learning-rate-max = 5e-5
|
|
40
|
+
train.learning-rate-min = 1e-6
|
|
41
|
+
train.seq-length = 512
|
|
42
|
+
train.training-arguments.output-dir = ""
|
|
43
|
+
train.training-arguments.learning-rate = ""
|
|
44
|
+
train.training-arguments.per-device-train-batch-size = 16
|
|
45
|
+
train.training-arguments.gradient-accumulation-steps = 1
|
|
46
|
+
train.training-arguments.logging-steps = 10
|
|
47
|
+
train.training-arguments.num-train-epochs = 3
|
|
48
|
+
train.training-arguments.max-steps = 10
|
|
49
|
+
train.training-arguments.save-steps = 1000
|
|
50
|
+
train.training-arguments.save-total-limit = 10
|
|
51
|
+
train.training-arguments.gradient-checkpointing = true
|
|
52
|
+
train.training-arguments.lr-scheduler-type = "constant"
|
|
53
|
+
strategy.fraction-fit = $fraction_fit
|
|
54
|
+
strategy.fraction-evaluate = 0.0
|
|
55
|
+
num-server-rounds = 200
|
|
56
|
+
|
|
57
|
+
[tool.flwr.app.config.static]
|
|
58
|
+
dataset.name = "$dataset_name"
|
|
34
59
|
|
|
35
60
|
[tool.flwr.federations]
|
|
36
61
|
default = "local-simulation"
|
|
37
62
|
|
|
38
63
|
[tool.flwr.federations.local-simulation]
|
|
39
|
-
options.num-supernodes =
|
|
64
|
+
options.num-supernodes = $num_clients
|
|
65
|
+
options.backend.client-resources.num-cpus = 6
|
|
66
|
+
options.backend.client-resources.num-gpus = 1.0
|
flwr/cli/run/run.py
CHANGED
|
@@ -124,14 +124,14 @@ def run(
|
|
|
124
124
|
|
|
125
125
|
|
|
126
126
|
def _run_with_superexec(
|
|
127
|
-
app:
|
|
127
|
+
app: Path,
|
|
128
128
|
federation_config: Dict[str, Any],
|
|
129
129
|
config_overrides: Optional[List[str]],
|
|
130
130
|
) -> None:
|
|
131
131
|
|
|
132
132
|
insecure_str = federation_config.get("insecure")
|
|
133
133
|
if root_certificates := federation_config.get("root-certificates"):
|
|
134
|
-
root_certificates_bytes =
|
|
134
|
+
root_certificates_bytes = (app / root_certificates).read_bytes()
|
|
135
135
|
if insecure := bool(insecure_str):
|
|
136
136
|
typer.secho(
|
|
137
137
|
"❌ `root_certificates` were provided but the `insecure` parameter"
|
flwr/client/__init__.py
CHANGED
|
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
|
|
|
20
20
|
from .client import Client as Client
|
|
21
21
|
from .client_app import ClientApp as ClientApp
|
|
22
22
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
|
-
from .supernode import run_client_app as run_client_app
|
|
24
|
-
from .supernode import run_supernode as run_supernode
|
|
25
23
|
from .typing import ClientFn as ClientFn
|
|
26
24
|
from .typing import ClientFnExt as ClientFnExt
|
|
27
25
|
|
|
@@ -32,8 +30,6 @@ __all__ = [
|
|
|
32
30
|
"ClientFnExt",
|
|
33
31
|
"NumPyClient",
|
|
34
32
|
"mod",
|
|
35
|
-
"run_client_app",
|
|
36
|
-
"run_supernode",
|
|
37
33
|
"start_client",
|
|
38
34
|
"start_numpy_client",
|
|
39
35
|
]
|
|
@@ -17,11 +17,13 @@
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
19
|
import collections
|
|
20
|
+
from logging import WARNING
|
|
20
21
|
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
22
|
|
|
22
23
|
import grpc
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
|
|
26
|
+
from flwr.common.logger import log
|
|
25
27
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
28
|
bytes_to_public_key,
|
|
27
29
|
compute_hmac,
|
|
@@ -151,8 +153,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
151
153
|
server_public_key_bytes = base64.urlsafe_b64decode(
|
|
152
154
|
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
153
155
|
)
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
self.
|
|
157
|
-
|
|
156
|
+
|
|
157
|
+
if server_public_key_bytes != b"":
|
|
158
|
+
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
|
|
159
|
+
else:
|
|
160
|
+
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
|
161
|
+
|
|
162
|
+
if self.server_public_key is not None:
|
|
163
|
+
self.shared_secret = generate_shared_key(
|
|
164
|
+
self.private_key, self.server_public_key
|
|
165
|
+
)
|
|
166
|
+
|
|
158
167
|
return response
|
flwr/client/supernode/app.py
CHANGED
|
@@ -77,7 +77,9 @@ def run_supernode() -> None:
|
|
|
77
77
|
authentication_keys=authentication_keys,
|
|
78
78
|
max_retries=args.max_retries,
|
|
79
79
|
max_wait_time=args.max_wait_time,
|
|
80
|
-
node_config=parse_config_args(
|
|
80
|
+
node_config=parse_config_args(
|
|
81
|
+
[args.node_config] if args.node_config else args.node_config
|
|
82
|
+
),
|
|
81
83
|
isolation=args.isolation,
|
|
82
84
|
supernode_address=args.supernode_address,
|
|
83
85
|
)
|
flwr/common/config.py
CHANGED
|
@@ -185,23 +185,26 @@ def parse_config_args(
|
|
|
185
185
|
if config is None:
|
|
186
186
|
return overrides
|
|
187
187
|
|
|
188
|
+
# Handle if .toml file is passed
|
|
189
|
+
if len(config) == 1 and config[0].endswith(".toml"):
|
|
190
|
+
with Path(config[0]).open("rb") as config_file:
|
|
191
|
+
overrides = flatten_dict(tomli.load(config_file))
|
|
192
|
+
return overrides
|
|
193
|
+
|
|
188
194
|
# Regular expression to capture key-value pairs with possible quoted values
|
|
189
195
|
pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
|
|
190
196
|
|
|
191
197
|
for config_line in config:
|
|
192
198
|
if config_line:
|
|
193
|
-
|
|
199
|
+
# .toml files aren't allowed alongside other configs
|
|
200
|
+
if config_line.endswith(".toml"):
|
|
201
|
+
raise ValueError(
|
|
202
|
+
"TOML files cannot be passed alongside key-value pairs."
|
|
203
|
+
)
|
|
194
204
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
and matches[0][0].endswith(".toml")
|
|
199
|
-
):
|
|
200
|
-
with Path(matches[0][0]).open("rb") as config_file:
|
|
201
|
-
overrides = flatten_dict(tomli.load(config_file))
|
|
202
|
-
else:
|
|
203
|
-
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
204
|
-
overrides.update(tomli.loads(toml_str))
|
|
205
|
+
matches = pattern.findall(config_line)
|
|
206
|
+
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
207
|
+
overrides.update(tomli.loads(toml_str))
|
|
205
208
|
|
|
206
209
|
return overrides
|
|
207
210
|
|
flwr/common/telemetry.py
CHANGED
|
@@ -132,53 +132,59 @@ class EventType(str, Enum):
|
|
|
132
132
|
# Ping
|
|
133
133
|
PING = auto()
|
|
134
134
|
|
|
135
|
-
#
|
|
135
|
+
# --- LEGACY FUNCTIONS -------------------------------------------------------------
|
|
136
|
+
|
|
137
|
+
# Legacy: `start_client` function
|
|
136
138
|
START_CLIENT_ENTER = auto()
|
|
137
139
|
START_CLIENT_LEAVE = auto()
|
|
138
140
|
|
|
139
|
-
#
|
|
141
|
+
# Legacy: `start_server` function
|
|
140
142
|
START_SERVER_ENTER = auto()
|
|
141
143
|
START_SERVER_LEAVE = auto()
|
|
142
144
|
|
|
143
|
-
#
|
|
144
|
-
|
|
145
|
-
|
|
145
|
+
# Legacy: `start_simulation` function
|
|
146
|
+
START_SIMULATION_ENTER = auto()
|
|
147
|
+
START_SIMULATION_LEAVE = auto()
|
|
146
148
|
|
|
147
|
-
#
|
|
148
|
-
RUN_FLEET_API_ENTER = auto()
|
|
149
|
-
RUN_FLEET_API_LEAVE = auto()
|
|
149
|
+
# --- `flwr` CLI -------------------------------------------------------------------
|
|
150
150
|
|
|
151
|
-
#
|
|
152
|
-
RUN_SUPERLINK_ENTER = auto()
|
|
153
|
-
RUN_SUPERLINK_LEAVE = auto()
|
|
151
|
+
# Not yet implemented
|
|
154
152
|
|
|
155
|
-
#
|
|
156
|
-
START_SIMULATION_ENTER = auto()
|
|
157
|
-
START_SIMULATION_LEAVE = auto()
|
|
153
|
+
# --- SuperExec --------------------------------------------------------------------
|
|
158
154
|
|
|
159
|
-
#
|
|
160
|
-
|
|
161
|
-
|
|
155
|
+
# SuperExec
|
|
156
|
+
RUN_SUPEREXEC_ENTER = auto()
|
|
157
|
+
RUN_SUPEREXEC_LEAVE = auto()
|
|
162
158
|
|
|
163
|
-
#
|
|
164
|
-
START_DRIVER_ENTER = auto()
|
|
165
|
-
START_DRIVER_LEAVE = auto()
|
|
159
|
+
# --- Simulation Engine ------------------------------------------------------------
|
|
166
160
|
|
|
167
|
-
# flower-
|
|
168
|
-
|
|
169
|
-
|
|
161
|
+
# CLI: flower-simulation
|
|
162
|
+
CLI_FLOWER_SIMULATION_ENTER = auto()
|
|
163
|
+
CLI_FLOWER_SIMULATION_LEAVE = auto()
|
|
170
164
|
|
|
171
|
-
#
|
|
172
|
-
|
|
173
|
-
|
|
165
|
+
# Python API: `run_simulation`
|
|
166
|
+
PYTHON_API_RUN_SIMULATION_ENTER = auto()
|
|
167
|
+
PYTHON_API_RUN_SIMULATION_LEAVE = auto()
|
|
174
168
|
|
|
175
|
-
#
|
|
169
|
+
# --- Deployment Engine ------------------------------------------------------------
|
|
170
|
+
|
|
171
|
+
# CLI: `flower-superlink`
|
|
172
|
+
RUN_SUPERLINK_ENTER = auto()
|
|
173
|
+
RUN_SUPERLINK_LEAVE = auto()
|
|
174
|
+
|
|
175
|
+
# CLI: `flower-supernode`
|
|
176
176
|
RUN_SUPERNODE_ENTER = auto()
|
|
177
177
|
RUN_SUPERNODE_LEAVE = auto()
|
|
178
178
|
|
|
179
|
-
#
|
|
180
|
-
|
|
181
|
-
|
|
179
|
+
# CLI: `flower-server-app`
|
|
180
|
+
RUN_SERVER_APP_ENTER = auto()
|
|
181
|
+
RUN_SERVER_APP_LEAVE = auto()
|
|
182
|
+
|
|
183
|
+
# --- DEPRECATED -------------------------------------------------------------------
|
|
184
|
+
|
|
185
|
+
# [DEPRECATED] CLI: `flower-client-app`
|
|
186
|
+
RUN_CLIENT_APP_ENTER = auto()
|
|
187
|
+
RUN_CLIENT_APP_LEAVE = auto()
|
|
182
188
|
|
|
183
189
|
|
|
184
190
|
# Use the ThreadPoolExecutor with max_workers=1 to have a queue
|
flwr/server/__init__.py
CHANGED
|
@@ -17,14 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from . import strategy
|
|
19
19
|
from . import workflow as workflow
|
|
20
|
-
from .app import run_superlink as run_superlink
|
|
21
20
|
from .app import start_server as start_server
|
|
22
21
|
from .client_manager import ClientManager as ClientManager
|
|
23
22
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
|
24
23
|
from .compat import LegacyContext as LegacyContext
|
|
25
24
|
from .driver import Driver as Driver
|
|
26
25
|
from .history import History as History
|
|
27
|
-
from .run_serverapp import run_server_app as run_server_app
|
|
28
26
|
from .server import Server as Server
|
|
29
27
|
from .server_app import ServerApp as ServerApp
|
|
30
28
|
from .server_config import ServerConfig as ServerConfig
|
|
@@ -40,8 +38,6 @@ __all__ = [
|
|
|
40
38
|
"ServerAppComponents",
|
|
41
39
|
"ServerConfig",
|
|
42
40
|
"SimpleClientManager",
|
|
43
|
-
"run_server_app",
|
|
44
|
-
"run_superlink",
|
|
45
41
|
"start_server",
|
|
46
42
|
"strategy",
|
|
47
43
|
"workflow",
|
flwr/server/app.py
CHANGED
|
@@ -278,24 +278,24 @@ def run_superlink() -> None:
|
|
|
278
278
|
fleet_thread.start()
|
|
279
279
|
bckg_threads.append(fleet_thread)
|
|
280
280
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
281
|
-
maybe_keys =
|
|
281
|
+
maybe_keys = _try_setup_node_authentication(args, certificates)
|
|
282
282
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
283
283
|
if maybe_keys is not None:
|
|
284
284
|
(
|
|
285
|
-
|
|
285
|
+
node_public_keys,
|
|
286
286
|
server_private_key,
|
|
287
287
|
server_public_key,
|
|
288
288
|
) = maybe_keys
|
|
289
289
|
state = state_factory.state()
|
|
290
|
-
state.
|
|
290
|
+
state.store_node_public_keys(node_public_keys)
|
|
291
291
|
state.store_server_private_public_key(
|
|
292
292
|
private_key_to_bytes(server_private_key),
|
|
293
293
|
public_key_to_bytes(server_public_key),
|
|
294
294
|
)
|
|
295
295
|
log(
|
|
296
296
|
INFO,
|
|
297
|
-
"
|
|
298
|
-
len(
|
|
297
|
+
"Node authentication enabled with %d known public keys",
|
|
298
|
+
len(node_public_keys),
|
|
299
299
|
)
|
|
300
300
|
interceptors = [AuthenticateServerInterceptor(state)]
|
|
301
301
|
|
|
@@ -344,7 +344,7 @@ def _format_address(address: str) -> Tuple[str, str, int]:
|
|
|
344
344
|
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
345
345
|
|
|
346
346
|
|
|
347
|
-
def
|
|
347
|
+
def _try_setup_node_authentication(
|
|
348
348
|
args: argparse.Namespace,
|
|
349
349
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
350
350
|
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
@@ -373,16 +373,16 @@ def _try_setup_client_authentication(
|
|
|
373
373
|
"`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
|
|
374
374
|
)
|
|
375
375
|
|
|
376
|
-
|
|
377
|
-
if not
|
|
376
|
+
node_keys_file_path = Path(args.auth_list_public_keys)
|
|
377
|
+
if not node_keys_file_path.exists():
|
|
378
378
|
sys.exit(
|
|
379
379
|
"The provided path to the known public keys CSV file does not exist: "
|
|
380
|
-
f"{
|
|
380
|
+
f"{node_keys_file_path}. "
|
|
381
381
|
"Please provide the CSV file path containing known public keys "
|
|
382
382
|
"to '--auth-list-public-keys'."
|
|
383
383
|
)
|
|
384
384
|
|
|
385
|
-
|
|
385
|
+
node_public_keys: Set[bytes] = set()
|
|
386
386
|
|
|
387
387
|
try:
|
|
388
388
|
ssh_private_key = load_ssh_private_key(
|
|
@@ -413,13 +413,13 @@ def _try_setup_client_authentication(
|
|
|
413
413
|
"path points to a valid public key file and try again."
|
|
414
414
|
)
|
|
415
415
|
|
|
416
|
-
with open(
|
|
416
|
+
with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
417
417
|
reader = csv.reader(csvfile)
|
|
418
418
|
for row in reader:
|
|
419
419
|
for element in row:
|
|
420
420
|
public_key = load_ssh_public_key(element.encode())
|
|
421
421
|
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
422
|
-
|
|
422
|
+
node_public_keys.add(public_key_to_bytes(public_key))
|
|
423
423
|
else:
|
|
424
424
|
sys.exit(
|
|
425
425
|
"Error: Unable to parse the public keys in the CSV "
|
|
@@ -427,7 +427,7 @@ def _try_setup_client_authentication(
|
|
|
427
427
|
"known SSH public keys files and try again."
|
|
428
428
|
)
|
|
429
429
|
return (
|
|
430
|
-
|
|
430
|
+
node_public_keys,
|
|
431
431
|
ssh_private_key,
|
|
432
432
|
ssh_public_key,
|
|
433
433
|
)
|
flwr/server/compat/app.py
CHANGED
|
@@ -18,7 +18,6 @@
|
|
|
18
18
|
from logging import INFO
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import EventType, event
|
|
22
21
|
from flwr.common.logger import log
|
|
23
22
|
from flwr.server.client_manager import ClientManager
|
|
24
23
|
from flwr.server.history import History
|
|
@@ -65,8 +64,6 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
65
64
|
hist : flwr.server.history.History
|
|
66
65
|
Object containing training and evaluation metrics.
|
|
67
66
|
"""
|
|
68
|
-
event(EventType.START_DRIVER_ENTER)
|
|
69
|
-
|
|
70
67
|
# Initialize the Driver API server and config
|
|
71
68
|
initialized_server, initialized_config = init_defaults(
|
|
72
69
|
server=server,
|
|
@@ -96,6 +93,4 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
96
93
|
f_stop.set()
|
|
97
94
|
thread.join()
|
|
98
95
|
|
|
99
|
-
event(EventType.START_SERVER_LEAVE)
|
|
100
|
-
|
|
101
96
|
return hist
|
|
@@ -21,7 +21,7 @@ from typing import Iterable, List, Optional, cast
|
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
24
|
-
from flwr.common import DEFAULT_TTL,
|
|
24
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.serde import (
|
|
@@ -94,7 +94,6 @@ class GrpcDriver(Driver):
|
|
|
94
94
|
|
|
95
95
|
This will not call GetRun.
|
|
96
96
|
"""
|
|
97
|
-
event(EventType.DRIVER_CONNECT)
|
|
98
97
|
if self._is_connected:
|
|
99
98
|
log(WARNING, "Already connected")
|
|
100
99
|
return
|
|
@@ -108,7 +107,6 @@ class GrpcDriver(Driver):
|
|
|
108
107
|
|
|
109
108
|
def _disconnect(self) -> None:
|
|
110
109
|
"""Disconnect from the Driver API."""
|
|
111
|
-
event(EventType.DRIVER_DISCONNECT)
|
|
112
110
|
if not self._is_connected:
|
|
113
111
|
log(DEBUG, "Already disconnected")
|
|
114
112
|
return
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -97,6 +97,21 @@ def run_server_app() -> None:
|
|
|
97
97
|
|
|
98
98
|
args = _parse_args_run_server_app().parse_args()
|
|
99
99
|
|
|
100
|
+
# Check if the server app reference is passed.
|
|
101
|
+
# Since Flower 1.11, passing a reference is not allowed.
|
|
102
|
+
app_path: Optional[str] = args.app
|
|
103
|
+
# If the provided app_path doesn't exist, and contains a ":",
|
|
104
|
+
# it is likely to be a server app reference instead of a path.
|
|
105
|
+
if app_path is not None and not Path(app_path).exists() and ":" in app_path:
|
|
106
|
+
sys.exit(
|
|
107
|
+
"It appears you've passed a reference like `server:app`.\n\n"
|
|
108
|
+
"Note that since version `1.11.0`, `flower-server-app` no longer supports "
|
|
109
|
+
"passing a reference to a `ServerApp` attribute. Instead, you need to pass "
|
|
110
|
+
"the path to Flower app via the argument `--app`. This is the path to a "
|
|
111
|
+
"directory containing a `pyproject.toml`. You can create a valid Flower "
|
|
112
|
+
"app by executing `flwr new` and following the prompt."
|
|
113
|
+
)
|
|
114
|
+
|
|
100
115
|
if args.server != ADDRESS_DRIVER_API:
|
|
101
116
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
102
117
|
warn_deprecated_feature(warn)
|
|
@@ -151,7 +166,6 @@ def run_server_app() -> None:
|
|
|
151
166
|
cert_path,
|
|
152
167
|
)
|
|
153
168
|
|
|
154
|
-
app_path: Optional[str] = args.app
|
|
155
169
|
if not (app_path is None) ^ (args.run_id is None):
|
|
156
170
|
raise sys.exit(
|
|
157
171
|
"Please provide either a Flower App path or a Run ID, but not both. "
|
|
@@ -51,19 +51,22 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
51
51
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
|
-
log(INFO, "
|
|
54
|
+
log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
|
|
55
|
+
log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
|
|
55
56
|
response = message_handler.create_node(
|
|
56
57
|
request=request,
|
|
57
58
|
state=self.state_factory.state(),
|
|
58
59
|
)
|
|
59
|
-
log(INFO, "
|
|
60
|
+
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
|
61
|
+
log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
|
|
60
62
|
return response
|
|
61
63
|
|
|
62
64
|
def DeleteNode(
|
|
63
65
|
self, request: DeleteNodeRequest, context: grpc.ServicerContext
|
|
64
66
|
) -> DeleteNodeResponse:
|
|
65
67
|
"""."""
|
|
66
|
-
log(INFO, "
|
|
68
|
+
log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
|
|
69
|
+
log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
|
|
67
70
|
return message_handler.delete_node(
|
|
68
71
|
request=request,
|
|
69
72
|
state=self.state_factory.state(),
|
|
@@ -71,7 +74,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
71
74
|
|
|
72
75
|
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
73
76
|
"""."""
|
|
74
|
-
log(DEBUG, "
|
|
77
|
+
log(DEBUG, "[Fleet.Ping] Request: %s", request)
|
|
75
78
|
return message_handler.ping(
|
|
76
79
|
request=request,
|
|
77
80
|
state=self.state_factory.state(),
|
|
@@ -81,7 +84,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
81
84
|
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
|
82
85
|
) -> PullTaskInsResponse:
|
|
83
86
|
"""Pull TaskIns."""
|
|
84
|
-
log(INFO, "
|
|
87
|
+
log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
|
|
88
|
+
log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
|
|
85
89
|
return message_handler.pull_task_ins(
|
|
86
90
|
request=request,
|
|
87
91
|
state=self.state_factory.state(),
|
|
@@ -91,7 +95,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
91
95
|
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
|
92
96
|
) -> PushTaskResResponse:
|
|
93
97
|
"""Push TaskRes."""
|
|
94
|
-
|
|
98
|
+
if request.task_res_list:
|
|
99
|
+
log(
|
|
100
|
+
INFO,
|
|
101
|
+
"[Fleet.PushTaskRes] Push results from node_id=%s",
|
|
102
|
+
request.task_res_list[0].task.producer.node_id,
|
|
103
|
+
)
|
|
104
|
+
else:
|
|
105
|
+
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
95
106
|
return message_handler.push_task_res(
|
|
96
107
|
request=request,
|
|
97
108
|
state=self.state_factory.state(),
|
|
@@ -101,7 +112,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
101
112
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
102
113
|
) -> GetRunResponse:
|
|
103
114
|
"""Get run information."""
|
|
104
|
-
log(INFO, "
|
|
115
|
+
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
105
116
|
return message_handler.get_run(
|
|
106
117
|
request=request,
|
|
107
118
|
state=self.state_factory.state(),
|
|
@@ -111,7 +122,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
111
122
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
112
123
|
) -> GetFabResponse:
|
|
113
124
|
"""Get FAB."""
|
|
114
|
-
log(
|
|
125
|
+
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
115
126
|
return message_handler.get_fab(
|
|
116
127
|
request=request,
|
|
117
128
|
ffs=self.ffs_factory.ffs(),
|
|
@@ -78,13 +78,13 @@ def _get_value_from_tuples(
|
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
81
|
-
"""Server interceptor for
|
|
81
|
+
"""Server interceptor for node authentication."""
|
|
82
82
|
|
|
83
83
|
def __init__(self, state: State):
|
|
84
84
|
self.state = state
|
|
85
85
|
|
|
86
|
-
self.
|
|
87
|
-
if len(self.
|
|
86
|
+
self.node_public_keys = state.get_node_public_keys()
|
|
87
|
+
if len(self.node_public_keys) == 0:
|
|
88
88
|
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
89
89
|
|
|
90
90
|
private_key = self.state.get_server_private_key()
|
|
@@ -103,9 +103,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
103
103
|
) -> grpc.RpcMethodHandler:
|
|
104
104
|
"""Flower server interceptor authentication logic.
|
|
105
105
|
|
|
106
|
-
Intercept all unary calls from
|
|
107
|
-
|
|
108
|
-
|
|
106
|
+
Intercept all unary calls from nodes and authenticate nodes by validating auth
|
|
107
|
+
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
108
|
+
terminate RPC call by setting context to abort.
|
|
109
109
|
"""
|
|
110
110
|
# One of the method handlers in
|
|
111
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
@@ -119,17 +119,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
119
119
|
request: Request,
|
|
120
120
|
context: grpc.ServicerContext,
|
|
121
121
|
) -> Response:
|
|
122
|
-
|
|
122
|
+
node_public_key_bytes = base64.urlsafe_b64decode(
|
|
123
123
|
_get_value_from_tuples(
|
|
124
124
|
_PUBLIC_KEY_HEADER, context.invocation_metadata()
|
|
125
125
|
)
|
|
126
126
|
)
|
|
127
|
-
if
|
|
127
|
+
if node_public_key_bytes not in self.node_public_keys:
|
|
128
128
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
129
129
|
|
|
130
130
|
if isinstance(request, CreateNodeRequest):
|
|
131
131
|
response = self._create_authenticated_node(
|
|
132
|
-
|
|
132
|
+
node_public_key_bytes, request, context
|
|
133
133
|
)
|
|
134
134
|
log(
|
|
135
135
|
INFO,
|
|
@@ -144,13 +144,13 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
144
144
|
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
145
145
|
)
|
|
146
146
|
)
|
|
147
|
-
public_key = bytes_to_public_key(
|
|
147
|
+
public_key = bytes_to_public_key(node_public_key_bytes)
|
|
148
148
|
|
|
149
149
|
if not self._verify_hmac(public_key, request, hmac_value):
|
|
150
150
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
151
151
|
|
|
152
152
|
# Verify node_id
|
|
153
|
-
node_id = self.state.get_node_id(
|
|
153
|
+
node_id = self.state.get_node_id(node_public_key_bytes)
|
|
154
154
|
|
|
155
155
|
if not self._verify_node_id(node_id, request):
|
|
156
156
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|