flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +128 -53
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +55 -24
- flwr/client/typing.py +2 -2
- flwr/common/config.py +87 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +110 -33
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/client/node_state.py
CHANGED
|
@@ -15,30 +15,72 @@
|
|
|
15
15
|
"""Node state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Dict, Optional
|
|
19
21
|
|
|
20
22
|
from flwr.common import Context, RecordSet
|
|
23
|
+
from flwr.common.config import get_fused_config, get_fused_config_from_dir
|
|
24
|
+
from flwr.common.typing import Run, UserConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass()
|
|
28
|
+
class RunInfo:
|
|
29
|
+
"""Contains the Context and initial run_config of a Run."""
|
|
30
|
+
|
|
31
|
+
context: Context
|
|
32
|
+
initial_run_config: UserConfig
|
|
21
33
|
|
|
22
34
|
|
|
23
35
|
class NodeState:
|
|
24
36
|
"""State of a node where client nodes execute runs."""
|
|
25
37
|
|
|
26
|
-
def __init__(
|
|
27
|
-
self
|
|
28
|
-
|
|
29
|
-
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
node_id: int,
|
|
41
|
+
node_config: UserConfig,
|
|
42
|
+
) -> None:
|
|
43
|
+
self.node_id = node_id
|
|
44
|
+
self.node_config = node_config
|
|
45
|
+
self.run_infos: Dict[int, RunInfo] = {}
|
|
30
46
|
|
|
31
|
-
def register_context(
|
|
47
|
+
def register_context(
|
|
48
|
+
self,
|
|
49
|
+
run_id: int,
|
|
50
|
+
run: Optional[Run] = None,
|
|
51
|
+
flwr_path: Optional[Path] = None,
|
|
52
|
+
app_dir: Optional[str] = None,
|
|
53
|
+
) -> None:
|
|
32
54
|
"""Register new run context for this node."""
|
|
33
|
-
if run_id not in self.
|
|
34
|
-
|
|
35
|
-
|
|
55
|
+
if run_id not in self.run_infos:
|
|
56
|
+
initial_run_config = {}
|
|
57
|
+
if app_dir:
|
|
58
|
+
# Load from app directory
|
|
59
|
+
app_path = Path(app_dir)
|
|
60
|
+
if app_path.is_dir():
|
|
61
|
+
override_config = run.override_config if run else {}
|
|
62
|
+
initial_run_config = get_fused_config_from_dir(
|
|
63
|
+
app_path, override_config
|
|
64
|
+
)
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError("The specified `app_dir` must be a directory.")
|
|
67
|
+
else:
|
|
68
|
+
# Load from .fab
|
|
69
|
+
initial_run_config = get_fused_config(run, flwr_path) if run else {}
|
|
70
|
+
self.run_infos[run_id] = RunInfo(
|
|
71
|
+
initial_run_config=initial_run_config,
|
|
72
|
+
context=Context(
|
|
73
|
+
node_id=self.node_id,
|
|
74
|
+
node_config=self.node_config,
|
|
75
|
+
state=RecordSet(),
|
|
76
|
+
run_config=initial_run_config.copy(),
|
|
77
|
+
),
|
|
36
78
|
)
|
|
37
79
|
|
|
38
80
|
def retrieve_context(self, run_id: int) -> Context:
|
|
39
81
|
"""Get run context given a run_id."""
|
|
40
|
-
if run_id in self.
|
|
41
|
-
return self.
|
|
82
|
+
if run_id in self.run_infos:
|
|
83
|
+
return self.run_infos[run_id].context
|
|
42
84
|
|
|
43
85
|
raise RuntimeError(
|
|
44
86
|
f"Context for run_id={run_id} doesn't exist."
|
|
@@ -48,4 +90,9 @@ class NodeState:
|
|
|
48
90
|
|
|
49
91
|
def update_context(self, run_id: int, context: Context) -> None:
|
|
50
92
|
"""Update run context."""
|
|
51
|
-
self.
|
|
93
|
+
if context.run_config != self.run_infos[run_id].initial_run_config:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
"The `run_config` field of the `Context` object cannot be "
|
|
96
|
+
f"modified (run_id: {run_id})."
|
|
97
|
+
)
|
|
98
|
+
self.run_infos[run_id].context = context
|
flwr/client/node_state_tests.py
CHANGED
|
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
|
|
|
41
41
|
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
|
|
42
42
|
|
|
43
43
|
# NodeState
|
|
44
|
-
node_state = NodeState(
|
|
44
|
+
node_state = NodeState(node_id=0, node_config={})
|
|
45
45
|
|
|
46
46
|
for task in tasks:
|
|
47
47
|
run_id = task.run_id
|
|
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
|
|
|
59
59
|
node_state.update_context(run_id=run_id, context=updated_state)
|
|
60
60
|
|
|
61
61
|
# Verify values
|
|
62
|
-
for run_id,
|
|
62
|
+
for run_id, run_info in node_state.run_infos.items():
|
|
63
63
|
assert (
|
|
64
|
-
context.state.configs_records["counter"]["count"]
|
|
64
|
+
run_info.context.state.configs_records["counter"]["count"]
|
|
65
|
+
== expected_values[run_id]
|
|
65
66
|
)
|
flwr/client/rest_client/connection.py
CHANGED
|
@@ -40,7 +40,12 @@ from flwr.common.constant import (
|
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.message import Message, Metadata
|
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Run
|
|
44
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
50
|
CreateNodeRequest,
|
|
46
51
|
CreateNodeResponse,
|
|
@@ -89,9 +94,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
89
94
|
Tuple[
|
|
90
95
|
Callable[[], Optional[Message]],
|
|
91
96
|
Callable[[Message], None],
|
|
97
|
+
Optional[Callable[[], Optional[int]]],
|
|
92
98
|
Optional[Callable[[], None]],
|
|
93
|
-
Optional[Callable[[],
|
|
94
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
99
|
+
Optional[Callable[[int], Run]],
|
|
95
100
|
]
|
|
96
101
|
]:
|
|
97
102
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -236,19 +241,20 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
236
241
|
if not ping_stop_event.is_set():
|
|
237
242
|
ping_stop_event.wait(next_interval)
|
|
238
243
|
|
|
239
|
-
def create_node() ->
|
|
244
|
+
def create_node() -> Optional[int]:
|
|
240
245
|
"""Set create_node."""
|
|
241
246
|
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
242
247
|
|
|
243
248
|
# Send the request
|
|
244
249
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
|
245
250
|
if res is None:
|
|
246
|
-
return
|
|
251
|
+
return None
|
|
247
252
|
|
|
248
253
|
# Remember the node and the ping-loop thread
|
|
249
254
|
nonlocal node, ping_thread
|
|
250
255
|
node = res.node
|
|
251
256
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
257
|
+
return node.node_id
|
|
252
258
|
|
|
253
259
|
def delete_node() -> None:
|
|
254
260
|
"""Set delete_node."""
|
|
@@ -344,16 +350,21 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
344
350
|
res.results, # pylint: disable=no-member
|
|
345
351
|
)
|
|
346
352
|
|
|
347
|
-
def get_run(run_id: int) ->
|
|
353
|
+
def get_run(run_id: int) -> Run:
|
|
348
354
|
# Construct the request
|
|
349
355
|
req = GetRunRequest(run_id=run_id)
|
|
350
356
|
|
|
351
357
|
# Send the request
|
|
352
358
|
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
353
359
|
if res is None:
|
|
354
|
-
return "", ""
|
|
360
|
+
return Run(run_id, "", "", {})
|
|
355
361
|
|
|
356
|
-
return
|
|
362
|
+
return Run(
|
|
363
|
+
run_id,
|
|
364
|
+
res.run.fab_id,
|
|
365
|
+
res.run.fab_version,
|
|
366
|
+
user_config_from_proto(res.run.override_config),
|
|
367
|
+
)
|
|
357
368
|
|
|
358
369
|
try:
|
|
359
370
|
# Yield methods
|
flwr/client/supernode/app.py
CHANGED
|
@@ -29,7 +29,12 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
29
29
|
|
|
30
30
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
31
31
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.config import
|
|
32
|
+
from flwr.common.config import (
|
|
33
|
+
get_flwr_dir,
|
|
34
|
+
get_project_config,
|
|
35
|
+
get_project_dir,
|
|
36
|
+
parse_config_args,
|
|
37
|
+
)
|
|
33
38
|
from flwr.common.constant import (
|
|
34
39
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
35
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -55,7 +60,12 @@ def run_supernode() -> None:
|
|
|
55
60
|
_warn_deprecated_server_arg(args)
|
|
56
61
|
|
|
57
62
|
root_certificates = _get_certificates(args)
|
|
58
|
-
load_fn = _get_load_client_app_fn(
|
|
63
|
+
load_fn = _get_load_client_app_fn(
|
|
64
|
+
default_app_ref=getattr(args, "client-app"),
|
|
65
|
+
dir_arg=args.dir,
|
|
66
|
+
flwr_dir_arg=args.flwr_dir,
|
|
67
|
+
multi_app=True,
|
|
68
|
+
)
|
|
59
69
|
authentication_keys = _try_setup_client_authentication(args)
|
|
60
70
|
|
|
61
71
|
_start_client_internal(
|
|
@@ -67,7 +77,8 @@ def run_supernode() -> None:
|
|
|
67
77
|
authentication_keys=authentication_keys,
|
|
68
78
|
max_retries=args.max_retries,
|
|
69
79
|
max_wait_time=args.max_wait_time,
|
|
70
|
-
|
|
80
|
+
node_config=parse_config_args([args.node_config]),
|
|
81
|
+
flwr_path=get_flwr_dir(args.flwr_dir),
|
|
71
82
|
)
|
|
72
83
|
|
|
73
84
|
# Graceful shutdown
|
|
@@ -87,11 +98,16 @@ def run_client_app() -> None:
|
|
|
87
98
|
_warn_deprecated_server_arg(args)
|
|
88
99
|
|
|
89
100
|
root_certificates = _get_certificates(args)
|
|
90
|
-
load_fn = _get_load_client_app_fn(
|
|
101
|
+
load_fn = _get_load_client_app_fn(
|
|
102
|
+
default_app_ref=getattr(args, "client-app"),
|
|
103
|
+
dir_arg=args.dir,
|
|
104
|
+
multi_app=False,
|
|
105
|
+
)
|
|
91
106
|
authentication_keys = _try_setup_client_authentication(args)
|
|
92
107
|
|
|
93
108
|
_start_client_internal(
|
|
94
109
|
server_address=args.superlink,
|
|
110
|
+
node_config=parse_config_args([args.node_config]),
|
|
95
111
|
load_client_app_fn=load_fn,
|
|
96
112
|
transport=args.transport,
|
|
97
113
|
root_certificates=root_certificates,
|
|
@@ -159,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
159
175
|
|
|
160
176
|
|
|
161
177
|
def _get_load_client_app_fn(
|
|
162
|
-
|
|
178
|
+
default_app_ref: str,
|
|
179
|
+
dir_arg: str,
|
|
180
|
+
multi_app: bool,
|
|
181
|
+
flwr_dir_arg: Optional[str] = None,
|
|
163
182
|
) -> Callable[[str, str], ClientApp]:
|
|
164
183
|
"""Get the load_client_app_fn function.
|
|
165
184
|
|
|
@@ -171,23 +190,27 @@ def _get_load_client_app_fn(
|
|
|
171
190
|
loads a default ClientApp.
|
|
172
191
|
"""
|
|
173
192
|
# Find the Flower directory containing Flower Apps (only for multi-app)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
193
|
+
if not multi_app:
|
|
194
|
+
flwr_dir = Path("")
|
|
195
|
+
else:
|
|
196
|
+
if flwr_dir_arg is None:
|
|
177
197
|
flwr_dir = get_flwr_dir()
|
|
178
198
|
else:
|
|
179
|
-
flwr_dir = Path(
|
|
199
|
+
flwr_dir = Path(flwr_dir_arg).absolute()
|
|
180
200
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
default_app_ref: str = getattr(args, "client-app")
|
|
201
|
+
inserted_path = None
|
|
184
202
|
|
|
185
203
|
if not multi_app:
|
|
186
204
|
log(
|
|
187
205
|
DEBUG,
|
|
188
206
|
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
189
|
-
|
|
207
|
+
default_app_ref,
|
|
190
208
|
)
|
|
209
|
+
# Insert sys.path
|
|
210
|
+
dir_path = Path(dir_arg).absolute()
|
|
211
|
+
sys.path.insert(0, str(dir_path))
|
|
212
|
+
inserted_path = str(dir_path)
|
|
213
|
+
|
|
191
214
|
valid, error_msg = validate(default_app_ref)
|
|
192
215
|
if not valid and error_msg:
|
|
193
216
|
raise LoadClientAppError(error_msg) from None
|
|
@@ -196,7 +219,7 @@ def _get_load_client_app_fn(
|
|
|
196
219
|
# If multi-app feature is disabled
|
|
197
220
|
if not multi_app:
|
|
198
221
|
# Get sys path to be inserted
|
|
199
|
-
|
|
222
|
+
dir_path = Path(dir_arg).absolute()
|
|
200
223
|
|
|
201
224
|
# Set app reference
|
|
202
225
|
client_app_ref = default_app_ref
|
|
@@ -209,7 +232,7 @@ def _get_load_client_app_fn(
|
|
|
209
232
|
|
|
210
233
|
log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
211
234
|
# Get sys path to be inserted
|
|
212
|
-
|
|
235
|
+
dir_path = Path(dir_arg).absolute()
|
|
213
236
|
|
|
214
237
|
# Set app reference
|
|
215
238
|
client_app_ref = default_app_ref
|
|
@@ -222,13 +245,21 @@ def _get_load_client_app_fn(
|
|
|
222
245
|
raise LoadClientAppError("Failed to load ClientApp") from e
|
|
223
246
|
|
|
224
247
|
# Get sys path to be inserted
|
|
225
|
-
|
|
248
|
+
dir_path = Path(project_dir).absolute()
|
|
226
249
|
|
|
227
250
|
# Set app reference
|
|
228
|
-
client_app_ref = config["
|
|
251
|
+
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
229
252
|
|
|
230
253
|
# Set sys.path
|
|
231
|
-
|
|
254
|
+
nonlocal inserted_path
|
|
255
|
+
if inserted_path != str(dir_path):
|
|
256
|
+
# Remove the previously inserted path
|
|
257
|
+
if inserted_path is not None:
|
|
258
|
+
sys.path.remove(inserted_path)
|
|
259
|
+
# Insert the new path
|
|
260
|
+
sys.path.insert(0, str(dir_path))
|
|
261
|
+
|
|
262
|
+
inserted_path = str(dir_path)
|
|
232
263
|
|
|
233
264
|
# Load ClientApp
|
|
234
265
|
log(
|
|
@@ -236,7 +267,7 @@ def _get_load_client_app_fn(
|
|
|
236
267
|
"Loading ClientApp `%s`",
|
|
237
268
|
client_app_ref,
|
|
238
269
|
)
|
|
239
|
-
client_app = load_app(client_app_ref, LoadClientAppError,
|
|
270
|
+
client_app = load_app(client_app_ref, LoadClientAppError, dir_path)
|
|
240
271
|
|
|
241
272
|
if not isinstance(client_app, ClientApp):
|
|
242
273
|
raise LoadClientAppError(
|
|
@@ -375,11 +406,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
375
406
|
help="The SuperNode's public key (as a path str) to enable authentication.",
|
|
376
407
|
)
|
|
377
408
|
parser.add_argument(
|
|
378
|
-
"--
|
|
379
|
-
type=
|
|
380
|
-
help="
|
|
381
|
-
"
|
|
382
|
-
"
|
|
409
|
+
"--node-config",
|
|
410
|
+
type=str,
|
|
411
|
+
help="A comma separated list of key/value pairs (separated by `=`) to "
|
|
412
|
+
"configure the SuperNode. "
|
|
413
|
+
"E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
|
|
383
414
|
)
|
|
384
415
|
|
|
385
416
|
|
flwr/client/typing.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Custom types for Flower clients."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Callable
|
|
18
|
+
from typing import Callable
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, Message
|
|
21
21
|
|
|
@@ -23,7 +23,7 @@ from .client import Client as Client
|
|
|
23
23
|
|
|
24
24
|
# Compatibility
|
|
25
25
|
ClientFn = Callable[[str], Client]
|
|
26
|
-
ClientFnExt = Callable[[
|
|
26
|
+
ClientFnExt = Callable[[Context], Client]
|
|
27
27
|
|
|
28
28
|
ClientAppCallable = Callable[[Message, Context], Message]
|
|
29
29
|
Mod = Callable[[Message, Context, ClientAppCallable], Message]
|
flwr/common/config.py
CHANGED
|
@@ -16,12 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
18
|
from pathlib import Path
|
|
19
|
-
from typing import Any, Dict, Optional, Union
|
|
19
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
|
|
20
20
|
|
|
21
21
|
import tomli
|
|
22
22
|
|
|
23
23
|
from flwr.cli.config_utils import validate_fields
|
|
24
24
|
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
|
|
25
|
+
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
@@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
|
30
31
|
return Path(
|
|
31
32
|
os.getenv(
|
|
32
33
|
FLWR_HOME,
|
|
33
|
-
f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}
|
|
34
|
+
Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
|
|
34
35
|
)
|
|
35
36
|
)
|
|
36
37
|
return Path(provided_path).absolute()
|
|
@@ -71,3 +72,87 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
|
|
|
71
72
|
)
|
|
72
73
|
|
|
73
74
|
return config
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _fuse_dicts(
|
|
78
|
+
main_dict: UserConfig,
|
|
79
|
+
override_dict: UserConfig,
|
|
80
|
+
) -> UserConfig:
|
|
81
|
+
fused_dict = main_dict.copy()
|
|
82
|
+
|
|
83
|
+
for key, value in override_dict.items():
|
|
84
|
+
if key in main_dict:
|
|
85
|
+
fused_dict[key] = value
|
|
86
|
+
|
|
87
|
+
return fused_dict
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_fused_config_from_dir(
|
|
91
|
+
project_dir: Path, override_config: UserConfig
|
|
92
|
+
) -> UserConfig:
|
|
93
|
+
"""Merge the overrides from a given dict with the config from a Flower App."""
|
|
94
|
+
default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
|
|
95
|
+
"config", {}
|
|
96
|
+
)
|
|
97
|
+
flat_default_config = flatten_dict(default_config)
|
|
98
|
+
|
|
99
|
+
return _fuse_dicts(flat_default_config, override_config)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
|
103
|
+
"""Merge the overrides from a `Run` with the config from a FAB.
|
|
104
|
+
|
|
105
|
+
Get the config using the fab_id and the fab_version, remove the nesting by adding
|
|
106
|
+
the nested keys as prefixes separated by dots, and fuse it with the override dict.
|
|
107
|
+
"""
|
|
108
|
+
if not run.fab_id or not run.fab_version:
|
|
109
|
+
return {}
|
|
110
|
+
|
|
111
|
+
project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
|
|
112
|
+
|
|
113
|
+
return get_fused_config_from_dir(project_dir, run.override_config)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
|
|
117
|
+
"""Flatten dict by joining nested keys with a given separator."""
|
|
118
|
+
items: List[Tuple[str, UserConfigValue]] = []
|
|
119
|
+
separator: str = "."
|
|
120
|
+
for k, v in raw_dict.items():
|
|
121
|
+
new_key = f"{parent_key}{separator}{k}" if parent_key else k
|
|
122
|
+
if isinstance(v, dict):
|
|
123
|
+
items.extend(flatten_dict(v, parent_key=new_key).items())
|
|
124
|
+
elif isinstance(v, get_args(UserConfigValue)):
|
|
125
|
+
items.append((new_key, cast(UserConfigValue, v)))
|
|
126
|
+
else:
|
|
127
|
+
raise ValueError(
|
|
128
|
+
f"The value for key {k} needs to be of type `int`, `float`, "
|
|
129
|
+
"`bool, `str`, or a `dict` of those.",
|
|
130
|
+
)
|
|
131
|
+
return dict(items)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def parse_config_args(
|
|
135
|
+
config: Optional[List[str]],
|
|
136
|
+
separator: str = ",",
|
|
137
|
+
) -> UserConfig:
|
|
138
|
+
"""Parse separator separated list of key-value pairs separated by '='."""
|
|
139
|
+
overrides: UserConfig = {}
|
|
140
|
+
|
|
141
|
+
if config is None:
|
|
142
|
+
return overrides
|
|
143
|
+
|
|
144
|
+
for config_line in config:
|
|
145
|
+
if config_line:
|
|
146
|
+
overrides_list = config_line.split(separator)
|
|
147
|
+
if (
|
|
148
|
+
len(overrides_list) == 1
|
|
149
|
+
and "=" not in overrides_list
|
|
150
|
+
and overrides_list[0].endswith(".toml")
|
|
151
|
+
):
|
|
152
|
+
with Path(overrides_list[0]).open("rb") as config_file:
|
|
153
|
+
overrides = flatten_dict(tomli.load(config_file))
|
|
154
|
+
else:
|
|
155
|
+
toml_str = "\n".join(overrides_list)
|
|
156
|
+
overrides.update(tomli.loads(toml_str))
|
|
157
|
+
|
|
158
|
+
return overrides
|
flwr/common/constant.py
CHANGED
|
@@ -57,6 +57,9 @@ APP_DIR = "apps"
|
|
|
57
57
|
FAB_CONFIG_FILE = "pyproject.toml"
|
|
58
58
|
FLWR_HOME = "FLWR_HOME"
|
|
59
59
|
|
|
60
|
+
# Constants entries in Node config for Simulation
|
|
61
|
+
PARTITION_ID_KEY = "partition-id"
|
|
62
|
+
NUM_PARTITIONS_KEY = "num-partitions"
|
|
60
63
|
|
|
61
64
|
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
|
|
62
65
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
|
flwr/common/context.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from .record import RecordSet
|
|
21
|
+
from .typing import UserConfig
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@dataclass
|
|
@@ -27,6 +27,11 @@ class Context:
|
|
|
27
27
|
|
|
28
28
|
Parameters
|
|
29
29
|
----------
|
|
30
|
+
node_id : int
|
|
31
|
+
The ID that identifies the node.
|
|
32
|
+
node_config : UserConfig
|
|
33
|
+
A config (key/value mapping) unique to the node and independent of the
|
|
34
|
+
`run_config`. This config persists across all runs this node participates in.
|
|
30
35
|
state : RecordSet
|
|
31
36
|
Holds records added by the entity in a given run and that will stay local.
|
|
32
37
|
This means that the data it holds will never leave the system it's running from.
|
|
@@ -34,15 +39,25 @@ class Context:
|
|
|
34
39
|
executing mods. It can also be used as a memory to access
|
|
35
40
|
at different points during the lifecycle of this entity (e.g. across
|
|
36
41
|
multiple rounds)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
42
|
+
run_config : UserConfig
|
|
43
|
+
A config (key/value mapping) held by the entity in a given run and that will
|
|
44
|
+
stay local. It can be used at any point during the lifecycle of this entity
|
|
45
|
+
(e.g. across multiple rounds)
|
|
41
46
|
"""
|
|
42
47
|
|
|
48
|
+
node_id: int
|
|
49
|
+
node_config: UserConfig
|
|
43
50
|
state: RecordSet
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def __init__(
|
|
51
|
+
run_config: UserConfig
|
|
52
|
+
|
|
53
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
54
|
+
self,
|
|
55
|
+
node_id: int,
|
|
56
|
+
node_config: UserConfig,
|
|
57
|
+
state: RecordSet,
|
|
58
|
+
run_config: UserConfig,
|
|
59
|
+
) -> None:
|
|
60
|
+
self.node_id = node_id
|
|
61
|
+
self.node_config = node_config
|
|
47
62
|
self.state = state
|
|
48
|
-
self.
|
|
63
|
+
self.run_config = run_config
|
flwr/common/logger.py
CHANGED
|
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
|
|
200
|
+
def warn_deprecated_feature_with_example(
|
|
201
|
+
deprecation_message: str, example_message: str, code_example: str
|
|
202
|
+
) -> None:
|
|
203
|
+
"""Warn if a feature is deprecated and show code example."""
|
|
204
|
+
log(
|
|
205
|
+
WARN,
|
|
206
|
+
"""DEPRECATED FEATURE: %s
|
|
207
|
+
|
|
208
|
+
Check the following `FEATURE UPDATE` warning message for the preferred
|
|
209
|
+
new mechanism to use this feature in Flower.
|
|
210
|
+
""",
|
|
211
|
+
deprecation_message,
|
|
212
|
+
)
|
|
213
|
+
log(
|
|
214
|
+
WARN,
|
|
215
|
+
"""FEATURE UPDATE: %s
|
|
216
|
+
------------------------------------------------------------
|
|
217
|
+
%s
|
|
218
|
+
------------------------------------------------------------
|
|
219
|
+
""",
|
|
220
|
+
example_message,
|
|
221
|
+
code_example,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
200
225
|
def warn_unsupported_feature(name: str) -> None:
|
|
201
226
|
"""Warn the user when they use an unsupported feature."""
|
|
202
227
|
log(
|
flwr/common/serde.py
CHANGED
|
@@ -671,3 +671,48 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
671
671
|
)
|
|
672
672
|
message.metadata.created_at = taskres.task.created_at
|
|
673
673
|
return message
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
# === User configs ===
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def user_config_to_proto(user_config: typing.UserConfig) -> Any:
|
|
680
|
+
"""Serialize `UserConfig` to ProtoBuf."""
|
|
681
|
+
proto = {}
|
|
682
|
+
for key, value in user_config.items():
|
|
683
|
+
proto[key] = user_config_value_to_proto(value)
|
|
684
|
+
return proto
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def user_config_from_proto(proto: Any) -> typing.UserConfig:
|
|
688
|
+
"""Deserialize `UserConfig` from ProtoBuf."""
|
|
689
|
+
metrics = {}
|
|
690
|
+
for key, value in proto.items():
|
|
691
|
+
metrics[key] = user_config_value_from_proto(value)
|
|
692
|
+
return metrics
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def user_config_value_to_proto(user_config_value: typing.UserConfigValue) -> Scalar:
|
|
696
|
+
"""Serialize `UserConfigValue` to ProtoBuf."""
|
|
697
|
+
if isinstance(user_config_value, bool):
|
|
698
|
+
return Scalar(bool=user_config_value)
|
|
699
|
+
|
|
700
|
+
if isinstance(user_config_value, float):
|
|
701
|
+
return Scalar(double=user_config_value)
|
|
702
|
+
|
|
703
|
+
if isinstance(user_config_value, int):
|
|
704
|
+
return Scalar(sint64=user_config_value)
|
|
705
|
+
|
|
706
|
+
if isinstance(user_config_value, str):
|
|
707
|
+
return Scalar(string=user_config_value)
|
|
708
|
+
|
|
709
|
+
raise ValueError(
|
|
710
|
+
f"Accepted types: {bool, float, int, str} (but not {type(user_config_value)})"
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
715
|
+
"""Deserialize `UserConfigValue` from ProtoBuf."""
|
|
716
|
+
scalar_field = scalar_msg.WhichOneof("scalar")
|
|
717
|
+
scalar = getattr(scalar_msg, cast(str, scalar_field))
|
|
718
|
+
return cast(typing.UserConfigValue, scalar)
|
flwr/common/telemetry.py
CHANGED
|
@@ -64,6 +64,18 @@ def _get_home() -> Path:
|
|
|
64
64
|
return Path().home()
|
|
65
65
|
|
|
66
66
|
|
|
67
|
+
def _get_partner_id() -> str:
|
|
68
|
+
"""Get partner ID."""
|
|
69
|
+
partner_id = os.getenv("FLWR_TELEMETRY_PARTNER_ID")
|
|
70
|
+
if not partner_id:
|
|
71
|
+
return "unavailable"
|
|
72
|
+
try:
|
|
73
|
+
uuid.UUID(partner_id)
|
|
74
|
+
except ValueError:
|
|
75
|
+
partner_id = "invalid"
|
|
76
|
+
return partner_id
|
|
77
|
+
|
|
78
|
+
|
|
67
79
|
def _get_source_id() -> str:
|
|
68
80
|
"""Get existing or new source ID."""
|
|
69
81
|
source_id = "unavailable"
|
|
@@ -177,6 +189,7 @@ state: Dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = {
|
|
|
177
189
|
"executor": None,
|
|
178
190
|
"source": None,
|
|
179
191
|
"cluster": None,
|
|
192
|
+
"partner": None,
|
|
180
193
|
}
|
|
181
194
|
|
|
182
195
|
|
|
@@ -202,11 +215,15 @@ def create_event(event_type: EventType, event_details: Optional[Dict[str, Any]])
|
|
|
202
215
|
if state["cluster"] is None:
|
|
203
216
|
state["cluster"] = str(uuid.uuid4())
|
|
204
217
|
|
|
218
|
+
if state["partner"] is None:
|
|
219
|
+
state["partner"] = _get_partner_id()
|
|
220
|
+
|
|
205
221
|
if event_details is None:
|
|
206
222
|
event_details = {}
|
|
207
223
|
|
|
208
224
|
date = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
|
|
209
225
|
context = {
|
|
226
|
+
"partner": state["partner"],
|
|
210
227
|
"source": state["source"],
|
|
211
228
|
"cluster": state["cluster"],
|
|
212
229
|
"date": date,
|