flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,24 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower gRPC Driver."""
15
+ """Flower gRPC Grid."""
16
16
 
17
17
 
18
18
  import time
19
- import warnings
20
19
  from collections.abc import Iterable
21
- from logging import DEBUG, WARNING
20
+ from logging import DEBUG, ERROR, WARNING
22
21
  from typing import Optional, cast
23
22
 
24
23
  import grpc
25
24
 
26
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
+ from flwr.common import Message, RecordDict
27
26
  from flwr.common.constant import (
28
27
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
29
28
  SUPERLINK_NODE_ID,
30
29
  )
31
30
  from flwr.common.grpc import create_channel, on_channel_state_change
32
- from flwr.common.logger import log
31
+ from flwr.common.logger import log, warn_deprecated_feature
33
32
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
34
33
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
35
34
  from flwr.common.typing import Run
@@ -46,18 +45,39 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
46
45
  )
47
46
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
48
47
 
49
- from .driver import Driver
48
+ from .grid import Grid
50
49
 
51
- ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
52
- [flwr-serverapp] Error: Not connected.
50
+ ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
53
51
 
54
- Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
55
- `GrpcDriverStub` methods.
52
+ [Grid.push_messages] gRPC error occurred:
53
+
54
+ The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
55
+ at once, or push messages individually, for example:
56
+
57
+ > msgs = [msg1, msg2, msg3]
58
+ > msg_ids = []
59
+ > for msg in msgs:
60
+ > msg_id = grid.push_messages([msg])
61
+ > msg_ids.extend(msg_id)
62
+ """
63
+
64
+ ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
65
+
66
+ [Grid.pull_messages] gRPC error occurred:
67
+
68
+ The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
69
+ at once, or pull messages individually, for example:
70
+
71
+ > msgs_ids = [msg_id1, msg_id2, msg_id3]
72
+ > msgs = []
73
+ > for msg_id in msg_ids:
74
+ > msg = grid.pull_messages([msg_id])
75
+ > msgs.extend(msg)
56
76
  """
57
77
 
58
78
 
59
- class GrpcDriver(Driver):
60
- """`GrpcDriver` provides an interface to the ServerAppIo API.
79
+ class GrpcGrid(Grid):
80
+ """`GrpcGrid` provides an interface to the ServerAppIo API.
61
81
 
62
82
  Parameters
63
83
  ----------
@@ -69,6 +89,8 @@ class GrpcDriver(Driver):
69
89
  established to an SSL-enabled Flower server.
70
90
  """
71
91
 
92
+ _deprecation_warning_logged = False
93
+
72
94
  def __init__( # pylint: disable=too-many-arguments
73
95
  self,
74
96
  serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
@@ -81,6 +103,7 @@ class GrpcDriver(Driver):
81
103
  self._channel: Optional[grpc.Channel] = None
82
104
  self.node = Node(node_id=SUPERLINK_NODE_ID)
83
105
  self._retry_invoker = _make_simple_grpc_retry_invoker()
106
+ super().__init__()
84
107
 
85
108
  @property
86
109
  def _is_connected(self) -> bool:
@@ -140,18 +163,15 @@ class GrpcDriver(Driver):
140
163
  def _check_message(self, message: Message) -> None:
141
164
  # Check if the message is valid
142
165
  if not (
143
- # Assume self._run being initialized
144
- message.metadata.run_id == cast(Run, self._run).run_id
145
- and message.metadata.src_node_id == self.node.node_id
146
- and message.metadata.message_id == ""
147
- and message.metadata.reply_to_message == ""
166
+ message.metadata.message_id == ""
167
+ and message.metadata.reply_to_message_id == ""
148
168
  and message.metadata.ttl > 0
149
169
  ):
150
170
  raise ValueError(f"Invalid message: {message}")
151
171
 
152
172
  def create_message( # pylint: disable=too-many-arguments,R0917
153
173
  self,
154
- content: RecordSet,
174
+ content: RecordDict,
155
175
  message_type: str,
156
176
  dst_node_id: int,
157
177
  group_id: str,
@@ -162,30 +182,17 @@ class GrpcDriver(Driver):
162
182
  This method constructs a new `Message` with given content and metadata.
163
183
  The `run_id` and `src_node_id` will be set automatically.
164
184
  """
165
- if ttl:
166
- warnings.warn(
167
- "A custom TTL was set, but note that the SuperLink does not enforce "
168
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
169
- "version of Flower.",
170
- stacklevel=2,
185
+ if not GrpcGrid._deprecation_warning_logged:
186
+ GrpcGrid._deprecation_warning_logged = True
187
+ warn_deprecated_feature(
188
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
189
+ "Use `Message` constructor instead."
171
190
  )
191
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
172
192
 
173
- ttl_ = DEFAULT_TTL if ttl is None else ttl
174
- metadata = Metadata(
175
- run_id=cast(Run, self._run).run_id,
176
- message_id="", # Will be set by the server
177
- src_node_id=self.node.node_id,
178
- dst_node_id=dst_node_id,
179
- reply_to_message="",
180
- group_id=group_id,
181
- ttl=ttl_,
182
- message_type=message_type,
183
- )
184
- return Message(metadata=metadata, content=content)
185
-
186
- def get_node_ids(self) -> list[int]:
193
+ def get_node_ids(self) -> Iterable[int]:
187
194
  """Get node IDs."""
188
- # Call GrpcDriverStub method
195
+ # Call GrpcServerAppIoStub method
189
196
  res: GetNodesResponse = self._stub.GetNodes(
190
197
  GetNodesRequest(run_id=cast(Run, self._run).run_id)
191
198
  )
@@ -198,21 +205,40 @@ class GrpcDriver(Driver):
198
205
  to the node specified in `dst_node_id`.
199
206
  """
200
207
  # Construct Messages
208
+ run_id = cast(Run, self._run).run_id
201
209
  message_proto_list: list[ProtoMessage] = []
202
210
  for msg in messages:
211
+ # Populate metadata
212
+ msg.metadata.__dict__["_run_id"] = run_id
213
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
203
214
  # Check message
204
215
  self._check_message(msg)
205
216
  # Convert to proto
206
217
  msg_proto = message_to_proto(msg)
207
218
  # Add to list
208
219
  message_proto_list.append(msg_proto)
209
- # Call GrpcDriverStub method
210
- res: PushInsMessagesResponse = self._stub.PushMessages(
211
- PushInsMessagesRequest(
212
- messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
220
+
221
+ try:
222
+ # Call GrpcServerAppIoStub method
223
+ res: PushInsMessagesResponse = self._stub.PushMessages(
224
+ PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
213
225
  )
214
- )
215
- return list(res.message_ids)
226
+ if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
227
+ message_proto_list
228
+ ):
229
+ log(
230
+ WARNING,
231
+ "Not all messages could be pushed to the SuperLink. The returned "
232
+ "list has `None` for those messages (the order is preserved as "
233
+ "passed to `push_messages`). This could be due to a malformed "
234
+ "message.",
235
+ )
236
+ return list(res.message_ids)
237
+ except grpc.RpcError as e:
238
+ if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
239
+ log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
240
+ return []
241
+ raise
216
242
 
217
243
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
218
244
  """Pull messages based on message IDs.
@@ -220,16 +246,22 @@ class GrpcDriver(Driver):
220
246
  This method is used to collect messages from the SuperLink that correspond to a
221
247
  set of given message IDs.
222
248
  """
223
- # Pull Messages
224
- res: PullResMessagesResponse = self._stub.PullMessages(
225
- PullResMessagesRequest(
226
- message_ids=message_ids,
227
- run_id=cast(Run, self._run).run_id,
249
+ try:
250
+ # Pull Messages
251
+ res: PullResMessagesResponse = self._stub.PullMessages(
252
+ PullResMessagesRequest(
253
+ message_ids=message_ids,
254
+ run_id=cast(Run, self._run).run_id,
255
+ )
228
256
  )
229
- )
230
- # Convert Message from Protobuf representation
231
- msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
232
- return msgs
257
+ # Convert Message from Protobuf representation
258
+ msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
259
+ return msgs
260
+ except grpc.RpcError as e:
261
+ if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
262
+ log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
263
+ return []
264
+ raise
233
265
 
234
266
  def send_and_receive(
235
267
  self,
@@ -253,7 +285,7 @@ class GrpcDriver(Driver):
253
285
  res_msgs = self.pull_messages(msg_ids)
254
286
  ret.extend(res_msgs)
255
287
  msg_ids.difference_update(
256
- {msg.metadata.reply_to_message for msg in res_msgs}
288
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
257
289
  )
258
290
  if len(msg_ids) == 0:
259
291
  break
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,36 +12,37 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower in-memory Driver."""
15
+ """Flower in-memory Grid."""
16
16
 
17
17
 
18
18
  import time
19
- import warnings
20
19
  from collections.abc import Iterable
21
20
  from typing import Optional, cast
22
21
  from uuid import UUID
23
22
 
24
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
23
+ from flwr.common import Message, RecordDict
25
24
  from flwr.common.constant import SUPERLINK_NODE_ID
26
- from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.common.logger import warn_deprecated_feature
27
26
  from flwr.common.typing import Run
28
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
28
  from flwr.server.superlink.linkstate import LinkStateFactory
30
29
 
31
- from .driver import Driver
30
+ from .grid import Grid
32
31
 
33
32
 
34
- class InMemoryDriver(Driver):
35
- """`InMemoryDriver` class provides an interface to the ServerAppIo API.
33
+ class InMemoryGrid(Grid):
34
+ """`InMemoryGrid` class provides an interface to the ServerAppIo API.
36
35
 
37
36
  Parameters
38
37
  ----------
39
38
  state_factory : StateFactory
40
- A StateFactory embedding a state that this driver can interface with.
39
+ A StateFactory embedding a state that this grid can interface with.
41
40
  pull_interval : float (default=0.1)
42
41
  Sleep duration between calls to `pull_messages`.
43
42
  """
44
43
 
44
+ _deprecation_warning_logged = False
45
+
45
46
  def __init__(
46
47
  self,
47
48
  state_factory: LinkStateFactory,
@@ -55,11 +56,10 @@ class InMemoryDriver(Driver):
55
56
  def _check_message(self, message: Message) -> None:
56
57
  # Check if the message is valid
57
58
  if not (
58
- message.metadata.run_id == cast(Run, self._run).run_id
59
- and message.metadata.src_node_id == self.node.node_id
60
- and message.metadata.message_id == ""
61
- and message.metadata.reply_to_message == ""
59
+ message.metadata.message_id == ""
60
+ and message.metadata.reply_to_message_id == ""
62
61
  and message.metadata.ttl > 0
62
+ and message.metadata.delivered_at == ""
63
63
  ):
64
64
  raise ValueError(f"Invalid message: {message}")
65
65
 
@@ -77,7 +77,7 @@ class InMemoryDriver(Driver):
77
77
 
78
78
  def create_message( # pylint: disable=too-many-arguments,R0917
79
79
  self,
80
- content: RecordSet,
80
+ content: RecordDict,
81
81
  message_type: str,
82
82
  dst_node_id: int,
83
83
  group_id: str,
@@ -88,30 +88,17 @@ class InMemoryDriver(Driver):
88
88
  This method constructs a new `Message` with given content and metadata.
89
89
  The `run_id` and `src_node_id` will be set automatically.
90
90
  """
91
- if ttl:
92
- warnings.warn(
93
- "A custom TTL was set, but note that the SuperLink does not enforce "
94
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
95
- "version of Flower.",
96
- stacklevel=2,
91
+ if not InMemoryGrid._deprecation_warning_logged:
92
+ InMemoryGrid._deprecation_warning_logged = True
93
+ warn_deprecated_feature(
94
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
95
+ "Use `Message` constructor instead."
97
96
  )
98
- ttl_ = DEFAULT_TTL if ttl is None else ttl
99
-
100
- metadata = Metadata(
101
- run_id=cast(Run, self._run).run_id,
102
- message_id="", # Will be set by the server
103
- src_node_id=self.node.node_id,
104
- dst_node_id=dst_node_id,
105
- reply_to_message="",
106
- group_id=group_id,
107
- ttl=ttl_,
108
- message_type=message_type,
109
- )
110
- return Message(metadata=metadata, content=content)
111
-
112
- def get_node_ids(self) -> list[int]:
97
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
98
+
99
+ def get_node_ids(self) -> Iterable[int]:
113
100
  """Get node IDs."""
114
- return list(self.state.get_nodes(cast(Run, self._run).run_id))
101
+ return self.state.get_nodes(cast(Run, self._run).run_id)
115
102
 
116
103
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
117
104
  """Push messages to specified node IDs.
@@ -119,18 +106,19 @@ class InMemoryDriver(Driver):
119
106
  This method takes an iterable of messages and sends each message
120
107
  to the node specified in `dst_node_id`.
121
108
  """
122
- task_ids: list[str] = []
109
+ msg_ids: list[str] = []
123
110
  for msg in messages:
111
+ # Populate metadata
112
+ msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
113
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
124
114
  # Check message
125
115
  self._check_message(msg)
126
- # Convert Message to TaskIns
127
- taskins = message_to_taskins(msg)
128
116
  # Store in state
129
- task_id = self.state.store_task_ins(taskins)
130
- if task_id:
131
- task_ids.append(str(task_id))
117
+ msg_id = self.state.store_message_ins(msg)
118
+ if msg_id:
119
+ msg_ids.append(str(msg_id))
132
120
 
133
- return task_ids
121
+ return msg_ids
134
122
 
135
123
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
136
124
  """Pull messages based on message IDs.
@@ -139,17 +127,16 @@ class InMemoryDriver(Driver):
139
127
  set of given message IDs.
140
128
  """
141
129
  msg_ids = {UUID(msg_id) for msg_id in message_ids}
142
- # Pull TaskRes
143
- task_res_list = self.state.get_task_res(task_ids=msg_ids)
144
- # Delete tasks in state
145
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
146
- task_ins_ids_to_delete = {
147
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
130
+ # Pull Messages
131
+ message_res_list = self.state.get_message_res(message_ids=msg_ids)
132
+ # Get IDs of Messages these replies are for
133
+ message_ins_ids_to_delete = {
134
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
148
135
  }
149
- self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
150
- # Convert TaskRes to Message
151
- msgs = [message_from_taskres(taskres) for taskres in task_res_list]
152
- return msgs
136
+ # Delete
137
+ self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
138
+
139
+ return message_res_list
153
140
 
154
141
  def send_and_receive(
155
142
  self,
@@ -173,7 +160,7 @@ class InMemoryDriver(Driver):
173
160
  res_msgs = self.pull_messages(msg_ids)
174
161
  ret.extend(res_msgs)
175
162
  msg_ids.difference_update(
176
- {msg.metadata.reply_to_message for msg in res_msgs}
163
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
177
164
  )
178
165
  if len(msg_ids) == 0:
179
166
  break
@@ -15,26 +15,25 @@
15
15
  """Run ServerApp."""
16
16
 
17
17
 
18
- from logging import DEBUG, ERROR
18
+ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import Context, EventType, event
22
- from flwr.common.exit_handlers import register_exit_handlers
21
+ from flwr.common import Context
23
22
  from flwr.common.logger import log
24
23
  from flwr.common.object_ref import load_app
25
24
 
26
- from .driver import Driver
25
+ from .grid import Grid
27
26
  from .server_app import LoadServerAppError, ServerApp
28
27
 
29
28
 
30
29
  def run(
31
- driver: Driver,
30
+ grid: Grid,
32
31
  context: Context,
33
32
  server_app_dir: str,
34
33
  server_app_attr: Optional[str] = None,
35
34
  loaded_server_app: Optional[ServerApp] = None,
36
35
  ) -> Context:
37
- """Run ServerApp with a given Driver."""
36
+ """Run ServerApp with a given Grid."""
38
37
  if not (server_app_attr is None) ^ (loaded_server_app is None):
39
38
  raise ValueError(
40
39
  "Either `server_app_attr` or `loaded_server_app` should be set "
@@ -60,17 +59,7 @@ def run(
60
59
  server_app = _load()
61
60
 
62
61
  # Call ServerApp
63
- server_app(driver=driver, context=context)
62
+ server_app(grid=grid, context=context)
64
63
 
65
64
  log(DEBUG, "ServerApp finished running.")
66
65
  return context
67
-
68
-
69
- def run_server_app() -> None:
70
- """Run Flower server app."""
71
- event(EventType.RUN_SERVER_APP_ENTER)
72
- log(
73
- ERROR,
74
- "The command `flower-server-app` has been replaced by `flwr run`.",
75
- )
76
- register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)