flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/run/run.py +5 -9
  12. flwr/client/app.py +6 -4
  13. flwr/client/client_app.py +162 -99
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/grpc_client/connection.py +24 -21
  16. flwr/client/message_handler/message_handler.py +27 -27
  17. flwr/client/mod/__init__.py +2 -2
  18. flwr/client/mod/centraldp_mods.py +7 -7
  19. flwr/client/mod/comms_mods.py +16 -22
  20. flwr/client/mod/localdp_mod.py +4 -4
  21. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  22. flwr/client/run_info_store.py +2 -2
  23. flwr/common/__init__.py +12 -4
  24. flwr/common/config.py +4 -4
  25. flwr/common/constant.py +6 -6
  26. flwr/common/context.py +4 -4
  27. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  28. flwr/common/logger.py +2 -2
  29. flwr/common/message.py +327 -102
  30. flwr/common/record/__init__.py +8 -4
  31. flwr/common/record/arrayrecord.py +626 -0
  32. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  33. flwr/common/record/conversion_utils.py +1 -1
  34. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  35. flwr/common/record/recorddict.py +288 -0
  36. flwr/common/recorddict_compat.py +410 -0
  37. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  38. flwr/common/serde.py +66 -71
  39. flwr/common/typing.py +8 -8
  40. flwr/proto/exec_pb2.py +3 -3
  41. flwr/proto/exec_pb2.pyi +3 -3
  42. flwr/proto/message_pb2.py +12 -12
  43. flwr/proto/message_pb2.pyi +9 -9
  44. flwr/proto/recorddict_pb2.py +70 -0
  45. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  46. flwr/proto/run_pb2.py +31 -31
  47. flwr/proto/run_pb2.pyi +3 -3
  48. flwr/server/__init__.py +3 -1
  49. flwr/server/app.py +56 -1
  50. flwr/server/compat/__init__.py +2 -2
  51. flwr/server/compat/app.py +11 -11
  52. flwr/server/compat/app_utils.py +16 -16
  53. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  54. flwr/server/fleet_event_log_interceptor.py +94 -0
  55. flwr/server/{driver → grid}/__init__.py +8 -7
  56. flwr/server/{driver/driver.py → grid/grid.py} +47 -18
  57. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  58. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  59. flwr/server/run_serverapp.py +4 -4
  60. flwr/server/server_app.py +38 -18
  61. flwr/server/serverapp/app.py +10 -10
  62. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  63. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  64. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  65. flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
  66. flwr/server/superlink/linkstate/linkstate.py +4 -4
  67. flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
  68. flwr/server/superlink/linkstate/utils.py +93 -27
  69. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  70. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  71. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  72. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  73. flwr/server/typing.py +3 -3
  74. flwr/server/utils/validator.py +4 -4
  75. flwr/server/workflow/default_workflows.py +48 -57
  76. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  77. flwr/simulation/app.py +2 -2
  78. flwr/simulation/ray_transport/ray_actor.py +4 -2
  79. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  80. flwr/simulation/run_simulation.py +15 -15
  81. flwr/superexec/deployment.py +4 -4
  82. flwr/superexec/exec_event_log_interceptor.py +135 -0
  83. flwr/superexec/exec_grpc.py +10 -4
  84. flwr/superexec/exec_servicer.py +2 -2
  85. flwr/superexec/exec_user_auth_interceptor.py +18 -2
  86. flwr/superexec/executor.py +3 -3
  87. flwr/superexec/simulation.py +3 -3
  88. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
  89. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
  90. flwr/common/record/parametersrecord.py +0 -339
  91. flwr/common/record/recordset.py +0 -209
  92. flwr/common/recordset_compat.py +0 -418
  93. flwr/proto/recordset_pb2.py +0 -70
  94. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  95. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  96. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  97. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
  98. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,35 +12,37 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower in-memory Driver."""
15
+ """Flower in-memory Grid."""
16
16
 
17
17
 
18
18
  import time
19
- import warnings
20
19
  from collections.abc import Iterable
21
20
  from typing import Optional, cast
22
21
  from uuid import UUID
23
22
 
24
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
23
+ from flwr.common import Message, RecordDict
25
24
  from flwr.common.constant import SUPERLINK_NODE_ID
25
+ from flwr.common.logger import warn_deprecated_feature
26
26
  from flwr.common.typing import Run
27
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
28
  from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
- from .driver import Driver
30
+ from .grid import Grid
31
31
 
32
32
 
33
- class InMemoryDriver(Driver):
34
- """`InMemoryDriver` class provides an interface to the ServerAppIo API.
33
+ class InMemoryGrid(Grid):
34
+ """`InMemoryGrid` class provides an interface to the ServerAppIo API.
35
35
 
36
36
  Parameters
37
37
  ----------
38
38
  state_factory : StateFactory
39
- A StateFactory embedding a state that this driver can interface with.
39
+ A StateFactory embedding a state that this grid can interface with.
40
40
  pull_interval : float (default=0.1)
41
41
  Sleep duration between calls to `pull_messages`.
42
42
  """
43
43
 
44
+ _deprecation_warning_logged = False
45
+
44
46
  def __init__(
45
47
  self,
46
48
  state_factory: LinkStateFactory,
@@ -54,10 +56,8 @@ class InMemoryDriver(Driver):
54
56
  def _check_message(self, message: Message) -> None:
55
57
  # Check if the message is valid
56
58
  if not (
57
- message.metadata.run_id == cast(Run, self._run).run_id
58
- and message.metadata.src_node_id == self.node.node_id
59
- and message.metadata.message_id == ""
60
- and message.metadata.reply_to_message == ""
59
+ message.metadata.message_id == ""
60
+ and message.metadata.reply_to_message_id == ""
61
61
  and message.metadata.ttl > 0
62
62
  and message.metadata.delivered_at == ""
63
63
  ):
@@ -77,7 +77,7 @@ class InMemoryDriver(Driver):
77
77
 
78
78
  def create_message( # pylint: disable=too-many-arguments,R0917
79
79
  self,
80
- content: RecordSet,
80
+ content: RecordDict,
81
81
  message_type: str,
82
82
  dst_node_id: int,
83
83
  group_id: str,
@@ -88,26 +88,13 @@ class InMemoryDriver(Driver):
88
88
  This method constructs a new `Message` with given content and metadata.
89
89
  The `run_id` and `src_node_id` will be set automatically.
90
90
  """
91
- if ttl:
92
- warnings.warn(
93
- "A custom TTL was set, but note that the SuperLink does not enforce "
94
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
95
- "version of Flower.",
96
- stacklevel=2,
91
+ if not InMemoryGrid._deprecation_warning_logged:
92
+ InMemoryGrid._deprecation_warning_logged = True
93
+ warn_deprecated_feature(
94
+ "`Driver.create_message` / `Grid.create_message` is deprecated."
95
+ "Use `Message` constructor instead."
97
96
  )
98
- ttl_ = DEFAULT_TTL if ttl is None else ttl
99
-
100
- metadata = Metadata(
101
- run_id=cast(Run, self._run).run_id,
102
- message_id="", # Will be set by the server
103
- src_node_id=self.node.node_id,
104
- dst_node_id=dst_node_id,
105
- reply_to_message="",
106
- group_id=group_id,
107
- ttl=ttl_,
108
- message_type=message_type,
109
- )
110
- return Message(metadata=metadata, content=content)
97
+ return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
111
98
 
112
99
  def get_node_ids(self) -> Iterable[int]:
113
100
  """Get node IDs."""
@@ -121,6 +108,9 @@ class InMemoryDriver(Driver):
121
108
  """
122
109
  msg_ids: list[str] = []
123
110
  for msg in messages:
111
+ # Populate metadata
112
+ msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
113
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
124
114
  # Check message
125
115
  self._check_message(msg)
126
116
  # Store in state
@@ -141,7 +131,7 @@ class InMemoryDriver(Driver):
141
131
  message_res_list = self.state.get_message_res(message_ids=msg_ids)
142
132
  # Get IDs of Messages these replies are for
143
133
  message_ins_ids_to_delete = {
144
- UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
134
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
145
135
  }
146
136
  # Delete
147
137
  self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
@@ -170,7 +160,7 @@ class InMemoryDriver(Driver):
170
160
  res_msgs = self.pull_messages(msg_ids)
171
161
  ret.extend(res_msgs)
172
162
  msg_ids.difference_update(
173
- {msg.metadata.reply_to_message for msg in res_msgs}
163
+ {msg.metadata.reply_to_message_id for msg in res_msgs}
174
164
  )
175
165
  if len(msg_ids) == 0:
176
166
  break
@@ -22,18 +22,18 @@ from flwr.common import Context
22
22
  from flwr.common.logger import log
23
23
  from flwr.common.object_ref import load_app
24
24
 
25
- from .driver import Driver
25
+ from .grid import Grid
26
26
  from .server_app import LoadServerAppError, ServerApp
27
27
 
28
28
 
29
29
  def run(
30
- driver: Driver,
30
+ grid: Grid,
31
31
  context: Context,
32
32
  server_app_dir: str,
33
33
  server_app_attr: Optional[str] = None,
34
34
  loaded_server_app: Optional[ServerApp] = None,
35
35
  ) -> Context:
36
- """Run ServerApp with a given Driver."""
36
+ """Run ServerApp with a given Grid."""
37
37
  if not (server_app_attr is None) ^ (loaded_server_app is None):
38
38
  raise ValueError(
39
39
  "Either `server_app_attr` or `loaded_server_app` should be set "
@@ -59,7 +59,7 @@ def run(
59
59
  server_app = _load()
60
60
 
61
61
  # Call ServerApp
62
- server_app(driver=driver, context=context)
62
+ server_app(grid=grid, context=context)
63
63
 
64
64
  log(DEBUG, "ServerApp finished running.")
65
65
  return context
flwr/server/server_app.py CHANGED
@@ -15,20 +15,18 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
+ import inspect
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from typing import Callable, Optional
21
22
 
22
23
  from flwr.common import Context
23
- from flwr.common.logger import (
24
- warn_deprecated_feature_with_example,
25
- warn_preview_feature,
26
- )
24
+ from flwr.common.logger import warn_deprecated_feature_with_example
27
25
  from flwr.server.strategy import Strategy
28
26
 
29
27
  from .client_manager import ClientManager
30
- from .compat import start_driver
31
- from .driver import Driver
28
+ from .compat import start_grid
29
+ from .grid import Driver, Grid
32
30
  from .server import Server
33
31
  from .server_config import ServerConfig
34
32
  from .typing import ServerAppCallable, ServerFn
@@ -46,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
46
44
  app = ServerApp(server_fn=server_fn)
47
45
  """
48
46
 
47
+ GRID_USAGE_EXAMPLE = """
48
+ app = ServerApp()
49
+
50
+ @app.main()
51
+ def main(grid: Grid, context: Context) -> None:
52
+ # Your existing ServerApp code ...
53
+ """
54
+
55
+ DRIVER_DEPRECATION_MSG = """
56
+ The `Driver` class is deprecated, it will be removed in a future release.
57
+ """
58
+ DRIVER_EXAMPLE_MSG = """
59
+ Instead, use `Grid` in the signature of your `ServerApp`. For example:
60
+ """
61
+
49
62
 
50
63
  @contextmanager
51
64
  def _empty_lifespan(_: Context) -> Iterator[None]:
@@ -57,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
57
70
 
58
71
  Examples
59
72
  --------
60
- Use the `ServerApp` with an existing `Strategy`:
73
+ Use the ``ServerApp`` with an existing ``Strategy``:
61
74
 
62
75
  >>> def server_fn(context: Context):
63
76
  >>> server_config = ServerConfig(num_rounds=3)
@@ -69,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
69
82
  >>>
70
83
  >>> app = ServerApp(server_fn=server_fn)
71
84
 
72
- Use the `ServerApp` with a custom main function:
85
+ Use the ``ServerApp`` with a custom main function:
73
86
 
74
87
  >>> app = ServerApp()
75
88
  >>>
76
89
  >>> @app.main()
77
- >>> def main(driver: Driver, context: Context) -> None:
90
+ >>> def main(grid: Grid, context: Context) -> None:
78
91
  >>> print("ServerApp running")
79
92
  """
80
93
 
@@ -114,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
114
127
  self._main: Optional[ServerAppCallable] = None
115
128
  self._lifespan = _empty_lifespan
116
129
 
117
- def __call__(self, driver: Driver, context: Context) -> None:
130
+ def __call__(self, grid: Grid, context: Context) -> None:
118
131
  """Execute `ServerApp`."""
119
132
  with self._lifespan(context):
120
133
  # Compatibility mode
@@ -126,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
126
139
  self._config = components.config
127
140
  self._strategy = components.strategy
128
141
  self._client_manager = components.client_manager
129
- start_driver(
142
+ start_grid(
130
143
  server=self._server,
131
144
  config=self._config,
132
145
  strategy=self._strategy,
133
146
  client_manager=self._client_manager,
134
- driver=driver,
147
+ grid=grid,
135
148
  )
136
149
  return
137
150
 
138
151
  # New execution mode
139
- self._main(driver, context)
152
+ self._main(grid, context)
140
153
 
141
154
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
142
155
  """Return a decorator that registers the main fn with the server app.
@@ -146,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
146
159
  >>> app = ServerApp()
147
160
  >>>
148
161
  >>> @app.main()
149
- >>> def main(driver: Driver, context: Context) -> None:
162
+ >>> def main(grid: Grid, context: Context) -> None:
150
163
  >>> print("ServerApp running")
151
164
  """
152
165
 
@@ -171,12 +184,20 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
171
184
  >>> app = ServerApp()
172
185
  >>>
173
186
  >>> @app.main()
174
- >>> def main(driver: Driver, context: Context) -> None:
187
+ >>> def main(grid: Grid, context: Context) -> None:
175
188
  >>> print("ServerApp running")
176
189
  """,
177
190
  )
178
191
 
179
- warn_preview_feature("ServerApp-register-main-function")
192
+ sig = inspect.signature(main_fn)
193
+ param = list(sig.parameters.values())[0]
194
+ # Check if parameter name or the annotation should be updated
195
+ if param.name == "driver" or param.annotation is Driver:
196
+ warn_deprecated_feature_with_example(
197
+ deprecation_message=DRIVER_DEPRECATION_MSG,
198
+ example_message=DRIVER_EXAMPLE_MSG,
199
+ code_example=GRID_USAGE_EXAMPLE,
200
+ )
180
201
 
181
202
  # Register provided function with the ServerApp object
182
203
  self._main = main_fn
@@ -212,10 +233,9 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
212
233
  """
213
234
 
214
235
  def lifespan_decorator(
215
- lifespan_fn: Callable[[Context], Iterator[None]]
236
+ lifespan_fn: Callable[[Context], Iterator[None]],
216
237
  ) -> Callable[[Context], Iterator[None]]:
217
238
  """Register the lifespan fn with the ServerApp object."""
218
- warn_preview_feature("ServerApp-register-lifespan-function")
219
239
 
220
240
  @contextmanager
221
241
  def decorated_lifespan(context: Context) -> Iterator[None]:
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
60
  PullServerAppInputsResponse,
61
61
  PushServerAppOutputsRequest,
62
62
  )
63
- from flwr.server.driver.grpc_driver import GrpcDriver
63
+ from flwr.server.grid.grpc_grid import GrpcGrid
64
64
  from flwr.server.run_serverapp import run as run_
65
65
 
66
66
 
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
106
106
  certificates: Optional[bytes] = None,
107
107
  ) -> None:
108
108
  """Run Flower ServerApp process."""
109
- driver = GrpcDriver(
109
+ grid = GrpcGrid(
110
110
  serverappio_service_address=serverappio_api_address,
111
111
  root_certificates=certificates,
112
112
  )
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
123
123
  # Pull ServerAppInputs from LinkState
124
124
  req = PullServerAppInputsRequest()
125
125
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
126
- res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
126
+ res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
127
127
  if not res.HasField("run"):
128
128
  sleep(3)
129
129
  run_status = None
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
135
135
 
136
136
  hash_run_id = get_sha256_hash(run.run_id)
137
137
 
138
- driver.set_run(run.run_id)
138
+ grid.set_run(run.run_id)
139
139
 
140
140
  # Start log uploader for this run
141
141
  log_uploader = start_log_uploader(
142
142
  log_queue=log_queue,
143
143
  node_id=0,
144
144
  run_id=run.run_id,
145
- stub=driver._stub,
145
+ stub=grid._stub,
146
146
  )
147
147
 
148
148
  log(DEBUG, "[flwr-serverapp] Start FAB installation.")
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
173
173
 
174
174
  # Change status to Running
175
175
  run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
176
- driver._stub.UpdateRunStatus(
176
+ grid._stub.UpdateRunStatus(
177
177
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
178
178
  )
179
179
 
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
182
182
  event_details={"run-id-hash": hash_run_id},
183
183
  )
184
184
 
185
- # Load and run the ServerApp with the Driver
185
+ # Load and run the ServerApp with the Grid
186
186
  updated_context = run_(
187
- driver=driver,
187
+ grid=grid,
188
188
  server_app_dir=app_path,
189
189
  server_app_attr=server_app_attr,
190
190
  context=context,
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
196
196
  out_req = PushServerAppOutputsRequest(
197
197
  run_id=run.run_id, context=context_proto
198
198
  )
199
- _ = driver._stub.PushServerAppOutputs(out_req)
199
+ _ = grid._stub.PushServerAppOutputs(out_req)
200
200
 
201
201
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
202
202
  except RunNotRunningException:
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
221
221
  # Update run status
222
222
  if run_status:
223
223
  run_status_proto = run_status_to_proto(run_status)
224
- driver._stub.UpdateRunStatus(
224
+ grid._stub.UpdateRunStatus(
225
225
  UpdateRunStatusRequest(
226
226
  run_id=run.run_id, run_status=run_status_proto
227
227
  )
@@ -21,9 +21,9 @@ from typing import Callable
21
21
  from flwr.client.client_app import ClientApp
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.message import Message
24
- from flwr.common.typing import ConfigsRecordValues
24
+ from flwr.common.typing import ConfigRecordValues
25
25
 
26
- BackendConfig = dict[str, dict[str, ConfigsRecordValues]]
26
+ BackendConfig = dict[str, dict[str, ConfigRecordValues]]
27
27
 
28
28
 
29
29
  class Backend(ABC):
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
- from flwr.common.typing import ConfigsRecordValues
29
+ from flwr.common.typing import ConfigRecordValues
30
30
  from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
31
31
  from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
32
32
 
@@ -104,7 +104,7 @@ class RayBackend(Backend):
104
104
  if not ray.is_initialized():
105
105
  ray_init_args: dict[
106
106
  str,
107
- ConfigsRecordValues,
107
+ ConfigRecordValues,
108
108
  ] = {}
109
109
 
110
110
  if backend_config.get(self.init_args_key):
@@ -130,9 +130,7 @@ def worker(
130
130
  e_code = ErrorCode.UNKNOWN
131
131
 
132
132
  reason = str(type(ex)) + ":<'" + str(ex) + "'>"
133
- out_mssg = message.create_error_reply(
134
- error=Error(code=e_code, reason=reason)
135
- )
133
+ out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
136
134
 
137
135
  finally:
138
136
  if out_mssg:
@@ -27,16 +27,18 @@ from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
+ PING_PATIENCE,
30
31
  RUN_ID_NUM_BYTES,
31
32
  SUPERLINK_NODE_ID,
32
33
  Status,
33
34
  )
34
- from flwr.common.record import ConfigsRecord
35
+ from flwr.common.record import ConfigRecord
35
36
  from flwr.common.typing import Run, RunStatus, UserConfig
36
37
  from flwr.server.superlink.linkstate.linkstate import LinkState
37
38
  from flwr.server.utils import validate_message
38
39
 
39
40
  from .utils import (
41
+ check_node_availability_for_in_message,
40
42
  generate_rand_int_from_bytes,
41
43
  has_valid_sub_status,
42
44
  is_valid_transition,
@@ -67,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
67
69
  # Map run_id to RunRecord
68
70
  self.run_ids: dict[int, RunRecord] = {}
69
71
  self.contexts: dict[int, Context] = {}
70
- self.federation_options: dict[int, ConfigsRecord] = {}
72
+ self.federation_options: dict[int, ConfigRecord] = {}
71
73
  self.message_ins_store: dict[UUID, Message] = {}
72
74
  self.message_res_store: dict[UUID, Message] = {}
73
75
  self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
@@ -156,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
156
158
  res_metadata = message.metadata
157
159
  with self.lock:
158
160
  # Check if the Message it is replying to exists and is valid
159
- msg_ins_id = res_metadata.reply_to_message
161
+ msg_ins_id = res_metadata.reply_to_message_id
160
162
  msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
161
163
 
162
164
  # Ensure that dst_node_id of original Message matches the src_node_id of
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
232
234
  with self.lock:
233
235
  current = time.time()
234
236
 
235
- # Verify Messge IDs
237
+ # Verify Message IDs
236
238
  ret = verify_message_ids(
237
239
  inquired_message_ids=message_ids,
238
240
  found_message_ins_dict=self.message_ins_store,
239
241
  current_time=current,
240
242
  )
241
243
 
244
+ # Check node availability
245
+ dst_node_ids = {
246
+ self.message_ins_store[message_id].metadata.dst_node_id
247
+ for message_id in message_ids
248
+ }
249
+ tmp_ret_dict = check_node_availability_for_in_message(
250
+ inquired_in_message_ids=message_ids,
251
+ found_in_message_dict=self.message_ins_store,
252
+ node_id_to_online_until={
253
+ node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
254
+ },
255
+ current_time=current,
256
+ )
257
+ ret.update(tmp_ret_dict)
258
+
242
259
  # Find all reply Messages
243
260
  message_res_found: list[Message] = []
244
261
  for message_id in message_ids:
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
317
334
  log(ERROR, "Unexpected node registration failure.")
318
335
  return 0
319
336
 
337
+ # Mark the node online util time.time() + ping_interval
320
338
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
321
339
  return node_id
322
340
 
@@ -381,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
381
399
  fab_version: Optional[str],
382
400
  fab_hash: Optional[str],
383
401
  override_config: UserConfig,
384
- federation_options: ConfigsRecord,
402
+ federation_options: ConfigRecord,
385
403
  ) -> int:
386
404
  """Create a new run for the specified `fab_hash`."""
387
405
  # Sample a random int64 as run_id
@@ -510,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
510
528
 
511
529
  return pending_run_id
512
530
 
513
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
531
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
514
532
  """Retrieve the federation options for the specified `run_id`."""
515
533
  with self.lock:
516
534
  if run_id not in self.run_ids:
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
519
537
  return self.federation_options[run_id]
520
538
 
521
539
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
522
- """Acknowledge a ping received from a node, serving as a heartbeat."""
540
+ """Acknowledge a ping received from a node, serving as a heartbeat.
541
+
542
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
543
+ marking the node as offline, where PING_PATIENCE = 2 in default.
544
+ """
523
545
  with self.lock:
524
546
  if node_id in self.node_ids:
525
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
547
+ self.node_ids[node_id] = (
548
+ time.time() + PING_PATIENCE * ping_interval,
549
+ ping_interval,
550
+ )
526
551
  return True
527
552
  return False
528
553
 
@@ -20,7 +20,7 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common import Context, Message
23
- from flwr.common.record import ConfigsRecord
23
+ from flwr.common.record import ConfigRecord
24
24
  from flwr.common.typing import Run, RunStatus, UserConfig
25
25
 
26
26
 
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_version: Optional[str],
165
165
  fab_hash: Optional[str],
166
166
  override_config: UserConfig,
167
- federation_options: ConfigsRecord,
167
+ federation_options: ConfigRecord,
168
168
  ) -> int:
169
169
  """Create a new run for the specified `fab_hash`."""
170
170
 
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
236
236
  """
237
237
 
238
238
  @abc.abstractmethod
239
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
239
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
240
240
  """Retrieve the federation options for the specified `run_id`.
241
241
 
242
242
  Parameters
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
246
246
 
247
247
  Returns
248
248
  -------
249
- Optional[ConfigsRecord]
249
+ Optional[ConfigRecord]
250
250
  The federation options for the run if it exists; None otherwise.
251
251
  """
252
252