flwr-nightly 1.17.0.dev20250317__py3-none-any.whl → 1.17.0.dev20250319__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. flwr/common/constant.py +5 -0
  2. flwr/common/logger.py +2 -2
  3. flwr/common/record/parametersrecord.py +336 -92
  4. flwr/server/__init__.py +3 -1
  5. flwr/server/app.py +1 -1
  6. flwr/server/compat/__init__.py +2 -2
  7. flwr/server/compat/app.py +11 -11
  8. flwr/server/compat/app_utils.py +16 -16
  9. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +9 -9
  10. flwr/server/{driver → grid}/__init__.py +8 -7
  11. flwr/server/{driver/driver.py → grid/grid.py} +44 -15
  12. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +12 -20
  13. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +6 -14
  14. flwr/server/run_serverapp.py +4 -4
  15. flwr/server/server_app.py +38 -12
  16. flwr/server/serverapp/app.py +10 -10
  17. flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -3
  18. flwr/server/superlink/linkstate/sqlite_linkstate.py +40 -2
  19. flwr/server/superlink/linkstate/utils.py +67 -10
  20. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  21. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  22. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +1 -1
  23. flwr/server/typing.py +3 -3
  24. flwr/server/workflow/default_workflows.py +17 -19
  25. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +15 -15
  26. flwr/simulation/run_simulation.py +10 -10
  27. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/METADATA +1 -1
  28. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/RECORD +31 -31
  29. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/LICENSE +0 -0
  30. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/WHEEL +0 -0
  31. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/entry_points.txt +0 -0
flwr/server/server_app.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
+ import inspect
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from typing import Callable, Optional
@@ -24,8 +25,8 @@ from flwr.common.logger import warn_deprecated_feature_with_example
24
25
  from flwr.server.strategy import Strategy
25
26
 
26
27
  from .client_manager import ClientManager
27
- from .compat import start_driver
28
- from .driver import Driver
28
+ from .compat import start_grid
29
+ from .grid import Driver, Grid
29
30
  from .server import Server
30
31
  from .server_config import ServerConfig
31
32
  from .typing import ServerAppCallable, ServerFn
@@ -43,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
43
44
  app = ServerApp(server_fn=server_fn)
44
45
  """
45
46
 
47
+ GRID_USAGE_EXAMPLE = """
48
+ app = ServerApp()
49
+
50
+ @app.main()
51
+ def main(grid: Grid, context: Context) -> None:
52
+ # Your existing ServerApp code ...
53
+ """
54
+
55
+ DRIVER_DEPRECATION_MSG = """
56
+ The `Driver` class is deprecated, it will be removed in a future release.
57
+ """
58
+ DRIVER_EXAMPLE_MSG = """
59
+ Instead, use `Grid` in the signature of your `ServerApp`. For example:
60
+ """
61
+
46
62
 
47
63
  @contextmanager
48
64
  def _empty_lifespan(_: Context) -> Iterator[None]:
@@ -54,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
54
70
 
55
71
  Examples
56
72
  --------
57
- Use the `ServerApp` with an existing `Strategy`:
73
+ Use the ``ServerApp`` with an existing ``Strategy``:
58
74
 
59
75
  >>> def server_fn(context: Context):
60
76
  >>> server_config = ServerConfig(num_rounds=3)
@@ -66,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
66
82
  >>>
67
83
  >>> app = ServerApp(server_fn=server_fn)
68
84
 
69
- Use the `ServerApp` with a custom main function:
85
+ Use the ``ServerApp`` with a custom main function:
70
86
 
71
87
  >>> app = ServerApp()
72
88
  >>>
73
89
  >>> @app.main()
74
- >>> def main(driver: Driver, context: Context) -> None:
90
+ >>> def main(grid: Grid, context: Context) -> None:
75
91
  >>> print("ServerApp running")
76
92
  """
77
93
 
@@ -111,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
111
127
  self._main: Optional[ServerAppCallable] = None
112
128
  self._lifespan = _empty_lifespan
113
129
 
114
- def __call__(self, driver: Driver, context: Context) -> None:
130
+ def __call__(self, grid: Grid, context: Context) -> None:
115
131
  """Execute `ServerApp`."""
116
132
  with self._lifespan(context):
117
133
  # Compatibility mode
@@ -123,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
123
139
  self._config = components.config
124
140
  self._strategy = components.strategy
125
141
  self._client_manager = components.client_manager
126
- start_driver(
142
+ start_grid(
127
143
  server=self._server,
128
144
  config=self._config,
129
145
  strategy=self._strategy,
130
146
  client_manager=self._client_manager,
131
- driver=driver,
147
+ grid=grid,
132
148
  )
133
149
  return
134
150
 
135
151
  # New execution mode
136
- self._main(driver, context)
152
+ self._main(grid, context)
137
153
 
138
154
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
139
155
  """Return a decorator that registers the main fn with the server app.
@@ -143,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
143
159
  >>> app = ServerApp()
144
160
  >>>
145
161
  >>> @app.main()
146
- >>> def main(driver: Driver, context: Context) -> None:
162
+ >>> def main(grid: Grid, context: Context) -> None:
147
163
  >>> print("ServerApp running")
148
164
  """
149
165
 
@@ -168,11 +184,21 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
168
184
  >>> app = ServerApp()
169
185
  >>>
170
186
  >>> @app.main()
171
- >>> def main(driver: Driver, context: Context) -> None:
187
+ >>> def main(grid: Grid, context: Context) -> None:
172
188
  >>> print("ServerApp running")
173
189
  """,
174
190
  )
175
191
 
192
+ sig = inspect.signature(main_fn)
193
+ param = list(sig.parameters.values())[0]
194
+ # Check if parameter name or the annotation should be updated
195
+ if param.name == "driver" or param.annotation is Driver:
196
+ warn_deprecated_feature_with_example(
197
+ deprecation_message=DRIVER_DEPRECATION_MSG,
198
+ example_message=DRIVER_EXAMPLE_MSG,
199
+ code_example=GRID_USAGE_EXAMPLE,
200
+ )
201
+
176
202
  # Register provided function with the ServerApp object
177
203
  self._main = main_fn
178
204
 
@@ -207,7 +233,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
207
233
  """
208
234
 
209
235
  def lifespan_decorator(
210
- lifespan_fn: Callable[[Context], Iterator[None]]
236
+ lifespan_fn: Callable[[Context], Iterator[None]],
211
237
  ) -> Callable[[Context], Iterator[None]]:
212
238
  """Register the lifespan fn with the ServerApp object."""
213
239
 
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
60
  PullServerAppInputsResponse,
61
61
  PushServerAppOutputsRequest,
62
62
  )
63
- from flwr.server.driver.grpc_driver import GrpcDriver
63
+ from flwr.server.grid.grpc_grid import GrpcGrid
64
64
  from flwr.server.run_serverapp import run as run_
65
65
 
66
66
 
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
106
106
  certificates: Optional[bytes] = None,
107
107
  ) -> None:
108
108
  """Run Flower ServerApp process."""
109
- driver = GrpcDriver(
109
+ grid = GrpcGrid(
110
110
  serverappio_service_address=serverappio_api_address,
111
111
  root_certificates=certificates,
112
112
  )
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
123
123
  # Pull ServerAppInputs from LinkState
124
124
  req = PullServerAppInputsRequest()
125
125
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
126
- res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
126
+ res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
127
127
  if not res.HasField("run"):
128
128
  sleep(3)
129
129
  run_status = None
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
135
135
 
136
136
  hash_run_id = get_sha256_hash(run.run_id)
137
137
 
138
- driver.set_run(run.run_id)
138
+ grid.set_run(run.run_id)
139
139
 
140
140
  # Start log uploader for this run
141
141
  log_uploader = start_log_uploader(
142
142
  log_queue=log_queue,
143
143
  node_id=0,
144
144
  run_id=run.run_id,
145
- stub=driver._stub,
145
+ stub=grid._stub,
146
146
  )
147
147
 
148
148
  log(DEBUG, "[flwr-serverapp] Start FAB installation.")
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
173
173
 
174
174
  # Change status to Running
175
175
  run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
176
- driver._stub.UpdateRunStatus(
176
+ grid._stub.UpdateRunStatus(
177
177
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
178
178
  )
179
179
 
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
182
182
  event_details={"run-id-hash": hash_run_id},
183
183
  )
184
184
 
185
- # Load and run the ServerApp with the Driver
185
+ # Load and run the ServerApp with the Grid
186
186
  updated_context = run_(
187
- driver=driver,
187
+ grid=grid,
188
188
  server_app_dir=app_path,
189
189
  server_app_attr=server_app_attr,
190
190
  context=context,
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
196
196
  out_req = PushServerAppOutputsRequest(
197
197
  run_id=run.run_id, context=context_proto
198
198
  )
199
- _ = driver._stub.PushServerAppOutputs(out_req)
199
+ _ = grid._stub.PushServerAppOutputs(out_req)
200
200
 
201
201
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
202
202
  except RunNotRunningException:
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
221
221
  # Update run status
222
222
  if run_status:
223
223
  run_status_proto = run_status_to_proto(run_status)
224
- driver._stub.UpdateRunStatus(
224
+ grid._stub.UpdateRunStatus(
225
225
  UpdateRunStatusRequest(
226
226
  run_id=run.run_id, run_status=run_status_proto
227
227
  )
@@ -27,6 +27,7 @@ from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
+ PING_PATIENCE,
30
31
  RUN_ID_NUM_BYTES,
31
32
  SUPERLINK_NODE_ID,
32
33
  Status,
@@ -37,6 +38,7 @@ from flwr.server.superlink.linkstate.linkstate import LinkState
37
38
  from flwr.server.utils import validate_message
38
39
 
39
40
  from .utils import (
41
+ check_node_availability_for_in_message,
40
42
  generate_rand_int_from_bytes,
41
43
  has_valid_sub_status,
42
44
  is_valid_transition,
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
232
234
  with self.lock:
233
235
  current = time.time()
234
236
 
235
- # Verify Messge IDs
237
+ # Verify Message IDs
236
238
  ret = verify_message_ids(
237
239
  inquired_message_ids=message_ids,
238
240
  found_message_ins_dict=self.message_ins_store,
239
241
  current_time=current,
240
242
  )
241
243
 
244
+ # Check node availability
245
+ dst_node_ids = {
246
+ self.message_ins_store[message_id].metadata.dst_node_id
247
+ for message_id in message_ids
248
+ }
249
+ tmp_ret_dict = check_node_availability_for_in_message(
250
+ inquired_in_message_ids=message_ids,
251
+ found_in_message_dict=self.message_ins_store,
252
+ node_id_to_online_until={
253
+ node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
254
+ },
255
+ current_time=current,
256
+ )
257
+ ret.update(tmp_ret_dict)
258
+
242
259
  # Find all reply Messages
243
260
  message_res_found: list[Message] = []
244
261
  for message_id in message_ids:
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
317
334
  log(ERROR, "Unexpected node registration failure.")
318
335
  return 0
319
336
 
337
+ # Mark the node online util time.time() + ping_interval
320
338
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
321
339
  return node_id
322
340
 
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
519
537
  return self.federation_options[run_id]
520
538
 
521
539
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
522
- """Acknowledge a ping received from a node, serving as a heartbeat."""
540
+ """Acknowledge a ping received from a node, serving as a heartbeat.
541
+
542
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
543
+ marking the node as offline, where PING_PATIENCE = 2 in default.
544
+ """
523
545
  with self.lock:
524
546
  if node_id in self.node_ids:
525
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
547
+ self.node_ids[node_id] = (
548
+ time.time() + PING_PATIENCE * ping_interval,
549
+ ping_interval,
550
+ )
526
551
  return True
527
552
  return False
528
553
 
@@ -30,6 +30,7 @@ from flwr.common import Context, Message, Metadata, log, now
30
30
  from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
+ PING_PATIENCE,
33
34
  RUN_ID_NUM_BYTES,
34
35
  SUPERLINK_NODE_ID,
35
36
  Status,
@@ -52,6 +53,7 @@ from flwr.server.utils.validator import validate_message
52
53
 
53
54
  from .linkstate import LinkState
54
55
  from .utils import (
56
+ check_node_availability_for_in_message,
55
57
  configsrecord_from_bytes,
56
58
  configsrecord_to_bytes,
57
59
  context_from_bytes,
@@ -442,6 +444,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
442
444
 
443
445
  def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
444
446
  """Get reply Messages for the given Message IDs."""
447
+ # pylint: disable-msg=too-many-locals
445
448
  ret: dict[UUID, Message] = {}
446
449
 
447
450
  # Verify Message IDs
@@ -465,6 +468,29 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
465
468
  current_time=current,
466
469
  )
467
470
 
471
+ # Check node availability
472
+ dst_node_ids: set[int] = set()
473
+ for message_id in message_ids:
474
+ in_message = found_message_ins_dict[message_id]
475
+ sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
476
+ dst_node_ids.add(sint_node_id)
477
+ query = f"""
478
+ SELECT node_id, online_until
479
+ FROM node
480
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
481
+ """
482
+ rows = self.query(query, tuple(dst_node_ids))
483
+ tmp_ret_dict = check_node_availability_for_in_message(
484
+ inquired_in_message_ids=message_ids,
485
+ found_in_message_dict=found_message_ins_dict,
486
+ node_id_to_online_until={
487
+ convert_sint64_to_uint64(row["node_id"]): row["online_until"]
488
+ for row in rows
489
+ },
490
+ current_time=current,
491
+ )
492
+ ret.update(tmp_ret_dict)
493
+
468
494
  # Find all reply Messages
469
495
  query = f"""
470
496
  SELECT *
@@ -584,6 +610,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
584
610
  "VALUES (?, ?, ?, ?)"
585
611
  )
586
612
 
613
+ # Mark the node online util time.time() + ping_interval
587
614
  try:
588
615
  self.query(
589
616
  query,
@@ -899,7 +926,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
899
926
  return configsrecord_from_bytes(row["federation_options"])
900
927
 
901
928
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
902
- """Acknowledge a ping received from a node, serving as a heartbeat."""
929
+ """Acknowledge a ping received from a node, serving as a heartbeat.
930
+
931
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
932
+ marking the node as offline, where PING_PATIENCE = 2 in default.
933
+ """
903
934
  sint64_node_id = convert_uint64_to_sint64(node_id)
904
935
 
905
936
  # Check if the node exists in the `node` table
@@ -909,7 +940,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
909
940
 
910
941
  # Update `online_until` and `ping_interval` for the given `node_id`
911
942
  query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
912
- self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
943
+ self.query(
944
+ query,
945
+ (
946
+ time.time() + PING_PATIENCE * ping_interval,
947
+ ping_interval,
948
+ sint64_node_id,
949
+ ),
950
+ )
913
951
  return True
914
952
 
915
953
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
@@ -34,12 +34,6 @@ from flwr.proto.message_pb2 import Context as ProtoContext
34
34
  from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
35
35
 
36
36
  # pylint: enable=E0611
37
-
38
- NODE_UNAVAILABLE_ERROR_REASON = (
39
- "Error: Node Unavailable - The destination node is currently unavailable. "
40
- "It exceeds the time limit specified in its last ping."
41
- )
42
-
43
37
  VALID_RUN_STATUS_TRANSITIONS = {
44
38
  (Status.PENDING, Status.STARTING),
45
39
  (Status.STARTING, Status.RUNNING),
@@ -60,6 +54,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
60
54
  REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
61
55
  "Error: Reply Message Unavailable - The reply message has expired."
62
56
  )
57
+ NODE_UNAVAILABLE_ERROR_REASON = (
58
+ "Error: Node Unavailable - The destination node is currently unavailable. "
59
+ "It exceeds twice the time limit specified in its last ping."
60
+ )
63
61
 
64
62
 
65
63
  def generate_rand_int_from_bytes(
@@ -237,7 +235,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
237
235
  return status.sub_status == ""
238
236
 
239
237
 
240
- def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Message:
238
+ def create_message_error_unavailable_res_message(
239
+ ins_metadata: Metadata, error_type: str
240
+ ) -> Message:
241
241
  """Generate an error Message that the SuperLink returns carrying the specified
242
242
  error."""
243
243
  current_time = now().timestamp()
@@ -256,8 +256,16 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
256
256
  return Message(
257
257
  metadata=metadata,
258
258
  error=Error(
259
- code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
260
- reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
259
+ code=(
260
+ ErrorCode.REPLY_MESSAGE_UNAVAILABLE
261
+ if error_type == "msg_unavail"
262
+ else ErrorCode.NODE_UNAVAILABLE
263
+ ),
264
+ reason=(
265
+ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
266
+ if error_type == "msg_unavail"
267
+ else NODE_UNAVAILABLE_ERROR_REASON
268
+ ),
261
269
  ),
262
270
  )
263
271
 
@@ -371,7 +379,56 @@ def verify_found_message_replies(
371
379
  if message_ttl_has_expired(message_res.metadata, current):
372
380
  # No need to insert the error Message
373
381
  message_res = create_message_error_unavailable_res_message(
374
- found_message_ins_dict[message_ins_id].metadata
382
+ found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
375
383
  )
376
384
  ret_dict[message_ins_id] = message_res
377
385
  return ret_dict
386
+
387
+
388
def check_node_availability_for_in_message(
    inquired_in_message_ids: set[UUID],
    found_in_message_dict: dict[UUID, Message],
    node_id_to_online_until: dict[int, float],
    current_time: Optional[float] = None,
    update_set: bool = True,
) -> dict[UUID, Message]:
    """Check node availability for given Message and generate error reply Message if
    unavailable. A Message error indicating node unavailability will be generated for
    each given Message whose destination node is offline or non-existent.

    Parameters
    ----------
    inquired_in_message_ids : set[UUID]
        Set of Message IDs for which to check destination node availability.
        Every ID must be a key of `found_in_message_dict`.
    found_in_message_dict : dict[UUID, Message]
        Dictionary containing all found Message indexed by their IDs.
    node_id_to_online_until : dict[int, float]
        Dictionary mapping node IDs to their online-until timestamps. A node ID
        missing from this mapping is treated as a non-existent node.
    current_time : Optional[float] (default: None)
        The current time to check for expiration. If set to `None`, the current time
        will automatically be set to the current timestamp using `now().timestamp()`.
    update_set : bool (default: True)
        If True, the IDs of Messages whose destination node is unavailable are
        removed in place from `inquired_in_message_ids`.

    Returns
    -------
    dict[UUID, Message]
        A dictionary of error Message indexed by the corresponding Message ID.
        Messages whose destination node is available are not included.
    """
    ret_dict: dict[UUID, Message] = {}
    # Compare against `None` explicitly so that a caller-supplied timestamp of
    # 0.0 is honored instead of being silently replaced by the wall-clock time.
    current = current_time if current_time is not None else now().timestamp()
    # Iterate over a snapshot because the set may be mutated inside the loop.
    for in_message_id in list(inquired_in_message_ids):
        in_message = found_in_message_dict[in_message_id]
        online_until = node_id_to_online_until.get(in_message.metadata.dst_node_id)
        # Generate a reply message containing an error reply
        # if the node is offline or doesn't exist.
        if online_until is None or online_until < current:
            if update_set:
                inquired_in_message_ids.remove(in_message_id)
            ret_dict[in_message_id] = create_message_error_unavailable_res_message(
                in_message.metadata, "node_unavail"
            )
    return ret_dict
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/server/typing.py CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
19
19
 
20
20
  from flwr.common import Context
21
21
 
22
- from .driver import Driver
22
+ from .grid import Grid
23
23
  from .serverapp_components import ServerAppComponents
24
24
 
25
- ServerAppCallable = Callable[[Driver, Context], None]
26
- Workflow = Callable[[Driver, Context], None]
25
+ ServerAppCallable = Callable[[Grid, Context], None]
26
+ Workflow = Callable[[Grid, Context], None]
27
27
  ServerFn = Callable[[Context], ServerAppComponents]
@@ -36,7 +36,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
36
36
  from ..client_proxy import ClientProxy
37
37
  from ..compat.app_utils import start_update_client_manager_thread
38
38
  from ..compat.legacy_context import LegacyContext
39
- from ..driver import Driver
39
+ from ..grid import Grid
40
40
  from ..typing import Workflow
41
41
  from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
42
42
 
@@ -56,7 +56,7 @@ class DefaultWorkflow:
56
56
  self.fit_workflow: Workflow = fit_workflow
57
57
  self.evaluate_workflow: Workflow = evaluate_workflow
58
58
 
59
- def __call__(self, driver: Driver, context: Context) -> None:
59
+ def __call__(self, grid: Grid, context: Context) -> None:
60
60
  """Execute the workflow."""
61
61
  if not isinstance(context, LegacyContext):
62
62
  raise TypeError(
@@ -65,7 +65,7 @@ class DefaultWorkflow:
65
65
 
66
66
  # Start the thread updating nodes
67
67
  thread, f_stop, c_done = start_update_client_manager_thread(
68
- driver, context.client_manager
68
+ grid, context.client_manager
69
69
  )
70
70
 
71
71
  # Wait until the node registration done
@@ -73,7 +73,7 @@ class DefaultWorkflow:
73
73
 
74
74
  # Initialize parameters
75
75
  log(INFO, "[INIT]")
76
- default_init_params_workflow(driver, context)
76
+ default_init_params_workflow(grid, context)
77
77
 
78
78
  # Run federated learning for num_rounds
79
79
  start_time = timeit.default_timer()
@@ -87,13 +87,13 @@ class DefaultWorkflow:
87
87
  cfg[Key.CURRENT_ROUND] = current_round
88
88
 
89
89
  # Fit round
90
- self.fit_workflow(driver, context)
90
+ self.fit_workflow(grid, context)
91
91
 
92
92
  # Centralized evaluation
93
- default_centralized_evaluation_workflow(driver, context)
93
+ default_centralized_evaluation_workflow(grid, context)
94
94
 
95
95
  # Evaluate round
96
- self.evaluate_workflow(driver, context)
96
+ self.evaluate_workflow(grid, context)
97
97
 
98
98
  # Bookkeeping and log results
99
99
  end_time = timeit.default_timer()
@@ -119,7 +119,7 @@ class DefaultWorkflow:
119
119
  thread.join()
120
120
 
121
121
 
122
- def default_init_params_workflow(driver: Driver, context: Context) -> None:
122
+ def default_init_params_workflow(grid: Grid, context: Context) -> None:
123
123
  """Execute the default workflow for parameters initialization."""
124
124
  if not isinstance(context, LegacyContext):
125
125
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -138,9 +138,9 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
138
138
  random_client = context.client_manager.sample(1)[0]
139
139
  # Send GetParametersIns and get the response
140
140
  content = compat.getparametersins_to_recordset(GetParametersIns({}))
141
- messages = driver.send_and_receive(
141
+ messages = grid.send_and_receive(
142
142
  [
143
- driver.create_message(
143
+ grid.create_message(
144
144
  content=content,
145
145
  message_type=MessageTypeLegacy.GET_PARAMETERS,
146
146
  dst_node_id=random_client.node_id,
@@ -186,7 +186,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
186
186
  log(INFO, "Evaluation returned no results (`None`)")
187
187
 
188
188
 
189
- def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
189
+ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
190
190
  """Execute the default workflow for centralized evaluation."""
191
191
  if not isinstance(context, LegacyContext):
192
192
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -218,9 +218,7 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
218
218
  )
219
219
 
220
220
 
221
- def default_fit_workflow( # pylint: disable=R0914
222
- driver: Driver, context: Context
223
- ) -> None:
221
+ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
224
222
  """Execute the default workflow for a single fit round."""
225
223
  if not isinstance(context, LegacyContext):
226
224
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -255,7 +253,7 @@ def default_fit_workflow( # pylint: disable=R0914
255
253
 
256
254
  # Build out messages
257
255
  out_messages = [
258
- driver.create_message(
256
+ grid.create_message(
259
257
  content=compat.fitins_to_recordset(fitins, True),
260
258
  message_type=MessageType.TRAIN,
261
259
  dst_node_id=proxy.node_id,
@@ -266,7 +264,7 @@ def default_fit_workflow( # pylint: disable=R0914
266
264
 
267
265
  # Send instructions to clients and
268
266
  # collect `fit` results from all clients participating in this round
269
- messages = list(driver.send_and_receive(out_messages))
267
+ messages = list(grid.send_and_receive(out_messages))
270
268
  del out_messages
271
269
  num_failures = len([msg for msg in messages if msg.has_error()])
272
270
 
@@ -307,7 +305,7 @@ def default_fit_workflow( # pylint: disable=R0914
307
305
 
308
306
 
309
307
  # pylint: disable-next=R0914
310
- def default_evaluate_workflow(driver: Driver, context: Context) -> None:
308
+ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
311
309
  """Execute the default workflow for a single evaluate round."""
312
310
  if not isinstance(context, LegacyContext):
313
311
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -341,7 +339,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
341
339
 
342
340
  # Build out messages
343
341
  out_messages = [
344
- driver.create_message(
342
+ grid.create_message(
345
343
  content=compat.evaluateins_to_recordset(evalins, True),
346
344
  message_type=MessageType.EVALUATE,
347
345
  dst_node_id=proxy.node_id,
@@ -352,7 +350,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
352
350
 
353
351
  # Send instructions to clients and
354
352
  # collect `evaluate` results from all clients participating in this round
355
- messages = list(driver.send_and_receive(out_messages))
353
+ messages = list(grid.send_and_receive(out_messages))
356
354
  del out_messages
357
355
  num_failures = len([msg for msg in messages if msg.has_error()])
358
356