flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/run/run.py +5 -9
  12. flwr/client/app.py +6 -4
  13. flwr/client/client_app.py +162 -99
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/grpc_client/connection.py +24 -21
  16. flwr/client/message_handler/message_handler.py +27 -27
  17. flwr/client/mod/__init__.py +2 -2
  18. flwr/client/mod/centraldp_mods.py +7 -7
  19. flwr/client/mod/comms_mods.py +16 -22
  20. flwr/client/mod/localdp_mod.py +4 -4
  21. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  22. flwr/client/run_info_store.py +2 -2
  23. flwr/common/__init__.py +12 -4
  24. flwr/common/config.py +4 -4
  25. flwr/common/constant.py +6 -6
  26. flwr/common/context.py +4 -4
  27. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  28. flwr/common/logger.py +2 -2
  29. flwr/common/message.py +327 -102
  30. flwr/common/record/__init__.py +8 -4
  31. flwr/common/record/arrayrecord.py +626 -0
  32. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  33. flwr/common/record/conversion_utils.py +1 -1
  34. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  35. flwr/common/record/recorddict.py +288 -0
  36. flwr/common/recorddict_compat.py +410 -0
  37. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  38. flwr/common/serde.py +66 -71
  39. flwr/common/typing.py +8 -8
  40. flwr/proto/exec_pb2.py +3 -3
  41. flwr/proto/exec_pb2.pyi +3 -3
  42. flwr/proto/message_pb2.py +12 -12
  43. flwr/proto/message_pb2.pyi +9 -9
  44. flwr/proto/recorddict_pb2.py +70 -0
  45. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  46. flwr/proto/run_pb2.py +31 -31
  47. flwr/proto/run_pb2.pyi +3 -3
  48. flwr/server/__init__.py +3 -1
  49. flwr/server/app.py +56 -1
  50. flwr/server/compat/__init__.py +2 -2
  51. flwr/server/compat/app.py +11 -11
  52. flwr/server/compat/app_utils.py +16 -16
  53. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  54. flwr/server/fleet_event_log_interceptor.py +94 -0
  55. flwr/server/{driver → grid}/__init__.py +8 -7
  56. flwr/server/{driver/driver.py → grid/grid.py} +47 -18
  57. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  58. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  59. flwr/server/run_serverapp.py +4 -4
  60. flwr/server/server_app.py +38 -18
  61. flwr/server/serverapp/app.py +10 -10
  62. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  63. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  64. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  65. flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
  66. flwr/server/superlink/linkstate/linkstate.py +4 -4
  67. flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
  68. flwr/server/superlink/linkstate/utils.py +93 -27
  69. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  70. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  71. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  72. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  73. flwr/server/typing.py +3 -3
  74. flwr/server/utils/validator.py +4 -4
  75. flwr/server/workflow/default_workflows.py +48 -57
  76. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  77. flwr/simulation/app.py +2 -2
  78. flwr/simulation/ray_transport/ray_actor.py +4 -2
  79. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  80. flwr/simulation/run_simulation.py +15 -15
  81. flwr/superexec/deployment.py +4 -4
  82. flwr/superexec/exec_event_log_interceptor.py +135 -0
  83. flwr/superexec/exec_grpc.py +10 -4
  84. flwr/superexec/exec_servicer.py +2 -2
  85. flwr/superexec/exec_user_auth_interceptor.py +18 -2
  86. flwr/superexec/executor.py +3 -3
  87. flwr/superexec/simulation.py +3 -3
  88. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
  89. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
  90. flwr/common/record/parametersrecord.py +0 -339
  91. flwr/common/record/recordset.py +0 -209
  92. flwr/common/recordset_compat.py +0 -418
  93. flwr/proto/recordset_pb2.py +0 -70
  94. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  95. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  96. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  97. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
  98. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,26 +12,26 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower ClientProxy implementation for Driver API."""
15
+ """Flower ClientProxy implementation using Grid."""
16
16
 
17
17
 
18
18
  from typing import Optional
19
19
 
20
20
  from flwr import common
21
- from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
22
- from flwr.common import recordset_compat as compat
21
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
22
+ from flwr.common import recorddict_compat as compat
23
23
  from flwr.server.client_proxy import ClientProxy
24
24
 
25
- from ..driver.driver import Driver
25
+ from ..grid.grid import Grid
26
26
 
27
27
 
28
- class DriverClientProxy(ClientProxy):
29
- """Flower client proxy which delegates work using the Driver API."""
28
+ class GridClientProxy(ClientProxy):
29
+ """Flower client proxy which delegates work using Grid."""
30
30
 
31
- def __init__(self, node_id: int, driver: Driver, run_id: int):
31
+ def __init__(self, node_id: int, grid: Grid, run_id: int):
32
32
  super().__init__(str(node_id))
33
33
  self.node_id = node_id
34
- self.driver = driver
34
+ self.grid = grid
35
35
  self.run_id = run_id
36
36
 
37
37
  def get_properties(
@@ -41,14 +41,14 @@ class DriverClientProxy(ClientProxy):
41
41
  group_id: Optional[int],
42
42
  ) -> common.GetPropertiesRes:
43
43
  """Return client's properties."""
44
- # Ins to RecordSet
45
- out_recordset = compat.getpropertiesins_to_recordset(ins)
44
+ # Ins to RecordDict
45
+ out_recorddict = compat.getpropertiesins_to_recorddict(ins)
46
46
  # Fetch response
47
- in_recordset = self._send_receive_recordset(
48
- out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
47
+ in_recorddict = self._send_receive_recorddict(
48
+ out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
49
49
  )
50
- # RecordSet to Res
51
- return compat.recordset_to_getpropertiesres(in_recordset)
50
+ # RecordDict to Res
51
+ return compat.recorddict_to_getpropertiesres(in_recorddict)
52
52
 
53
53
  def get_parameters(
54
54
  self,
@@ -57,40 +57,40 @@ class DriverClientProxy(ClientProxy):
57
57
  group_id: Optional[int],
58
58
  ) -> common.GetParametersRes:
59
59
  """Return the current local model parameters."""
60
- # Ins to RecordSet
61
- out_recordset = compat.getparametersins_to_recordset(ins)
60
+ # Ins to RecordDict
61
+ out_recorddict = compat.getparametersins_to_recorddict(ins)
62
62
  # Fetch response
63
- in_recordset = self._send_receive_recordset(
64
- out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
63
+ in_recorddict = self._send_receive_recorddict(
64
+ out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
65
65
  )
66
- # RecordSet to Res
67
- return compat.recordset_to_getparametersres(in_recordset, False)
66
+ # RecordDict to Res
67
+ return compat.recorddict_to_getparametersres(in_recorddict, False)
68
68
 
69
69
  def fit(
70
70
  self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
71
71
  ) -> common.FitRes:
72
72
  """Train model parameters on the locally held dataset."""
73
- # Ins to RecordSet
74
- out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
73
+ # Ins to RecordDict
74
+ out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
75
75
  # Fetch response
76
- in_recordset = self._send_receive_recordset(
77
- out_recordset, MessageType.TRAIN, timeout, group_id
76
+ in_recorddict = self._send_receive_recorddict(
77
+ out_recorddict, MessageType.TRAIN, timeout, group_id
78
78
  )
79
- # RecordSet to Res
80
- return compat.recordset_to_fitres(in_recordset, keep_input=False)
79
+ # RecordDict to Res
80
+ return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
81
81
 
82
82
  def evaluate(
83
83
  self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
84
84
  ) -> common.EvaluateRes:
85
85
  """Evaluate model parameters on the locally held dataset."""
86
- # Ins to RecordSet
87
- out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
86
+ # Ins to RecordDict
87
+ out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
88
88
  # Fetch response
89
- in_recordset = self._send_receive_recordset(
90
- out_recordset, MessageType.EVALUATE, timeout, group_id
89
+ in_recorddict = self._send_receive_recorddict(
90
+ out_recorddict, MessageType.EVALUATE, timeout, group_id
91
91
  )
92
- # RecordSet to Res
93
- return compat.recordset_to_evaluateres(in_recordset)
92
+ # RecordDict to Res
93
+ return compat.recorddict_to_evaluateres(in_recorddict)
94
94
 
95
95
  def reconnect(
96
96
  self,
@@ -101,17 +101,17 @@ class DriverClientProxy(ClientProxy):
101
101
  """Disconnect and (optionally) reconnect later."""
102
102
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
103
103
 
104
- def _send_receive_recordset(
104
+ def _send_receive_recorddict(
105
105
  self,
106
- recordset: RecordSet,
106
+ recorddict: RecordDict,
107
107
  message_type: str,
108
108
  timeout: Optional[float],
109
109
  group_id: Optional[int],
110
- ) -> RecordSet:
110
+ ) -> RecordDict:
111
111
 
112
112
  # Create message
113
- message = self.driver.create_message(
114
- content=recordset,
113
+ message = Message(
114
+ content=recorddict,
115
115
  message_type=message_type,
116
116
  dst_node_id=self.node_id,
117
117
  group_id=str(group_id) if group_id else "",
@@ -119,7 +119,7 @@ class DriverClientProxy(ClientProxy):
119
119
  )
120
120
 
121
121
  # Send message and wait for reply
122
- messages = list(self.driver.send_and_receive(messages=[message]))
122
+ messages = list(self.grid.send_and_receive(messages=[message]))
123
123
 
124
124
  # A single reply is expected
125
125
  if len(messages) != 1:
@@ -0,0 +1,94 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower Fleet API event log interceptor."""
16
+
17
+
18
+ from typing import Any, Callable, cast
19
+
20
+ import grpc
21
+ from google.protobuf.message import Message as GrpcMessage
22
+
23
+ from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
24
+ from flwr.common.typing import LogEntry
25
+
26
+
27
+ class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
28
+ """Fleet API interceptor for logging events."""
29
+
30
+ def __init__(self, log_plugin: EventLogWriterPlugin) -> None:
31
+ self.log_plugin = log_plugin
32
+
33
+ def intercept_service(
34
+ self,
35
+ continuation: Callable[[Any], Any],
36
+ handler_call_details: grpc.HandlerCallDetails,
37
+ ) -> grpc.RpcMethodHandler:
38
+ """Flower Fleet API server interceptor logging logic.
39
+
40
+ Intercept all unary-unary calls from users and log the event. Continue RPC call
41
+ if event logger is enabled on the SuperLink, else, terminate RPC call by setting
42
+ context to abort.
43
+ """
44
+ # One of the method handlers in
45
+ # `flwr.server.superlink.fleet.grpc_rere.fleet_servicer.FleetServicer`
46
+ method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
47
+ method_name: str = handler_call_details.method
48
+ return self._generic_event_log_unary_method_handler(method_handler, method_name)
49
+
50
+ def _generic_event_log_unary_method_handler(
51
+ self, method_handler: grpc.RpcMethodHandler, method_name: str
52
+ ) -> grpc.RpcMethodHandler:
53
+ def _generic_method_handler(
54
+ request: GrpcMessage,
55
+ context: grpc.ServicerContext,
56
+ ) -> GrpcMessage:
57
+ log_entry: LogEntry
58
+ # Log before call
59
+ log_entry = self.log_plugin.compose_log_before_event(
60
+ request=request,
61
+ context=context,
62
+ user_info=None,
63
+ method_name=method_name,
64
+ )
65
+ self.log_plugin.write_log(log_entry)
66
+
67
+ call = method_handler.unary_unary
68
+ unary_response, error = None, None
69
+ try:
70
+ unary_response = cast(GrpcMessage, call(request, context))
71
+ except BaseException as e:
72
+ error = e
73
+ raise
74
+ finally:
75
+ log_entry = self.log_plugin.compose_log_after_event(
76
+ request=request,
77
+ context=context,
78
+ user_info=None,
79
+ method_name=method_name,
80
+ response=unary_response or error,
81
+ )
82
+ self.log_plugin.write_log(log_entry)
83
+ return unary_response
84
+
85
+ if method_handler.unary_unary:
86
+ message_handler = grpc.unary_unary_rpc_method_handler
87
+ else:
88
+ # If the method type is not `unary_unary` raise an error
89
+ raise NotImplementedError("This RPC method type is not supported.")
90
+ return message_handler(
91
+ _generic_method_handler,
92
+ request_deserializer=method_handler.request_deserializer,
93
+ response_serializer=method_handler.response_serializer,
94
+ )
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver SDK."""
15
+ """Flower grid SDK."""
16
16
 
17
17
 
18
- from .driver import Driver
19
- from .grpc_driver import GrpcDriver
20
- from .inmemory_driver import InMemoryDriver
18
+ from .grid import Driver, Grid
19
+ from .grpc_grid import GrpcGrid
20
+ from .inmemory_grid import InMemoryGrid
21
21
 
22
22
  __all__ = [
23
23
  "Driver",
24
- "GrpcDriver",
25
- "InMemoryDriver",
24
+ "Grid",
25
+ "GrpcGrid",
26
+ "InMemoryGrid",
26
27
  ]
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,32 +12,32 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver (abstract base class)."""
15
+ """Grid (abstract base class)."""
16
16
 
17
17
 
18
18
  from abc import ABC, abstractmethod
19
19
  from collections.abc import Iterable
20
20
  from typing import Optional
21
21
 
22
- from flwr.common import Message, RecordSet
22
+ from flwr.common import Message, RecordDict
23
23
  from flwr.common.typing import Run
24
24
 
25
25
 
26
- class Driver(ABC):
27
- """Abstract base Driver class for the ServerAppIo API."""
26
+ class Grid(ABC):
27
+ """Abstract base class Grid to send/receive messages."""
28
28
 
29
29
  @abstractmethod
30
30
  def set_run(self, run_id: int) -> None:
31
31
  """Request a run to the SuperLink with a given `run_id`.
32
32
 
33
- If a Run with the specified `run_id` exists, a local Run
33
+ If a ``Run`` with the specified ``run_id`` exists, a local ``Run``
34
34
  object will be created. It enables further functionality
35
- in the driver, such as sending `Messages`.
35
+ in the grid, such as sending ``Message``s.
36
36
 
37
37
  Parameters
38
38
  ----------
39
39
  run_id : int
40
- The `run_id` of the Run this Driver object operates in.
40
+ The ``run_id`` of the ``Run`` this ``Grid`` object operates in.
41
41
  """
42
42
 
43
43
  @property
@@ -48,7 +48,7 @@ class Driver(ABC):
48
48
  @abstractmethod
49
49
  def create_message( # pylint: disable=too-many-arguments,R0917
50
50
  self,
51
- content: RecordSet,
51
+ content: RecordDict,
52
52
  message_type: str,
53
53
  dst_node_id: int,
54
54
  group_id: str,
@@ -56,12 +56,12 @@ class Driver(ABC):
56
56
  ) -> Message:
57
57
  """Create a new message with specified parameters.
58
58
 
59
- This method constructs a new `Message` with given content and metadata.
60
- The `run_id` and `src_node_id` will be set automatically.
59
+ This method constructs a new ``Message`` with given content and metadata.
60
+ The ``run_id`` and ``src_node_id`` will be set automatically.
61
61
 
62
62
  Parameters
63
63
  ----------
64
- content : RecordSet
64
+ content : RecordDict
65
65
  The content for the new message. This holds records that are to be sent
66
66
  to the destination node.
67
67
  message_type : str
@@ -71,12 +71,12 @@ class Driver(ABC):
71
71
  The ID of the destination node to which the message is being sent.
72
72
  group_id : str
73
73
  The ID of the group to which this message is associated. In some settings,
74
- this is used as the FL round.
74
+ this is used as the federated learning round.
75
75
  ttl : Optional[float] (default: None)
76
76
  Time-to-live for the round trip of this message, i.e., the time from sending
77
77
  this message to receiving a reply. It specifies in seconds the duration for
78
78
  which the message and its potential reply are considered valid. If unset,
79
- the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
79
+ the default TTL (i.e., ``common.DEFAULT_TTL``) will be used.
80
80
 
81
81
  Returns
82
82
  -------
@@ -93,7 +93,7 @@ class Driver(ABC):
93
93
  """Push messages to specified node IDs.
94
94
 
95
95
  This method takes an iterable of messages and sends each message
96
- to the node specified in `dst_node_id`.
96
+ to the node specified in ``dst_node_id``.
97
97
 
98
98
  Parameters
99
99
  ----------
@@ -154,8 +154,37 @@ class Driver(ABC):
154
154
 
155
155
  Notes
156
156
  -----
157
- This method uses `push_messages` to send the messages and `pull_messages`
158
- to collect the replies. If `timeout` is set, the method may not return
157
+ This method uses ``push_messages`` to send the messages and ``pull_messages``
158
+ to collect the replies. If ``timeout`` is set, the method may not return
159
159
  replies for all sent messages. A message remains valid until its TTL,
160
- which is not affected by `timeout`.
160
+ which is not affected by ``timeout``.
161
161
  """
162
+
163
+
164
+ class Driver(Grid):
165
+ """Deprecated abstract base class ``Driver``, use ``Grid`` instead.
166
+
167
+ This class is provided solely for backward compatibility with legacy
168
+ code that previously relied on the ``Driver`` class. It has been deprecated
169
+ in favor of the updated abstract base class ``Grid``, which now encompasses
170
+ all communication-related functionality and improvements between the
171
+ ServerApp and the SuperLink.
172
+
173
+ .. warning::
174
+ ``Driver`` is deprecated and will be removed in a future release.
175
+ Use `Grid` in the signature of your ServerApp.
176
+
177
+ Examples
178
+ --------
179
+ Legacy (deprecated) usage::
180
+
181
+ @app.main()
182
+ def main(driver: Driver, context: Context) -> None:
183
+ ...
184
+
185
+ Updated usage::
186
+
187
+ @app.main()
188
+ def main(grid: Grid, context: Context) -> None:
189
+ ...
190
+ """
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,24 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower gRPC Driver."""
15
+ """Flower gRPC Grid."""
16
16
 
17
17
 
18
18
  import time
19
- import warnings
20
19
  from collections.abc import Iterable
21
- from logging import DEBUG, WARNING
20
+ from logging import DEBUG, ERROR, WARNING
22
21
  from typing import Optional, cast
23
22
 
24
23
  import grpc
25
24
 
26
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
+ from flwr.common import Message, RecordDict
27
26
  from flwr.common.constant import (
28
27
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
29
28
  SUPERLINK_NODE_ID,
30
29
  )
31
30
  from flwr.common.grpc import create_channel, on_channel_state_change
32
- from flwr.common.logger import log
31
+ from flwr.common.logger import log, warn_deprecated_feature
33
32
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
34
33
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
35
34
  from flwr.common.typing import Run
@@ -46,18 +45,39 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
46
45
  )
47
46
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
48
47
 
49
- from .driver import Driver
48
+ from .grid import Grid
50
49
 
51
- ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
52
- [flwr-serverapp] Error: Not connected.
50
+ ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
53
51
 
54
- Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
55
- `GrpcDriverStub` methods.
52
+ [Grid.push_messages] gRPC error occurred:
53
+
54
+ The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
55
+ at once, or push messages individually, for example:
56
+
57
+ > msgs = [msg1, msg2, msg3]
58
+ > msg_ids = []
59
+ > for msg in msgs:
60
+ > msg_id = grid.push_messages([msg])
61
+ > msg_ids.extend(msg_id)
62
+ """
63
+
64
+ ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
65
+
66
+ [Grid.pull_messages] gRPC error occurred:
67
+
68
+ The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
69
+ at once, or pull messages individually, for example:
70
+
71
+ > msgs_ids = [msg_id1, msg_id2, msg_id3]
72
+ > msgs = []
73
+ > for msg_id in msg_ids:
74
+ > msg = grid.pull_messages([msg_id])
75
+ > msgs.extend(msg)
56
76
  """
57
77
 
58
78
 
59
- class GrpcDriver(Driver):
60
- """`GrpcDriver` provides an interface to the ServerAppIo API.
79
+ class GrpcGrid(Grid):
80
+ """`GrpcGrid` provides an interface to the ServerAppIo API.
61
81
 
62
82
  Parameters
63
83
  ----------
@@ -69,6 +89,8 @@ class GrpcDriver(Driver):
69
89
  established to an SSL-enabled Flower server.
70
90
  """
71
91
 
92
+ _deprecation_warning_logged = False
93
+
72
94
  def __init__( # pylint: disable=too-many-arguments
73
95
  self,
74
96
  serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
@@ -81,6 +103,7 @@ class GrpcDriver(Driver):
81
103
  self._channel: Optional[grpc.Channel] = None
82
104
  self.node = Node(node_id=SUPERLINK_NODE_ID)
83
105
  self._retry_invoker = _make_simple_grpc_retry_invoker()
106
+ super().__init__()
84
107
 
85
108
  @property
86
109
  def _is_connected(self) -> bool:
@@ -140,18 +163,15 @@ class GrpcDriver(Driver):
140
163
  def _check_message(self, message: Message) -> None:
141
164
  # Check if the message is valid
142
165
  if not (
143
- # Assume self._run being initialized
144
- message.metadata.run_id == cast(Run, self._run).run_id
145
- and message.metadata.src_node_id == self.node.node_id
146
- and message.metadata.message_id == ""
147
- and message.metadata.reply_to_message == ""
166
+ message.metadata.message_id == ""
167
+ and message.metadata.reply_to_message_id == ""
148
168
  and message.metadata.ttl > 0
149
169
  ):
150
170
  raise ValueError(f"Invalid message: {message}")
151
171
 
152
172
  def create_message( # pylint: disable=too-many-arguments,R0917
153
173
  self,
154
- content: RecordSet,
174
+ content: RecordDict,
155
175
  message_type: str,
156
176
  dst_node_id: int,
157
177
  group_id: str,
@@ -162,30 +182,17 @@ class GrpcDriver(Driver):
162
182
  This method constructs a new `Message` with given content and metadata.
163
183
  The `run_id` and `src_node_id` will be set automatically.
164
184
  """
165
- if ttl:
166
- warnings.warn(
167
- "A custom TTL was set, but note that the SuperLink does not enforce "
168
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
169
- "version of Flower.",
170
- stacklevel=2,
185
+ if not GrpcGrid._deprecation_warning_logged:
186
+ GrpcGrid._deprecation_warning_logged = True
187
+ warn_deprecated_feature(
188
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
189
+ "Use `Message` constructor instead."
171
190
  )
172
-
173
- ttl_ = DEFAULT_TTL if ttl is None else ttl
174
- metadata = Metadata(
175
- run_id=cast(Run, self._run).run_id,
176
- message_id="", # Will be set by the server
177
- src_node_id=self.node.node_id,
178
- dst_node_id=dst_node_id,
179
- reply_to_message="",
180
- group_id=group_id,
181
- ttl=ttl_,
182
- message_type=message_type,
183
- )
184
- return Message(metadata=metadata, content=content)
191
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
185
192
 
186
193
  def get_node_ids(self) -> Iterable[int]:
187
194
  """Get node IDs."""
188
- # Call GrpcDriverStub method
195
+ # Call GrpcServerAppIoStub method
189
196
  res: GetNodesResponse = self._stub.GetNodes(
190
197
  GetNodesRequest(run_id=cast(Run, self._run).run_id)
191
198
  )
@@ -198,30 +205,40 @@ class GrpcDriver(Driver):
198
205
  to the node specified in `dst_node_id`.
199
206
  """
200
207
  # Construct Messages
208
+ run_id = cast(Run, self._run).run_id
201
209
  message_proto_list: list[ProtoMessage] = []
202
210
  for msg in messages:
211
+ # Populate metadata
212
+ msg.metadata.__dict__["_run_id"] = run_id
213
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
203
214
  # Check message
204
215
  self._check_message(msg)
205
216
  # Convert to proto
206
217
  msg_proto = message_to_proto(msg)
207
218
  # Add to list
208
219
  message_proto_list.append(msg_proto)
209
- # Call GrpcDriverStub method
210
- res: PushInsMessagesResponse = self._stub.PushMessages(
211
- PushInsMessagesRequest(
212
- messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
213
- )
214
- )
215
- if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
216
- list(message_proto_list)
217
- ):
218
- log(
219
- WARNING,
220
- "Not all messages could be pushed to the SuperLink. The returned "
221
- "list has `None` for those messages (the order is preserved as passed "
222
- "to `push_messages`). This could be due to a malformed message.",
220
+
221
+ try:
222
+ # Call GrpcServerAppIoStub method
223
+ res: PushInsMessagesResponse = self._stub.PushMessages(
224
+ PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
223
225
  )
224
- return list(res.message_ids)
226
+ if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
227
+ message_proto_list
228
+ ):
229
+ log(
230
+ WARNING,
231
+ "Not all messages could be pushed to the SuperLink. The returned "
232
+ "list has `None` for those messages (the order is preserved as "
233
+ "passed to `push_messages`). This could be due to a malformed "
234
+ "message.",
235
+ )
236
+ return list(res.message_ids)
237
+ except grpc.RpcError as e:
238
+ if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
239
+ log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
240
+ return []
241
+ raise
225
242
 
226
243
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
227
244
  """Pull messages based on message IDs.
@@ -229,16 +246,22 @@ class GrpcDriver(Driver):
229
246
  This method is used to collect messages from the SuperLink that correspond to a
230
247
  set of given message IDs.
231
248
  """
232
- # Pull Messages
233
- res: PullResMessagesResponse = self._stub.PullMessages(
234
- PullResMessagesRequest(
235
- message_ids=message_ids,
236
- run_id=cast(Run, self._run).run_id,
249
+ try:
250
+ # Pull Messages
251
+ res: PullResMessagesResponse = self._stub.PullMessages(
252
+ PullResMessagesRequest(
253
+ message_ids=message_ids,
254
+ run_id=cast(Run, self._run).run_id,
255
+ )
237
256
  )
238
- )
239
- # Convert Message from Protobuf representation
240
- msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
241
- return msgs
257
+ # Convert Message from Protobuf representation
258
+ msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
259
+ return msgs
260
+ except grpc.RpcError as e:
261
+ if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
262
+ log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
263
+ return []
264
+ raise
242
265
 
243
266
  def send_and_receive(
244
267
  self,
@@ -262,7 +285,7 @@ class GrpcDriver(Driver):
262
285
  res_msgs = self.pull_messages(msg_ids)
263
286
  ret.extend(res_msgs)
264
287
  msg_ids.difference_update(
265
- {msg.metadata.reply_to_message for msg in res_msgs}
288
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
266
289
  )
267
290
  if len(msg_ids) == 0:
268
291
  break