flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -15,29 +15,26 @@
15
15
  """Utility functions for State."""
16
16
 
17
17
 
18
- from logging import ERROR
19
18
  from os import urandom
20
- from typing import Optional, Union
19
+ from typing import Optional
21
20
  from uuid import UUID, uuid4
22
21
 
23
- from flwr.common import ConfigsRecord, Context, log, now, serde
24
- from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
22
+ from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
+ from flwr.common.constant import (
24
+ SUPERLINK_NODE_ID,
25
+ ErrorCode,
26
+ MessageType,
27
+ Status,
28
+ SubStatus,
29
+ )
30
+ from flwr.common.message import make_message
25
31
  from flwr.common.typing import RunStatus
26
32
 
27
33
  # pylint: disable=E0611
28
- from flwr.proto.error_pb2 import Error
29
34
  from flwr.proto.message_pb2 import Context as ProtoContext
30
- from flwr.proto.node_pb2 import Node
31
- from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
32
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
35
+ from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
33
36
 
34
37
  # pylint: enable=E0611
35
-
36
- NODE_UNAVAILABLE_ERROR_REASON = (
37
- "Error: Node Unavailable - The destination node is currently unavailable. "
38
- "It exceeds the time limit specified in its last ping."
39
- )
40
-
41
38
  VALID_RUN_STATUS_TRANSITIONS = {
42
39
  (Status.PENDING, Status.STARTING),
43
40
  (Status.STARTING, Status.RUNNING),
@@ -58,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
58
55
  REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
59
56
  "Error: Reply Message Unavailable - The reply message has expired."
60
57
  )
58
+ NODE_UNAVAILABLE_ERROR_REASON = (
59
+ "Error: Node Unavailable - The destination node is currently unavailable. "
60
+ "It exceeds twice the time limit specified in its last ping."
61
+ )
61
62
 
62
63
 
63
64
  def generate_rand_int_from_bytes(
@@ -171,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
171
172
  return serde.context_from_proto(ProtoContext.FromString(context_bytes))
172
173
 
173
174
 
174
- def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
175
- """Serialize a `ConfigsRecord` to bytes."""
176
- return serde.configs_record_to_proto(configs_record).SerializeToString()
175
+ def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
176
+ """Serialize a `ConfigRecord` to bytes."""
177
+ return serde.config_record_to_proto(config_record).SerializeToString()
177
178
 
178
179
 
179
- def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
180
- """Deserialize `ConfigsRecord` from bytes."""
181
- return serde.configs_record_from_proto(
182
- ProtoConfigsRecord.FromString(configsrecord_bytes)
180
+ def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
181
+ """Deserialize `ConfigRecord` from bytes."""
182
+ return serde.config_record_from_proto(
183
+ ProtoConfigRecord.FromString(configrecord_bytes)
183
184
  )
184
185
 
185
186
 
@@ -235,165 +236,202 @@ def has_valid_sub_status(status: RunStatus) -> bool:
235
236
  return status.sub_status == ""
236
237
 
237
238
 
238
- def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> TaskRes:
239
- """Generate a TaskRes with a TaskIns unavailable error.
240
-
241
- Parameters
242
- ----------
243
- taskins_id : Union[str, UUID]
244
- The ID of the unavailable TaskIns.
245
-
246
- Returns
247
- -------
248
- TaskRes
249
- A TaskRes with an error code MESSAGE_UNAVAILABLE to indicate that the
250
- inquired TaskIns ID cannot be found (due to non-existence or expiration).
251
- """
239
+ def create_message_error_unavailable_res_message(
240
+ ins_metadata: Metadata, error_type: str
241
+ ) -> Message:
242
+ """Generate an error Message that the SuperLink returns carrying the specified
243
+ error."""
252
244
  current_time = now().timestamp()
253
- return TaskRes(
254
- task_id=str(uuid4()),
255
- group_id="", # Unknown group ID
256
- run_id=0, # Unknown run ID
257
- task=Task(
258
- # This function is only called by SuperLink, and thus it's the producer.
259
- producer=Node(node_id=SUPERLINK_NODE_ID),
260
- consumer=Node(node_id=SUPERLINK_NODE_ID),
261
- created_at=current_time,
262
- ttl=0,
263
- ancestry=[str(taskins_id)],
264
- task_type="", # Unknown message type
265
- error=Error(
266
- code=ErrorCode.MESSAGE_UNAVAILABLE,
267
- reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
245
+ ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
246
+ metadata = Metadata(
247
+ run_id=ins_metadata.run_id,
248
+ message_id=str(uuid4()),
249
+ src_node_id=SUPERLINK_NODE_ID,
250
+ dst_node_id=SUPERLINK_NODE_ID,
251
+ reply_to_message_id=ins_metadata.message_id,
252
+ group_id=ins_metadata.group_id,
253
+ message_type=ins_metadata.message_type,
254
+ created_at=current_time,
255
+ ttl=ttl,
256
+ )
257
+
258
+ return make_message(
259
+ metadata=metadata,
260
+ error=Error(
261
+ code=(
262
+ ErrorCode.REPLY_MESSAGE_UNAVAILABLE
263
+ if error_type == "msg_unavail"
264
+ else ErrorCode.NODE_UNAVAILABLE
265
+ ),
266
+ reason=(
267
+ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
268
+ if error_type == "msg_unavail"
269
+ else NODE_UNAVAILABLE_ERROR_REASON
268
270
  ),
269
271
  ),
270
272
  )
271
273
 
272
274
 
273
- def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
274
- """Generate a TaskRes with a reply message unavailable error from a TaskIns.
275
-
276
- Parameters
277
- ----------
278
- ref_taskins : TaskIns
279
- The reference TaskIns object.
275
+ def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
276
+ """Error to indicate that the enquired Message had expired before reply arrived or
277
+ that it isn't found."""
278
+ metadata = Metadata(
279
+ run_id=0, # Unknown
280
+ message_id=str(uuid4()),
281
+ src_node_id=SUPERLINK_NODE_ID,
282
+ dst_node_id=SUPERLINK_NODE_ID,
283
+ reply_to_message_id=str(reply_to_message_id),
284
+ group_id="", # Unknown
285
+ message_type=MessageType.SYSTEM,
286
+ created_at=now().timestamp(),
287
+ ttl=0,
288
+ )
280
289
 
281
- Returns
282
- -------
283
- TaskRes
284
- The generated TaskRes with an error code REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
285
- indicating that the original TaskRes has expired.
286
- """
287
- current_time = now().timestamp()
288
- ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
289
- if ttl < 0:
290
- log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
291
- ttl = 0
292
- return TaskRes(
293
- task_id=str(uuid4()),
294
- group_id=ref_taskins.group_id,
295
- run_id=ref_taskins.run_id,
296
- task=Task(
297
- # This function is only called by SuperLink, and thus it's the producer.
298
- producer=Node(node_id=SUPERLINK_NODE_ID),
299
- consumer=Node(node_id=SUPERLINK_NODE_ID),
300
- created_at=current_time,
301
- ttl=ttl,
302
- ancestry=[ref_taskins.task_id],
303
- task_type=ref_taskins.task.task_type,
304
- error=Error(
305
- code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
306
- reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
307
- ),
290
+ return make_message(
291
+ metadata=metadata,
292
+ error=Error(
293
+ code=ErrorCode.MESSAGE_UNAVAILABLE,
294
+ reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
308
295
  ),
309
296
  )
310
297
 
311
298
 
312
- def has_expired(task_ins_or_res: Union[TaskIns, TaskRes], current_time: float) -> bool:
313
- """Check if the TaskIns/TaskRes has expired."""
314
- return task_ins_or_res.task.ttl + task_ins_or_res.task.created_at < current_time
299
+ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
300
+ """Check if the Message has expired."""
301
+ return message_metadata.ttl + message_metadata.created_at < current_time
315
302
 
316
303
 
317
- def verify_taskins_ids(
318
- inquired_taskins_ids: set[UUID],
319
- found_taskins_dict: dict[UUID, TaskIns],
304
+ def verify_message_ids(
305
+ inquired_message_ids: set[UUID],
306
+ found_message_ins_dict: dict[UUID, Message],
320
307
  current_time: Optional[float] = None,
321
308
  update_set: bool = True,
322
- ) -> dict[UUID, TaskRes]:
323
- """Verify found TaskIns and generate error TaskRes for invalid ones.
309
+ ) -> dict[UUID, Message]:
310
+ """Verify found Messages and generate error Messages for invalid ones.
324
311
 
325
312
  Parameters
326
313
  ----------
327
- inquired_taskins_ids : set[UUID]
328
- Set of TaskIns IDs for which to generate error TaskRes if invalid.
329
- found_taskins_dict : dict[UUID, TaskIns]
330
- Dictionary containing all found TaskIns indexed by their IDs.
314
+ inquired_message_ids : set[UUID]
315
+ Set of Message IDs for which to generate error Message if invalid.
316
+ found_message_ins_dict : dict[UUID, Message]
317
+ Dictionary containing all found Message indexed by their IDs.
331
318
  current_time : Optional[float] (default: None)
332
319
  The current time to check for expiration. If set to `None`, the current time
333
320
  will automatically be set to the current timestamp using `now().timestamp()`.
334
321
  update_set : bool (default: True)
335
- If True, the `inquired_taskins_ids` will be updated to remove invalid ones,
322
+ If True, the `inquired_message_ids` will be updated to remove invalid ones,
336
323
  by default True.
337
324
 
338
325
  Returns
339
326
  -------
340
- dict[UUID, TaskRes]
341
- A dictionary of error TaskRes indexed by the corresponding TaskIns ID.
327
+ dict[UUID, Message]
328
+ A dictionary of error Message indexed by the corresponding ID of the message
329
+ they are a reply of.
342
330
  """
343
331
  ret_dict = {}
344
332
  current = current_time if current_time else now().timestamp()
345
- for taskins_id in list(inquired_taskins_ids):
346
- # Generate error TaskRes if the task_ins doesn't exist or has expired
347
- taskins = found_taskins_dict.get(taskins_id)
348
- if taskins is None or has_expired(taskins, current):
333
+ for message_id in list(inquired_message_ids):
334
+ # Generate error message if the inquired message doesn't exist or has expired
335
+ message_ins = found_message_ins_dict.get(message_id)
336
+ if message_ins is None or message_ttl_has_expired(
337
+ message_ins.metadata, current
338
+ ):
349
339
  if update_set:
350
- inquired_taskins_ids.remove(taskins_id)
351
- taskres = create_taskres_for_unavailable_taskins(taskins_id)
352
- ret_dict[taskins_id] = taskres
340
+ inquired_message_ids.remove(message_id)
341
+ message_res = create_message_error_unavailable_ins_message(message_id)
342
+ ret_dict[message_id] = message_res
353
343
  return ret_dict
354
344
 
355
345
 
356
- def verify_found_taskres(
357
- inquired_taskins_ids: set[UUID],
358
- found_taskins_dict: dict[UUID, TaskIns],
359
- found_taskres_list: list[TaskRes],
346
+ def verify_found_message_replies(
347
+ inquired_message_ids: set[UUID],
348
+ found_message_ins_dict: dict[UUID, Message],
349
+ found_message_res_list: list[Message],
360
350
  current_time: Optional[float] = None,
361
351
  update_set: bool = True,
362
- ) -> dict[UUID, TaskRes]:
363
- """Verify found TaskRes and generate error TaskRes for invalid ones.
352
+ ) -> dict[UUID, Message]:
353
+ """Verify found Message replies and generate error Message for invalid ones.
364
354
 
365
355
  Parameters
366
356
  ----------
367
- inquired_taskins_ids : set[UUID]
368
- Set of TaskIns IDs for which to generate error TaskRes if invalid.
369
- found_taskins_dict : dict[UUID, TaskIns]
370
- Dictionary containing all found TaskIns indexed by their IDs.
371
- found_taskres_list : dict[TaskIns, TaskRes]
372
- List of found TaskRes to be verified.
357
+ inquired_message_ids : set[UUID]
358
+ Set of Message IDs for which to generate error Message if invalid.
359
+ found_message_ins_dict : dict[UUID, Message]
360
+ Dictionary containing all found instruction Messages indexed by their IDs.
361
+ found_message_res_list : dict[Message, Message]
362
+ List of found Message to be verified.
373
363
  current_time : Optional[float] (default: None)
374
364
  The current time to check for expiration. If set to `None`, the current time
375
365
  will automatically be set to the current timestamp using `now().timestamp()`.
376
366
  update_set : bool (default: True)
377
- If True, the `inquired_taskins_ids` will be updated to remove ones
378
- that have a TaskRes, by default True.
367
+ If True, the `inquired_message_ids` will be updated to remove ones
368
+ that have a reply Message, by default True.
379
369
 
380
370
  Returns
381
371
  -------
382
- dict[UUID, TaskRes]
383
- A dictionary of TaskRes indexed by the corresponding TaskIns ID.
372
+ dict[UUID, Message]
373
+ A dictionary of Message indexed by the corresponding Message ID.
384
374
  """
385
- ret_dict: dict[UUID, TaskRes] = {}
375
+ ret_dict: dict[UUID, Message] = {}
386
376
  current = current_time if current_time else now().timestamp()
387
- for taskres in found_taskres_list:
388
- taskins_id = UUID(taskres.task.ancestry[0])
377
+ for message_res in found_message_res_list:
378
+ message_ins_id = UUID(message_res.metadata.reply_to_message_id)
389
379
  if update_set:
390
- inquired_taskins_ids.remove(taskins_id)
391
- # Check if the TaskRes has expired
392
- if has_expired(taskres, current):
393
- # No need to insert the error TaskRes
394
- taskres = create_taskres_for_unavailable_taskres(
395
- found_taskins_dict[taskins_id]
380
+ inquired_message_ids.remove(message_ins_id)
381
+ # Check if the reply Message has expired
382
+ if message_ttl_has_expired(message_res.metadata, current):
383
+ # No need to insert the error Message
384
+ message_res = create_message_error_unavailable_res_message(
385
+ found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
386
+ )
387
+ ret_dict[message_ins_id] = message_res
388
+ return ret_dict
389
+
390
+
391
+ def check_node_availability_for_in_message(
392
+ inquired_in_message_ids: set[UUID],
393
+ found_in_message_dict: dict[UUID, Message],
394
+ node_id_to_online_until: dict[int, float],
395
+ current_time: Optional[float] = None,
396
+ update_set: bool = True,
397
+ ) -> dict[UUID, Message]:
398
+ """Check node availability for given Message and generate error reply Message if
399
+ unavailable. A Message error indicating node unavailability will be generated for
400
+ each given Message whose destination node is offline or non-existent.
401
+
402
+ Parameters
403
+ ----------
404
+ inquired_in_message_ids : set[UUID]
405
+ Set of Message IDs for which to check destination node availability.
406
+ found_in_message_dict : dict[UUID, Message]
407
+ Dictionary containing all found Message indexed by their IDs.
408
+ node_id_to_online_until : dict[int, float]
409
+ Dictionary mapping node IDs to their online-until timestamps.
410
+ current_time : Optional[float] (default: None)
411
+ The current time to check for expiration. If set to `None`, the current time
412
+ will automatically be set to the current timestamp using `now().timestamp()`.
413
+ update_set : bool (default: True)
414
+ If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
415
+ by default True.
416
+
417
+ Returns
418
+ -------
419
+ dict[UUID, Message]
420
+ A dictionary of error Message indexed by the corresponding Message ID.
421
+ """
422
+ ret_dict = {}
423
+ current = current_time if current_time else now().timestamp()
424
+ for in_message_id in list(inquired_in_message_ids):
425
+ in_message = found_in_message_dict[in_message_id]
426
+ node_id = in_message.metadata.dst_node_id
427
+ online_until = node_id_to_online_until.get(node_id)
428
+ # Generate a reply message containing an error reply
429
+ # if the node is offline or doesn't exist.
430
+ if online_until is None or online_until < current:
431
+ if update_set:
432
+ inquired_in_message_ids.remove(in_message_id)
433
+ reply_message = create_message_error_unavailable_res_message(
434
+ in_message.metadata, "node_unavail"
396
435
  )
397
- taskres.task.delivered_at = now().isoformat()
398
- ret_dict[taskins_id] = taskres
436
+ ret_dict[in_message_id] = reply_message
399
437
  return ret_dict
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord
26
- from flwr.common.constant import Status
25
+ from flwr.common import ConfigRecord, Message
26
+ from flwr.common.constant import SUPERLINK_NODE_ID, Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
29
29
  context_from_proto,
@@ -31,9 +31,7 @@ from flwr.common.serde import (
31
31
  fab_from_proto,
32
32
  fab_to_proto,
33
33
  message_from_proto,
34
- message_from_taskres,
35
34
  message_to_proto,
36
- message_to_taskins,
37
35
  run_status_from_proto,
38
36
  run_status_to_proto,
39
37
  run_to_proto,
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
69
67
  PushServerAppOutputsRequest,
70
68
  PushServerAppOutputsResponse,
71
69
  )
72
- from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
73
70
  from flwr.server.superlink.ffs.ffs import Ffs
74
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
75
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
76
73
  from flwr.server.superlink.utils import abort_if
77
- from flwr.server.utils.validator import validate_task_ins_or_res
74
+ from flwr.server.utils.validator import validate_message
78
75
 
79
76
 
80
77
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -130,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
130
127
  request.fab_version,
131
128
  fab_hash,
132
129
  user_config_from_proto(request.override_config),
133
- ConfigsRecord(),
130
+ ConfigRecord(),
134
131
  )
135
132
  return CreateRunResponse(run_id=run_id)
136
133
 
@@ -161,20 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
161
158
  while request.messages_list:
162
159
  message_proto = request.messages_list.pop(0)
163
160
  message = message_from_proto(message_proto=message_proto)
164
- task_ins = message_to_taskins(message=message)
165
- validation_errors = validate_task_ins_or_res(task_ins)
161
+ validation_errors = validate_message(message, is_reply_message=False)
166
162
  _raise_if(
167
163
  validation_error=bool(validation_errors),
168
164
  request_name="PushMessages",
169
165
  detail=", ".join(validation_errors),
170
166
  )
171
167
  _raise_if(
172
- validation_error=request.run_id != task_ins.run_id,
168
+ validation_error=request.run_id != message.metadata.run_id,
173
169
  request_name="PushMessages",
174
- detail="`task_ins` has mismatched `run_id`",
170
+ detail="`Message.metadata` has mismatched `run_id`",
175
171
  )
176
172
  # Store
177
- message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
173
+ message_id: Optional[UUID] = state.store_message_ins(message=message)
178
174
  message_ids.append(message_id)
179
175
 
180
176
  return PushInsMessagesResponse(
@@ -200,32 +196,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
200
196
  context,
201
197
  )
202
198
 
203
- # Convert each task_id str to UUID
199
+ # Convert each message_id str to UUID
204
200
  message_ids: set[UUID] = {
205
201
  UUID(message_id) for message_id in request.message_ids
206
202
  }
207
203
 
208
204
  # Read from state
209
- task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
205
+ messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
210
206
 
211
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
212
- task_ins_ids_to_delete = {
213
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
207
+ # Delete the instruction Messages and their replies if found
208
+ message_ins_ids_to_delete = {
209
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
214
210
  }
215
211
 
216
- state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
212
+ state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
217
213
 
218
- # Convert to Messages
214
+ # Convert Messages to proto
219
215
  messages_list = []
220
- while task_res_list:
221
- task_res = task_res_list.pop(0)
222
- _raise_if(
223
- validation_error=request.run_id != task_res.run_id,
224
- request_name="PullMessages",
225
- detail="`task_res` has mismatched `run_id`",
226
- )
227
- message = message_from_taskres(taskres=task_res)
228
- messages_list.append(message_to_proto(message))
216
+ while messages_res:
217
+ msg = messages_res.pop(0)
218
+
219
+ # Skip `run_id` check for SuperLink generated replies
220
+ if msg.metadata.src_node_id != SUPERLINK_NODE_ID:
221
+ _raise_if(
222
+ validation_error=request.run_id != msg.metadata.run_id,
223
+ request_name="PullMessages",
224
+ detail="`message.metadata` has mismatched `run_id`",
225
+ )
226
+ messages_list.append(message_to_proto(msg))
229
227
 
230
228
  return PullResMessagesResponse(messages_list=messages_list)
231
229
 
@@ -24,7 +24,7 @@ from grpc import ServicerContext
24
24
  from flwr.common.constant import Status
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.serde import (
27
- configs_record_to_proto,
27
+ config_record_to_proto,
28
28
  context_from_proto,
29
29
  context_to_proto,
30
30
  fab_to_proto,
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
182
182
  )
183
183
  return GetFederationOptionsResponse()
184
184
  return GetFederationOptionsResponse(
185
- federation_options=configs_record_to_proto(federation_options)
185
+ federation_options=config_record_to_proto(federation_options)
186
186
  )
flwr/server/typing.py CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
19
19
 
20
20
  from flwr.common import Context
21
21
 
22
- from .driver import Driver
22
+ from .grid import Grid
23
23
  from .serverapp_components import ServerAppComponents
24
24
 
25
- ServerAppCallable = Callable[[Driver, Context], None]
26
- Workflow = Callable[[Driver, Context], None]
25
+ ServerAppCallable = Callable[[Grid, Context], None]
26
+ Workflow = Callable[[Grid, Context], None]
27
27
  ServerFn = Callable[[Context], ServerAppComponents]
@@ -16,9 +16,9 @@
16
16
 
17
17
 
18
18
  from .tensorboard import tensorboard as tensorboard
19
- from .validator import validate_task_ins_or_res as validate_task_ins_or_res
19
+ from .validator import validate_message as validate_message
20
20
 
21
21
  __all__ = [
22
22
  "tensorboard",
23
- "validate_task_ins_or_res",
23
+ "validate_message",
24
24
  ]