flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,26 +12,26 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower ClientProxy implementation
|
|
15
|
+
"""Flower ClientProxy implementation using Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from typing import Optional
|
|
19
19
|
|
|
20
20
|
from flwr import common
|
|
21
|
-
from flwr.common import Message, MessageType, MessageTypeLegacy,
|
|
22
|
-
from flwr.common import
|
|
21
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
|
|
22
|
+
from flwr.common import recorddict_compat as compat
|
|
23
23
|
from flwr.server.client_proxy import ClientProxy
|
|
24
24
|
|
|
25
|
-
from ..
|
|
25
|
+
from ..grid.grid import Grid
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
class
|
|
29
|
-
"""Flower client proxy which delegates work using
|
|
28
|
+
class GridClientProxy(ClientProxy):
|
|
29
|
+
"""Flower client proxy which delegates work using Grid."""
|
|
30
30
|
|
|
31
|
-
def __init__(self, node_id: int,
|
|
31
|
+
def __init__(self, node_id: int, grid: Grid, run_id: int):
|
|
32
32
|
super().__init__(str(node_id))
|
|
33
33
|
self.node_id = node_id
|
|
34
|
-
self.
|
|
34
|
+
self.grid = grid
|
|
35
35
|
self.run_id = run_id
|
|
36
36
|
|
|
37
37
|
def get_properties(
|
|
@@ -41,14 +41,14 @@ class DriverClientProxy(ClientProxy):
|
|
|
41
41
|
group_id: Optional[int],
|
|
42
42
|
) -> common.GetPropertiesRes:
|
|
43
43
|
"""Return client's properties."""
|
|
44
|
-
# Ins to
|
|
45
|
-
|
|
44
|
+
# Ins to RecordDict
|
|
45
|
+
out_recorddict = compat.getpropertiesins_to_recorddict(ins)
|
|
46
46
|
# Fetch response
|
|
47
|
-
|
|
48
|
-
|
|
47
|
+
in_recorddict = self._send_receive_recorddict(
|
|
48
|
+
out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
|
|
49
49
|
)
|
|
50
|
-
#
|
|
51
|
-
return compat.
|
|
50
|
+
# RecordDict to Res
|
|
51
|
+
return compat.recorddict_to_getpropertiesres(in_recorddict)
|
|
52
52
|
|
|
53
53
|
def get_parameters(
|
|
54
54
|
self,
|
|
@@ -57,40 +57,40 @@ class DriverClientProxy(ClientProxy):
|
|
|
57
57
|
group_id: Optional[int],
|
|
58
58
|
) -> common.GetParametersRes:
|
|
59
59
|
"""Return the current local model parameters."""
|
|
60
|
-
# Ins to
|
|
61
|
-
|
|
60
|
+
# Ins to RecordDict
|
|
61
|
+
out_recorddict = compat.getparametersins_to_recorddict(ins)
|
|
62
62
|
# Fetch response
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
in_recorddict = self._send_receive_recorddict(
|
|
64
|
+
out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
|
|
65
65
|
)
|
|
66
|
-
#
|
|
67
|
-
return compat.
|
|
66
|
+
# RecordDict to Res
|
|
67
|
+
return compat.recorddict_to_getparametersres(in_recorddict, False)
|
|
68
68
|
|
|
69
69
|
def fit(
|
|
70
70
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
|
71
71
|
) -> common.FitRes:
|
|
72
72
|
"""Train model parameters on the locally held dataset."""
|
|
73
|
-
# Ins to
|
|
74
|
-
|
|
73
|
+
# Ins to RecordDict
|
|
74
|
+
out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
|
|
75
75
|
# Fetch response
|
|
76
|
-
|
|
77
|
-
|
|
76
|
+
in_recorddict = self._send_receive_recorddict(
|
|
77
|
+
out_recorddict, MessageType.TRAIN, timeout, group_id
|
|
78
78
|
)
|
|
79
|
-
#
|
|
80
|
-
return compat.
|
|
79
|
+
# RecordDict to Res
|
|
80
|
+
return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
|
|
81
81
|
|
|
82
82
|
def evaluate(
|
|
83
83
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
|
84
84
|
) -> common.EvaluateRes:
|
|
85
85
|
"""Evaluate model parameters on the locally held dataset."""
|
|
86
|
-
# Ins to
|
|
87
|
-
|
|
86
|
+
# Ins to RecordDict
|
|
87
|
+
out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
|
|
88
88
|
# Fetch response
|
|
89
|
-
|
|
90
|
-
|
|
89
|
+
in_recorddict = self._send_receive_recorddict(
|
|
90
|
+
out_recorddict, MessageType.EVALUATE, timeout, group_id
|
|
91
91
|
)
|
|
92
|
-
#
|
|
93
|
-
return compat.
|
|
92
|
+
# RecordDict to Res
|
|
93
|
+
return compat.recorddict_to_evaluateres(in_recorddict)
|
|
94
94
|
|
|
95
95
|
def reconnect(
|
|
96
96
|
self,
|
|
@@ -101,17 +101,17 @@ class DriverClientProxy(ClientProxy):
|
|
|
101
101
|
"""Disconnect and (optionally) reconnect later."""
|
|
102
102
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|
|
103
103
|
|
|
104
|
-
def
|
|
104
|
+
def _send_receive_recorddict(
|
|
105
105
|
self,
|
|
106
|
-
|
|
106
|
+
recorddict: RecordDict,
|
|
107
107
|
message_type: str,
|
|
108
108
|
timeout: Optional[float],
|
|
109
109
|
group_id: Optional[int],
|
|
110
|
-
) ->
|
|
110
|
+
) -> RecordDict:
|
|
111
111
|
|
|
112
112
|
# Create message
|
|
113
|
-
message =
|
|
114
|
-
content=
|
|
113
|
+
message = Message(
|
|
114
|
+
content=recorddict,
|
|
115
115
|
message_type=message_type,
|
|
116
116
|
dst_node_id=self.node_id,
|
|
117
117
|
group_id=str(group_id) if group_id else "",
|
|
@@ -119,7 +119,7 @@ class DriverClientProxy(ClientProxy):
|
|
|
119
119
|
)
|
|
120
120
|
|
|
121
121
|
# Send message and wait for reply
|
|
122
|
-
messages = list(self.
|
|
122
|
+
messages = list(self.grid.send_and_receive(messages=[message]))
|
|
123
123
|
|
|
124
124
|
# A single reply is expected
|
|
125
125
|
if len(messages) != 1:
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower Fleet API event log interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Any, Callable, cast
|
|
19
|
+
|
|
20
|
+
import grpc
|
|
21
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
22
|
+
|
|
23
|
+
from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
|
|
24
|
+
from flwr.common.typing import LogEntry
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
28
|
+
"""Fleet API interceptor for logging events."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, log_plugin: EventLogWriterPlugin) -> None:
|
|
31
|
+
self.log_plugin = log_plugin
|
|
32
|
+
|
|
33
|
+
def intercept_service(
|
|
34
|
+
self,
|
|
35
|
+
continuation: Callable[[Any], Any],
|
|
36
|
+
handler_call_details: grpc.HandlerCallDetails,
|
|
37
|
+
) -> grpc.RpcMethodHandler:
|
|
38
|
+
"""Flower Fleet API server interceptor logging logic.
|
|
39
|
+
|
|
40
|
+
Intercept all unary-unary calls from users and log the event. Continue RPC call
|
|
41
|
+
if event logger is enabled on the SuperLink, else, terminate RPC call by setting
|
|
42
|
+
context to abort.
|
|
43
|
+
"""
|
|
44
|
+
# One of the method handlers in
|
|
45
|
+
# `flwr.server.superlink.fleet.grpc_rere.fleet_servicer.FleetServicer`
|
|
46
|
+
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
47
|
+
method_name: str = handler_call_details.method
|
|
48
|
+
return self._generic_event_log_unary_method_handler(method_handler, method_name)
|
|
49
|
+
|
|
50
|
+
def _generic_event_log_unary_method_handler(
|
|
51
|
+
self, method_handler: grpc.RpcMethodHandler, method_name: str
|
|
52
|
+
) -> grpc.RpcMethodHandler:
|
|
53
|
+
def _generic_method_handler(
|
|
54
|
+
request: GrpcMessage,
|
|
55
|
+
context: grpc.ServicerContext,
|
|
56
|
+
) -> GrpcMessage:
|
|
57
|
+
log_entry: LogEntry
|
|
58
|
+
# Log before call
|
|
59
|
+
log_entry = self.log_plugin.compose_log_before_event(
|
|
60
|
+
request=request,
|
|
61
|
+
context=context,
|
|
62
|
+
user_info=None,
|
|
63
|
+
method_name=method_name,
|
|
64
|
+
)
|
|
65
|
+
self.log_plugin.write_log(log_entry)
|
|
66
|
+
|
|
67
|
+
call = method_handler.unary_unary
|
|
68
|
+
unary_response, error = None, None
|
|
69
|
+
try:
|
|
70
|
+
unary_response = cast(GrpcMessage, call(request, context))
|
|
71
|
+
except BaseException as e:
|
|
72
|
+
error = e
|
|
73
|
+
raise
|
|
74
|
+
finally:
|
|
75
|
+
log_entry = self.log_plugin.compose_log_after_event(
|
|
76
|
+
request=request,
|
|
77
|
+
context=context,
|
|
78
|
+
user_info=None,
|
|
79
|
+
method_name=method_name,
|
|
80
|
+
response=unary_response or error,
|
|
81
|
+
)
|
|
82
|
+
self.log_plugin.write_log(log_entry)
|
|
83
|
+
return unary_response
|
|
84
|
+
|
|
85
|
+
if method_handler.unary_unary:
|
|
86
|
+
message_handler = grpc.unary_unary_rpc_method_handler
|
|
87
|
+
else:
|
|
88
|
+
# If the method type is not `unary_unary` raise an error
|
|
89
|
+
raise NotImplementedError("This RPC method type is not supported.")
|
|
90
|
+
return message_handler(
|
|
91
|
+
_generic_method_handler,
|
|
92
|
+
request_deserializer=method_handler.request_deserializer,
|
|
93
|
+
response_serializer=method_handler.response_serializer,
|
|
94
|
+
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,15 +12,16 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower grid SDK."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
from .
|
|
18
|
+
from .grid import Driver, Grid
|
|
19
|
+
from .grpc_grid import GrpcGrid
|
|
20
|
+
from .inmemory_grid import InMemoryGrid
|
|
21
21
|
|
|
22
22
|
__all__ = [
|
|
23
23
|
"Driver",
|
|
24
|
-
"
|
|
25
|
-
"
|
|
24
|
+
"Grid",
|
|
25
|
+
"GrpcGrid",
|
|
26
|
+
"InMemoryGrid",
|
|
26
27
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,32 +12,32 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Grid (abstract base class)."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from collections.abc import Iterable
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
|
-
from flwr.common import Message,
|
|
22
|
+
from flwr.common import Message, RecordDict
|
|
23
23
|
from flwr.common.typing import Run
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class
|
|
27
|
-
"""Abstract base
|
|
26
|
+
class Grid(ABC):
|
|
27
|
+
"""Abstract base class Grid to send/receive messages."""
|
|
28
28
|
|
|
29
29
|
@abstractmethod
|
|
30
30
|
def set_run(self, run_id: int) -> None:
|
|
31
31
|
"""Request a run to the SuperLink with a given `run_id`.
|
|
32
32
|
|
|
33
|
-
If a Run with the specified
|
|
33
|
+
If a ``Run`` with the specified ``run_id`` exists, a local ``Run``
|
|
34
34
|
object will be created. It enables further functionality
|
|
35
|
-
in the
|
|
35
|
+
in the grid, such as sending ``Message``s.
|
|
36
36
|
|
|
37
37
|
Parameters
|
|
38
38
|
----------
|
|
39
39
|
run_id : int
|
|
40
|
-
The
|
|
40
|
+
The ``run_id`` of the ``Run`` this ``Grid`` object operates in.
|
|
41
41
|
"""
|
|
42
42
|
|
|
43
43
|
@property
|
|
@@ -48,7 +48,7 @@ class Driver(ABC):
|
|
|
48
48
|
@abstractmethod
|
|
49
49
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
50
50
|
self,
|
|
51
|
-
content:
|
|
51
|
+
content: RecordDict,
|
|
52
52
|
message_type: str,
|
|
53
53
|
dst_node_id: int,
|
|
54
54
|
group_id: str,
|
|
@@ -56,12 +56,12 @@ class Driver(ABC):
|
|
|
56
56
|
) -> Message:
|
|
57
57
|
"""Create a new message with specified parameters.
|
|
58
58
|
|
|
59
|
-
This method constructs a new
|
|
60
|
-
The
|
|
59
|
+
This method constructs a new ``Message`` with given content and metadata.
|
|
60
|
+
The ``run_id`` and ``src_node_id`` will be set automatically.
|
|
61
61
|
|
|
62
62
|
Parameters
|
|
63
63
|
----------
|
|
64
|
-
content :
|
|
64
|
+
content : RecordDict
|
|
65
65
|
The content for the new message. This holds records that are to be sent
|
|
66
66
|
to the destination node.
|
|
67
67
|
message_type : str
|
|
@@ -71,12 +71,12 @@ class Driver(ABC):
|
|
|
71
71
|
The ID of the destination node to which the message is being sent.
|
|
72
72
|
group_id : str
|
|
73
73
|
The ID of the group to which this message is associated. In some settings,
|
|
74
|
-
this is used as the
|
|
74
|
+
this is used as the federated learning round.
|
|
75
75
|
ttl : Optional[float] (default: None)
|
|
76
76
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
77
77
|
this message to receiving a reply. It specifies in seconds the duration for
|
|
78
78
|
which the message and its potential reply are considered valid. If unset,
|
|
79
|
-
the default TTL (i.e.,
|
|
79
|
+
the default TTL (i.e., ``common.DEFAULT_TTL``) will be used.
|
|
80
80
|
|
|
81
81
|
Returns
|
|
82
82
|
-------
|
|
@@ -93,7 +93,7 @@ class Driver(ABC):
|
|
|
93
93
|
"""Push messages to specified node IDs.
|
|
94
94
|
|
|
95
95
|
This method takes an iterable of messages and sends each message
|
|
96
|
-
to the node specified in
|
|
96
|
+
to the node specified in ``dst_node_id``.
|
|
97
97
|
|
|
98
98
|
Parameters
|
|
99
99
|
----------
|
|
@@ -154,8 +154,37 @@ class Driver(ABC):
|
|
|
154
154
|
|
|
155
155
|
Notes
|
|
156
156
|
-----
|
|
157
|
-
This method uses
|
|
158
|
-
to collect the replies. If
|
|
157
|
+
This method uses ``push_messages`` to send the messages and ``pull_messages``
|
|
158
|
+
to collect the replies. If ``timeout`` is set, the method may not return
|
|
159
159
|
replies for all sent messages. A message remains valid until its TTL,
|
|
160
|
-
which is not affected by
|
|
160
|
+
which is not affected by ``timeout``.
|
|
161
161
|
"""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Driver(Grid):
|
|
165
|
+
"""Deprecated abstract base class ``Driver``, use ``Grid`` instead.
|
|
166
|
+
|
|
167
|
+
This class is provided solely for backward compatibility with legacy
|
|
168
|
+
code that previously relied on the ``Driver`` class. It has been deprecated
|
|
169
|
+
in favor of the updated abstract base class ``Grid``, which now encompasses
|
|
170
|
+
all communication-related functionality and improvements between the
|
|
171
|
+
ServerApp and the SuperLink.
|
|
172
|
+
|
|
173
|
+
.. warning::
|
|
174
|
+
``Driver`` is deprecated and will be removed in a future release.
|
|
175
|
+
Use `Grid` in the signature of your ServerApp.
|
|
176
|
+
|
|
177
|
+
Examples
|
|
178
|
+
--------
|
|
179
|
+
Legacy (deprecated) usage::
|
|
180
|
+
|
|
181
|
+
@app.main()
|
|
182
|
+
def main(driver: Driver, context: Context) -> None:
|
|
183
|
+
...
|
|
184
|
+
|
|
185
|
+
Updated usage::
|
|
186
|
+
|
|
187
|
+
@app.main()
|
|
188
|
+
def main(grid: Grid, context: Context) -> None:
|
|
189
|
+
...
|
|
190
|
+
"""
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,24 +12,23 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower gRPC
|
|
15
|
+
"""Flower gRPC Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
import warnings
|
|
20
19
|
from collections.abc import Iterable
|
|
21
|
-
from logging import DEBUG, WARNING
|
|
20
|
+
from logging import DEBUG, ERROR, WARNING
|
|
22
21
|
from typing import Optional, cast
|
|
23
22
|
|
|
24
23
|
import grpc
|
|
25
24
|
|
|
26
|
-
from flwr.common import
|
|
25
|
+
from flwr.common import Message, RecordDict
|
|
27
26
|
from flwr.common.constant import (
|
|
28
27
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
29
28
|
SUPERLINK_NODE_ID,
|
|
30
29
|
)
|
|
31
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
32
|
-
from flwr.common.logger import log
|
|
31
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
|
33
32
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
34
33
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
35
34
|
from flwr.common.typing import Run
|
|
@@ -46,18 +45,39 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
46
45
|
)
|
|
47
46
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
48
47
|
|
|
49
|
-
from .
|
|
48
|
+
from .grid import Grid
|
|
50
49
|
|
|
51
|
-
|
|
52
|
-
[flwr-serverapp] Error: Not connected.
|
|
50
|
+
ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
|
|
53
51
|
|
|
54
|
-
|
|
55
|
-
|
|
52
|
+
[Grid.push_messages] gRPC error occurred:
|
|
53
|
+
|
|
54
|
+
The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
|
|
55
|
+
at once, or push messages individually, for example:
|
|
56
|
+
|
|
57
|
+
> msgs = [msg1, msg2, msg3]
|
|
58
|
+
> msg_ids = []
|
|
59
|
+
> for msg in msgs:
|
|
60
|
+
> msg_id = grid.push_messages([msg])
|
|
61
|
+
> msg_ids.extend(msg_id)
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
|
|
65
|
+
|
|
66
|
+
[Grid.pull_messages] gRPC error occurred:
|
|
67
|
+
|
|
68
|
+
The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
|
|
69
|
+
at once, or pull messages individually, for example:
|
|
70
|
+
|
|
71
|
+
> msgs_ids = [msg_id1, msg_id2, msg_id3]
|
|
72
|
+
> msgs = []
|
|
73
|
+
> for msg_id in msg_ids:
|
|
74
|
+
> msg = grid.pull_messages([msg_id])
|
|
75
|
+
> msgs.extend(msg)
|
|
56
76
|
"""
|
|
57
77
|
|
|
58
78
|
|
|
59
|
-
class
|
|
60
|
-
"""`
|
|
79
|
+
class GrpcGrid(Grid):
|
|
80
|
+
"""`GrpcGrid` provides an interface to the ServerAppIo API.
|
|
61
81
|
|
|
62
82
|
Parameters
|
|
63
83
|
----------
|
|
@@ -69,6 +89,8 @@ class GrpcDriver(Driver):
|
|
|
69
89
|
established to an SSL-enabled Flower server.
|
|
70
90
|
"""
|
|
71
91
|
|
|
92
|
+
_deprecation_warning_logged = False
|
|
93
|
+
|
|
72
94
|
def __init__( # pylint: disable=too-many-arguments
|
|
73
95
|
self,
|
|
74
96
|
serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
@@ -81,6 +103,7 @@ class GrpcDriver(Driver):
|
|
|
81
103
|
self._channel: Optional[grpc.Channel] = None
|
|
82
104
|
self.node = Node(node_id=SUPERLINK_NODE_ID)
|
|
83
105
|
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
106
|
+
super().__init__()
|
|
84
107
|
|
|
85
108
|
@property
|
|
86
109
|
def _is_connected(self) -> bool:
|
|
@@ -140,18 +163,15 @@ class GrpcDriver(Driver):
|
|
|
140
163
|
def _check_message(self, message: Message) -> None:
|
|
141
164
|
# Check if the message is valid
|
|
142
165
|
if not (
|
|
143
|
-
|
|
144
|
-
message.metadata.
|
|
145
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
146
|
-
and message.metadata.message_id == ""
|
|
147
|
-
and message.metadata.reply_to_message == ""
|
|
166
|
+
message.metadata.message_id == ""
|
|
167
|
+
and message.metadata.reply_to_message_id == ""
|
|
148
168
|
and message.metadata.ttl > 0
|
|
149
169
|
):
|
|
150
170
|
raise ValueError(f"Invalid message: {message}")
|
|
151
171
|
|
|
152
172
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
153
173
|
self,
|
|
154
|
-
content:
|
|
174
|
+
content: RecordDict,
|
|
155
175
|
message_type: str,
|
|
156
176
|
dst_node_id: int,
|
|
157
177
|
group_id: str,
|
|
@@ -162,30 +182,17 @@ class GrpcDriver(Driver):
|
|
|
162
182
|
This method constructs a new `Message` with given content and metadata.
|
|
163
183
|
The `run_id` and `src_node_id` will be set automatically.
|
|
164
184
|
"""
|
|
165
|
-
if
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
"
|
|
169
|
-
"
|
|
170
|
-
stacklevel=2,
|
|
185
|
+
if not GrpcGrid._deprecation_warning_logged:
|
|
186
|
+
GrpcGrid._deprecation_warning_logged = True
|
|
187
|
+
warn_deprecated_feature(
|
|
188
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
|
189
|
+
"Use `Message` constructor instead."
|
|
171
190
|
)
|
|
172
|
-
|
|
173
|
-
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
174
|
-
metadata = Metadata(
|
|
175
|
-
run_id=cast(Run, self._run).run_id,
|
|
176
|
-
message_id="", # Will be set by the server
|
|
177
|
-
src_node_id=self.node.node_id,
|
|
178
|
-
dst_node_id=dst_node_id,
|
|
179
|
-
reply_to_message="",
|
|
180
|
-
group_id=group_id,
|
|
181
|
-
ttl=ttl_,
|
|
182
|
-
message_type=message_type,
|
|
183
|
-
)
|
|
184
|
-
return Message(metadata=metadata, content=content)
|
|
191
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
|
185
192
|
|
|
186
193
|
def get_node_ids(self) -> Iterable[int]:
|
|
187
194
|
"""Get node IDs."""
|
|
188
|
-
# Call
|
|
195
|
+
# Call GrpcServerAppIoStub method
|
|
189
196
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
190
197
|
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
191
198
|
)
|
|
@@ -198,30 +205,40 @@ class GrpcDriver(Driver):
|
|
|
198
205
|
to the node specified in `dst_node_id`.
|
|
199
206
|
"""
|
|
200
207
|
# Construct Messages
|
|
208
|
+
run_id = cast(Run, self._run).run_id
|
|
201
209
|
message_proto_list: list[ProtoMessage] = []
|
|
202
210
|
for msg in messages:
|
|
211
|
+
# Populate metadata
|
|
212
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
|
213
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
203
214
|
# Check message
|
|
204
215
|
self._check_message(msg)
|
|
205
216
|
# Convert to proto
|
|
206
217
|
msg_proto = message_to_proto(msg)
|
|
207
218
|
# Add to list
|
|
208
219
|
message_proto_list.append(msg_proto)
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
)
|
|
215
|
-
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
|
216
|
-
list(message_proto_list)
|
|
217
|
-
):
|
|
218
|
-
log(
|
|
219
|
-
WARNING,
|
|
220
|
-
"Not all messages could be pushed to the SuperLink. The returned "
|
|
221
|
-
"list has `None` for those messages (the order is preserved as passed "
|
|
222
|
-
"to `push_messages`). This could be due to a malformed message.",
|
|
220
|
+
|
|
221
|
+
try:
|
|
222
|
+
# Call GrpcServerAppIoStub method
|
|
223
|
+
res: PushInsMessagesResponse = self._stub.PushMessages(
|
|
224
|
+
PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
|
|
223
225
|
)
|
|
224
|
-
|
|
226
|
+
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
|
227
|
+
message_proto_list
|
|
228
|
+
):
|
|
229
|
+
log(
|
|
230
|
+
WARNING,
|
|
231
|
+
"Not all messages could be pushed to the SuperLink. The returned "
|
|
232
|
+
"list has `None` for those messages (the order is preserved as "
|
|
233
|
+
"passed to `push_messages`). This could be due to a malformed "
|
|
234
|
+
"message.",
|
|
235
|
+
)
|
|
236
|
+
return list(res.message_ids)
|
|
237
|
+
except grpc.RpcError as e:
|
|
238
|
+
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
239
|
+
log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
|
|
240
|
+
return []
|
|
241
|
+
raise
|
|
225
242
|
|
|
226
243
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
227
244
|
"""Pull messages based on message IDs.
|
|
@@ -229,16 +246,22 @@ class GrpcDriver(Driver):
|
|
|
229
246
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
230
247
|
set of given message IDs.
|
|
231
248
|
"""
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
249
|
+
try:
|
|
250
|
+
# Pull Messages
|
|
251
|
+
res: PullResMessagesResponse = self._stub.PullMessages(
|
|
252
|
+
PullResMessagesRequest(
|
|
253
|
+
message_ids=message_ids,
|
|
254
|
+
run_id=cast(Run, self._run).run_id,
|
|
255
|
+
)
|
|
237
256
|
)
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
257
|
+
# Convert Message from Protobuf representation
|
|
258
|
+
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
|
|
259
|
+
return msgs
|
|
260
|
+
except grpc.RpcError as e:
|
|
261
|
+
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
262
|
+
log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
|
|
263
|
+
return []
|
|
264
|
+
raise
|
|
242
265
|
|
|
243
266
|
def send_and_receive(
|
|
244
267
|
self,
|
|
@@ -262,7 +285,7 @@ class GrpcDriver(Driver):
|
|
|
262
285
|
res_msgs = self.pull_messages(msg_ids)
|
|
263
286
|
ret.extend(res_msgs)
|
|
264
287
|
msg_ids.difference_update(
|
|
265
|
-
{msg.metadata.
|
|
288
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
|
266
289
|
)
|
|
267
290
|
if len(msg_ids) == 0:
|
|
268
291
|
break
|