flwr-nightly 1.19.0.dev20250528__py3-none-any.whl → 1.19.0.dev20250530__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. flwr/cli/utils.py +11 -3
  2. flwr/client/mod/comms_mods.py +36 -17
  3. flwr/common/auth_plugin/auth_plugin.py +9 -3
  4. flwr/common/exit_handlers.py +30 -0
  5. flwr/common/inflatable_grpc_utils.py +27 -13
  6. flwr/common/message.py +11 -0
  7. flwr/common/record/array.py +10 -21
  8. flwr/common/record/arrayrecord.py +1 -1
  9. flwr/common/recorddict_compat.py +2 -2
  10. flwr/common/serde.py +1 -1
  11. flwr/proto/fleet_pb2.py +16 -16
  12. flwr/proto/fleet_pb2.pyi +5 -5
  13. flwr/proto/message_pb2.py +10 -10
  14. flwr/proto/message_pb2.pyi +4 -4
  15. flwr/proto/serverappio_pb2.py +26 -26
  16. flwr/proto/serverappio_pb2.pyi +5 -5
  17. flwr/server/app.py +45 -57
  18. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -0
  19. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -7
  20. flwr/server/superlink/fleet/rest_rere/rest_api.py +5 -2
  21. flwr/server/superlink/linkstate/utils.py +8 -5
  22. flwr/server/superlink/serverappio/serverappio_servicer.py +45 -5
  23. flwr/server/superlink/utils.py +29 -0
  24. flwr/supercore/object_store/__init__.py +2 -1
  25. flwr/supercore/object_store/in_memory_object_store.py +9 -2
  26. flwr/supercore/object_store/object_store.py +12 -0
  27. flwr/superexec/exec_grpc.py +4 -3
  28. flwr/superexec/exec_user_auth_interceptor.py +33 -4
  29. flwr/supernode/start_client_internal.py +144 -170
  30. {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/METADATA +1 -1
  31. {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/RECORD +33 -33
  32. {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/WHEEL +0 -0
  33. {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/entry_points.txt +0 -0
@@ -20,15 +20,15 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
20
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
21
21
 
22
22
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x87\x02\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12]\n\x17msg_to_children_mapping\x18\x03 \x03(\x0b\x32<.flwr.proto.PushInsMessagesRequest.MsgToChildrenMappingEntry\x1aR\n\x19MsgToChildrenMappingEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xe4\x08\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8d\x02\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x61\n\x19msg_to_descendant_mapping\x18\x03 \x03(\x0b\x32>.flwr.proto.PushInsMessagesRequest.MsgToDescendantMappingEntry\x1aT\n\x1bMsgToDescendantMappingEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xe4\x08\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x62\x06proto3')
24
24
 
25
25
  _globals = globals()
26
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
27
27
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.serverappio_pb2', _globals)
28
28
  if _descriptor._USE_C_DESCRIPTORS == False:
29
29
  DESCRIPTOR._options = None
30
- _globals['_PUSHINSMESSAGESREQUEST_MSGTOCHILDRENMAPPINGENTRY']._options = None
31
- _globals['_PUSHINSMESSAGESREQUEST_MSGTOCHILDRENMAPPINGENTRY']._serialized_options = b'8\001'
30
+ _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._options = None
31
+ _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_options = b'8\001'
32
32
  _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._options = None
33
33
  _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_options = b'8\001'
34
34
  _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._options = None
@@ -38,27 +38,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
38
38
  _globals['_GETNODESRESPONSE']._serialized_start=222
39
39
  _globals['_GETNODESRESPONSE']._serialized_end=273
40
40
  _globals['_PUSHINSMESSAGESREQUEST']._serialized_start=276
41
- _globals['_PUSHINSMESSAGESREQUEST']._serialized_end=539
42
- _globals['_PUSHINSMESSAGESREQUEST_MSGTOCHILDRENMAPPINGENTRY']._serialized_start=457
43
- _globals['_PUSHINSMESSAGESREQUEST_MSGTOCHILDRENMAPPINGENTRY']._serialized_end=539
44
- _globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=542
45
- _globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=746
46
- _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=671
47
- _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=746
48
- _globals['_PULLRESMESSAGESREQUEST']._serialized_start=748
49
- _globals['_PULLRESMESSAGESREQUEST']._serialized_end=809
50
- _globals['_PULLRESMESSAGESRESPONSE']._serialized_start=812
51
- _globals['_PULLRESMESSAGESRESPONSE']._serialized_end=1039
52
- _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=964
53
- _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=1039
54
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=1041
55
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=1069
56
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=1071
57
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1198
58
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1200
59
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1283
60
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1285
61
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1315
62
- _globals['_SERVERAPPIO']._serialized_start=1318
63
- _globals['_SERVERAPPIO']._serialized_end=2442
41
+ _globals['_PUSHINSMESSAGESREQUEST']._serialized_end=545
42
+ _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_start=461
43
+ _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_end=545
44
+ _globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=548
45
+ _globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=752
46
+ _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=677
47
+ _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=752
48
+ _globals['_PULLRESMESSAGESREQUEST']._serialized_start=754
49
+ _globals['_PULLRESMESSAGESREQUEST']._serialized_end=815
50
+ _globals['_PULLRESMESSAGESRESPONSE']._serialized_start=818
51
+ _globals['_PULLRESMESSAGESRESPONSE']._serialized_end=1045
52
+ _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=970
53
+ _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=1045
54
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=1047
55
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=1075
56
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=1077
57
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1204
58
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1206
59
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1289
60
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1291
61
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1321
62
+ _globals['_SERVERAPPIO']._serialized_start=1324
63
+ _globals['_SERVERAPPIO']._serialized_end=2448
64
64
  # @@protoc_insertion_point(module_scope)
@@ -42,7 +42,7 @@ global___GetNodesResponse = GetNodesResponse
42
42
  class PushInsMessagesRequest(google.protobuf.message.Message):
43
43
  """PushMessages messages"""
44
44
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
45
- class MsgToChildrenMappingEntry(google.protobuf.message.Message):
45
+ class MsgToDescendantMappingEntry(google.protobuf.message.Message):
46
46
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
47
47
  KEY_FIELD_NUMBER: builtins.int
48
48
  VALUE_FIELD_NUMBER: builtins.int
@@ -59,19 +59,19 @@ class PushInsMessagesRequest(google.protobuf.message.Message):
59
59
 
60
60
  MESSAGES_LIST_FIELD_NUMBER: builtins.int
61
61
  RUN_ID_FIELD_NUMBER: builtins.int
62
- MSG_TO_CHILDREN_MAPPING_FIELD_NUMBER: builtins.int
62
+ MSG_TO_DESCENDANT_MAPPING_FIELD_NUMBER: builtins.int
63
63
  @property
64
64
  def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
65
65
  run_id: builtins.int
66
66
  @property
67
- def msg_to_children_mapping(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.message_pb2.ObjectIDs]: ...
67
+ def msg_to_descendant_mapping(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.message_pb2.ObjectIDs]: ...
68
68
  def __init__(self,
69
69
  *,
70
70
  messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
71
71
  run_id: builtins.int = ...,
72
- msg_to_children_mapping: typing.Optional[typing.Mapping[typing.Text, flwr.proto.message_pb2.ObjectIDs]] = ...,
72
+ msg_to_descendant_mapping: typing.Optional[typing.Mapping[typing.Text, flwr.proto.message_pb2.ObjectIDs]] = ...,
73
73
  ) -> None: ...
74
- def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","msg_to_children_mapping",b"msg_to_children_mapping","run_id",b"run_id"]) -> None: ...
74
+ def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","msg_to_descendant_mapping",b"msg_to_descendant_mapping","run_id",b"run_id"]) -> None: ...
75
75
  global___PushInsMessagesRequest = PushInsMessagesRequest
76
76
 
77
77
  class PushInsMessagesResponse(google.protobuf.message.Message):
flwr/server/app.py CHANGED
@@ -27,7 +27,7 @@ from collections.abc import Sequence
27
27
  from logging import DEBUG, INFO, WARN
28
28
  from pathlib import Path
29
29
  from time import sleep
30
- from typing import Any, Optional
30
+ from typing import Any, Callable, Optional, TypeVar
31
31
 
32
32
  import grpc
33
33
  import yaml
@@ -85,6 +85,7 @@ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
85
85
 
86
86
  DATABASE = ":flwr-in-memory-state:"
87
87
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
88
+ P = TypeVar("P", ExecAuthPlugin, ExecAuthzPlugin)
88
89
 
89
90
 
90
91
  try:
@@ -151,21 +152,16 @@ def run_superlink() -> None:
151
152
  verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
152
153
 
153
154
  auth_plugin: Optional[ExecAuthPlugin] = None
154
- authz_plugin: Optional[ExecAuthzPlugin] = None # pylint: disable=unused-variable
155
+ authz_plugin: Optional[ExecAuthzPlugin] = None
155
156
  event_log_plugin: Optional[EventLogWriterPlugin] = None
156
157
  # Load the auth plugin if the args.user_auth_config is provided
157
158
  if cfg_path := getattr(args, "user_auth_config", None):
158
- auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path), verify_tls_cert)
159
+ auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins(
160
+ Path(cfg_path), verify_tls_cert
161
+ )
159
162
  # Enable event logging if the args.enable_event_log is True
160
163
  if args.enable_event_log:
161
164
  event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
162
- # Enable authorization if the args.enable_authorization is True
163
- if args.enable_authorization:
164
- # pylint: disable=unused-variable
165
- authz_plugin = _try_obtain_exec_authz_plugin( # noqa: F841
166
- Path(cfg_path), verify_tls_cert
167
- )
168
- # pylint: enable=unused-variable
169
165
 
170
166
  # Initialize StateFactory
171
167
  state_factory = LinkStateFactory(args.database)
@@ -188,6 +184,7 @@ def run_superlink() -> None:
188
184
  [args.executor_config] if args.executor_config else args.executor_config
189
185
  ),
190
186
  auth_plugin=auth_plugin,
187
+ authz_plugin=authz_plugin,
191
188
  event_log_plugin=event_log_plugin,
192
189
  )
193
190
  grpc_servers = [exec_server]
@@ -258,6 +255,7 @@ def run_superlink() -> None:
258
255
  args.ssl_certfile,
259
256
  state_factory,
260
257
  ffs_factory,
258
+ objectstore_factory,
261
259
  num_workers,
262
260
  ),
263
261
  daemon=True,
@@ -483,62 +481,50 @@ def _try_load_public_keys_node_authentication(
483
481
  return node_public_keys
484
482
 
485
483
 
486
- def _try_obtain_exec_auth_plugin(
484
+ def _try_obtain_exec_auth_plugins(
487
485
  config_path: Path, verify_tls_cert: bool
488
- ) -> Optional[ExecAuthPlugin]:
486
+ ) -> tuple[ExecAuthPlugin, ExecAuthzPlugin]:
487
+ """Obtain Exec API authentication and authorization plugins."""
489
488
  # Load YAML file
490
489
  with config_path.open("r", encoding="utf-8") as file:
491
490
  config: dict[str, Any] = yaml.safe_load(file)
492
491
 
493
- # Load authentication configuration
494
- auth_config: dict[str, Any] = config.get("authentication", {})
495
- auth_type: str = auth_config.get(AUTH_TYPE_YAML_KEY, "")
496
-
497
- # Load authentication plugin
498
- try:
499
- all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
500
- auth_plugin_class = all_plugins[auth_type]
501
- return auth_plugin_class(
502
- user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
503
- )
504
- except KeyError:
505
- if auth_type != "":
506
- sys.exit(
507
- f'Authentication type "{auth_type}" is not supported. '
508
- "Please provide a valid authentication type in the configuration."
492
+ def _load_plugin(
493
+ section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
494
+ ) -> P:
495
+ section_cfg = config.get(section, {})
496
+ auth_plugin_name = section_cfg.get(yaml_key, "")
497
+ try:
498
+ plugins: dict[str, type[P]] = loader()
499
+ plugin_cls: type[P] = plugins[auth_plugin_name]
500
+ return plugin_cls(
501
+ user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
509
502
  )
510
- sys.exit("No authentication type is provided in the configuration.")
511
- except NotImplementedError:
512
- sys.exit("No authentication plugins are currently supported.")
513
-
514
-
515
- def _try_obtain_exec_authz_plugin(
516
- config_path: Path, verify_tls_cert: bool
517
- ) -> Optional[ExecAuthzPlugin]:
518
- # Load YAML file
519
- with config_path.open("r", encoding="utf-8") as file:
520
- config: dict[str, Any] = yaml.safe_load(file)
503
+ except KeyError:
504
+ if auth_plugin_name:
505
+ sys.exit(
506
+ f"{yaml_key}: {auth_plugin_name} is not supported. "
507
+ f"Please provide a valid {section} type in the configuration."
508
+ )
509
+ sys.exit(f"No {section} type is provided in the configuration.")
510
+ except NotImplementedError:
511
+ sys.exit(f"No {section} plugins are currently supported.")
521
512
 
522
- # Load authentication configuration
523
- authz_config: dict[str, Any] = config.get("authorization", {})
524
- authz_type: str = authz_config.get(AUTHZ_TYPE_YAML_KEY, "")
513
+ # Load authentication plugin
514
+ auth_plugin = _load_plugin(
515
+ section="authentication",
516
+ yaml_key=AUTH_TYPE_YAML_KEY,
517
+ loader=get_exec_auth_plugins,
518
+ )
525
519
 
526
520
  # Load authorization plugin
527
- try:
528
- all_plugins: dict[str, type[ExecAuthzPlugin]] = get_exec_authz_plugins()
529
- authz_plugin_class = all_plugins[authz_type]
530
- return authz_plugin_class(
531
- user_authz_config_path=config_path, verify_tls_cert=verify_tls_cert
532
- )
533
- except KeyError:
534
- if authz_type != "":
535
- sys.exit(
536
- f'Authentication type "{authz_type}" is not supported. '
537
- "Please provide a valid authorization type in the configuration."
538
- )
539
- sys.exit("No authorization type is provided in the configuration.")
540
- except NotImplementedError:
541
- sys.exit("No authorization plugins are currently supported.")
521
+ authz_plugin = _load_plugin(
522
+ section="authorization",
523
+ yaml_key=AUTHZ_TYPE_YAML_KEY,
524
+ loader=get_exec_authz_plugins,
525
+ )
526
+
527
+ return auth_plugin, authz_plugin
542
528
 
543
529
 
544
530
  def _try_obtain_exec_event_log_writer_plugin() -> Optional[EventLogWriterPlugin]:
@@ -636,6 +622,7 @@ def _run_fleet_api_rest(
636
622
  ssl_certfile: Optional[str],
637
623
  state_factory: LinkStateFactory,
638
624
  ffs_factory: FfsFactory,
625
+ objectstore_factory: ObjectStoreFactory,
639
626
  num_workers: int,
640
627
  ) -> None:
641
628
  """Run ServerAppIo API (REST-based)."""
@@ -651,6 +638,7 @@ def _run_fleet_api_rest(
651
638
  # See: https://www.starlette.io/applications/#accessing-the-app-instance
652
639
  fast_api_app.state.STATE_FACTORY = state_factory
653
640
  fast_api_app.state.FFS_FACTORY = ffs_factory
641
+ fast_api_app.state.OBJECTSTORE_FACTORY = objectstore_factory
654
642
 
655
643
  uvicorn.run(
656
644
  app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
@@ -114,6 +114,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
114
114
  return message_handler.pull_messages(
115
115
  request=request,
116
116
  state=self.state_factory.state(),
117
+ store=self.objectstore_factory.store(),
117
118
  )
118
119
 
119
120
  def PushMessages(
@@ -133,6 +134,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
133
134
  res = message_handler.push_messages(
134
135
  request=request,
135
136
  state=self.state_factory.state(),
137
+ store=self.objectstore_factory.store(),
136
138
  )
137
139
  except InvalidRunStatusException as e:
138
140
  abort_grpc_context(e.message, context)
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
  """Fleet API message handlers."""
16
16
 
17
-
17
+ from logging import ERROR
18
18
  from typing import Optional
19
19
 
20
- from flwr.common import Message
20
+ from flwr.common import Message, log
21
21
  from flwr.common.constant import Status
22
22
  from flwr.common.serde import (
23
23
  fab_to_proto,
@@ -42,6 +42,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
42
42
  SendNodeHeartbeatRequest,
43
43
  SendNodeHeartbeatResponse,
44
44
  )
45
+ from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
45
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
46
47
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
47
48
  GetRunRequest,
@@ -51,6 +52,9 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
51
52
  from flwr.server.superlink.ffs.ffs import Ffs
52
53
  from flwr.server.superlink.linkstate import LinkState
53
54
  from flwr.server.superlink.utils import check_abort
55
+ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
56
+
57
+ from ...utils import store_mapping_and_register_objects
54
58
 
55
59
 
56
60
  def create_node(
@@ -86,7 +90,9 @@ def send_node_heartbeat(
86
90
 
87
91
 
88
92
  def pull_messages(
89
- request: PullMessagesRequest, state: LinkState
93
+ request: PullMessagesRequest,
94
+ state: LinkState,
95
+ store: ObjectStore,
90
96
  ) -> PullMessagesResponse:
91
97
  """Pull Messages handler."""
92
98
  # Get node_id if client node is not anonymous
@@ -98,14 +104,31 @@ def pull_messages(
98
104
 
99
105
  # Convert to Messages
100
106
  msg_proto = []
107
+ objects_to_pull: dict[str, ObjectIDs] = {}
101
108
  for msg in message_list:
102
- msg_proto.append(message_to_proto(msg))
103
-
104
- return PullMessagesResponse(messages_list=msg_proto)
109
+ try:
110
+ msg_proto.append(message_to_proto(msg))
111
+
112
+ msg_object_id = msg.metadata.message_id
113
+ descendants = store.get_message_descendant_ids(msg_object_id)
114
+ # Include the object_id of the message itself
115
+ objects_to_pull[msg_object_id] = ObjectIDs(
116
+ object_ids=descendants + [msg_object_id]
117
+ )
118
+ except NoObjectInStoreError as e:
119
+ log(ERROR, e.message)
120
+ # Delete message ins from state
121
+ state.delete_messages(message_ins_ids={msg_object_id})
122
+
123
+ return PullMessagesResponse(
124
+ messages_list=msg_proto, objects_to_pull=objects_to_pull
125
+ )
105
126
 
106
127
 
107
128
  def push_messages(
108
- request: PushMessagesRequest, state: LinkState
129
+ request: PushMessagesRequest,
130
+ state: LinkState,
131
+ store: ObjectStore,
109
132
  ) -> PushMessagesResponse:
110
133
  """Push Messages handler."""
111
134
  # Convert Message from proto
@@ -123,10 +146,14 @@ def push_messages(
123
146
  # Store Message in State
124
147
  message_id: Optional[str] = state.store_message_res(message=msg)
125
148
 
149
+ # Store Message object to descendants mapping and preregister objects
150
+ objects_to_push = store_mapping_and_register_objects(store, request=request)
151
+
126
152
  # Build response
127
153
  response = PushMessagesResponse(
128
154
  reconnect=Reconnect(reconnect=5),
129
155
  results={str(message_id): 0},
156
+ objects_to_push=objects_to_push,
130
157
  )
131
158
  return response
132
159
 
@@ -43,6 +43,7 @@ from flwr.server.superlink.ffs.ffs import Ffs
43
43
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
44
44
  from flwr.server.superlink.fleet.message_handler import message_handler
45
45
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
46
+ from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
46
47
 
47
48
  try:
48
49
  from starlette.applications import Starlette
@@ -113,9 +114,10 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
113
114
  """Pull PullMessages."""
114
115
  # Get state from app
115
116
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
117
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
116
118
 
117
119
  # Handle message
118
- return message_handler.pull_messages(request=request, state=state)
120
+ return message_handler.pull_messages(request=request, state=state, store=store)
119
121
 
120
122
 
121
123
  @rest_request_response(PushMessagesRequest)
@@ -123,9 +125,10 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
123
125
  """Pull PushMessages."""
124
126
  # Get state from app
125
127
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
128
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
126
129
 
127
130
  # Handle message
128
- return message_handler.push_messages(request=request, state=state)
131
+ return message_handler.push_messages(request=request, state=state, store=store)
129
132
 
130
133
 
131
134
  @rest_request_response(SendNodeHeartbeatRequest)
@@ -17,7 +17,6 @@
17
17
 
18
18
  from os import urandom
19
19
  from typing import Optional
20
- from uuid import uuid4
21
20
 
22
21
  from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
22
  from flwr.common.constant import (
@@ -246,7 +245,7 @@ def create_message_error_unavailable_res_message(
246
245
  ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
247
246
  metadata = Metadata(
248
247
  run_id=ins_metadata.run_id,
249
- message_id=str(uuid4()),
248
+ message_id="",
250
249
  src_node_id=SUPERLINK_NODE_ID,
251
250
  dst_node_id=SUPERLINK_NODE_ID,
252
251
  reply_to_message_id=ins_metadata.message_id,
@@ -256,7 +255,7 @@ def create_message_error_unavailable_res_message(
256
255
  ttl=ttl,
257
256
  )
258
257
 
259
- return make_message(
258
+ msg = make_message(
260
259
  metadata=metadata,
261
260
  error=Error(
262
261
  code=(
@@ -271,6 +270,8 @@ def create_message_error_unavailable_res_message(
271
270
  ),
272
271
  ),
273
272
  )
273
+ msg.metadata.__dict__["_message_id"] = msg.object_id
274
+ return msg
274
275
 
275
276
 
276
277
  def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
@@ -278,7 +279,7 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
278
279
  that it isn't found."""
279
280
  metadata = Metadata(
280
281
  run_id=0, # Unknown
281
- message_id=str(uuid4()),
282
+ message_id="",
282
283
  src_node_id=SUPERLINK_NODE_ID,
283
284
  dst_node_id=SUPERLINK_NODE_ID,
284
285
  reply_to_message_id=reply_to_message_id,
@@ -288,13 +289,15 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
288
289
  ttl=0,
289
290
  )
290
291
 
291
- return make_message(
292
+ msg = make_message(
292
293
  metadata=metadata,
293
294
  error=Error(
294
295
  code=ErrorCode.MESSAGE_UNAVAILABLE,
295
296
  reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
296
297
  ),
297
298
  )
299
+ msg.metadata.__dict__["_message_id"] = msg.object_id
300
+ return msg
298
301
 
299
302
 
300
303
  def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
@@ -16,14 +16,14 @@
16
16
 
17
17
 
18
18
  import threading
19
- from logging import DEBUG, INFO
19
+ from logging import DEBUG, ERROR, INFO
20
20
  from typing import Optional
21
21
 
22
22
  import grpc
23
23
 
24
24
  from flwr.common import Message
25
25
  from flwr.common.constant import SUPERLINK_NODE_ID, Status
26
- from flwr.common.inflatable import check_body_len_consistency
26
+ from flwr.common.inflatable import check_body_len_consistency, get_desdendant_object_ids
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
29
29
  context_from_proto,
@@ -47,6 +47,7 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
47
47
  PushLogsResponse,
48
48
  )
49
49
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
50
+ ObjectIDs,
50
51
  PullObjectRequest,
51
52
  PullObjectResponse,
52
53
  PushObjectRequest,
@@ -78,7 +79,9 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
78
79
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
79
80
  from flwr.server.superlink.utils import abort_if
80
81
  from flwr.server.utils.validator import validate_message
81
- from flwr.supercore.object_store import ObjectStoreFactory
82
+ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
83
+
84
+ from ..utils import store_mapping_and_register_objects
82
85
 
83
86
 
84
87
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -158,10 +161,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
158
161
  message_id: Optional[str] = state.store_message_ins(message=message)
159
162
  message_ids.append(message_id)
160
163
 
164
+ # Init store
165
+ store = self.objectstore_factory.store()
166
+
167
+ # Store Message object to descendants mapping and preregister objects
168
+ objects_to_push = store_mapping_and_register_objects(store, request=request)
169
+
161
170
  return PushInsMessagesResponse(
162
171
  message_ids=[
163
172
  str(message_id) if message_id else "" for message_id in message_ids
164
- ]
173
+ ],
174
+ objects_to_push=objects_to_push,
165
175
  )
166
176
 
167
177
  def PullMessages(
@@ -173,6 +183,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
173
183
  # Init state
174
184
  state: LinkState = self.state_factory.state()
175
185
 
186
+ # Init store
187
+ store = self.objectstore_factory.store()
188
+
176
189
  # Abort if the run is not running
177
190
  abort_if(
178
191
  request.run_id,
@@ -186,6 +199,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
186
199
  message_ids=set(request.message_ids)
187
200
  )
188
201
 
202
+ # Register messages generated by LinkState in the Store for consistency
203
+ for msg_res in messages_res:
204
+ if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
205
+ descendants = list(get_desdendant_object_ids(msg_res))
206
+ message_obj_id = msg_res.metadata.message_id
207
+ # Store mapping
208
+ store.set_message_descendant_ids(
209
+ msg_object_id=message_obj_id, descendant_ids=descendants
210
+ )
211
+ # Preregister
212
+ store.preregister(descendants + [message_obj_id])
213
+
189
214
  # Delete the instruction Messages and their replies if found
190
215
  message_ins_ids_to_delete = {
191
216
  msg_res.metadata.reply_to_message_id for msg_res in messages_res
@@ -195,6 +220,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
195
220
 
196
221
  # Convert Messages to proto
197
222
  messages_list = []
223
+ objects_to_pull: dict[str, ObjectIDs] = {}
198
224
  while messages_res:
199
225
  msg = messages_res.pop(0)
200
226
 
@@ -207,7 +233,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
207
233
  )
208
234
  messages_list.append(message_to_proto(msg))
209
235
 
210
- return PullResMessagesResponse(messages_list=messages_list)
236
+ try:
237
+ msg_object_id = msg.metadata.message_id
238
+ descendants = store.get_message_descendant_ids(msg_object_id)
239
+ # Include the object_id of the message itself
240
+ objects_to_pull[msg_object_id] = ObjectIDs(
241
+ object_ids=descendants + [msg_object_id]
242
+ )
243
+ except NoObjectInStoreError as e:
244
+ log(ERROR, e.message)
245
+ # Delete message ins from state
246
+ state.delete_messages(message_ins_ids={msg_object_id})
247
+
248
+ return PullResMessagesResponse(
249
+ messages_list=messages_list, objects_to_pull=objects_to_pull
250
+ )
211
251
 
212
252
  def GetRun(
213
253
  self, request: GetRunRequest, context: grpc.ServicerContext
@@ -21,7 +21,11 @@ import grpc
21
21
 
22
22
  from flwr.common.constant import Status, SubStatus
23
23
  from flwr.common.typing import RunStatus
24
+ from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
25
+ from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
26
+ from flwr.proto.serverappio_pb2 import PushInsMessagesRequest # pylint: disable=E0611
24
27
  from flwr.server.superlink.linkstate import LinkState
28
+ from flwr.supercore.object_store import ObjectStore
25
29
 
26
30
  _STATUS_TO_MSG = {
27
31
  Status.PENDING: "Run is pending.",
@@ -63,3 +67,28 @@ def abort_if(
63
67
  """Abort context if status of the provided `run_id` is in `abort_status_list`."""
64
68
  msg = check_abort(run_id, abort_status_list, state)
65
69
  abort_grpc_context(msg, context)
70
+
71
+
72
+ def store_mapping_and_register_objects(
73
+ store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
74
+ ) -> dict[str, ObjectIDs]:
75
+ """Store Message object to descendants mapping and preregister objects."""
76
+ objects_to_push: dict[str, ObjectIDs] = {}
77
+ for (
78
+ message_obj_id,
79
+ descendant_obj_ids,
80
+ ) in request.msg_to_descendant_mapping.items():
81
+ descendants = list(descendant_obj_ids.object_ids)
82
+ # Store mapping
83
+ store.set_message_descendant_ids(
84
+ msg_object_id=message_obj_id, descendant_ids=descendants
85
+ )
86
+
87
+ # Preregister
88
+ object_ids_just_registered = store.preregister(descendants + [message_obj_id])
89
+ # Keep track of objects that need to be pushed
90
+ objects_to_push[message_obj_id] = ObjectIDs(
91
+ object_ids=object_ids_just_registered
92
+ )
93
+
94
+ return objects_to_push
@@ -14,10 +14,11 @@
14
14
  # ==============================================================================
15
15
  """Flower ObjectStore."""
16
16
 
17
- from .object_store import ObjectStore
17
+ from .object_store import NoObjectInStoreError, ObjectStore
18
18
  from .object_store_factory import ObjectStoreFactory
19
19
 
20
20
  __all__ = [
21
+ "NoObjectInStoreError",
21
22
  "ObjectStore",
22
23
  "ObjectStoreFactory",
23
24
  ]