flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic; see the package's registry page for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +47 -27
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +32 -21
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +133 -54
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +39 -39
- flwr/client/typing.py +2 -2
- flwr/common/config.py +92 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +20 -11
- flwr/proto/exec_pb2.pyi +41 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -18
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +269 -70
- flwr/superexec/app.py +17 -11
- flwr/superexec/deployment.py +111 -35
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +6 -1
- flwr/superexec/executor.py +21 -0
- flwr/superexec/simulation.py +181 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
- flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
|
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
|
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.message import Message, Metadata
|
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Run
|
|
44
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
50
|
CreateNodeRequest,
|
|
46
51
|
DeleteNodeRequest,
|
|
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
78
83
|
Tuple[
|
|
79
84
|
Callable[[], Optional[Message]],
|
|
80
85
|
Callable[[Message], None],
|
|
86
|
+
Optional[Callable[[], Optional[int]]],
|
|
81
87
|
Optional[Callable[[], None]],
|
|
82
|
-
Optional[Callable[[],
|
|
83
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
88
|
+
Optional[Callable[[int], Run]],
|
|
84
89
|
]
|
|
85
90
|
]:
|
|
86
91
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
175
180
|
if not ping_stop_event.is_set():
|
|
176
181
|
ping_stop_event.wait(next_interval)
|
|
177
182
|
|
|
178
|
-
def create_node() ->
|
|
183
|
+
def create_node() -> Optional[int]:
|
|
179
184
|
"""Set create_node."""
|
|
180
185
|
# Call FleetAPI
|
|
181
186
|
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
188
193
|
nonlocal node, ping_thread
|
|
189
194
|
node = cast(Node, create_node_response.node)
|
|
190
195
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
196
|
+
return node.node_id
|
|
191
197
|
|
|
192
198
|
def delete_node() -> None:
|
|
193
199
|
"""Set delete_node."""
|
|
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
266
272
|
# Cleanup
|
|
267
273
|
metadata = None
|
|
268
274
|
|
|
269
|
-
def get_run(run_id: int) ->
|
|
275
|
+
def get_run(run_id: int) -> Run:
|
|
270
276
|
# Call FleetAPI
|
|
271
277
|
get_run_request = GetRunRequest(run_id=run_id)
|
|
272
278
|
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
275
281
|
)
|
|
276
282
|
|
|
277
283
|
# Return fab_id and fab_version
|
|
278
|
-
return
|
|
284
|
+
return Run(
|
|
285
|
+
run_id,
|
|
286
|
+
get_run_response.run.fab_id,
|
|
287
|
+
get_run_response.run.fab_version,
|
|
288
|
+
user_config_from_proto(get_run_response.run.override_config),
|
|
289
|
+
)
|
|
279
290
|
|
|
280
291
|
try:
|
|
281
292
|
# Yield methods
|
|
@@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
|
|
|
92
92
|
client_fn: ClientFnExt, message: Message, context: Context
|
|
93
93
|
) -> Message:
|
|
94
94
|
"""Handle legacy message in the inner most mod."""
|
|
95
|
-
client = client_fn(
|
|
95
|
+
client = client_fn(context)
|
|
96
96
|
|
|
97
97
|
# Check if NumPyClient is returned
|
|
98
98
|
if isinstance(client, NumPyClient):
|
flwr/client/node_state.py
CHANGED
|
@@ -15,30 +15,72 @@
|
|
|
15
15
|
"""Node state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Dict, Optional
|
|
19
21
|
|
|
20
22
|
from flwr.common import Context, RecordSet
|
|
23
|
+
from flwr.common.config import get_fused_config, get_fused_config_from_dir
|
|
24
|
+
from flwr.common.typing import Run, UserConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass()
|
|
28
|
+
class RunInfo:
|
|
29
|
+
"""Contains the Context and initial run_config of a Run."""
|
|
30
|
+
|
|
31
|
+
context: Context
|
|
32
|
+
initial_run_config: UserConfig
|
|
21
33
|
|
|
22
34
|
|
|
23
35
|
class NodeState:
|
|
24
36
|
"""State of a node where client nodes execute runs."""
|
|
25
37
|
|
|
26
|
-
def __init__(
|
|
27
|
-
self
|
|
28
|
-
|
|
29
|
-
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
node_id: int,
|
|
41
|
+
node_config: UserConfig,
|
|
42
|
+
) -> None:
|
|
43
|
+
self.node_id = node_id
|
|
44
|
+
self.node_config = node_config
|
|
45
|
+
self.run_infos: Dict[int, RunInfo] = {}
|
|
30
46
|
|
|
31
|
-
def register_context(
|
|
47
|
+
def register_context(
|
|
48
|
+
self,
|
|
49
|
+
run_id: int,
|
|
50
|
+
run: Optional[Run] = None,
|
|
51
|
+
flwr_path: Optional[Path] = None,
|
|
52
|
+
app_dir: Optional[str] = None,
|
|
53
|
+
) -> None:
|
|
32
54
|
"""Register new run context for this node."""
|
|
33
|
-
if run_id not in self.
|
|
34
|
-
|
|
35
|
-
|
|
55
|
+
if run_id not in self.run_infos:
|
|
56
|
+
initial_run_config = {}
|
|
57
|
+
if app_dir:
|
|
58
|
+
# Load from app directory
|
|
59
|
+
app_path = Path(app_dir)
|
|
60
|
+
if app_path.is_dir():
|
|
61
|
+
override_config = run.override_config if run else {}
|
|
62
|
+
initial_run_config = get_fused_config_from_dir(
|
|
63
|
+
app_path, override_config
|
|
64
|
+
)
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError("The specified `app_dir` must be a directory.")
|
|
67
|
+
else:
|
|
68
|
+
# Load from .fab
|
|
69
|
+
initial_run_config = get_fused_config(run, flwr_path) if run else {}
|
|
70
|
+
self.run_infos[run_id] = RunInfo(
|
|
71
|
+
initial_run_config=initial_run_config,
|
|
72
|
+
context=Context(
|
|
73
|
+
node_id=self.node_id,
|
|
74
|
+
node_config=self.node_config,
|
|
75
|
+
state=RecordSet(),
|
|
76
|
+
run_config=initial_run_config.copy(),
|
|
77
|
+
),
|
|
36
78
|
)
|
|
37
79
|
|
|
38
80
|
def retrieve_context(self, run_id: int) -> Context:
|
|
39
81
|
"""Get run context given a run_id."""
|
|
40
|
-
if run_id in self.
|
|
41
|
-
return self.
|
|
82
|
+
if run_id in self.run_infos:
|
|
83
|
+
return self.run_infos[run_id].context
|
|
42
84
|
|
|
43
85
|
raise RuntimeError(
|
|
44
86
|
f"Context for run_id={run_id} doesn't exist."
|
|
@@ -48,4 +90,9 @@ class NodeState:
|
|
|
48
90
|
|
|
49
91
|
def update_context(self, run_id: int, context: Context) -> None:
|
|
50
92
|
"""Update run context."""
|
|
51
|
-
self.
|
|
93
|
+
if context.run_config != self.run_infos[run_id].initial_run_config:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
"The `run_config` field of the `Context` object cannot be "
|
|
96
|
+
f"modified (run_id: {run_id})."
|
|
97
|
+
)
|
|
98
|
+
self.run_infos[run_id].context = context
|
flwr/client/node_state_tests.py
CHANGED
|
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
|
|
|
41
41
|
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
|
|
42
42
|
|
|
43
43
|
# NodeState
|
|
44
|
-
node_state = NodeState(
|
|
44
|
+
node_state = NodeState(node_id=0, node_config={})
|
|
45
45
|
|
|
46
46
|
for task in tasks:
|
|
47
47
|
run_id = task.run_id
|
|
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
|
|
|
59
59
|
node_state.update_context(run_id=run_id, context=updated_state)
|
|
60
60
|
|
|
61
61
|
# Verify values
|
|
62
|
-
for run_id,
|
|
62
|
+
for run_id, run_info in node_state.run_infos.items():
|
|
63
63
|
assert (
|
|
64
|
-
context.state.configs_records["counter"]["count"]
|
|
64
|
+
run_info.context.state.configs_records["counter"]["count"]
|
|
65
|
+
== expected_values[run_id]
|
|
65
66
|
)
|
|
@@ -40,7 +40,12 @@ from flwr.common.constant import (
|
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.message import Message, Metadata
|
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Run
|
|
44
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
50
|
CreateNodeRequest,
|
|
46
51
|
CreateNodeResponse,
|
|
@@ -89,9 +94,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
89
94
|
Tuple[
|
|
90
95
|
Callable[[], Optional[Message]],
|
|
91
96
|
Callable[[Message], None],
|
|
97
|
+
Optional[Callable[[], Optional[int]]],
|
|
92
98
|
Optional[Callable[[], None]],
|
|
93
|
-
Optional[Callable[[],
|
|
94
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
99
|
+
Optional[Callable[[int], Run]],
|
|
95
100
|
]
|
|
96
101
|
]:
|
|
97
102
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -236,19 +241,20 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
236
241
|
if not ping_stop_event.is_set():
|
|
237
242
|
ping_stop_event.wait(next_interval)
|
|
238
243
|
|
|
239
|
-
def create_node() ->
|
|
244
|
+
def create_node() -> Optional[int]:
|
|
240
245
|
"""Set create_node."""
|
|
241
246
|
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
242
247
|
|
|
243
248
|
# Send the request
|
|
244
249
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
|
245
250
|
if res is None:
|
|
246
|
-
return
|
|
251
|
+
return None
|
|
247
252
|
|
|
248
253
|
# Remember the node and the ping-loop thread
|
|
249
254
|
nonlocal node, ping_thread
|
|
250
255
|
node = res.node
|
|
251
256
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
257
|
+
return node.node_id
|
|
252
258
|
|
|
253
259
|
def delete_node() -> None:
|
|
254
260
|
"""Set delete_node."""
|
|
@@ -344,16 +350,21 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
344
350
|
res.results, # pylint: disable=no-member
|
|
345
351
|
)
|
|
346
352
|
|
|
347
|
-
def get_run(run_id: int) ->
|
|
353
|
+
def get_run(run_id: int) -> Run:
|
|
348
354
|
# Construct the request
|
|
349
355
|
req = GetRunRequest(run_id=run_id)
|
|
350
356
|
|
|
351
357
|
# Send the request
|
|
352
358
|
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
353
359
|
if res is None:
|
|
354
|
-
return "", ""
|
|
360
|
+
return Run(run_id, "", "", {})
|
|
355
361
|
|
|
356
|
-
return
|
|
362
|
+
return Run(
|
|
363
|
+
run_id,
|
|
364
|
+
res.run.fab_id,
|
|
365
|
+
res.run.fab_version,
|
|
366
|
+
user_config_from_proto(res.run.override_config),
|
|
367
|
+
)
|
|
357
368
|
|
|
358
369
|
try:
|
|
359
370
|
# Yield methods
|
flwr/client/supernode/app.py
CHANGED
|
@@ -29,7 +29,12 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
29
29
|
|
|
30
30
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
31
31
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.config import
|
|
32
|
+
from flwr.common.config import (
|
|
33
|
+
get_flwr_dir,
|
|
34
|
+
get_project_config,
|
|
35
|
+
get_project_dir,
|
|
36
|
+
parse_config_args,
|
|
37
|
+
)
|
|
33
38
|
from flwr.common.constant import (
|
|
34
39
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
35
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -55,7 +60,12 @@ def run_supernode() -> None:
|
|
|
55
60
|
_warn_deprecated_server_arg(args)
|
|
56
61
|
|
|
57
62
|
root_certificates = _get_certificates(args)
|
|
58
|
-
load_fn = _get_load_client_app_fn(
|
|
63
|
+
load_fn = _get_load_client_app_fn(
|
|
64
|
+
default_app_ref=getattr(args, "client-app"),
|
|
65
|
+
project_dir=args.dir,
|
|
66
|
+
flwr_dir=args.flwr_dir,
|
|
67
|
+
multi_app=True,
|
|
68
|
+
)
|
|
59
69
|
authentication_keys = _try_setup_client_authentication(args)
|
|
60
70
|
|
|
61
71
|
_start_client_internal(
|
|
@@ -67,7 +77,8 @@ def run_supernode() -> None:
|
|
|
67
77
|
authentication_keys=authentication_keys,
|
|
68
78
|
max_retries=args.max_retries,
|
|
69
79
|
max_wait_time=args.max_wait_time,
|
|
70
|
-
|
|
80
|
+
node_config=parse_config_args([args.node_config]),
|
|
81
|
+
flwr_path=get_flwr_dir(args.flwr_dir),
|
|
71
82
|
)
|
|
72
83
|
|
|
73
84
|
# Graceful shutdown
|
|
@@ -87,11 +98,16 @@ def run_client_app() -> None:
|
|
|
87
98
|
_warn_deprecated_server_arg(args)
|
|
88
99
|
|
|
89
100
|
root_certificates = _get_certificates(args)
|
|
90
|
-
load_fn = _get_load_client_app_fn(
|
|
101
|
+
load_fn = _get_load_client_app_fn(
|
|
102
|
+
default_app_ref=getattr(args, "client-app"),
|
|
103
|
+
project_dir=args.dir,
|
|
104
|
+
multi_app=False,
|
|
105
|
+
)
|
|
91
106
|
authentication_keys = _try_setup_client_authentication(args)
|
|
92
107
|
|
|
93
108
|
_start_client_internal(
|
|
94
109
|
server_address=args.superlink,
|
|
110
|
+
node_config=parse_config_args([args.node_config]),
|
|
95
111
|
load_client_app_fn=load_fn,
|
|
96
112
|
transport=args.transport,
|
|
97
113
|
root_certificates=root_certificates,
|
|
@@ -159,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
159
175
|
|
|
160
176
|
|
|
161
177
|
def _get_load_client_app_fn(
|
|
162
|
-
|
|
178
|
+
default_app_ref: str,
|
|
179
|
+
project_dir: str,
|
|
180
|
+
multi_app: bool,
|
|
181
|
+
flwr_dir: Optional[str] = None,
|
|
163
182
|
) -> Callable[[str, str], ClientApp]:
|
|
164
183
|
"""Get the load_client_app_fn function.
|
|
165
184
|
|
|
@@ -170,34 +189,21 @@ def _get_load_client_app_fn(
|
|
|
170
189
|
If `multi_app` is False, it ignores `fab_id` and `fab_version` and
|
|
171
190
|
loads a default ClientApp.
|
|
172
191
|
"""
|
|
173
|
-
# Find the Flower directory containing Flower Apps (only for multi-app)
|
|
174
|
-
flwr_dir = Path("")
|
|
175
|
-
if "flwr_dir" in args:
|
|
176
|
-
if args.flwr_dir is None:
|
|
177
|
-
flwr_dir = get_flwr_dir()
|
|
178
|
-
else:
|
|
179
|
-
flwr_dir = Path(args.flwr_dir).absolute()
|
|
180
|
-
|
|
181
|
-
sys.path.insert(0, str(flwr_dir.absolute()))
|
|
182
|
-
|
|
183
|
-
default_app_ref: str = getattr(args, "client-app")
|
|
184
|
-
|
|
185
192
|
if not multi_app:
|
|
186
193
|
log(
|
|
187
194
|
DEBUG,
|
|
188
195
|
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
189
|
-
|
|
196
|
+
default_app_ref,
|
|
190
197
|
)
|
|
191
|
-
|
|
198
|
+
|
|
199
|
+
valid, error_msg = validate(default_app_ref, project_dir=project_dir)
|
|
192
200
|
if not valid and error_msg:
|
|
193
201
|
raise LoadClientAppError(error_msg) from None
|
|
194
202
|
|
|
195
203
|
def _load(fab_id: str, fab_version: str) -> ClientApp:
|
|
204
|
+
runtime_project_dir = Path(project_dir).absolute()
|
|
196
205
|
# If multi-app feature is disabled
|
|
197
206
|
if not multi_app:
|
|
198
|
-
# Get sys path to be inserted
|
|
199
|
-
sys_path = Path(args.dir).absolute()
|
|
200
|
-
|
|
201
207
|
# Set app reference
|
|
202
208
|
client_app_ref = default_app_ref
|
|
203
209
|
# If multi-app feature is enabled but the fab id is not specified
|
|
@@ -208,27 +214,21 @@ def _get_load_client_app_fn(
|
|
|
208
214
|
) from None
|
|
209
215
|
|
|
210
216
|
log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
211
|
-
# Get sys path to be inserted
|
|
212
|
-
sys_path = Path(args.dir).absolute()
|
|
213
217
|
|
|
214
218
|
# Set app reference
|
|
215
219
|
client_app_ref = default_app_ref
|
|
216
220
|
# If multi-app feature is enabled
|
|
217
221
|
else:
|
|
218
222
|
try:
|
|
219
|
-
|
|
220
|
-
|
|
223
|
+
runtime_project_dir = get_project_dir(
|
|
224
|
+
fab_id, fab_version, get_flwr_dir(flwr_dir)
|
|
225
|
+
)
|
|
226
|
+
config = get_project_config(runtime_project_dir)
|
|
221
227
|
except Exception as e:
|
|
222
228
|
raise LoadClientAppError("Failed to load ClientApp") from e
|
|
223
229
|
|
|
224
|
-
# Get sys path to be inserted
|
|
225
|
-
sys_path = Path(project_dir).absolute()
|
|
226
|
-
|
|
227
230
|
# Set app reference
|
|
228
|
-
client_app_ref = config["
|
|
229
|
-
|
|
230
|
-
# Set sys.path
|
|
231
|
-
sys.path.insert(0, str(sys_path))
|
|
231
|
+
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
232
232
|
|
|
233
233
|
# Load ClientApp
|
|
234
234
|
log(
|
|
@@ -236,7 +236,7 @@ def _get_load_client_app_fn(
|
|
|
236
236
|
"Loading ClientApp `%s`",
|
|
237
237
|
client_app_ref,
|
|
238
238
|
)
|
|
239
|
-
client_app = load_app(client_app_ref, LoadClientAppError,
|
|
239
|
+
client_app = load_app(client_app_ref, LoadClientAppError, runtime_project_dir)
|
|
240
240
|
|
|
241
241
|
if not isinstance(client_app, ClientApp):
|
|
242
242
|
raise LoadClientAppError(
|
|
@@ -375,11 +375,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
375
375
|
help="The SuperNode's public key (as a path str) to enable authentication.",
|
|
376
376
|
)
|
|
377
377
|
parser.add_argument(
|
|
378
|
-
"--
|
|
379
|
-
type=
|
|
380
|
-
help="
|
|
381
|
-
"
|
|
382
|
-
"
|
|
378
|
+
"--node-config",
|
|
379
|
+
type=str,
|
|
380
|
+
help="A comma separated list of key/value pairs (separated by `=`) to "
|
|
381
|
+
"configure the SuperNode. "
|
|
382
|
+
"E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
|
|
383
383
|
)
|
|
384
384
|
|
|
385
385
|
|
flwr/client/typing.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Custom types for Flower clients."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Callable
|
|
18
|
+
from typing import Callable
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, Message
|
|
21
21
|
|
|
@@ -23,7 +23,7 @@ from .client import Client as Client
|
|
|
23
23
|
|
|
24
24
|
# Compatibility
|
|
25
25
|
ClientFn = Callable[[str], Client]
|
|
26
|
-
ClientFnExt = Callable[[
|
|
26
|
+
ClientFnExt = Callable[[Context], Client]
|
|
27
27
|
|
|
28
28
|
ClientAppCallable = Callable[[Message, Context], Message]
|
|
29
29
|
Mod = Callable[[Message, Context, ClientAppCallable], Message]
|
flwr/common/config.py
CHANGED
|
@@ -16,12 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
18
|
from pathlib import Path
|
|
19
|
-
from typing import Any, Dict, Optional, Union
|
|
19
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
|
|
20
20
|
|
|
21
21
|
import tomli
|
|
22
22
|
|
|
23
23
|
from flwr.cli.config_utils import validate_fields
|
|
24
24
|
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
|
|
25
|
+
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
@@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
|
30
31
|
return Path(
|
|
31
32
|
os.getenv(
|
|
32
33
|
FLWR_HOME,
|
|
33
|
-
f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}
|
|
34
|
+
Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
|
|
34
35
|
)
|
|
35
36
|
)
|
|
36
37
|
return Path(provided_path).absolute()
|
|
@@ -71,3 +72,92 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
|
|
|
71
72
|
)
|
|
72
73
|
|
|
73
74
|
return config
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _fuse_dicts(
|
|
78
|
+
main_dict: UserConfig,
|
|
79
|
+
override_dict: UserConfig,
|
|
80
|
+
) -> UserConfig:
|
|
81
|
+
fused_dict = main_dict.copy()
|
|
82
|
+
|
|
83
|
+
for key, value in override_dict.items():
|
|
84
|
+
if key in main_dict:
|
|
85
|
+
fused_dict[key] = value
|
|
86
|
+
|
|
87
|
+
return fused_dict
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_fused_config_from_dir(
|
|
91
|
+
project_dir: Path, override_config: UserConfig
|
|
92
|
+
) -> UserConfig:
|
|
93
|
+
"""Merge the overrides from a given dict with the config from a Flower App."""
|
|
94
|
+
default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
|
|
95
|
+
"config", {}
|
|
96
|
+
)
|
|
97
|
+
flat_default_config = flatten_dict(default_config)
|
|
98
|
+
|
|
99
|
+
return _fuse_dicts(flat_default_config, override_config)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
|
103
|
+
"""Merge the overrides from a `Run` with the config from a FAB.
|
|
104
|
+
|
|
105
|
+
Get the config using the fab_id and the fab_version, remove the nesting by adding
|
|
106
|
+
the nested keys as prefixes separated by dots, and fuse it with the override dict.
|
|
107
|
+
"""
|
|
108
|
+
if not run.fab_id or not run.fab_version:
|
|
109
|
+
return {}
|
|
110
|
+
|
|
111
|
+
project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
|
|
112
|
+
|
|
113
|
+
return get_fused_config_from_dir(project_dir, run.override_config)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def flatten_dict(
|
|
117
|
+
raw_dict: Optional[Dict[str, Any]], parent_key: str = ""
|
|
118
|
+
) -> UserConfig:
|
|
119
|
+
"""Flatten dict by joining nested keys with a given separator."""
|
|
120
|
+
if raw_dict is None:
|
|
121
|
+
return {}
|
|
122
|
+
|
|
123
|
+
items: List[Tuple[str, UserConfigValue]] = []
|
|
124
|
+
separator: str = "."
|
|
125
|
+
for k, v in raw_dict.items():
|
|
126
|
+
new_key = f"{parent_key}{separator}{k}" if parent_key else k
|
|
127
|
+
if isinstance(v, dict):
|
|
128
|
+
items.extend(flatten_dict(v, parent_key=new_key).items())
|
|
129
|
+
elif isinstance(v, get_args(UserConfigValue)):
|
|
130
|
+
items.append((new_key, cast(UserConfigValue, v)))
|
|
131
|
+
else:
|
|
132
|
+
raise ValueError(
|
|
133
|
+
f"The value for key {k} needs to be of type `int`, `float`, "
|
|
134
|
+
"`bool, `str`, or a `dict` of those.",
|
|
135
|
+
)
|
|
136
|
+
return dict(items)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def parse_config_args(
|
|
140
|
+
config: Optional[List[str]],
|
|
141
|
+
separator: str = ",",
|
|
142
|
+
) -> UserConfig:
|
|
143
|
+
"""Parse separator separated list of key-value pairs separated by '='."""
|
|
144
|
+
overrides: UserConfig = {}
|
|
145
|
+
|
|
146
|
+
if config is None:
|
|
147
|
+
return overrides
|
|
148
|
+
|
|
149
|
+
for config_line in config:
|
|
150
|
+
if config_line:
|
|
151
|
+
overrides_list = config_line.split(separator)
|
|
152
|
+
if (
|
|
153
|
+
len(overrides_list) == 1
|
|
154
|
+
and "=" not in overrides_list
|
|
155
|
+
and overrides_list[0].endswith(".toml")
|
|
156
|
+
):
|
|
157
|
+
with Path(overrides_list[0]).open("rb") as config_file:
|
|
158
|
+
overrides = flatten_dict(tomli.load(config_file))
|
|
159
|
+
else:
|
|
160
|
+
toml_str = "\n".join(overrides_list)
|
|
161
|
+
overrides.update(tomli.loads(toml_str))
|
|
162
|
+
|
|
163
|
+
return overrides
|
flwr/common/constant.py
CHANGED
|
@@ -57,6 +57,9 @@ APP_DIR = "apps"
|
|
|
57
57
|
FAB_CONFIG_FILE = "pyproject.toml"
|
|
58
58
|
FLWR_HOME = "FLWR_HOME"
|
|
59
59
|
|
|
60
|
+
# Constants entries in Node config for Simulation
|
|
61
|
+
PARTITION_ID_KEY = "partition-id"
|
|
62
|
+
NUM_PARTITIONS_KEY = "num-partitions"
|
|
60
63
|
|
|
61
64
|
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
|
|
62
65
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
|
flwr/common/context.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from .record import RecordSet
|
|
21
|
+
from .typing import UserConfig
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@dataclass
|
|
@@ -27,6 +27,11 @@ class Context:
|
|
|
27
27
|
|
|
28
28
|
Parameters
|
|
29
29
|
----------
|
|
30
|
+
node_id : int
|
|
31
|
+
The ID that identifies the node.
|
|
32
|
+
node_config : UserConfig
|
|
33
|
+
A config (key/value mapping) unique to the node and independent of the
|
|
34
|
+
`run_config`. This config persists across all runs this node participates in.
|
|
30
35
|
state : RecordSet
|
|
31
36
|
Holds records added by the entity in a given run and that will stay local.
|
|
32
37
|
This means that the data it holds will never leave the system it's running from.
|
|
@@ -34,15 +39,25 @@ class Context:
|
|
|
34
39
|
executing mods. It can also be used as a memory to access
|
|
35
40
|
at different points during the lifecycle of this entity (e.g. across
|
|
36
41
|
multiple rounds)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
42
|
+
run_config : UserConfig
|
|
43
|
+
A config (key/value mapping) held by the entity in a given run and that will
|
|
44
|
+
stay local. It can be used at any point during the lifecycle of this entity
|
|
45
|
+
(e.g. across multiple rounds)
|
|
41
46
|
"""
|
|
42
47
|
|
|
48
|
+
node_id: int
|
|
49
|
+
node_config: UserConfig
|
|
43
50
|
state: RecordSet
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def __init__(
|
|
51
|
+
run_config: UserConfig
|
|
52
|
+
|
|
53
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
54
|
+
self,
|
|
55
|
+
node_id: int,
|
|
56
|
+
node_config: UserConfig,
|
|
57
|
+
state: RecordSet,
|
|
58
|
+
run_config: UserConfig,
|
|
59
|
+
) -> None:
|
|
60
|
+
self.node_id = node_id
|
|
61
|
+
self.node_config = node_config
|
|
47
62
|
self.state = state
|
|
48
|
-
self.
|
|
63
|
+
self.run_config = run_config
|
flwr/common/logger.py
CHANGED
|
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
|
|
200
|
+
def warn_deprecated_feature_with_example(
|
|
201
|
+
deprecation_message: str, example_message: str, code_example: str
|
|
202
|
+
) -> None:
|
|
203
|
+
"""Warn if a feature is deprecated and show code example."""
|
|
204
|
+
log(
|
|
205
|
+
WARN,
|
|
206
|
+
"""DEPRECATED FEATURE: %s
|
|
207
|
+
|
|
208
|
+
Check the following `FEATURE UPDATE` warning message for the preferred
|
|
209
|
+
new mechanism to use this feature in Flower.
|
|
210
|
+
""",
|
|
211
|
+
deprecation_message,
|
|
212
|
+
)
|
|
213
|
+
log(
|
|
214
|
+
WARN,
|
|
215
|
+
"""FEATURE UPDATE: %s
|
|
216
|
+
------------------------------------------------------------
|
|
217
|
+
%s
|
|
218
|
+
------------------------------------------------------------
|
|
219
|
+
""",
|
|
220
|
+
example_message,
|
|
221
|
+
code_example,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
200
225
|
def warn_unsupported_feature(name: str) -> None:
|
|
201
226
|
"""Warn the user when they use an unsupported feature."""
|
|
202
227
|
log(
|