flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +6 -4
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +23 -20
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +2 -0
- flwr/common/constant.py +2 -0
- flwr/common/context.py +4 -4
- flwr/common/logger.py +2 -2
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +2 -2
- flwr/common/record/metricsrecord.py +1 -1
- flwr/common/record/parametersrecord.py +1 -1
- flwr/common/record/{recordset.py → recorddict.py} +57 -17
- flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
- flwr/common/serde.py +33 -37
- flwr/proto/exec_pb2.py +32 -32
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
- flwr/proto/run_pb2.py +32 -32
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +2 -0
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/grid_client_proxy.py +38 -38
- flwr/server/grid/__init__.py +7 -6
- flwr/server/grid/grid.py +46 -17
- flwr/server/grid/grpc_grid.py +26 -33
- flwr/server/grid/inmemory_grid.py +19 -25
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +37 -11
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
- flwr/server/superlink/linkstate/utils.py +77 -17
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +24 -26
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +13 -13
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
flwr/server/compat/app.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Flower
|
15
|
+
"""Flower grid app."""
|
16
16
|
|
17
17
|
|
18
18
|
from logging import INFO
|
@@ -25,27 +25,27 @@ from flwr.server.server import Server, init_defaults, run_fl
|
|
25
25
|
from flwr.server.server_config import ServerConfig
|
26
26
|
from flwr.server.strategy import Strategy
|
27
27
|
|
28
|
-
from ..grid import
|
28
|
+
from ..grid import Grid
|
29
29
|
from .app_utils import start_update_client_manager_thread
|
30
30
|
|
31
31
|
|
32
|
-
def
|
32
|
+
def start_grid( # pylint: disable=too-many-arguments, too-many-locals
|
33
33
|
*,
|
34
|
-
|
34
|
+
grid: Grid,
|
35
35
|
server: Optional[Server] = None,
|
36
36
|
config: Optional[ServerConfig] = None,
|
37
37
|
strategy: Optional[Strategy] = None,
|
38
38
|
client_manager: Optional[ClientManager] = None,
|
39
39
|
) -> History:
|
40
|
-
"""Start a Flower
|
40
|
+
"""Start a Flower server.
|
41
41
|
|
42
42
|
Parameters
|
43
43
|
----------
|
44
|
-
|
45
|
-
The
|
44
|
+
grid : Grid
|
45
|
+
The Grid object to use.
|
46
46
|
server : Optional[flwr.server.Server] (default: None)
|
47
47
|
A server implementation, either `flwr.server.Server` or a subclass
|
48
|
-
thereof. If no instance is provided, then `
|
48
|
+
thereof. If no instance is provided, then `start_grid` will create
|
49
49
|
one.
|
50
50
|
config : Optional[ServerConfig] (default: None)
|
51
51
|
Currently supported values are `num_rounds` (int, default: 1) and
|
@@ -56,7 +56,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
56
56
|
`start_server` will use `flwr.server.strategy.FedAvg`.
|
57
57
|
client_manager : Optional[flwr.server.ClientManager] (default: None)
|
58
58
|
An implementation of the class `flwr.server.ClientManager`. If no
|
59
|
-
implementation is provided, then `
|
59
|
+
implementation is provided, then `start_grid` will use
|
60
60
|
`flwr.server.SimpleClientManager`.
|
61
61
|
|
62
62
|
Returns
|
@@ -64,7 +64,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
64
64
|
hist : flwr.server.history.History
|
65
65
|
Object containing training and evaluation metrics.
|
66
66
|
"""
|
67
|
-
# Initialize the
|
67
|
+
# Initialize the server and config
|
68
68
|
initialized_server, initialized_config = init_defaults(
|
69
69
|
server=server,
|
70
70
|
config=config,
|
@@ -80,7 +80,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
80
80
|
|
81
81
|
# Start the thread updating nodes
|
82
82
|
thread, f_stop, c_done = start_update_client_manager_thread(
|
83
|
-
|
83
|
+
grid, initialized_server.client_manager()
|
84
84
|
)
|
85
85
|
|
86
86
|
# Wait until the node registration done
|
flwr/server/compat/app_utils.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Utility functions for the `
|
15
|
+
"""Utility functions for the `start_grid`."""
|
16
16
|
|
17
17
|
|
18
18
|
import threading
|
@@ -20,18 +20,18 @@ import threading
|
|
20
20
|
from flwr.common.typing import RunNotRunningException
|
21
21
|
|
22
22
|
from ..client_manager import ClientManager
|
23
|
-
from ..grid import
|
24
|
-
from .grid_client_proxy import
|
23
|
+
from ..grid import Grid
|
24
|
+
from .grid_client_proxy import GridClientProxy
|
25
25
|
|
26
26
|
|
27
27
|
def start_update_client_manager_thread(
|
28
|
-
|
28
|
+
grid: Grid,
|
29
29
|
client_manager: ClientManager,
|
30
30
|
) -> tuple[threading.Thread, threading.Event, threading.Event]:
|
31
31
|
"""Periodically update the nodes list in the client manager in a thread.
|
32
32
|
|
33
|
-
This function starts a thread that periodically uses the associated
|
34
|
-
get all node_ids. Each node_id is then converted into a `
|
33
|
+
This function starts a thread that periodically uses the associated grid to
|
34
|
+
get all node_ids. Each node_id is then converted into a `GridClientProxy`
|
35
35
|
instance and stored in the `registered_nodes` dictionary with node_id as key.
|
36
36
|
|
37
37
|
New nodes will be added to the ClientManager via `client_manager.register()`,
|
@@ -40,8 +40,8 @@ def start_update_client_manager_thread(
|
|
40
40
|
|
41
41
|
Parameters
|
42
42
|
----------
|
43
|
-
|
44
|
-
The
|
43
|
+
grid : Grid
|
44
|
+
The Grid object to use.
|
45
45
|
client_manager : ClientManager
|
46
46
|
The ClientManager object to be updated.
|
47
47
|
|
@@ -59,7 +59,7 @@ def start_update_client_manager_thread(
|
|
59
59
|
thread = threading.Thread(
|
60
60
|
target=_update_client_manager,
|
61
61
|
args=(
|
62
|
-
|
62
|
+
grid,
|
63
63
|
client_manager,
|
64
64
|
f_stop,
|
65
65
|
c_done,
|
@@ -72,17 +72,17 @@ def start_update_client_manager_thread(
|
|
72
72
|
|
73
73
|
|
74
74
|
def _update_client_manager(
|
75
|
-
|
75
|
+
grid: Grid,
|
76
76
|
client_manager: ClientManager,
|
77
77
|
f_stop: threading.Event,
|
78
78
|
c_done: threading.Event,
|
79
79
|
) -> None:
|
80
80
|
"""Update the nodes list in the client manager."""
|
81
|
-
# Loop until the
|
82
|
-
registered_nodes: dict[int,
|
81
|
+
# Loop until the grid is disconnected
|
82
|
+
registered_nodes: dict[int, GridClientProxy] = {}
|
83
83
|
while not f_stop.is_set():
|
84
84
|
try:
|
85
|
-
all_node_ids = set(
|
85
|
+
all_node_ids = set(grid.get_node_ids())
|
86
86
|
except RunNotRunningException:
|
87
87
|
f_stop.set()
|
88
88
|
break
|
@@ -97,10 +97,10 @@ def _update_client_manager(
|
|
97
97
|
|
98
98
|
# Register new nodes
|
99
99
|
for node_id in new_nodes:
|
100
|
-
client_proxy =
|
100
|
+
client_proxy = GridClientProxy(
|
101
101
|
node_id=node_id,
|
102
|
-
|
103
|
-
run_id=
|
102
|
+
grid=grid,
|
103
|
+
run_id=grid.run.run_id,
|
104
104
|
)
|
105
105
|
if client_manager.register(client_proxy):
|
106
106
|
registered_nodes[node_id] = client_proxy
|
flwr/server/compat/grid_client_proxy.py
CHANGED
@@ -12,26 +12,26 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Flower ClientProxy implementation
|
15
|
+
"""Flower ClientProxy implementation using Grid."""
|
16
16
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
20
|
from flwr import common
|
21
|
-
from flwr.common import Message, MessageType, MessageTypeLegacy,
|
22
|
-
from flwr.common import
|
21
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
|
22
|
+
from flwr.common import recorddict_compat as compat
|
23
23
|
from flwr.server.client_proxy import ClientProxy
|
24
24
|
|
25
|
-
from ..grid.grid import
|
25
|
+
from ..grid.grid import Grid
|
26
26
|
|
27
27
|
|
28
|
-
class
|
29
|
-
"""Flower client proxy which delegates work using
|
28
|
+
class GridClientProxy(ClientProxy):
|
29
|
+
"""Flower client proxy which delegates work using Grid."""
|
30
30
|
|
31
|
-
def __init__(self, node_id: int,
|
31
|
+
def __init__(self, node_id: int, grid: Grid, run_id: int):
|
32
32
|
super().__init__(str(node_id))
|
33
33
|
self.node_id = node_id
|
34
|
-
self.
|
34
|
+
self.grid = grid
|
35
35
|
self.run_id = run_id
|
36
36
|
|
37
37
|
def get_properties(
|
@@ -41,14 +41,14 @@ class DriverClientProxy(ClientProxy):
|
|
41
41
|
group_id: Optional[int],
|
42
42
|
) -> common.GetPropertiesRes:
|
43
43
|
"""Return client's properties."""
|
44
|
-
# Ins to
|
45
|
-
|
44
|
+
# Ins to RecordDict
|
45
|
+
out_recorddict = compat.getpropertiesins_to_recorddict(ins)
|
46
46
|
# Fetch response
|
47
|
-
|
48
|
-
|
47
|
+
in_recorddict = self._send_receive_recorddict(
|
48
|
+
out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
49
49
|
)
|
50
|
-
#
|
51
|
-
return compat.
|
50
|
+
# RecordDict to Res
|
51
|
+
return compat.recorddict_to_getpropertiesres(in_recorddict)
|
52
52
|
|
53
53
|
def get_parameters(
|
54
54
|
self,
|
@@ -57,40 +57,40 @@ class DriverClientProxy(ClientProxy):
|
|
57
57
|
group_id: Optional[int],
|
58
58
|
) -> common.GetParametersRes:
|
59
59
|
"""Return the current local model parameters."""
|
60
|
-
# Ins to
|
61
|
-
|
60
|
+
# Ins to RecordDict
|
61
|
+
out_recorddict = compat.getparametersins_to_recorddict(ins)
|
62
62
|
# Fetch response
|
63
|
-
|
64
|
-
|
63
|
+
in_recorddict = self._send_receive_recorddict(
|
64
|
+
out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
65
65
|
)
|
66
|
-
#
|
67
|
-
return compat.
|
66
|
+
# RecordDict to Res
|
67
|
+
return compat.recorddict_to_getparametersres(in_recorddict, False)
|
68
68
|
|
69
69
|
def fit(
|
70
70
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
71
71
|
) -> common.FitRes:
|
72
72
|
"""Train model parameters on the locally held dataset."""
|
73
|
-
# Ins to
|
74
|
-
|
73
|
+
# Ins to RecordDict
|
74
|
+
out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
|
75
75
|
# Fetch response
|
76
|
-
|
77
|
-
|
76
|
+
in_recorddict = self._send_receive_recorddict(
|
77
|
+
out_recorddict, MessageType.TRAIN, timeout, group_id
|
78
78
|
)
|
79
|
-
#
|
80
|
-
return compat.
|
79
|
+
# RecordDict to Res
|
80
|
+
return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
|
81
81
|
|
82
82
|
def evaluate(
|
83
83
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
84
84
|
) -> common.EvaluateRes:
|
85
85
|
"""Evaluate model parameters on the locally held dataset."""
|
86
|
-
# Ins to
|
87
|
-
|
86
|
+
# Ins to RecordDict
|
87
|
+
out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
|
88
88
|
# Fetch response
|
89
|
-
|
90
|
-
|
89
|
+
in_recorddict = self._send_receive_recorddict(
|
90
|
+
out_recorddict, MessageType.EVALUATE, timeout, group_id
|
91
91
|
)
|
92
|
-
#
|
93
|
-
return compat.
|
92
|
+
# RecordDict to Res
|
93
|
+
return compat.recorddict_to_evaluateres(in_recorddict)
|
94
94
|
|
95
95
|
def reconnect(
|
96
96
|
self,
|
@@ -101,17 +101,17 @@ class DriverClientProxy(ClientProxy):
|
|
101
101
|
"""Disconnect and (optionally) reconnect later."""
|
102
102
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
103
103
|
|
104
|
-
def
|
104
|
+
def _send_receive_recorddict(
|
105
105
|
self,
|
106
|
-
|
106
|
+
recorddict: RecordDict,
|
107
107
|
message_type: str,
|
108
108
|
timeout: Optional[float],
|
109
109
|
group_id: Optional[int],
|
110
|
-
) ->
|
110
|
+
) -> RecordDict:
|
111
111
|
|
112
112
|
# Create message
|
113
|
-
message = self.
|
114
|
-
content=
|
113
|
+
message = self.grid.create_message(
|
114
|
+
content=recorddict,
|
115
115
|
message_type=message_type,
|
116
116
|
dst_node_id=self.node_id,
|
117
117
|
group_id=str(group_id) if group_id else "",
|
@@ -119,7 +119,7 @@ class DriverClientProxy(ClientProxy):
|
|
119
119
|
)
|
120
120
|
|
121
121
|
# Send message and wait for reply
|
122
|
-
messages = list(self.
|
122
|
+
messages = list(self.grid.send_and_receive(messages=[message]))
|
123
123
|
|
124
124
|
# A single reply is expected
|
125
125
|
if len(messages) != 1:
|
flwr/server/grid/__init__.py
CHANGED
@@ -12,15 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Flower
|
15
|
+
"""Flower grid SDK."""
|
16
16
|
|
17
17
|
|
18
|
-
from .grid import Driver
|
19
|
-
from .grpc_grid import
|
20
|
-
from .inmemory_grid import
|
18
|
+
from .grid import Driver, Grid
|
19
|
+
from .grpc_grid import GrpcGrid
|
20
|
+
from .inmemory_grid import InMemoryGrid
|
21
21
|
|
22
22
|
__all__ = [
|
23
23
|
"Driver",
|
24
|
-
"
|
25
|
-
"
|
24
|
+
"Grid",
|
25
|
+
"GrpcGrid",
|
26
|
+
"InMemoryGrid",
|
26
27
|
]
|
flwr/server/grid/grid.py
CHANGED
@@ -12,32 +12,32 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""
|
15
|
+
"""Grid (abstract base class)."""
|
16
16
|
|
17
17
|
|
18
18
|
from abc import ABC, abstractmethod
|
19
19
|
from collections.abc import Iterable
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
-
from flwr.common import Message,
|
22
|
+
from flwr.common import Message, RecordDict
|
23
23
|
from flwr.common.typing import Run
|
24
24
|
|
25
25
|
|
26
|
-
class
|
27
|
-
"""Abstract base
|
26
|
+
class Grid(ABC):
|
27
|
+
"""Abstract base class Grid to send/receive messages."""
|
28
28
|
|
29
29
|
@abstractmethod
|
30
30
|
def set_run(self, run_id: int) -> None:
|
31
31
|
"""Request a run to the SuperLink with a given `run_id`.
|
32
32
|
|
33
|
-
If a Run with the specified
|
33
|
+
If a ``Run`` with the specified ``run_id`` exists, a local ``Run``
|
34
34
|
object will be created. It enables further functionality
|
35
|
-
in the
|
35
|
+
in the grid, such as sending ``Message``s.
|
36
36
|
|
37
37
|
Parameters
|
38
38
|
----------
|
39
39
|
run_id : int
|
40
|
-
The
|
40
|
+
The ``run_id`` of the ``Run`` this ``Grid`` object operates in.
|
41
41
|
"""
|
42
42
|
|
43
43
|
@property
|
@@ -48,7 +48,7 @@ class Driver(ABC):
|
|
48
48
|
@abstractmethod
|
49
49
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
50
50
|
self,
|
51
|
-
content:
|
51
|
+
content: RecordDict,
|
52
52
|
message_type: str,
|
53
53
|
dst_node_id: int,
|
54
54
|
group_id: str,
|
@@ -56,12 +56,12 @@ class Driver(ABC):
|
|
56
56
|
) -> Message:
|
57
57
|
"""Create a new message with specified parameters.
|
58
58
|
|
59
|
-
This method constructs a new
|
60
|
-
The
|
59
|
+
This method constructs a new ``Message`` with given content and metadata.
|
60
|
+
The ``run_id`` and ``src_node_id`` will be set automatically.
|
61
61
|
|
62
62
|
Parameters
|
63
63
|
----------
|
64
|
-
content :
|
64
|
+
content : RecordDict
|
65
65
|
The content for the new message. This holds records that are to be sent
|
66
66
|
to the destination node.
|
67
67
|
message_type : str
|
@@ -71,12 +71,12 @@ class Driver(ABC):
|
|
71
71
|
The ID of the destination node to which the message is being sent.
|
72
72
|
group_id : str
|
73
73
|
The ID of the group to which this message is associated. In some settings,
|
74
|
-
this is used as the
|
74
|
+
this is used as the federated learning round.
|
75
75
|
ttl : Optional[float] (default: None)
|
76
76
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
77
77
|
this message to receiving a reply. It specifies in seconds the duration for
|
78
78
|
which the message and its potential reply are considered valid. If unset,
|
79
|
-
the default TTL (i.e.,
|
79
|
+
the default TTL (i.e., ``common.DEFAULT_TTL``) will be used.
|
80
80
|
|
81
81
|
Returns
|
82
82
|
-------
|
@@ -93,7 +93,7 @@ class Driver(ABC):
|
|
93
93
|
"""Push messages to specified node IDs.
|
94
94
|
|
95
95
|
This method takes an iterable of messages and sends each message
|
96
|
-
to the node specified in
|
96
|
+
to the node specified in ``dst_node_id``.
|
97
97
|
|
98
98
|
Parameters
|
99
99
|
----------
|
@@ -154,8 +154,37 @@ class Driver(ABC):
|
|
154
154
|
|
155
155
|
Notes
|
156
156
|
-----
|
157
|
-
This method uses
|
158
|
-
to collect the replies. If
|
157
|
+
This method uses ``push_messages`` to send the messages and ``pull_messages``
|
158
|
+
to collect the replies. If ``timeout`` is set, the method may not return
|
159
159
|
replies for all sent messages. A message remains valid until its TTL,
|
160
|
-
which is not affected by
|
160
|
+
which is not affected by ``timeout``.
|
161
161
|
"""
|
162
|
+
|
163
|
+
|
164
|
+
class Driver(Grid):
|
165
|
+
"""Deprecated abstract base class ``Driver``, use ``Grid`` instead.
|
166
|
+
|
167
|
+
This class is provided solely for backward compatibility with legacy
|
168
|
+
code that previously relied on the ``Driver`` class. It has been deprecated
|
169
|
+
in favor of the updated abstract base class ``Grid``, which now encompasses
|
170
|
+
all communication-related functionality and improvements between the
|
171
|
+
ServerApp and the SuperLink.
|
172
|
+
|
173
|
+
.. warning::
|
174
|
+
``Driver`` is deprecated and will be removed in a future release.
|
175
|
+
Use `Grid` in the signature of your ServerApp.
|
176
|
+
|
177
|
+
Examples
|
178
|
+
--------
|
179
|
+
Legacy (deprecated) usage::
|
180
|
+
|
181
|
+
@app.main()
|
182
|
+
def main(driver: Driver, context: Context) -> None:
|
183
|
+
...
|
184
|
+
|
185
|
+
Updated usage::
|
186
|
+
|
187
|
+
@app.main()
|
188
|
+
def main(grid: Grid, context: Context) -> None:
|
189
|
+
...
|
190
|
+
"""
|
flwr/server/grid/grpc_grid.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Flower gRPC
|
15
|
+
"""Flower gRPC Grid."""
|
16
16
|
|
17
17
|
|
18
18
|
import time
|
@@ -22,13 +22,13 @@ from typing import Optional, cast
|
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
|
-
from flwr.common import
|
25
|
+
from flwr.common import Message, RecordDict
|
26
26
|
from flwr.common.constant import (
|
27
27
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
28
28
|
SUPERLINK_NODE_ID,
|
29
29
|
)
|
30
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
31
|
-
from flwr.common.logger import log
|
31
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
32
32
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
33
33
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
34
34
|
from flwr.common.typing import Run
|
@@ -45,11 +45,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
45
45
|
)
|
46
46
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
47
47
|
|
48
|
-
from .grid import
|
48
|
+
from .grid import Grid
|
49
49
|
|
50
50
|
ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
|
51
51
|
|
52
|
-
[
|
52
|
+
[Grid.push_messages] gRPC error occurred:
|
53
53
|
|
54
54
|
The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
|
55
55
|
at once, or push messages individually, for example:
|
@@ -57,13 +57,13 @@ at once, or push messages individually, for example:
|
|
57
57
|
> msgs = [msg1, msg2, msg3]
|
58
58
|
> msg_ids = []
|
59
59
|
> for msg in msgs:
|
60
|
-
> msg_id =
|
60
|
+
> msg_id = grid.push_messages([msg])
|
61
61
|
> msg_ids.extend(msg_id)
|
62
62
|
"""
|
63
63
|
|
64
64
|
ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
|
65
65
|
|
66
|
-
[
|
66
|
+
[Grid.pull_messages] gRPC error occurred:
|
67
67
|
|
68
68
|
The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
|
69
69
|
at once, or pull messages individually, for example:
|
@@ -71,13 +71,13 @@ at once, or pull messages individually, for example:
|
|
71
71
|
> msgs_ids = [msg_id1, msg_id2, msg_id3]
|
72
72
|
> msgs = []
|
73
73
|
> for msg_id in msg_ids:
|
74
|
-
> msg =
|
74
|
+
> msg = grid.pull_messages([msg_id])
|
75
75
|
> msgs.extend(msg)
|
76
76
|
"""
|
77
77
|
|
78
78
|
|
79
|
-
class
|
80
|
-
"""`
|
79
|
+
class GrpcGrid(Grid):
|
80
|
+
"""`GrpcGrid` provides an interface to the ServerAppIo API.
|
81
81
|
|
82
82
|
Parameters
|
83
83
|
----------
|
@@ -101,6 +101,7 @@ class GrpcDriver(Driver):
|
|
101
101
|
self._channel: Optional[grpc.Channel] = None
|
102
102
|
self.node = Node(node_id=SUPERLINK_NODE_ID)
|
103
103
|
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
104
|
+
super().__init__()
|
104
105
|
|
105
106
|
@property
|
106
107
|
def _is_connected(self) -> bool:
|
@@ -160,18 +161,15 @@ class GrpcDriver(Driver):
|
|
160
161
|
def _check_message(self, message: Message) -> None:
|
161
162
|
# Check if the message is valid
|
162
163
|
if not (
|
163
|
-
|
164
|
-
message.metadata.
|
165
|
-
and message.metadata.src_node_id == self.node.node_id
|
166
|
-
and message.metadata.message_id == ""
|
167
|
-
and message.metadata.reply_to_message == ""
|
164
|
+
message.metadata.message_id == ""
|
165
|
+
and message.metadata.reply_to_message_id == ""
|
168
166
|
and message.metadata.ttl > 0
|
169
167
|
):
|
170
168
|
raise ValueError(f"Invalid message: {message}")
|
171
169
|
|
172
170
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
173
171
|
self,
|
174
|
-
content:
|
172
|
+
content: RecordDict,
|
175
173
|
message_type: str,
|
176
174
|
dst_node_id: int,
|
177
175
|
group_id: str,
|
@@ -182,22 +180,15 @@ class GrpcDriver(Driver):
|
|
182
180
|
This method constructs a new `Message` with given content and metadata.
|
183
181
|
The `run_id` and `src_node_id` will be set automatically.
|
184
182
|
"""
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
message_id="", # Will be set by the server
|
189
|
-
src_node_id=self.node.node_id,
|
190
|
-
dst_node_id=dst_node_id,
|
191
|
-
reply_to_message="",
|
192
|
-
group_id=group_id,
|
193
|
-
ttl=ttl_,
|
194
|
-
message_type=message_type,
|
183
|
+
warn_deprecated_feature(
|
184
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
185
|
+
"Use `Message` constructor instead."
|
195
186
|
)
|
196
|
-
return Message(
|
187
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
197
188
|
|
198
189
|
def get_node_ids(self) -> Iterable[int]:
|
199
190
|
"""Get node IDs."""
|
200
|
-
# Call
|
191
|
+
# Call GrpcServerAppIoStub method
|
201
192
|
res: GetNodesResponse = self._stub.GetNodes(
|
202
193
|
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
203
194
|
)
|
@@ -210,8 +201,12 @@ class GrpcDriver(Driver):
|
|
210
201
|
to the node specified in `dst_node_id`.
|
211
202
|
"""
|
212
203
|
# Construct Messages
|
204
|
+
run_id = cast(Run, self._run).run_id
|
213
205
|
message_proto_list: list[ProtoMessage] = []
|
214
206
|
for msg in messages:
|
207
|
+
# Populate metadata
|
208
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
209
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
215
210
|
# Check message
|
216
211
|
self._check_message(msg)
|
217
212
|
# Convert to proto
|
@@ -220,11 +215,9 @@ class GrpcDriver(Driver):
|
|
220
215
|
message_proto_list.append(msg_proto)
|
221
216
|
|
222
217
|
try:
|
223
|
-
# Call
|
218
|
+
# Call GrpcServerAppIoStub method
|
224
219
|
res: PushInsMessagesResponse = self._stub.PushMessages(
|
225
|
-
PushInsMessagesRequest(
|
226
|
-
messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
|
227
|
-
)
|
220
|
+
PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
|
228
221
|
)
|
229
222
|
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
230
223
|
message_proto_list
|
@@ -288,7 +281,7 @@ class GrpcDriver(Driver):
|
|
288
281
|
res_msgs = self.pull_messages(msg_ids)
|
289
282
|
ret.extend(res_msgs)
|
290
283
|
msg_ids.difference_update(
|
291
|
-
{msg.metadata.
|
284
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
292
285
|
)
|
293
286
|
if len(msg_ids) == 0:
|
294
287
|
break
|