flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. flwr/client/app.py +6 -4
  2. flwr/client/clientapp/app.py +2 -2
  3. flwr/client/grpc_client/connection.py +23 -20
  4. flwr/client/message_handler/message_handler.py +27 -27
  5. flwr/client/mod/centraldp_mods.py +7 -7
  6. flwr/client/mod/localdp_mod.py +4 -4
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
  8. flwr/client/run_info_store.py +2 -2
  9. flwr/common/__init__.py +2 -0
  10. flwr/common/constant.py +2 -0
  11. flwr/common/context.py +4 -4
  12. flwr/common/logger.py +2 -2
  13. flwr/common/message.py +269 -101
  14. flwr/common/record/__init__.py +2 -1
  15. flwr/common/record/configsrecord.py +2 -2
  16. flwr/common/record/metricsrecord.py +1 -1
  17. flwr/common/record/parametersrecord.py +1 -1
  18. flwr/common/record/{recordset.py → recorddict.py} +57 -17
  19. flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
  20. flwr/common/serde.py +33 -37
  21. flwr/proto/exec_pb2.py +32 -32
  22. flwr/proto/exec_pb2.pyi +3 -3
  23. flwr/proto/message_pb2.py +12 -12
  24. flwr/proto/message_pb2.pyi +9 -9
  25. flwr/proto/recorddict_pb2.py +70 -0
  26. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
  27. flwr/proto/run_pb2.py +32 -32
  28. flwr/proto/run_pb2.pyi +3 -3
  29. flwr/server/__init__.py +2 -0
  30. flwr/server/compat/__init__.py +2 -2
  31. flwr/server/compat/app.py +11 -11
  32. flwr/server/compat/app_utils.py +16 -16
  33. flwr/server/compat/grid_client_proxy.py +38 -38
  34. flwr/server/grid/__init__.py +7 -6
  35. flwr/server/grid/grid.py +46 -17
  36. flwr/server/grid/grpc_grid.py +26 -33
  37. flwr/server/grid/inmemory_grid.py +19 -25
  38. flwr/server/run_serverapp.py +4 -4
  39. flwr/server/server_app.py +37 -11
  40. flwr/server/serverapp/app.py +10 -10
  41. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  42. flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
  43. flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
  44. flwr/server/superlink/linkstate/utils.py +77 -17
  45. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
  46. flwr/server/typing.py +3 -3
  47. flwr/server/utils/validator.py +4 -4
  48. flwr/server/workflow/default_workflows.py +24 -26
  49. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
  50. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  51. flwr/simulation/run_simulation.py +13 -13
  52. flwr/superexec/deployment.py +2 -2
  53. flwr/superexec/simulation.py +2 -2
  54. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
  55. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
  56. flwr/proto/recordset_pb2.py +0 -70
  57. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  58. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  59. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
  60. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
  61. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower in-memory Driver."""
15
+ """Flower in-memory Grid."""
16
16
 
17
17
 
18
18
  import time
@@ -20,22 +20,23 @@ from collections.abc import Iterable
20
20
  from typing import Optional, cast
21
21
  from uuid import UUID
22
22
 
23
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
23
+ from flwr.common import Message, RecordDict
24
24
  from flwr.common.constant import SUPERLINK_NODE_ID
25
+ from flwr.common.logger import warn_deprecated_feature
25
26
  from flwr.common.typing import Run
26
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
28
  from flwr.server.superlink.linkstate import LinkStateFactory
28
29
 
29
- from .grid import Driver
30
+ from .grid import Grid
30
31
 
31
32
 
32
- class InMemoryDriver(Driver):
33
- """`InMemoryDriver` class provides an interface to the ServerAppIo API.
33
+ class InMemoryGrid(Grid):
34
+ """`InMemoryGrid` class provides an interface to the ServerAppIo API.
34
35
 
35
36
  Parameters
36
37
  ----------
37
38
  state_factory : StateFactory
38
- A StateFactory embedding a state that this driver can interface with.
39
+ A StateFactory embedding a state that this grid can interface with.
39
40
  pull_interval : float (default=0.1)
40
41
  Sleep duration between calls to `pull_messages`.
41
42
  """
@@ -53,10 +54,8 @@ class InMemoryDriver(Driver):
53
54
  def _check_message(self, message: Message) -> None:
54
55
  # Check if the message is valid
55
56
  if not (
56
- message.metadata.run_id == cast(Run, self._run).run_id
57
- and message.metadata.src_node_id == self.node.node_id
58
- and message.metadata.message_id == ""
59
- and message.metadata.reply_to_message == ""
57
+ message.metadata.message_id == ""
58
+ and message.metadata.reply_to_message_id == ""
60
59
  and message.metadata.ttl > 0
61
60
  and message.metadata.delivered_at == ""
62
61
  ):
@@ -76,7 +75,7 @@ class InMemoryDriver(Driver):
76
75
 
77
76
  def create_message( # pylint: disable=too-many-arguments,R0917
78
77
  self,
79
- content: RecordSet,
78
+ content: RecordDict,
80
79
  message_type: str,
81
80
  dst_node_id: int,
82
81
  group_id: str,
@@ -87,19 +86,11 @@ class InMemoryDriver(Driver):
87
86
  This method constructs a new `Message` with given content and metadata.
88
87
  The `run_id` and `src_node_id` will be set automatically.
89
88
  """
90
- ttl_ = DEFAULT_TTL if ttl is None else ttl
91
-
92
- metadata = Metadata(
93
- run_id=cast(Run, self._run).run_id,
94
- message_id="", # Will be set by the server
95
- src_node_id=self.node.node_id,
96
- dst_node_id=dst_node_id,
97
- reply_to_message="",
98
- group_id=group_id,
99
- ttl=ttl_,
100
- message_type=message_type,
89
+ warn_deprecated_feature(
90
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
91
+ "Use `Message` constructor instead."
101
92
  )
102
- return Message(metadata=metadata, content=content)
93
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
103
94
 
104
95
  def get_node_ids(self) -> Iterable[int]:
105
96
  """Get node IDs."""
@@ -113,6 +104,9 @@ class InMemoryDriver(Driver):
113
104
  """
114
105
  msg_ids: list[str] = []
115
106
  for msg in messages:
107
+ # Populate metadata
108
+ msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
109
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
116
110
  # Check message
117
111
  self._check_message(msg)
118
112
  # Store in state
@@ -133,7 +127,7 @@ class InMemoryDriver(Driver):
133
127
  message_res_list = self.state.get_message_res(message_ids=msg_ids)
134
128
  # Get IDs of Messages these replies are for
135
129
  message_ins_ids_to_delete = {
136
- UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
130
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
137
131
  }
138
132
  # Delete
139
133
  self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
@@ -162,7 +156,7 @@ class InMemoryDriver(Driver):
162
156
  res_msgs = self.pull_messages(msg_ids)
163
157
  ret.extend(res_msgs)
164
158
  msg_ids.difference_update(
165
- {msg.metadata.reply_to_message for msg in res_msgs}
159
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
166
160
  )
167
161
  if len(msg_ids) == 0:
168
162
  break
@@ -22,18 +22,18 @@ from flwr.common import Context
22
22
  from flwr.common.logger import log
23
23
  from flwr.common.object_ref import load_app
24
24
 
25
- from .grid import Driver
25
+ from .grid import Grid
26
26
  from .server_app import LoadServerAppError, ServerApp
27
27
 
28
28
 
29
29
  def run(
30
- driver: Driver,
30
+ grid: Grid,
31
31
  context: Context,
32
32
  server_app_dir: str,
33
33
  server_app_attr: Optional[str] = None,
34
34
  loaded_server_app: Optional[ServerApp] = None,
35
35
  ) -> Context:
36
- """Run ServerApp with a given Driver."""
36
+ """Run ServerApp with a given Grid."""
37
37
  if not (server_app_attr is None) ^ (loaded_server_app is None):
38
38
  raise ValueError(
39
39
  "Either `server_app_attr` or `loaded_server_app` should be set "
@@ -59,7 +59,7 @@ def run(
59
59
  server_app = _load()
60
60
 
61
61
  # Call ServerApp
62
- server_app(driver=driver, context=context)
62
+ server_app(grid=grid, context=context)
63
63
 
64
64
  log(DEBUG, "ServerApp finished running.")
65
65
  return context
flwr/server/server_app.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
+ import inspect
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from typing import Callable, Optional
@@ -24,8 +25,8 @@ from flwr.common.logger import warn_deprecated_feature_with_example
24
25
  from flwr.server.strategy import Strategy
25
26
 
26
27
  from .client_manager import ClientManager
27
- from .compat import start_driver
28
- from .grid import Driver
28
+ from .compat import start_grid
29
+ from .grid import Driver, Grid
29
30
  from .server import Server
30
31
  from .server_config import ServerConfig
31
32
  from .typing import ServerAppCallable, ServerFn
@@ -43,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
43
44
  app = ServerApp(server_fn=server_fn)
44
45
  """
45
46
 
47
+ GRID_USAGE_EXAMPLE = """
48
+ app = ServerApp()
49
+
50
+ @app.main()
51
+ def main(grid: Grid, context: Context) -> None:
52
+ # Your existing ServerApp code ...
53
+ """
54
+
55
+ DRIVER_DEPRECATION_MSG = """
56
+ The `Driver` class is deprecated, it will be removed in a future release.
57
+ """
58
+ DRIVER_EXAMPLE_MSG = """
59
+ Instead, use `Grid` in the signature of your `ServerApp`. For example:
60
+ """
61
+
46
62
 
47
63
  @contextmanager
48
64
  def _empty_lifespan(_: Context) -> Iterator[None]:
@@ -54,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
54
70
 
55
71
  Examples
56
72
  --------
57
- Use the `ServerApp` with an existing `Strategy`:
73
+ Use the ``ServerApp`` with an existing ``Strategy``:
58
74
 
59
75
  >>> def server_fn(context: Context):
60
76
  >>> server_config = ServerConfig(num_rounds=3)
@@ -66,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
66
82
  >>>
67
83
  >>> app = ServerApp(server_fn=server_fn)
68
84
 
69
- Use the `ServerApp` with a custom main function:
85
+ Use the ``ServerApp`` with a custom main function:
70
86
 
71
87
  >>> app = ServerApp()
72
88
  >>>
73
89
  >>> @app.main()
74
- >>> def main(driver: Driver, context: Context) -> None:
90
+ >>> def main(grid: Grid, context: Context) -> None:
75
91
  >>> print("ServerApp running")
76
92
  """
77
93
 
@@ -111,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
111
127
  self._main: Optional[ServerAppCallable] = None
112
128
  self._lifespan = _empty_lifespan
113
129
 
114
- def __call__(self, driver: Driver, context: Context) -> None:
130
+ def __call__(self, grid: Grid, context: Context) -> None:
115
131
  """Execute `ServerApp`."""
116
132
  with self._lifespan(context):
117
133
  # Compatibility mode
@@ -123,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
123
139
  self._config = components.config
124
140
  self._strategy = components.strategy
125
141
  self._client_manager = components.client_manager
126
- start_driver(
142
+ start_grid(
127
143
  server=self._server,
128
144
  config=self._config,
129
145
  strategy=self._strategy,
130
146
  client_manager=self._client_manager,
131
- driver=driver,
147
+ grid=grid,
132
148
  )
133
149
  return
134
150
 
135
151
  # New execution mode
136
- self._main(driver, context)
152
+ self._main(grid, context)
137
153
 
138
154
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
139
155
  """Return a decorator that registers the main fn with the server app.
@@ -143,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
143
159
  >>> app = ServerApp()
144
160
  >>>
145
161
  >>> @app.main()
146
- >>> def main(driver: Driver, context: Context) -> None:
162
+ >>> def main(grid: Grid, context: Context) -> None:
147
163
  >>> print("ServerApp running")
148
164
  """
149
165
 
@@ -168,11 +184,21 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
168
184
  >>> app = ServerApp()
169
185
  >>>
170
186
  >>> @app.main()
171
- >>> def main(driver: Driver, context: Context) -> None:
187
+ >>> def main(grid: Grid, context: Context) -> None:
172
188
  >>> print("ServerApp running")
173
189
  """,
174
190
  )
175
191
 
192
+ sig = inspect.signature(main_fn)
193
+ param = list(sig.parameters.values())[0]
194
+ # Check if parameter name or the annotation should be updated
195
+ if param.name == "driver" or param.annotation is Driver:
196
+ warn_deprecated_feature_with_example(
197
+ deprecation_message=DRIVER_DEPRECATION_MSG,
198
+ example_message=DRIVER_EXAMPLE_MSG,
199
+ code_example=GRID_USAGE_EXAMPLE,
200
+ )
201
+
176
202
  # Register provided function with the ServerApp object
177
203
  self._main = main_fn
178
204
 
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
60
  PullServerAppInputsResponse,
61
61
  PushServerAppOutputsRequest,
62
62
  )
63
- from flwr.server.grid.grpc_grid import GrpcDriver
63
+ from flwr.server.grid.grpc_grid import GrpcGrid
64
64
  from flwr.server.run_serverapp import run as run_
65
65
 
66
66
 
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
106
106
  certificates: Optional[bytes] = None,
107
107
  ) -> None:
108
108
  """Run Flower ServerApp process."""
109
- driver = GrpcDriver(
109
+ grid = GrpcGrid(
110
110
  serverappio_service_address=serverappio_api_address,
111
111
  root_certificates=certificates,
112
112
  )
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
123
123
  # Pull ServerAppInputs from LinkState
124
124
  req = PullServerAppInputsRequest()
125
125
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
126
- res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
126
+ res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
127
127
  if not res.HasField("run"):
128
128
  sleep(3)
129
129
  run_status = None
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
135
135
 
136
136
  hash_run_id = get_sha256_hash(run.run_id)
137
137
 
138
- driver.set_run(run.run_id)
138
+ grid.set_run(run.run_id)
139
139
 
140
140
  # Start log uploader for this run
141
141
  log_uploader = start_log_uploader(
142
142
  log_queue=log_queue,
143
143
  node_id=0,
144
144
  run_id=run.run_id,
145
- stub=driver._stub,
145
+ stub=grid._stub,
146
146
  )
147
147
 
148
148
  log(DEBUG, "[flwr-serverapp] Start FAB installation.")
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
173
173
 
174
174
  # Change status to Running
175
175
  run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
176
- driver._stub.UpdateRunStatus(
176
+ grid._stub.UpdateRunStatus(
177
177
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
178
178
  )
179
179
 
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
182
182
  event_details={"run-id-hash": hash_run_id},
183
183
  )
184
184
 
185
- # Load and run the ServerApp with the Driver
185
+ # Load and run the ServerApp with the Grid
186
186
  updated_context = run_(
187
- driver=driver,
187
+ grid=grid,
188
188
  server_app_dir=app_path,
189
189
  server_app_attr=server_app_attr,
190
190
  context=context,
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
196
196
  out_req = PushServerAppOutputsRequest(
197
197
  run_id=run.run_id, context=context_proto
198
198
  )
199
- _ = driver._stub.PushServerAppOutputs(out_req)
199
+ _ = grid._stub.PushServerAppOutputs(out_req)
200
200
 
201
201
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
202
202
  except RunNotRunningException:
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
221
221
  # Update run status
222
222
  if run_status:
223
223
  run_status_proto = run_status_to_proto(run_status)
224
- driver._stub.UpdateRunStatus(
224
+ grid._stub.UpdateRunStatus(
225
225
  UpdateRunStatusRequest(
226
226
  run_id=run.run_id, run_status=run_status_proto
227
227
  )
@@ -130,9 +130,7 @@ def worker(
130
130
  e_code = ErrorCode.UNKNOWN
131
131
 
132
132
  reason = str(type(ex)) + ":<'" + str(ex) + "'>"
133
- out_mssg = message.create_error_reply(
134
- error=Error(code=e_code, reason=reason)
135
- )
133
+ out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
136
134
 
137
135
  finally:
138
136
  if out_mssg:
@@ -27,6 +27,7 @@ from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
+ PING_PATIENCE,
30
31
  RUN_ID_NUM_BYTES,
31
32
  SUPERLINK_NODE_ID,
32
33
  Status,
@@ -37,6 +38,7 @@ from flwr.server.superlink.linkstate.linkstate import LinkState
37
38
  from flwr.server.utils import validate_message
38
39
 
39
40
  from .utils import (
41
+ check_node_availability_for_in_message,
40
42
  generate_rand_int_from_bytes,
41
43
  has_valid_sub_status,
42
44
  is_valid_transition,
@@ -156,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
156
158
  res_metadata = message.metadata
157
159
  with self.lock:
158
160
  # Check if the Message it is replying to exists and is valid
159
- msg_ins_id = res_metadata.reply_to_message
161
+ msg_ins_id = res_metadata.reply_to_message_id
160
162
  msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
161
163
 
162
164
  # Ensure that dst_node_id of original Message matches the src_node_id of
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
232
234
  with self.lock:
233
235
  current = time.time()
234
236
 
235
- # Verify Messge IDs
237
+ # Verify Message IDs
236
238
  ret = verify_message_ids(
237
239
  inquired_message_ids=message_ids,
238
240
  found_message_ins_dict=self.message_ins_store,
239
241
  current_time=current,
240
242
  )
241
243
 
244
+ # Check node availability
245
+ dst_node_ids = {
246
+ self.message_ins_store[message_id].metadata.dst_node_id
247
+ for message_id in message_ids
248
+ }
249
+ tmp_ret_dict = check_node_availability_for_in_message(
250
+ inquired_in_message_ids=message_ids,
251
+ found_in_message_dict=self.message_ins_store,
252
+ node_id_to_online_until={
253
+ node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
254
+ },
255
+ current_time=current,
256
+ )
257
+ ret.update(tmp_ret_dict)
258
+
242
259
  # Find all reply Messages
243
260
  message_res_found: list[Message] = []
244
261
  for message_id in message_ids:
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
317
334
  log(ERROR, "Unexpected node registration failure.")
318
335
  return 0
319
336
 
337
+ # Mark the node online util time.time() + ping_interval
320
338
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
321
339
  return node_id
322
340
 
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
519
537
  return self.federation_options[run_id]
520
538
 
521
539
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
522
- """Acknowledge a ping received from a node, serving as a heartbeat."""
540
+ """Acknowledge a ping received from a node, serving as a heartbeat.
541
+
542
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
543
+ marking the node as offline, where PING_PATIENCE = 2 in default.
544
+ """
523
545
  with self.lock:
524
546
  if node_id in self.node_ids:
525
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
547
+ self.node_ids[node_id] = (
548
+ time.time() + PING_PATIENCE * ping_interval,
549
+ ping_interval,
550
+ )
526
551
  return True
527
552
  return False
528
553
 
@@ -30,28 +30,31 @@ from flwr.common import Context, Message, Metadata, log, now
30
30
  from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
+ PING_PATIENCE,
33
34
  RUN_ID_NUM_BYTES,
34
35
  SUPERLINK_NODE_ID,
35
36
  Status,
36
37
  )
38
+ from flwr.common.message import make_message
37
39
  from flwr.common.record import ConfigsRecord
38
40
  from flwr.common.serde import (
39
41
  error_from_proto,
40
42
  error_to_proto,
41
- recordset_from_proto,
42
- recordset_to_proto,
43
+ recorddict_from_proto,
44
+ recorddict_to_proto,
43
45
  )
44
46
  from flwr.common.typing import Run, RunStatus, UserConfig
45
47
 
46
48
  # pylint: disable=E0611
47
49
  from flwr.proto.error_pb2 import Error as ProtoError
48
- from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
50
+ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
49
51
 
50
52
  # pylint: enable=E0611
51
53
  from flwr.server.utils.validator import validate_message
52
54
 
53
55
  from .linkstate import LinkState
54
56
  from .utils import (
57
+ check_node_availability_for_in_message,
55
58
  configsrecord_from_bytes,
56
59
  configsrecord_to_bytes,
57
60
  context_from_bytes,
@@ -129,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
129
132
  run_id INTEGER,
130
133
  src_node_id INTEGER,
131
134
  dst_node_id INTEGER,
132
- reply_to_message TEXT,
135
+ reply_to_message_id TEXT,
133
136
  created_at REAL,
134
137
  delivered_at TEXT,
135
138
  ttl REAL,
@@ -148,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
148
151
  run_id INTEGER,
149
152
  src_node_id INTEGER,
150
153
  dst_node_id INTEGER,
151
- reply_to_message TEXT,
154
+ reply_to_message_id TEXT,
152
155
  created_at REAL,
153
156
  delivered_at TEXT,
154
157
  ttl REAL,
@@ -371,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
371
374
  return None
372
375
 
373
376
  res_metadata = message.metadata
374
- msg_ins_id = res_metadata.reply_to_message
377
+ msg_ins_id = res_metadata.reply_to_message_id
375
378
  msg_ins = self.get_valid_message_ins(msg_ins_id)
376
379
  if msg_ins is None:
377
380
  log(
@@ -442,6 +445,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
442
445
 
443
446
  def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
444
447
  """Get reply Messages for the given Message IDs."""
448
+ # pylint: disable-msg=too-many-locals
445
449
  ret: dict[UUID, Message] = {}
446
450
 
447
451
  # Verify Message IDs
@@ -465,11 +469,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
465
469
  current_time=current,
466
470
  )
467
471
 
472
+ # Check node availability
473
+ dst_node_ids: set[int] = set()
474
+ for message_id in message_ids:
475
+ in_message = found_message_ins_dict[message_id]
476
+ sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
477
+ dst_node_ids.add(sint_node_id)
478
+ query = f"""
479
+ SELECT node_id, online_until
480
+ FROM node
481
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
482
+ """
483
+ rows = self.query(query, tuple(dst_node_ids))
484
+ tmp_ret_dict = check_node_availability_for_in_message(
485
+ inquired_in_message_ids=message_ids,
486
+ found_in_message_dict=found_message_ins_dict,
487
+ node_id_to_online_until={
488
+ convert_sint64_to_uint64(row["node_id"]): row["online_until"]
489
+ for row in rows
490
+ },
491
+ current_time=current,
492
+ )
493
+ ret.update(tmp_ret_dict)
494
+
468
495
  # Find all reply Messages
469
496
  query = f"""
470
497
  SELECT *
471
498
  FROM message_res
472
- WHERE reply_to_message IN ({",".join(["?"] * len(message_ids))})
499
+ WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
473
500
  AND delivered_at = "";
474
501
  """
475
502
  rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
@@ -542,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
542
569
  # Delete reply Message
543
570
  query_2 = f"""
544
571
  DELETE FROM message_res
545
- WHERE reply_to_message IN ({placeholders});
572
+ WHERE reply_to_message_id IN ({placeholders});
546
573
  """
547
574
 
548
575
  with self.conn:
@@ -584,6 +611,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
584
611
  "VALUES (?, ?, ?, ?)"
585
612
  )
586
613
 
614
+ # Mark the node online util time.time() + ping_interval
587
615
  try:
588
616
  self.query(
589
617
  query,
@@ -899,7 +927,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
899
927
  return configsrecord_from_bytes(row["federation_options"])
900
928
 
901
929
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
902
- """Acknowledge a ping received from a node, serving as a heartbeat."""
930
+ """Acknowledge a ping received from a node, serving as a heartbeat.
931
+
932
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
933
+ marking the node as offline, where PING_PATIENCE = 2 in default.
934
+ """
903
935
  sint64_node_id = convert_uint64_to_sint64(node_id)
904
936
 
905
937
  # Check if the node exists in the `node` table
@@ -909,7 +941,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
909
941
 
910
942
  # Update `online_until` and `ping_interval` for the given `node_id`
911
943
  query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
912
- self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
944
+ self.query(
945
+ query,
946
+ (
947
+ time.time() + PING_PATIENCE * ping_interval,
948
+ ping_interval,
949
+ sint64_node_id,
950
+ ),
951
+ )
913
952
  return True
914
953
 
915
954
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
@@ -1026,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1026
1065
  "run_id": message.metadata.run_id,
1027
1066
  "src_node_id": message.metadata.src_node_id,
1028
1067
  "dst_node_id": message.metadata.dst_node_id,
1029
- "reply_to_message": message.metadata.reply_to_message,
1068
+ "reply_to_message_id": message.metadata.reply_to_message_id,
1030
1069
  "created_at": message.metadata.created_at,
1031
1070
  "delivered_at": message.metadata.delivered_at,
1032
1071
  "ttl": message.metadata.ttl,
@@ -1036,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1036
1075
  }
1037
1076
 
1038
1077
  if message.has_content():
1039
- result["content"] = recordset_to_proto(message.content).SerializeToString()
1078
+ result["content"] = recorddict_to_proto(message.content).SerializeToString()
1040
1079
  else:
1041
1080
  result["error"] = error_to_proto(message.error).SerializeToString()
1042
1081
 
@@ -1047,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
1047
1086
  """Transform dict to Message."""
1048
1087
  content, error = None, None
1049
1088
  if (b_content := message_dict.pop("content")) is not None:
1050
- content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
1089
+ content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1051
1090
  if (b_error := message_dict.pop("error")) is not None:
1052
1091
  error = error_from_proto(ProtoError.FromString(b_error))
1053
1092
 
1054
1093
  # Metadata constructor doesn't allow passing created_at. We set it later
1055
1094
  metadata = Metadata(
1056
- **{
1057
- k: v
1058
- for k, v in message_dict.items()
1059
- if k not in ["created_at", "delivered_at"]
1060
- }
1095
+ **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1061
1096
  )
1062
- msg = Message(metadata=metadata, content=content, error=error)
1063
- msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
1097
+ msg = make_message(metadata=metadata, content=content, error=error)
1064
1098
  msg.metadata.delivered_at = message_dict["delivered_at"]
1065
1099
  return msg
1066
1100