flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. flwr/client/app.py +6 -4
  2. flwr/client/clientapp/app.py +2 -2
  3. flwr/client/grpc_client/connection.py +23 -20
  4. flwr/client/message_handler/message_handler.py +27 -27
  5. flwr/client/mod/centraldp_mods.py +7 -7
  6. flwr/client/mod/localdp_mod.py +4 -4
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
  8. flwr/client/run_info_store.py +2 -2
  9. flwr/common/__init__.py +2 -0
  10. flwr/common/constant.py +2 -0
  11. flwr/common/context.py +4 -4
  12. flwr/common/logger.py +2 -2
  13. flwr/common/message.py +269 -101
  14. flwr/common/record/__init__.py +2 -1
  15. flwr/common/record/configsrecord.py +2 -2
  16. flwr/common/record/metricsrecord.py +1 -1
  17. flwr/common/record/parametersrecord.py +1 -1
  18. flwr/common/record/{recordset.py → recorddict.py} +57 -17
  19. flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
  20. flwr/common/serde.py +33 -37
  21. flwr/proto/exec_pb2.py +32 -32
  22. flwr/proto/exec_pb2.pyi +3 -3
  23. flwr/proto/message_pb2.py +12 -12
  24. flwr/proto/message_pb2.pyi +9 -9
  25. flwr/proto/recorddict_pb2.py +70 -0
  26. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
  27. flwr/proto/run_pb2.py +32 -32
  28. flwr/proto/run_pb2.pyi +3 -3
  29. flwr/server/__init__.py +2 -0
  30. flwr/server/compat/__init__.py +2 -2
  31. flwr/server/compat/app.py +11 -11
  32. flwr/server/compat/app_utils.py +16 -16
  33. flwr/server/compat/grid_client_proxy.py +38 -38
  34. flwr/server/grid/__init__.py +7 -6
  35. flwr/server/grid/grid.py +46 -17
  36. flwr/server/grid/grpc_grid.py +26 -33
  37. flwr/server/grid/inmemory_grid.py +19 -25
  38. flwr/server/run_serverapp.py +4 -4
  39. flwr/server/server_app.py +37 -11
  40. flwr/server/serverapp/app.py +10 -10
  41. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  42. flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
  43. flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
  44. flwr/server/superlink/linkstate/utils.py +77 -17
  45. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
  46. flwr/server/typing.py +3 -3
  47. flwr/server/utils/validator.py +4 -4
  48. flwr/server/workflow/default_workflows.py +24 -26
  49. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
  50. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  51. flwr/simulation/run_simulation.py +13 -13
  52. flwr/superexec/deployment.py +2 -2
  53. flwr/superexec/simulation.py +2 -2
  54. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
  55. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
  56. flwr/proto/recordset_pb2.py +0 -70
  57. flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  58. flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  59. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
  60. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
  61. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
flwr/server/compat/app.py CHANGED
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver app."""
15
+ """Flower grid app."""
16
16
 
17
17
 
18
18
  from logging import INFO
@@ -25,27 +25,27 @@ from flwr.server.server import Server, init_defaults, run_fl
25
25
  from flwr.server.server_config import ServerConfig
26
26
  from flwr.server.strategy import Strategy
27
27
 
28
- from ..grid import Driver
28
+ from ..grid import Grid
29
29
  from .app_utils import start_update_client_manager_thread
30
30
 
31
31
 
32
- def start_driver( # pylint: disable=too-many-arguments, too-many-locals
32
+ def start_grid( # pylint: disable=too-many-arguments, too-many-locals
33
33
  *,
34
- driver: Driver,
34
+ grid: Grid,
35
35
  server: Optional[Server] = None,
36
36
  config: Optional[ServerConfig] = None,
37
37
  strategy: Optional[Strategy] = None,
38
38
  client_manager: Optional[ClientManager] = None,
39
39
  ) -> History:
40
- """Start a Flower Driver API server.
40
+ """Start a Flower server.
41
41
 
42
42
  Parameters
43
43
  ----------
44
- driver : Driver
45
- The Driver object to use.
44
+ grid : Grid
45
+ The Grid object to use.
46
46
  server : Optional[flwr.server.Server] (default: None)
47
47
  A server implementation, either `flwr.server.Server` or a subclass
48
- thereof. If no instance is provided, then `start_driver` will create
48
+ thereof. If no instance is provided, then `start_grid` will create
49
49
  one.
50
50
  config : Optional[ServerConfig] (default: None)
51
51
  Currently supported values are `num_rounds` (int, default: 1) and
@@ -56,7 +56,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
56
56
  `start_server` will use `flwr.server.strategy.FedAvg`.
57
57
  client_manager : Optional[flwr.server.ClientManager] (default: None)
58
58
  An implementation of the class `flwr.server.ClientManager`. If no
59
- implementation is provided, then `start_driver` will use
59
+ implementation is provided, then `start_grid` will use
60
60
  `flwr.server.SimpleClientManager`.
61
61
 
62
62
  Returns
@@ -64,7 +64,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
64
64
  hist : flwr.server.history.History
65
65
  Object containing training and evaluation metrics.
66
66
  """
67
- # Initialize the Driver API server and config
67
+ # Initialize the server and config
68
68
  initialized_server, initialized_config = init_defaults(
69
69
  server=server,
70
70
  config=config,
@@ -80,7 +80,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
80
80
 
81
81
  # Start the thread updating nodes
82
82
  thread, f_stop, c_done = start_update_client_manager_thread(
83
- driver, initialized_server.client_manager()
83
+ grid, initialized_server.client_manager()
84
84
  )
85
85
 
86
86
  # Wait until the node registration done
flwr/server/compat/app_utils.py CHANGED
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Utility functions for the `start_driver`."""
15
+ """Utility functions for the `start_grid`."""
16
16
 
17
17
 
18
18
  import threading
@@ -20,18 +20,18 @@ import threading
20
20
  from flwr.common.typing import RunNotRunningException
21
21
 
22
22
  from ..client_manager import ClientManager
23
- from ..grid import Driver
24
- from .grid_client_proxy import DriverClientProxy
23
+ from ..grid import Grid
24
+ from .grid_client_proxy import GridClientProxy
25
25
 
26
26
 
27
27
  def start_update_client_manager_thread(
28
- driver: Driver,
28
+ grid: Grid,
29
29
  client_manager: ClientManager,
30
30
  ) -> tuple[threading.Thread, threading.Event, threading.Event]:
31
31
  """Periodically update the nodes list in the client manager in a thread.
32
32
 
33
- This function starts a thread that periodically uses the associated driver to
34
- get all node_ids. Each node_id is then converted into a `DriverClientProxy`
33
+ This function starts a thread that periodically uses the associated grid to
34
+ get all node_ids. Each node_id is then converted into a `GridClientProxy`
35
35
  instance and stored in the `registered_nodes` dictionary with node_id as key.
36
36
 
37
37
  New nodes will be added to the ClientManager via `client_manager.register()`,
@@ -40,8 +40,8 @@ def start_update_client_manager_thread(
40
40
 
41
41
  Parameters
42
42
  ----------
43
- driver : Driver
44
- The Driver object to use.
43
+ grid : Grid
44
+ The Grid object to use.
45
45
  client_manager : ClientManager
46
46
  The ClientManager object to be updated.
47
47
 
@@ -59,7 +59,7 @@ def start_update_client_manager_thread(
59
59
  thread = threading.Thread(
60
60
  target=_update_client_manager,
61
61
  args=(
62
- driver,
62
+ grid,
63
63
  client_manager,
64
64
  f_stop,
65
65
  c_done,
@@ -72,17 +72,17 @@ def start_update_client_manager_thread(
72
72
 
73
73
 
74
74
  def _update_client_manager(
75
- driver: Driver,
75
+ grid: Grid,
76
76
  client_manager: ClientManager,
77
77
  f_stop: threading.Event,
78
78
  c_done: threading.Event,
79
79
  ) -> None:
80
80
  """Update the nodes list in the client manager."""
81
- # Loop until the driver is disconnected
82
- registered_nodes: dict[int, DriverClientProxy] = {}
81
+ # Loop until the grid is disconnected
82
+ registered_nodes: dict[int, GridClientProxy] = {}
83
83
  while not f_stop.is_set():
84
84
  try:
85
- all_node_ids = set(driver.get_node_ids())
85
+ all_node_ids = set(grid.get_node_ids())
86
86
  except RunNotRunningException:
87
87
  f_stop.set()
88
88
  break
@@ -97,10 +97,10 @@ def _update_client_manager(
97
97
 
98
98
  # Register new nodes
99
99
  for node_id in new_nodes:
100
- client_proxy = DriverClientProxy(
100
+ client_proxy = GridClientProxy(
101
101
  node_id=node_id,
102
- driver=driver,
103
- run_id=driver.run.run_id,
102
+ grid=grid,
103
+ run_id=grid.run.run_id,
104
104
  )
105
105
  if client_manager.register(client_proxy):
106
106
  registered_nodes[node_id] = client_proxy
flwr/server/compat/grid_client_proxy.py CHANGED
@@ -12,26 +12,26 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower ClientProxy implementation for Driver API."""
15
+ """Flower ClientProxy implementation using Grid."""
16
16
 
17
17
 
18
18
  from typing import Optional
19
19
 
20
20
  from flwr import common
21
- from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
22
- from flwr.common import recordset_compat as compat
21
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordDict
22
+ from flwr.common import recorddict_compat as compat
23
23
  from flwr.server.client_proxy import ClientProxy
24
24
 
25
- from ..grid.grid import Driver
25
+ from ..grid.grid import Grid
26
26
 
27
27
 
28
- class DriverClientProxy(ClientProxy):
29
- """Flower client proxy which delegates work using the Driver API."""
28
+ class GridClientProxy(ClientProxy):
29
+ """Flower client proxy which delegates work using Grid."""
30
30
 
31
- def __init__(self, node_id: int, driver: Driver, run_id: int):
31
+ def __init__(self, node_id: int, grid: Grid, run_id: int):
32
32
  super().__init__(str(node_id))
33
33
  self.node_id = node_id
34
- self.driver = driver
34
+ self.grid = grid
35
35
  self.run_id = run_id
36
36
 
37
37
  def get_properties(
@@ -41,14 +41,14 @@ class DriverClientProxy(ClientProxy):
41
41
  group_id: Optional[int],
42
42
  ) -> common.GetPropertiesRes:
43
43
  """Return client's properties."""
44
- # Ins to RecordSet
45
- out_recordset = compat.getpropertiesins_to_recordset(ins)
44
+ # Ins to RecordDict
45
+ out_recorddict = compat.getpropertiesins_to_recorddict(ins)
46
46
  # Fetch response
47
- in_recordset = self._send_receive_recordset(
48
- out_recordset, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
47
+ in_recorddict = self._send_receive_recorddict(
48
+ out_recorddict, MessageTypeLegacy.GET_PROPERTIES, timeout, group_id
49
49
  )
50
- # RecordSet to Res
51
- return compat.recordset_to_getpropertiesres(in_recordset)
50
+ # RecordDict to Res
51
+ return compat.recorddict_to_getpropertiesres(in_recorddict)
52
52
 
53
53
  def get_parameters(
54
54
  self,
@@ -57,40 +57,40 @@ class DriverClientProxy(ClientProxy):
57
57
  group_id: Optional[int],
58
58
  ) -> common.GetParametersRes:
59
59
  """Return the current local model parameters."""
60
- # Ins to RecordSet
61
- out_recordset = compat.getparametersins_to_recordset(ins)
60
+ # Ins to RecordDict
61
+ out_recorddict = compat.getparametersins_to_recorddict(ins)
62
62
  # Fetch response
63
- in_recordset = self._send_receive_recordset(
64
- out_recordset, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
63
+ in_recorddict = self._send_receive_recorddict(
64
+ out_recorddict, MessageTypeLegacy.GET_PARAMETERS, timeout, group_id
65
65
  )
66
- # RecordSet to Res
67
- return compat.recordset_to_getparametersres(in_recordset, False)
66
+ # RecordDict to Res
67
+ return compat.recorddict_to_getparametersres(in_recorddict, False)
68
68
 
69
69
  def fit(
70
70
  self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
71
71
  ) -> common.FitRes:
72
72
  """Train model parameters on the locally held dataset."""
73
- # Ins to RecordSet
74
- out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
73
+ # Ins to RecordDict
74
+ out_recorddict = compat.fitins_to_recorddict(ins, keep_input=True)
75
75
  # Fetch response
76
- in_recordset = self._send_receive_recordset(
77
- out_recordset, MessageType.TRAIN, timeout, group_id
76
+ in_recorddict = self._send_receive_recorddict(
77
+ out_recorddict, MessageType.TRAIN, timeout, group_id
78
78
  )
79
- # RecordSet to Res
80
- return compat.recordset_to_fitres(in_recordset, keep_input=False)
79
+ # RecordDict to Res
80
+ return compat.recorddict_to_fitres(in_recorddict, keep_input=False)
81
81
 
82
82
  def evaluate(
83
83
  self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
84
84
  ) -> common.EvaluateRes:
85
85
  """Evaluate model parameters on the locally held dataset."""
86
- # Ins to RecordSet
87
- out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
86
+ # Ins to RecordDict
87
+ out_recorddict = compat.evaluateins_to_recorddict(ins, keep_input=True)
88
88
  # Fetch response
89
- in_recordset = self._send_receive_recordset(
90
- out_recordset, MessageType.EVALUATE, timeout, group_id
89
+ in_recorddict = self._send_receive_recorddict(
90
+ out_recorddict, MessageType.EVALUATE, timeout, group_id
91
91
  )
92
- # RecordSet to Res
93
- return compat.recordset_to_evaluateres(in_recordset)
92
+ # RecordDict to Res
93
+ return compat.recorddict_to_evaluateres(in_recorddict)
94
94
 
95
95
  def reconnect(
96
96
  self,
@@ -101,17 +101,17 @@ class DriverClientProxy(ClientProxy):
101
101
  """Disconnect and (optionally) reconnect later."""
102
102
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
103
103
 
104
- def _send_receive_recordset(
104
+ def _send_receive_recorddict(
105
105
  self,
106
- recordset: RecordSet,
106
+ recorddict: RecordDict,
107
107
  message_type: str,
108
108
  timeout: Optional[float],
109
109
  group_id: Optional[int],
110
- ) -> RecordSet:
110
+ ) -> RecordDict:
111
111
 
112
112
  # Create message
113
- message = self.driver.create_message(
114
- content=recordset,
113
+ message = self.grid.create_message(
114
+ content=recorddict,
115
115
  message_type=message_type,
116
116
  dst_node_id=self.node_id,
117
117
  group_id=str(group_id) if group_id else "",
@@ -119,7 +119,7 @@ class DriverClientProxy(ClientProxy):
119
119
  )
120
120
 
121
121
  # Send message and wait for reply
122
- messages = list(self.driver.send_and_receive(messages=[message]))
122
+ messages = list(self.grid.send_and_receive(messages=[message]))
123
123
 
124
124
  # A single reply is expected
125
125
  if len(messages) != 1:
flwr/server/grid/__init__.py CHANGED
@@ -12,15 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver SDK."""
15
+ """Flower grid SDK."""
16
16
 
17
17
 
18
- from .grid import Driver
19
- from .grpc_grid import GrpcDriver
20
- from .inmemory_grid import InMemoryDriver
18
+ from .grid import Driver, Grid
19
+ from .grpc_grid import GrpcGrid
20
+ from .inmemory_grid import InMemoryGrid
21
21
 
22
22
  __all__ = [
23
23
  "Driver",
24
- "GrpcDriver",
25
- "InMemoryDriver",
24
+ "Grid",
25
+ "GrpcGrid",
26
+ "InMemoryGrid",
26
27
  ]
flwr/server/grid/grid.py CHANGED
@@ -12,32 +12,32 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Driver (abstract base class)."""
15
+ """Grid (abstract base class)."""
16
16
 
17
17
 
18
18
  from abc import ABC, abstractmethod
19
19
  from collections.abc import Iterable
20
20
  from typing import Optional
21
21
 
22
- from flwr.common import Message, RecordSet
22
+ from flwr.common import Message, RecordDict
23
23
  from flwr.common.typing import Run
24
24
 
25
25
 
26
- class Driver(ABC):
27
- """Abstract base Driver class for the ServerAppIo API."""
26
+ class Grid(ABC):
27
+ """Abstract base class Grid to send/receive messages."""
28
28
 
29
29
  @abstractmethod
30
30
  def set_run(self, run_id: int) -> None:
31
31
  """Request a run to the SuperLink with a given `run_id`.
32
32
 
33
- If a Run with the specified `run_id` exists, a local Run
33
+ If a ``Run`` with the specified ``run_id`` exists, a local ``Run``
34
34
  object will be created. It enables further functionality
35
- in the driver, such as sending `Messages`.
35
+ in the grid, such as sending ``Message``s.
36
36
 
37
37
  Parameters
38
38
  ----------
39
39
  run_id : int
40
- The `run_id` of the Run this Driver object operates in.
40
+ The ``run_id`` of the ``Run`` this ``Grid`` object operates in.
41
41
  """
42
42
 
43
43
  @property
@@ -48,7 +48,7 @@ class Driver(ABC):
48
48
  @abstractmethod
49
49
  def create_message( # pylint: disable=too-many-arguments,R0917
50
50
  self,
51
- content: RecordSet,
51
+ content: RecordDict,
52
52
  message_type: str,
53
53
  dst_node_id: int,
54
54
  group_id: str,
@@ -56,12 +56,12 @@ class Driver(ABC):
56
56
  ) -> Message:
57
57
  """Create a new message with specified parameters.
58
58
 
59
- This method constructs a new `Message` with given content and metadata.
60
- The `run_id` and `src_node_id` will be set automatically.
59
+ This method constructs a new ``Message`` with given content and metadata.
60
+ The ``run_id`` and ``src_node_id`` will be set automatically.
61
61
 
62
62
  Parameters
63
63
  ----------
64
- content : RecordSet
64
+ content : RecordDict
65
65
  The content for the new message. This holds records that are to be sent
66
66
  to the destination node.
67
67
  message_type : str
@@ -71,12 +71,12 @@ class Driver(ABC):
71
71
  The ID of the destination node to which the message is being sent.
72
72
  group_id : str
73
73
  The ID of the group to which this message is associated. In some settings,
74
- this is used as the FL round.
74
+ this is used as the federated learning round.
75
75
  ttl : Optional[float] (default: None)
76
76
  Time-to-live for the round trip of this message, i.e., the time from sending
77
77
  this message to receiving a reply. It specifies in seconds the duration for
78
78
  which the message and its potential reply are considered valid. If unset,
79
- the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
79
+ the default TTL (i.e., ``common.DEFAULT_TTL``) will be used.
80
80
 
81
81
  Returns
82
82
  -------
@@ -93,7 +93,7 @@ class Driver(ABC):
93
93
  """Push messages to specified node IDs.
94
94
 
95
95
  This method takes an iterable of messages and sends each message
96
- to the node specified in `dst_node_id`.
96
+ to the node specified in ``dst_node_id``.
97
97
 
98
98
  Parameters
99
99
  ----------
@@ -154,8 +154,37 @@ class Driver(ABC):
154
154
 
155
155
  Notes
156
156
  -----
157
- This method uses `push_messages` to send the messages and `pull_messages`
158
- to collect the replies. If `timeout` is set, the method may not return
157
+ This method uses ``push_messages`` to send the messages and ``pull_messages``
158
+ to collect the replies. If ``timeout`` is set, the method may not return
159
159
  replies for all sent messages. A message remains valid until its TTL,
160
- which is not affected by `timeout`.
160
+ which is not affected by ``timeout``.
161
161
  """
162
+
163
+
164
+ class Driver(Grid):
165
+ """Deprecated abstract base class ``Driver``, use ``Grid`` instead.
166
+
167
+ This class is provided solely for backward compatibility with legacy
168
+ code that previously relied on the ``Driver`` class. It has been deprecated
169
+ in favor of the updated abstract base class ``Grid``, which now encompasses
170
+ all communication-related functionality and improvements between the
171
+ ServerApp and the SuperLink.
172
+
173
+ .. warning::
174
+ ``Driver`` is deprecated and will be removed in a future release.
175
+ Use `Grid` in the signature of your ServerApp.
176
+
177
+ Examples
178
+ --------
179
+ Legacy (deprecated) usage::
180
+
181
+ @app.main()
182
+ def main(driver: Driver, context: Context) -> None:
183
+ ...
184
+
185
+ Updated usage::
186
+
187
+ @app.main()
188
+ def main(grid: Grid, context: Context) -> None:
189
+ ...
190
+ """
flwr/server/grid/grpc_grid.py CHANGED
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower gRPC Driver."""
15
+ """Flower gRPC Grid."""
16
16
 
17
17
 
18
18
  import time
@@ -22,13 +22,13 @@ from typing import Optional, cast
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
+ from flwr.common import Message, RecordDict
26
26
  from flwr.common.constant import (
27
27
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
28
28
  SUPERLINK_NODE_ID,
29
29
  )
30
30
  from flwr.common.grpc import create_channel, on_channel_state_change
31
- from flwr.common.logger import log
31
+ from flwr.common.logger import log, warn_deprecated_feature
32
32
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
33
33
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
34
34
  from flwr.common.typing import Run
@@ -45,11 +45,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
47
47
 
48
- from .grid import Driver
48
+ from .grid import Grid
49
49
 
50
50
  ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
51
51
 
52
- [Driver.push_messages] gRPC error occurred:
52
+ [Grid.push_messages] gRPC error occurred:
53
53
 
54
54
  The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
55
55
  at once, or push messages individually, for example:
@@ -57,13 +57,13 @@ at once, or push messages individually, for example:
57
57
  > msgs = [msg1, msg2, msg3]
58
58
  > msg_ids = []
59
59
  > for msg in msgs:
60
- > msg_id = driver.push_messages([msg])
60
+ > msg_id = grid.push_messages([msg])
61
61
  > msg_ids.extend(msg_id)
62
62
  """
63
63
 
64
64
  ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
65
65
 
66
- [Driver.pull_messages] gRPC error occurred:
66
+ [Grid.pull_messages] gRPC error occurred:
67
67
 
68
68
  The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
69
69
  at once, or pull messages individually, for example:
@@ -71,13 +71,13 @@ at once, or pull messages individually, for example:
71
71
  > msgs_ids = [msg_id1, msg_id2, msg_id3]
72
72
  > msgs = []
73
73
  > for msg_id in msg_ids:
74
- > msg = driver.pull_messages([msg_id])
74
+ > msg = grid.pull_messages([msg_id])
75
75
  > msgs.extend(msg)
76
76
  """
77
77
 
78
78
 
79
- class GrpcDriver(Driver):
80
- """`GrpcDriver` provides an interface to the ServerAppIo API.
79
+ class GrpcGrid(Grid):
80
+ """`GrpcGrid` provides an interface to the ServerAppIo API.
81
81
 
82
82
  Parameters
83
83
  ----------
@@ -101,6 +101,7 @@ class GrpcDriver(Driver):
101
101
  self._channel: Optional[grpc.Channel] = None
102
102
  self.node = Node(node_id=SUPERLINK_NODE_ID)
103
103
  self._retry_invoker = _make_simple_grpc_retry_invoker()
104
+ super().__init__()
104
105
 
105
106
  @property
106
107
  def _is_connected(self) -> bool:
@@ -160,18 +161,15 @@ class GrpcDriver(Driver):
160
161
  def _check_message(self, message: Message) -> None:
161
162
  # Check if the message is valid
162
163
  if not (
163
- # Assume self._run being initialized
164
- message.metadata.run_id == cast(Run, self._run).run_id
165
- and message.metadata.src_node_id == self.node.node_id
166
- and message.metadata.message_id == ""
167
- and message.metadata.reply_to_message == ""
164
+ message.metadata.message_id == ""
165
+ and message.metadata.reply_to_message_id == ""
168
166
  and message.metadata.ttl > 0
169
167
  ):
170
168
  raise ValueError(f"Invalid message: {message}")
171
169
 
172
170
  def create_message( # pylint: disable=too-many-arguments,R0917
173
171
  self,
174
- content: RecordSet,
172
+ content: RecordDict,
175
173
  message_type: str,
176
174
  dst_node_id: int,
177
175
  group_id: str,
@@ -182,22 +180,15 @@ class GrpcDriver(Driver):
182
180
  This method constructs a new `Message` with given content and metadata.
183
181
  The `run_id` and `src_node_id` will be set automatically.
184
182
  """
185
- ttl_ = DEFAULT_TTL if ttl is None else ttl
186
- metadata = Metadata(
187
- run_id=cast(Run, self._run).run_id,
188
- message_id="", # Will be set by the server
189
- src_node_id=self.node.node_id,
190
- dst_node_id=dst_node_id,
191
- reply_to_message="",
192
- group_id=group_id,
193
- ttl=ttl_,
194
- message_type=message_type,
183
+ warn_deprecated_feature(
184
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
185
+ "Use `Message` constructor instead."
195
186
  )
196
- return Message(metadata=metadata, content=content)
187
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
197
188
 
198
189
  def get_node_ids(self) -> Iterable[int]:
199
190
  """Get node IDs."""
200
- # Call GrpcDriverStub method
191
+ # Call GrpcServerAppIoStub method
201
192
  res: GetNodesResponse = self._stub.GetNodes(
202
193
  GetNodesRequest(run_id=cast(Run, self._run).run_id)
203
194
  )
@@ -210,8 +201,12 @@ class GrpcDriver(Driver):
210
201
  to the node specified in `dst_node_id`.
211
202
  """
212
203
  # Construct Messages
204
+ run_id = cast(Run, self._run).run_id
213
205
  message_proto_list: list[ProtoMessage] = []
214
206
  for msg in messages:
207
+ # Populate metadata
208
+ msg.metadata.__dict__["_run_id"] = run_id
209
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
215
210
  # Check message
216
211
  self._check_message(msg)
217
212
  # Convert to proto
@@ -220,11 +215,9 @@ class GrpcDriver(Driver):
220
215
  message_proto_list.append(msg_proto)
221
216
 
222
217
  try:
223
- # Call GrpcDriverStub method
218
+ # Call GrpcServerAppIoStub method
224
219
  res: PushInsMessagesResponse = self._stub.PushMessages(
225
- PushInsMessagesRequest(
226
- messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
227
- )
220
+ PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
228
221
  )
229
222
  if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
230
223
  message_proto_list
@@ -288,7 +281,7 @@ class GrpcDriver(Driver):
288
281
  res_msgs = self.pull_messages(msg_ids)
289
282
  ret.extend(res_msgs)
290
283
  msg_ids.difference_update(
291
- {msg.metadata.reply_to_message for msg in res_msgs}
284
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
292
285
  )
293
286
  if len(msg_ids) == 0:
294
287
  break