flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
flwr/server/compat/app.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower grid app."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
@@ -25,27 +25,27 @@ from flwr.server.server import Server, init_defaults, run_fl
|
|
|
25
25
|
from flwr.server.server_config import ServerConfig
|
|
26
26
|
from flwr.server.strategy import Strategy
|
|
27
27
|
|
|
28
|
-
from ..
|
|
28
|
+
from ..grid import Grid
|
|
29
29
|
from .app_utils import start_update_client_manager_thread
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def
|
|
32
|
+
def start_grid( # pylint: disable=too-many-arguments, too-many-locals
|
|
33
33
|
*,
|
|
34
|
-
|
|
34
|
+
grid: Grid,
|
|
35
35
|
server: Optional[Server] = None,
|
|
36
36
|
config: Optional[ServerConfig] = None,
|
|
37
37
|
strategy: Optional[Strategy] = None,
|
|
38
38
|
client_manager: Optional[ClientManager] = None,
|
|
39
39
|
) -> History:
|
|
40
|
-
"""Start a Flower
|
|
40
|
+
"""Start a Flower server.
|
|
41
41
|
|
|
42
42
|
Parameters
|
|
43
43
|
----------
|
|
44
|
-
|
|
45
|
-
The
|
|
44
|
+
grid : Grid
|
|
45
|
+
The Grid object to use.
|
|
46
46
|
server : Optional[flwr.server.Server] (default: None)
|
|
47
47
|
A server implementation, either `flwr.server.Server` or a subclass
|
|
48
|
-
thereof. If no instance is provided, then `
|
|
48
|
+
thereof. If no instance is provided, then `start_grid` will create
|
|
49
49
|
one.
|
|
50
50
|
config : Optional[ServerConfig] (default: None)
|
|
51
51
|
Currently supported values are `num_rounds` (int, default: 1) and
|
|
@@ -56,7 +56,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
56
56
|
`start_server` will use `flwr.server.strategy.FedAvg`.
|
|
57
57
|
client_manager : Optional[flwr.server.ClientManager] (default: None)
|
|
58
58
|
An implementation of the class `flwr.server.ClientManager`. If no
|
|
59
|
-
implementation is provided, then `
|
|
59
|
+
implementation is provided, then `start_grid` will use
|
|
60
60
|
`flwr.server.SimpleClientManager`.
|
|
61
61
|
|
|
62
62
|
Returns
|
|
@@ -64,7 +64,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
64
64
|
hist : flwr.server.history.History
|
|
65
65
|
Object containing training and evaluation metrics.
|
|
66
66
|
"""
|
|
67
|
-
# Initialize the
|
|
67
|
+
# Initialize the server and config
|
|
68
68
|
initialized_server, initialized_config = init_defaults(
|
|
69
69
|
server=server,
|
|
70
70
|
config=config,
|
|
@@ -79,10 +79,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
79
79
|
log(INFO, "")
|
|
80
80
|
|
|
81
81
|
# Start the thread updating nodes
|
|
82
|
-
thread, f_stop = start_update_client_manager_thread(
|
|
83
|
-
|
|
82
|
+
thread, f_stop, c_done = start_update_client_manager_thread(
|
|
83
|
+
grid, initialized_server.client_manager()
|
|
84
84
|
)
|
|
85
85
|
|
|
86
|
+
# Wait until the node registration done
|
|
87
|
+
c_done.wait()
|
|
88
|
+
|
|
86
89
|
# Start training
|
|
87
90
|
hist = run_fl(
|
|
88
91
|
server=initialized_server,
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Utility functions for the `
|
|
15
|
+
"""Utility functions for the `start_grid`."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
@@ -20,18 +20,18 @@ import threading
|
|
|
20
20
|
from flwr.common.typing import RunNotRunningException
|
|
21
21
|
|
|
22
22
|
from ..client_manager import ClientManager
|
|
23
|
-
from ..
|
|
24
|
-
from
|
|
23
|
+
from ..grid import Grid
|
|
24
|
+
from .grid_client_proxy import GridClientProxy
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def start_update_client_manager_thread(
|
|
28
|
-
|
|
28
|
+
grid: Grid,
|
|
29
29
|
client_manager: ClientManager,
|
|
30
|
-
) -> tuple[threading.Thread, threading.Event]:
|
|
30
|
+
) -> tuple[threading.Thread, threading.Event, threading.Event]:
|
|
31
31
|
"""Periodically update the nodes list in the client manager in a thread.
|
|
32
32
|
|
|
33
|
-
This function starts a thread that periodically uses the associated
|
|
34
|
-
get all node_ids. Each node_id is then converted into a `
|
|
33
|
+
This function starts a thread that periodically uses the associated grid to
|
|
34
|
+
get all node_ids. Each node_id is then converted into a `GridClientProxy`
|
|
35
35
|
instance and stored in the `registered_nodes` dictionary with node_id as key.
|
|
36
36
|
|
|
37
37
|
New nodes will be added to the ClientManager via `client_manager.register()`,
|
|
@@ -40,8 +40,8 @@ def start_update_client_manager_thread(
|
|
|
40
40
|
|
|
41
41
|
Parameters
|
|
42
42
|
----------
|
|
43
|
-
|
|
44
|
-
The
|
|
43
|
+
grid : Grid
|
|
44
|
+
The Grid object to use.
|
|
45
45
|
client_manager : ClientManager
|
|
46
46
|
The ClientManager object to be updated.
|
|
47
47
|
|
|
@@ -51,33 +51,38 @@ def start_update_client_manager_thread(
|
|
|
51
51
|
A thread that updates the ClientManager and handles the stop event.
|
|
52
52
|
threading.Event
|
|
53
53
|
An event that, when set, signals the thread to stop.
|
|
54
|
+
threading.Event
|
|
55
|
+
An event that, when set, signals the node registration done.
|
|
54
56
|
"""
|
|
55
57
|
f_stop = threading.Event()
|
|
58
|
+
c_done = threading.Event()
|
|
56
59
|
thread = threading.Thread(
|
|
57
60
|
target=_update_client_manager,
|
|
58
61
|
args=(
|
|
59
|
-
|
|
62
|
+
grid,
|
|
60
63
|
client_manager,
|
|
61
64
|
f_stop,
|
|
65
|
+
c_done,
|
|
62
66
|
),
|
|
63
67
|
daemon=True,
|
|
64
68
|
)
|
|
65
69
|
thread.start()
|
|
66
70
|
|
|
67
|
-
return thread, f_stop
|
|
71
|
+
return thread, f_stop, c_done
|
|
68
72
|
|
|
69
73
|
|
|
70
74
|
def _update_client_manager(
|
|
71
|
-
|
|
75
|
+
grid: Grid,
|
|
72
76
|
client_manager: ClientManager,
|
|
73
77
|
f_stop: threading.Event,
|
|
78
|
+
c_done: threading.Event,
|
|
74
79
|
) -> None:
|
|
75
80
|
"""Update the nodes list in the client manager."""
|
|
76
|
-
# Loop until the
|
|
77
|
-
registered_nodes: dict[int,
|
|
81
|
+
# Loop until the grid is disconnected
|
|
82
|
+
registered_nodes: dict[int, GridClientProxy] = {}
|
|
78
83
|
while not f_stop.is_set():
|
|
79
84
|
try:
|
|
80
|
-
all_node_ids = set(
|
|
85
|
+
all_node_ids = set(grid.get_node_ids())
|
|
81
86
|
except RunNotRunningException:
|
|
82
87
|
f_stop.set()
|
|
83
88
|
break
|
|
@@ -92,16 +97,19 @@ def _update_client_manager(
|
|
|
92
97
|
|
|
93
98
|
# Register new nodes
|
|
94
99
|
for node_id in new_nodes:
|
|
95
|
-
client_proxy =
|
|
100
|
+
client_proxy = GridClientProxy(
|
|
96
101
|
node_id=node_id,
|
|
97
|
-
|
|
98
|
-
run_id=
|
|
102
|
+
grid=grid,
|
|
103
|
+
run_id=grid.run.run_id,
|
|
99
104
|
)
|
|
100
105
|
if client_manager.register(client_proxy):
|
|
101
106
|
registered_nodes[node_id] = client_proxy
|
|
102
107
|
else:
|
|
103
108
|
raise RuntimeError("Could not register node.")
|
|
104
109
|
|
|
110
|
+
# Flag first pass for nodes registration is completed
|
|
111
|
+
c_done.set()
|
|
112
|
+
|
|
105
113
|
# Sleep for 3 seconds
|
|
106
114
|
if not f_stop.is_set():
|
|
107
115
|
f_stop.wait(3)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,26 +12,26 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower ClientProxy implementation
|
|
15
|
+
"""Flower ClientProxy implementation using Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from typing import Optional
|
|
19
19
|
|
|
20
20
|
from flwr import common
|
|
21
|
-
from flwr.common import Message, MessageType, MessageTypeLegacy,
|
|
22
|
-
from flwr.common import
|
|
21
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
|
|
22
|
+
from flwr.common import recorddict_compat as compat
|
|
23
23
|
from flwr.server.client_proxy import ClientProxy
|
|
24
24
|
|
|
25
|
-
from ..
|
|
25
|
+
from ..grid.grid import Grid
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
class
|
|
29
|
-
"""Flower client proxy which delegates work using
|
|
28
|
+
class GridClientProxy(ClientProxy):
|
|
29
|
+
"""Flower client proxy which delegates work using Grid."""
|
|
30
30
|
|
|
31
|
-
def __init__(self, node_id: int,
|
|
31
|
+
def __init__(self, node_id: int, grid: Grid, run_id: int):
|
|
32
32
|
super().__init__(str(node_id))
|
|
33
33
|
self.node_id = node_id
|
|
34
|
-
self.
|
|
34
|
+
self.grid = grid
|
|
35
35
|
self.run_id = run_id
|
|
36
36
|
|
|
37
37
|
def get_properties(
|
|
@@ -41,14 +41,14 @@ class DriverClientProxy(ClientProxy):
|
|
|
41
41
|
group_id: Optional[int],
|
|
42
42
|
) -> common.GetPropertiesRes:
|
|
43
43
|
"""Return client's properties."""
|
|
44
|
-
# Ins to
|
|
45
|
-
|
|
44
|
+
# Ins to RecordDict
|
|
45
|
+
out_recorddict = compat.getpropertiesins_to_recorddict(ins)
|
|
46
46
|
# Fetch response
|
|
47
|
-
|
|
48
|
-
|
|
47
|
+
in_recorddict = self._send_receive_recorddict(
|
|
48
|
+
out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
|
49
49
|
)
|
|
50
|
-
#
|
|
51
|
-
return compat.
|
|
50
|
+
# RecordDict to Res
|
|
51
|
+
return compat.recorddict_to_getpropertiesres(in_recorddict)
|
|
52
52
|
|
|
53
53
|
def get_parameters(
|
|
54
54
|
self,
|
|
@@ -57,40 +57,40 @@ class DriverClientProxy(ClientProxy):
|
|
|
57
57
|
group_id: Optional[int],
|
|
58
58
|
) -> common.GetParametersRes:
|
|
59
59
|
"""Return the current local model parameters."""
|
|
60
|
-
# Ins to
|
|
61
|
-
|
|
60
|
+
# Ins to RecordDict
|
|
61
|
+
out_recorddict = compat.getparametersins_to_recorddict(ins)
|
|
62
62
|
# Fetch response
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
in_recorddict = self._send_receive_recorddict(
|
|
64
|
+
out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
|
65
65
|
)
|
|
66
|
-
#
|
|
67
|
-
return compat.
|
|
66
|
+
# RecordDict to Res
|
|
67
|
+
return compat.recorddict_to_getparametersres(in_recorddict, False)
|
|
68
68
|
|
|
69
69
|
def fit(
|
|
70
70
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
|
71
71
|
) -> common.FitRes:
|
|
72
72
|
"""Train model parameters on the locally held dataset."""
|
|
73
|
-
# Ins to
|
|
74
|
-
|
|
73
|
+
# Ins to RecordDict
|
|
74
|
+
out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
|
|
75
75
|
# Fetch response
|
|
76
|
-
|
|
77
|
-
|
|
76
|
+
in_recorddict = self._send_receive_recorddict(
|
|
77
|
+
out_recorddict, MessageType.TRAIN, timeout, group_id
|
|
78
78
|
)
|
|
79
|
-
#
|
|
80
|
-
return compat.
|
|
79
|
+
# RecordDict to Res
|
|
80
|
+
return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
|
|
81
81
|
|
|
82
82
|
def evaluate(
|
|
83
83
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
|
84
84
|
) -> common.EvaluateRes:
|
|
85
85
|
"""Evaluate model parameters on the locally held dataset."""
|
|
86
|
-
# Ins to
|
|
87
|
-
|
|
86
|
+
# Ins to RecordDict
|
|
87
|
+
out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
|
|
88
88
|
# Fetch response
|
|
89
|
-
|
|
90
|
-
|
|
89
|
+
in_recorddict = self._send_receive_recorddict(
|
|
90
|
+
out_recorddict, MessageType.EVALUATE, timeout, group_id
|
|
91
91
|
)
|
|
92
|
-
#
|
|
93
|
-
return compat.
|
|
92
|
+
# RecordDict to Res
|
|
93
|
+
return compat.recorddict_to_evaluateres(in_recorddict)
|
|
94
94
|
|
|
95
95
|
def reconnect(
|
|
96
96
|
self,
|
|
@@ -101,25 +101,25 @@ class DriverClientProxy(ClientProxy):
|
|
|
101
101
|
"""Disconnect and (optionally) reconnect later."""
|
|
102
102
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
|
103
103
|
|
|
104
|
-
def
|
|
104
|
+
def _send_receive_recorddict(
|
|
105
105
|
self,
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
recorddict: RecordDict,
|
|
107
|
+
message_type: str,
|
|
108
108
|
timeout: Optional[float],
|
|
109
109
|
group_id: Optional[int],
|
|
110
|
-
) ->
|
|
110
|
+
) -> RecordDict:
|
|
111
111
|
|
|
112
112
|
# Create message
|
|
113
|
-
message =
|
|
114
|
-
content=
|
|
115
|
-
message_type=
|
|
113
|
+
message = Message(
|
|
114
|
+
content=recorddict,
|
|
115
|
+
message_type=message_type,
|
|
116
116
|
dst_node_id=self.node_id,
|
|
117
117
|
group_id=str(group_id) if group_id else "",
|
|
118
118
|
ttl=timeout,
|
|
119
119
|
)
|
|
120
120
|
|
|
121
121
|
# Send message and wait for reply
|
|
122
|
-
messages = list(self.
|
|
122
|
+
messages = list(self.grid.send_and_receive(messages=[message]))
|
|
123
123
|
|
|
124
124
|
# A single reply is expected
|
|
125
125
|
if len(messages) != 1:
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower Fleet API event log interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Any, Callable, cast
|
|
19
|
+
|
|
20
|
+
import grpc
|
|
21
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
22
|
+
|
|
23
|
+
from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
|
|
24
|
+
from flwr.common.typing import LogEntry
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
28
|
+
"""Fleet API interceptor for logging events."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, log_plugin: EventLogWriterPlugin) -> None:
|
|
31
|
+
self.log_plugin = log_plugin
|
|
32
|
+
|
|
33
|
+
def intercept_service(
|
|
34
|
+
self,
|
|
35
|
+
continuation: Callable[[Any], Any],
|
|
36
|
+
handler_call_details: grpc.HandlerCallDetails,
|
|
37
|
+
) -> grpc.RpcMethodHandler:
|
|
38
|
+
"""Flower Fleet API server interceptor logging logic.
|
|
39
|
+
|
|
40
|
+
Intercept all unary-unary calls from users and log the event. Continue RPC call
|
|
41
|
+
if event logger is enabled on the SuperLink, else, terminate RPC call by setting
|
|
42
|
+
context to abort.
|
|
43
|
+
"""
|
|
44
|
+
# One of the method handlers in
|
|
45
|
+
# `flwr.server.superlink.fleet.grpc_rere.fleet_servicer.FleetServicer`
|
|
46
|
+
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
47
|
+
method_name: str = handler_call_details.method
|
|
48
|
+
return self._generic_event_log_unary_method_handler(method_handler, method_name)
|
|
49
|
+
|
|
50
|
+
def _generic_event_log_unary_method_handler(
|
|
51
|
+
self, method_handler: grpc.RpcMethodHandler, method_name: str
|
|
52
|
+
) -> grpc.RpcMethodHandler:
|
|
53
|
+
def _generic_method_handler(
|
|
54
|
+
request: GrpcMessage,
|
|
55
|
+
context: grpc.ServicerContext,
|
|
56
|
+
) -> GrpcMessage:
|
|
57
|
+
log_entry: LogEntry
|
|
58
|
+
# Log before call
|
|
59
|
+
log_entry = self.log_plugin.compose_log_before_event(
|
|
60
|
+
request=request,
|
|
61
|
+
context=context,
|
|
62
|
+
user_info=None,
|
|
63
|
+
method_name=method_name,
|
|
64
|
+
)
|
|
65
|
+
self.log_plugin.write_log(log_entry)
|
|
66
|
+
|
|
67
|
+
call = method_handler.unary_unary
|
|
68
|
+
unary_response, error = None, None
|
|
69
|
+
try:
|
|
70
|
+
unary_response = cast(GrpcMessage, call(request, context))
|
|
71
|
+
except BaseException as e:
|
|
72
|
+
error = e
|
|
73
|
+
raise
|
|
74
|
+
finally:
|
|
75
|
+
log_entry = self.log_plugin.compose_log_after_event(
|
|
76
|
+
request=request,
|
|
77
|
+
context=context,
|
|
78
|
+
user_info=None,
|
|
79
|
+
method_name=method_name,
|
|
80
|
+
response=unary_response or error,
|
|
81
|
+
)
|
|
82
|
+
self.log_plugin.write_log(log_entry)
|
|
83
|
+
return unary_response
|
|
84
|
+
|
|
85
|
+
if method_handler.unary_unary:
|
|
86
|
+
message_handler = grpc.unary_unary_rpc_method_handler
|
|
87
|
+
else:
|
|
88
|
+
# If the method type is not `unary_unary` raise an error
|
|
89
|
+
raise NotImplementedError("This RPC method type is not supported.")
|
|
90
|
+
return message_handler(
|
|
91
|
+
_generic_method_handler,
|
|
92
|
+
request_deserializer=method_handler.request_deserializer,
|
|
93
|
+
response_serializer=method_handler.response_serializer,
|
|
94
|
+
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,15 +12,16 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower grid SDK."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
from .
|
|
18
|
+
from .grid import Driver, Grid
|
|
19
|
+
from .grpc_grid import GrpcGrid
|
|
20
|
+
from .inmemory_grid import InMemoryGrid
|
|
21
21
|
|
|
22
22
|
__all__ = [
|
|
23
23
|
"Driver",
|
|
24
|
-
"
|
|
25
|
-
"
|
|
24
|
+
"Grid",
|
|
25
|
+
"GrpcGrid",
|
|
26
|
+
"InMemoryGrid",
|
|
26
27
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,32 +12,32 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Grid (abstract base class)."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from collections.abc import Iterable
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
|
-
from flwr.common import Message,
|
|
22
|
+
from flwr.common import Message, RecordDict
|
|
23
23
|
from flwr.common.typing import Run
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class
|
|
27
|
-
"""Abstract base
|
|
26
|
+
class Grid(ABC):
|
|
27
|
+
"""Abstract base class Grid to send/receive messages."""
|
|
28
28
|
|
|
29
29
|
@abstractmethod
|
|
30
30
|
def set_run(self, run_id: int) -> None:
|
|
31
31
|
"""Request a run to the SuperLink with a given `run_id`.
|
|
32
32
|
|
|
33
|
-
If a Run with the specified
|
|
33
|
+
If a ``Run`` with the specified ``run_id`` exists, a local ``Run``
|
|
34
34
|
object will be created. It enables further functionality
|
|
35
|
-
in the
|
|
35
|
+
in the grid, such as sending ``Message``s.
|
|
36
36
|
|
|
37
37
|
Parameters
|
|
38
38
|
----------
|
|
39
39
|
run_id : int
|
|
40
|
-
The
|
|
40
|
+
The ``run_id`` of the ``Run`` this ``Grid`` object operates in.
|
|
41
41
|
"""
|
|
42
42
|
|
|
43
43
|
@property
|
|
@@ -48,7 +48,7 @@ class Driver(ABC):
|
|
|
48
48
|
@abstractmethod
|
|
49
49
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
50
50
|
self,
|
|
51
|
-
content:
|
|
51
|
+
content: RecordDict,
|
|
52
52
|
message_type: str,
|
|
53
53
|
dst_node_id: int,
|
|
54
54
|
group_id: str,
|
|
@@ -56,12 +56,12 @@ class Driver(ABC):
|
|
|
56
56
|
) -> Message:
|
|
57
57
|
"""Create a new message with specified parameters.
|
|
58
58
|
|
|
59
|
-
This method constructs a new
|
|
60
|
-
The
|
|
59
|
+
This method constructs a new ``Message`` with given content and metadata.
|
|
60
|
+
The ``run_id`` and ``src_node_id`` will be set automatically.
|
|
61
61
|
|
|
62
62
|
Parameters
|
|
63
63
|
----------
|
|
64
|
-
content :
|
|
64
|
+
content : RecordDict
|
|
65
65
|
The content for the new message. This holds records that are to be sent
|
|
66
66
|
to the destination node.
|
|
67
67
|
message_type : str
|
|
@@ -71,12 +71,12 @@ class Driver(ABC):
|
|
|
71
71
|
The ID of the destination node to which the message is being sent.
|
|
72
72
|
group_id : str
|
|
73
73
|
The ID of the group to which this message is associated. In some settings,
|
|
74
|
-
this is used as the
|
|
74
|
+
this is used as the federated learning round.
|
|
75
75
|
ttl : Optional[float] (default: None)
|
|
76
76
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
77
77
|
this message to receiving a reply. It specifies in seconds the duration for
|
|
78
78
|
which the message and its potential reply are considered valid. If unset,
|
|
79
|
-
the default TTL (i.e.,
|
|
79
|
+
the default TTL (i.e., ``common.DEFAULT_TTL``) will be used.
|
|
80
80
|
|
|
81
81
|
Returns
|
|
82
82
|
-------
|
|
@@ -85,7 +85,7 @@ class Driver(ABC):
|
|
|
85
85
|
"""
|
|
86
86
|
|
|
87
87
|
@abstractmethod
|
|
88
|
-
def get_node_ids(self) ->
|
|
88
|
+
def get_node_ids(self) -> Iterable[int]:
|
|
89
89
|
"""Get node IDs."""
|
|
90
90
|
|
|
91
91
|
@abstractmethod
|
|
@@ -93,7 +93,7 @@ class Driver(ABC):
|
|
|
93
93
|
"""Push messages to specified node IDs.
|
|
94
94
|
|
|
95
95
|
This method takes an iterable of messages and sends each message
|
|
96
|
-
to the node specified in
|
|
96
|
+
to the node specified in ``dst_node_id``.
|
|
97
97
|
|
|
98
98
|
Parameters
|
|
99
99
|
----------
|
|
@@ -154,8 +154,37 @@ class Driver(ABC):
|
|
|
154
154
|
|
|
155
155
|
Notes
|
|
156
156
|
-----
|
|
157
|
-
This method uses
|
|
158
|
-
to collect the replies. If
|
|
157
|
+
This method uses ``push_messages`` to send the messages and ``pull_messages``
|
|
158
|
+
to collect the replies. If ``timeout`` is set, the method may not return
|
|
159
159
|
replies for all sent messages. A message remains valid until its TTL,
|
|
160
|
-
which is not affected by
|
|
160
|
+
which is not affected by ``timeout``.
|
|
161
161
|
"""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Driver(Grid):
|
|
165
|
+
"""Deprecated abstract base class ``Driver``, use ``Grid`` instead.
|
|
166
|
+
|
|
167
|
+
This class is provided solely for backward compatibility with legacy
|
|
168
|
+
code that previously relied on the ``Driver`` class. It has been deprecated
|
|
169
|
+
in favor of the updated abstract base class ``Grid``, which now encompasses
|
|
170
|
+
all communication-related functionality and improvements between the
|
|
171
|
+
ServerApp and the SuperLink.
|
|
172
|
+
|
|
173
|
+
.. warning::
|
|
174
|
+
``Driver`` is deprecated and will be removed in a future release.
|
|
175
|
+
Use `Grid` in the signature of your ServerApp.
|
|
176
|
+
|
|
177
|
+
Examples
|
|
178
|
+
--------
|
|
179
|
+
Legacy (deprecated) usage::
|
|
180
|
+
|
|
181
|
+
@app.main()
|
|
182
|
+
def main(driver: Driver, context: Context) -> None:
|
|
183
|
+
...
|
|
184
|
+
|
|
185
|
+
Updated usage::
|
|
186
|
+
|
|
187
|
+
@app.main()
|
|
188
|
+
def main(grid: Grid, context: Context) -> None:
|
|
189
|
+
...
|
|
190
|
+
"""
|