flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/run/run.py +5 -9
  12. flwr/client/app.py +6 -4
  13. flwr/client/client_app.py +162 -99
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/grpc_client/connection.py +24 -21
  16. flwr/client/message_handler/message_handler.py +27 -27
  17. flwr/client/mod/__init__.py +2 -2
  18. flwr/client/mod/centraldp_mods.py +7 -7
  19. flwr/client/mod/comms_mods.py +16 -22
  20. flwr/client/mod/localdp_mod.py +4 -4
  21. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  22. flwr/client/run_info_store.py +2 -2
  23. flwr/common/__init__.py +12 -4
  24. flwr/common/config.py +4 -4
  25. flwr/common/constant.py +6 -6
  26. flwr/common/context.py +4 -4
  27. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  28. flwr/common/logger.py +2 -2
  29. flwr/common/message.py +327 -102
  30. flwr/common/record/__init__.py +8 -4
  31. flwr/common/record/arrayrecord.py +626 -0
  32. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  33. flwr/common/record/conversion_utils.py +1 -1
  34. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  35. flwr/common/record/recorddict.py +288 -0
  36. flwr/common/recorddict_compat.py +410 -0
  37. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  38. flwr/common/serde.py +66 -71
  39. flwr/common/typing.py +8 -8
  40. flwr/proto/exec_pb2.py +3 -3
  41. flwr/proto/exec_pb2.pyi +3 -3
  42. flwr/proto/message_pb2.py +12 -12
  43. flwr/proto/message_pb2.pyi +9 -9
  44. flwr/proto/recorddict_pb2.py +70 -0
  45. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  46. flwr/proto/run_pb2.py +31 -31
  47. flwr/proto/run_pb2.pyi +3 -3
  48. flwr/server/__init__.py +3 -1
  49. flwr/server/app.py +56 -1
  50. flwr/server/compat/__init__.py +2 -2
  51. flwr/server/compat/app.py +11 -11
  52. flwr/server/compat/app_utils.py +16 -16
  53. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  54. flwr/server/fleet_event_log_interceptor.py +94 -0
  55. flwr/server/{driver → grid}/__init__.py +8 -7
  56. flwr/server/{driver/driver.py → grid/grid.py} +47 -18
  57. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  58. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  59. flwr/server/run_serverapp.py +4 -4
  60. flwr/server/server_app.py +38 -18
  61. flwr/server/serverapp/app.py +10 -10
  62. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  63. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  64. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  65. flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
  66. flwr/server/superlink/linkstate/linkstate.py +4 -4
  67. flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
  68. flwr/server/superlink/linkstate/utils.py +93 -27
  69. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  70. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  71. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  72. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  73. flwr/server/typing.py +3 -3
  74. flwr/server/utils/validator.py +4 -4
  75. flwr/server/workflow/default_workflows.py +48 -57
  76. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  77. flwr/simulation/app.py +2 -2
  78. flwr/simulation/ray_transport/ray_actor.py +4 -2
  79. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  80. flwr/simulation/run_simulation.py +15 -15
  81. flwr/superexec/deployment.py +4 -4
  82. flwr/superexec/exec_event_log_interceptor.py +135 -0
  83. flwr/superexec/exec_grpc.py +10 -4
  84. flwr/superexec/exec_servicer.py +2 -2
  85. flwr/superexec/exec_user_auth_interceptor.py +18 -2
  86. flwr/superexec/executor.py +3 -3
  87. flwr/superexec/simulation.py +3 -3
  88. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
  89. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
  90. flwr/common/record/parametersrecord.py +0 -339
  91. flwr/common/record/recordset.py +0 -209
  92. flwr/common/recordset_compat.py +0 -418
  93. flwr/proto/recordset_pb2.py +0 -70
  94. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  95. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  96. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  97. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
  98. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
@@ -30,30 +30,33 @@ from flwr.common import Context, Message, Metadata, log, now
30
30
  from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
+ PING_PATIENCE,
33
34
  RUN_ID_NUM_BYTES,
34
35
  SUPERLINK_NODE_ID,
35
36
  Status,
36
37
  )
37
- from flwr.common.record import ConfigsRecord
38
+ from flwr.common.message import make_message
39
+ from flwr.common.record import ConfigRecord
38
40
  from flwr.common.serde import (
39
41
  error_from_proto,
40
42
  error_to_proto,
41
- recordset_from_proto,
42
- recordset_to_proto,
43
+ recorddict_from_proto,
44
+ recorddict_to_proto,
43
45
  )
44
46
  from flwr.common.typing import Run, RunStatus, UserConfig
45
47
 
46
48
  # pylint: disable=E0611
47
49
  from flwr.proto.error_pb2 import Error as ProtoError
48
- from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
50
+ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
49
51
 
50
52
  # pylint: enable=E0611
51
53
  from flwr.server.utils.validator import validate_message
52
54
 
53
55
  from .linkstate import LinkState
54
56
  from .utils import (
55
- configsrecord_from_bytes,
56
- configsrecord_to_bytes,
57
+ check_node_availability_for_in_message,
58
+ configrecord_from_bytes,
59
+ configrecord_to_bytes,
57
60
  context_from_bytes,
58
61
  context_to_bytes,
59
62
  convert_sint64_to_uint64,
@@ -129,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
129
132
  run_id INTEGER,
130
133
  src_node_id INTEGER,
131
134
  dst_node_id INTEGER,
132
- reply_to_message TEXT,
135
+ reply_to_message_id TEXT,
133
136
  created_at REAL,
134
137
  delivered_at TEXT,
135
138
  ttl REAL,
@@ -148,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
148
151
  run_id INTEGER,
149
152
  src_node_id INTEGER,
150
153
  dst_node_id INTEGER,
151
- reply_to_message TEXT,
154
+ reply_to_message_id TEXT,
152
155
  created_at REAL,
153
156
  delivered_at TEXT,
154
157
  ttl REAL,
@@ -371,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
371
374
  return None
372
375
 
373
376
  res_metadata = message.metadata
374
- msg_ins_id = res_metadata.reply_to_message
377
+ msg_ins_id = res_metadata.reply_to_message_id
375
378
  msg_ins = self.get_valid_message_ins(msg_ins_id)
376
379
  if msg_ins is None:
377
380
  log(
@@ -442,6 +445,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
442
445
 
443
446
  def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
444
447
  """Get reply Messages for the given Message IDs."""
448
+ # pylint: disable-msg=too-many-locals
445
449
  ret: dict[UUID, Message] = {}
446
450
 
447
451
  # Verify Message IDs
@@ -465,11 +469,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
465
469
  current_time=current,
466
470
  )
467
471
 
472
+ # Check node availability
473
+ dst_node_ids: set[int] = set()
474
+ for message_id in message_ids:
475
+ in_message = found_message_ins_dict[message_id]
476
+ sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
477
+ dst_node_ids.add(sint_node_id)
478
+ query = f"""
479
+ SELECT node_id, online_until
480
+ FROM node
481
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
482
+ """
483
+ rows = self.query(query, tuple(dst_node_ids))
484
+ tmp_ret_dict = check_node_availability_for_in_message(
485
+ inquired_in_message_ids=message_ids,
486
+ found_in_message_dict=found_message_ins_dict,
487
+ node_id_to_online_until={
488
+ convert_sint64_to_uint64(row["node_id"]): row["online_until"]
489
+ for row in rows
490
+ },
491
+ current_time=current,
492
+ )
493
+ ret.update(tmp_ret_dict)
494
+
468
495
  # Find all reply Messages
469
496
  query = f"""
470
497
  SELECT *
471
498
  FROM message_res
472
- WHERE reply_to_message IN ({",".join(["?"] * len(message_ids))})
499
+ WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
473
500
  AND delivered_at = "";
474
501
  """
475
502
  rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
@@ -542,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
542
569
  # Delete reply Message
543
570
  query_2 = f"""
544
571
  DELETE FROM message_res
545
- WHERE reply_to_message IN ({placeholders});
572
+ WHERE reply_to_message_id IN ({placeholders});
546
573
  """
547
574
 
548
575
  with self.conn:
@@ -584,6 +611,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
584
611
  "VALUES (?, ?, ?, ?)"
585
612
  )
586
613
 
614
+ # Mark the node online until time.time() + ping_interval
587
615
  try:
588
616
  self.query(
589
617
  query,
@@ -699,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
699
727
  fab_version: Optional[str],
700
728
  fab_hash: Optional[str],
701
729
  override_config: UserConfig,
702
- federation_options: ConfigsRecord,
730
+ federation_options: ConfigRecord,
703
731
  ) -> int:
704
732
  """Create a new run for the specified `fab_id` and `fab_version`."""
705
733
  # Sample a random int64 as run_id
@@ -725,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
725
753
  fab_version,
726
754
  fab_hash,
727
755
  override_config_json,
728
- configsrecord_to_bytes(federation_options),
756
+ configrecord_to_bytes(federation_options),
729
757
  ]
730
758
  data += [
731
759
  now().isoformat(),
@@ -883,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
883
911
 
884
912
  return pending_run_id
885
913
 
886
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
914
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
887
915
  """Retrieve the federation options for the specified `run_id`."""
888
916
  # Convert the uint64 value to sint64 for SQLite
889
917
  sint64_run_id = convert_uint64_to_sint64(run_id)
@@ -896,10 +924,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
896
924
  return None
897
925
 
898
926
  row = rows[0]
899
- return configsrecord_from_bytes(row["federation_options"])
927
+ return configrecord_from_bytes(row["federation_options"])
900
928
 
901
929
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
902
- """Acknowledge a ping received from a node, serving as a heartbeat."""
930
+ """Acknowledge a ping received from a node, serving as a heartbeat.
931
+
932
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
933
+ marking the node as offline, where PING_PATIENCE = 2 by default.
934
+ """
903
935
  sint64_node_id = convert_uint64_to_sint64(node_id)
904
936
 
905
937
  # Check if the node exists in the `node` table
@@ -909,7 +941,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
909
941
 
910
942
  # Update `online_until` and `ping_interval` for the given `node_id`
911
943
  query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
912
- self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
944
+ self.query(
945
+ query,
946
+ (
947
+ time.time() + PING_PATIENCE * ping_interval,
948
+ ping_interval,
949
+ sint64_node_id,
950
+ ),
951
+ )
913
952
  return True
914
953
 
915
954
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
@@ -1026,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1026
1065
  "run_id": message.metadata.run_id,
1027
1066
  "src_node_id": message.metadata.src_node_id,
1028
1067
  "dst_node_id": message.metadata.dst_node_id,
1029
- "reply_to_message": message.metadata.reply_to_message,
1068
+ "reply_to_message_id": message.metadata.reply_to_message_id,
1030
1069
  "created_at": message.metadata.created_at,
1031
1070
  "delivered_at": message.metadata.delivered_at,
1032
1071
  "ttl": message.metadata.ttl,
@@ -1036,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1036
1075
  }
1037
1076
 
1038
1077
  if message.has_content():
1039
- result["content"] = recordset_to_proto(message.content).SerializeToString()
1078
+ result["content"] = recorddict_to_proto(message.content).SerializeToString()
1040
1079
  else:
1041
1080
  result["error"] = error_to_proto(message.error).SerializeToString()
1042
1081
 
@@ -1047,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
1047
1086
  """Transform dict to Message."""
1048
1087
  content, error = None, None
1049
1088
  if (b_content := message_dict.pop("content")) is not None:
1050
- content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
1089
+ content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1051
1090
  if (b_error := message_dict.pop("error")) is not None:
1052
1091
  error = error_from_proto(ProtoError.FromString(b_error))
1053
1092
 
1054
1093
  # Metadata constructor doesn't allow passing created_at. We set it later
1055
1094
  metadata = Metadata(
1056
- **{
1057
- k: v
1058
- for k, v in message_dict.items()
1059
- if k not in ["created_at", "delivered_at"]
1060
- }
1095
+ **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1061
1096
  )
1062
- msg = Message(metadata=metadata, content=content, error=error)
1063
- msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
1097
+ msg = make_message(metadata=metadata, content=content, error=error)
1064
1098
  msg.metadata.delivered_at = message_dict["delivered_at"]
1065
1099
  return msg
1066
1100
 
@@ -19,21 +19,22 @@ from os import urandom
19
19
  from typing import Optional
20
20
  from uuid import UUID, uuid4
21
21
 
22
- from flwr.common import ConfigsRecord, Context, Error, Message, Metadata, now, serde
23
- from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
22
+ from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
+ from flwr.common.constant import (
24
+ SUPERLINK_NODE_ID,
25
+ ErrorCode,
26
+ MessageType,
27
+ Status,
28
+ SubStatus,
29
+ )
30
+ from flwr.common.message import make_message
24
31
  from flwr.common.typing import RunStatus
25
32
 
26
33
  # pylint: disable=E0611
27
34
  from flwr.proto.message_pb2 import Context as ProtoContext
28
- from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
35
+ from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
29
36
 
30
37
  # pylint: enable=E0611
31
-
32
- NODE_UNAVAILABLE_ERROR_REASON = (
33
- "Error: Node Unavailable - The destination node is currently unavailable. "
34
- "It exceeds the time limit specified in its last ping."
35
- )
36
-
37
38
  VALID_RUN_STATUS_TRANSITIONS = {
38
39
  (Status.PENDING, Status.STARTING),
39
40
  (Status.STARTING, Status.RUNNING),
@@ -54,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
54
55
  REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
55
56
  "Error: Reply Message Unavailable - The reply message has expired."
56
57
  )
58
+ NODE_UNAVAILABLE_ERROR_REASON = (
59
+ "Error: Node Unavailable - The destination node is currently unavailable. "
60
+ "It exceeds twice the time limit specified in its last ping."
61
+ )
57
62
 
58
63
 
59
64
  def generate_rand_int_from_bytes(
@@ -167,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
167
172
  return serde.context_from_proto(ProtoContext.FromString(context_bytes))
168
173
 
169
174
 
170
- def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
171
- """Serialize a `ConfigsRecord` to bytes."""
172
- return serde.configs_record_to_proto(configs_record).SerializeToString()
175
+ def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
176
+ """Serialize a `ConfigRecord` to bytes."""
177
+ return serde.config_record_to_proto(config_record).SerializeToString()
173
178
 
174
179
 
175
- def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
176
- """Deserialize `ConfigsRecord` from bytes."""
177
- return serde.configs_record_from_proto(
178
- ProtoConfigsRecord.FromString(configsrecord_bytes)
180
+ def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
181
+ """Deserialize `ConfigRecord` from bytes."""
182
+ return serde.config_record_from_proto(
183
+ ProtoConfigRecord.FromString(configrecord_bytes)
179
184
  )
180
185
 
181
186
 
@@ -231,7 +236,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
231
236
  return status.sub_status == ""
232
237
 
233
238
 
234
- def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Message:
239
+ def create_message_error_unavailable_res_message(
240
+ ins_metadata: Metadata, error_type: str
241
+ ) -> Message:
235
242
  """Generate an error Message that the SuperLink returns carrying the specified
236
243
  error."""
237
244
  current_time = now().timestamp()
@@ -241,22 +248,31 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
241
248
  message_id=str(uuid4()),
242
249
  src_node_id=SUPERLINK_NODE_ID,
243
250
  dst_node_id=SUPERLINK_NODE_ID,
244
- reply_to_message=ins_metadata.message_id,
251
+ reply_to_message_id=ins_metadata.message_id,
245
252
  group_id=ins_metadata.group_id,
246
253
  message_type=ins_metadata.message_type,
254
+ created_at=current_time,
247
255
  ttl=ttl,
248
256
  )
249
257
 
250
- return Message(
258
+ return make_message(
251
259
  metadata=metadata,
252
260
  error=Error(
253
- code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
254
- reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
261
+ code=(
262
+ ErrorCode.REPLY_MESSAGE_UNAVAILABLE
263
+ if error_type == "msg_unavail"
264
+ else ErrorCode.NODE_UNAVAILABLE
265
+ ),
266
+ reason=(
267
+ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
268
+ if error_type == "msg_unavail"
269
+ else NODE_UNAVAILABLE_ERROR_REASON
270
+ ),
255
271
  ),
256
272
  )
257
273
 
258
274
 
259
- def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Message:
275
+ def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
260
276
  """Error to indicate that the enquired Message had expired before reply arrived or
261
277
  that it isn't found."""
262
278
  metadata = Metadata(
@@ -264,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
264
280
  message_id=str(uuid4()),
265
281
  src_node_id=SUPERLINK_NODE_ID,
266
282
  dst_node_id=SUPERLINK_NODE_ID,
267
- reply_to_message=str(reply_to_message),
283
+ reply_to_message_id=str(reply_to_message_id),
268
284
  group_id="", # Unknown
269
- message_type="", # Unknown
285
+ message_type=MessageType.SYSTEM,
286
+ created_at=now().timestamp(),
270
287
  ttl=0,
271
288
  )
272
289
 
273
- return Message(
290
+ return make_message(
274
291
  metadata=metadata,
275
292
  error=Error(
276
293
  code=ErrorCode.MESSAGE_UNAVAILABLE,
@@ -358,14 +375,63 @@ def verify_found_message_replies(
358
375
  ret_dict: dict[UUID, Message] = {}
359
376
  current = current_time if current_time else now().timestamp()
360
377
  for message_res in found_message_res_list:
361
- message_ins_id = UUID(message_res.metadata.reply_to_message)
378
+ message_ins_id = UUID(message_res.metadata.reply_to_message_id)
362
379
  if update_set:
363
380
  inquired_message_ids.remove(message_ins_id)
364
381
  # Check if the reply Message has expired
365
382
  if message_ttl_has_expired(message_res.metadata, current):
366
383
  # No need to insert the error Message
367
384
  message_res = create_message_error_unavailable_res_message(
368
- found_message_ins_dict[message_ins_id].metadata
385
+ found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
369
386
  )
370
387
  ret_dict[message_ins_id] = message_res
371
388
  return ret_dict
389
+
390
+
391
+ def check_node_availability_for_in_message(
392
+ inquired_in_message_ids: set[UUID],
393
+ found_in_message_dict: dict[UUID, Message],
394
+ node_id_to_online_until: dict[int, float],
395
+ current_time: Optional[float] = None,
396
+ update_set: bool = True,
397
+ ) -> dict[UUID, Message]:
398
+ """Check node availability for given Message and generate error reply Message if
399
+ unavailable. A Message error indicating node unavailability will be generated for
400
+ each given Message whose destination node is offline or non-existent.
401
+
402
+ Parameters
403
+ ----------
404
+ inquired_in_message_ids : set[UUID]
405
+ Set of Message IDs for which to check destination node availability.
406
+ found_in_message_dict : dict[UUID, Message]
407
+ Dictionary containing all found Message indexed by their IDs.
408
+ node_id_to_online_until : dict[int, float]
409
+ Dictionary mapping node IDs to their online-until timestamps.
410
+ current_time : Optional[float] (default: None)
411
+ The current time to check for expiration. If set to `None`, the current time
412
+ will automatically be set to the current timestamp using `now().timestamp()`.
413
+ update_set : bool (default: True)
414
+ If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
415
+ by default True.
416
+
417
+ Returns
418
+ -------
419
+ dict[UUID, Message]
420
+ A dictionary of error Message indexed by the corresponding Message ID.
421
+ """
422
+ ret_dict = {}
423
+ current = current_time if current_time else now().timestamp()
424
+ for in_message_id in list(inquired_in_message_ids):
425
+ in_message = found_in_message_dict[in_message_id]
426
+ node_id = in_message.metadata.dst_node_id
427
+ online_until = node_id_to_online_until.get(node_id)
428
+ # Generate a reply message containing an error reply
429
+ # if the node is offline or doesn't exist.
430
+ if online_until is None or online_until < current:
431
+ if update_set:
432
+ inquired_in_message_ids.remove(in_message_id)
433
+ reply_message = create_message_error_unavailable_res_message(
434
+ in_message.metadata, "node_unavail"
435
+ )
436
+ ret_dict[in_message_id] = reply_message
437
+ return ret_dict
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord, Message
25
+ from flwr.common import ConfigRecord, Message
26
26
  from flwr.common.constant import SUPERLINK_NODE_ID, Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
127
127
  request.fab_version,
128
128
  fab_hash,
129
129
  user_config_from_proto(request.override_config),
130
- ConfigsRecord(),
130
+ ConfigRecord(),
131
131
  )
132
132
  return CreateRunResponse(run_id=run_id)
133
133
 
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
206
206
 
207
207
  # Delete the instruction Messages and their replies if found
208
208
  message_ins_ids_to_delete = {
209
- UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
209
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
210
210
  }
211
211
 
212
212
  state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
@@ -24,7 +24,7 @@ from grpc import ServicerContext
24
24
  from flwr.common.constant import Status
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.serde import (
27
- configs_record_to_proto,
27
+ config_record_to_proto,
28
28
  context_from_proto,
29
29
  context_to_proto,
30
30
  fab_to_proto,
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
182
182
  )
183
183
  return GetFederationOptionsResponse()
184
184
  return GetFederationOptionsResponse(
185
- federation_options=configs_record_to_proto(federation_options)
185
+ federation_options=config_record_to_proto(federation_options)
186
186
  )
flwr/server/typing.py CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
19
19
 
20
20
  from flwr.common import Context
21
21
 
22
- from .driver import Driver
22
+ from .grid import Grid
23
23
  from .serverapp_components import ServerAppComponents
24
24
 
25
- ServerAppCallable = Callable[[Driver, Context], None]
26
- Workflow = Callable[[Driver, Context], None]
25
+ ServerAppCallable = Callable[[Grid, Context], None]
26
+ Workflow = Callable[[Grid, Context], None]
27
27
  ServerFn = Callable[[Context], ServerAppComponents]
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
68
68
 
69
69
  # Link response to original message
70
70
  if not is_reply_message:
71
- if metadata.reply_to_message != "":
72
- validation_errors.append("`metadata.reply_to_message` MUST not be set.")
71
+ if metadata.reply_to_message_id != "":
72
+ validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
73
73
  if metadata.src_node_id != SUPERLINK_NODE_ID:
74
74
  validation_errors.append(
75
75
  f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
79
79
  f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
80
80
  )
81
81
  else:
82
- if metadata.reply_to_message == "":
83
- validation_errors.append("`metadata.reply_to_message` MUST be set.")
82
+ if metadata.reply_to_message_id == "":
83
+ validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
84
84
  if metadata.src_node_id == SUPERLINK_NODE_ID:
85
85
  validation_errors.append(
86
86
  f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"