flwr-nightly 1.19.0.dev20250528__py3-none-any.whl → 1.19.0.dev20250530__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/utils.py +11 -3
- flwr/client/mod/comms_mods.py +36 -17
- flwr/common/auth_plugin/auth_plugin.py +9 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/inflatable_grpc_utils.py +27 -13
- flwr/common/message.py +11 -0
- flwr/common/record/array.py +10 -21
- flwr/common/record/arrayrecord.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/serde.py +1 -1
- flwr/proto/fleet_pb2.py +16 -16
- flwr/proto/fleet_pb2.pyi +5 -5
- flwr/proto/message_pb2.py +10 -10
- flwr/proto/message_pb2.pyi +4 -4
- flwr/proto/serverappio_pb2.py +26 -26
- flwr/proto/serverappio_pb2.pyi +5 -5
- flwr/server/app.py +45 -57
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +5 -2
- flwr/server/superlink/linkstate/utils.py +8 -5
- flwr/server/superlink/serverappio/serverappio_servicer.py +45 -5
- flwr/server/superlink/utils.py +29 -0
- flwr/supercore/object_store/__init__.py +2 -1
- flwr/supercore/object_store/in_memory_object_store.py +9 -2
- flwr/supercore/object_store/object_store.py +12 -0
- flwr/superexec/exec_grpc.py +4 -3
- flwr/superexec/exec_user_auth_interceptor.py +33 -4
- flwr/supernode/start_client_internal.py +144 -170
- {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/RECORD +33 -33
- {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250528.dist-info → flwr_nightly-1.19.0.dev20250530.dist-info}/entry_points.txt +0 -0
flwr/proto/serverappio_pb2.py
CHANGED
@@ -20,15 +20,15 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
20
20
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
21
21
|
|
22
22
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\
|
23
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8d\x02\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x61\n\x19msg_to_descendant_mapping\x18\x03 \x03(\x0b\x32>.flwr.proto.PushInsMessagesRequest.MsgToDescendantMappingEntry\x1aT\n\x1bMsgToDescendantMappingEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xe4\x08\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x62\x06proto3')
|
24
24
|
|
25
25
|
_globals = globals()
|
26
26
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
27
27
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.serverappio_pb2', _globals)
|
28
28
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
29
29
|
DESCRIPTOR._options = None
|
30
|
-
_globals['
|
31
|
-
_globals['
|
30
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._options = None
|
31
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_options = b'8\001'
|
32
32
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._options = None
|
33
33
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_options = b'8\001'
|
34
34
|
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._options = None
|
@@ -38,27 +38,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
38
38
|
_globals['_GETNODESRESPONSE']._serialized_start=222
|
39
39
|
_globals['_GETNODESRESPONSE']._serialized_end=273
|
40
40
|
_globals['_PUSHINSMESSAGESREQUEST']._serialized_start=276
|
41
|
-
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=
|
42
|
-
_globals['
|
43
|
-
_globals['
|
44
|
-
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=
|
45
|
-
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=
|
46
|
-
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=
|
47
|
-
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=
|
48
|
-
_globals['_PULLRESMESSAGESREQUEST']._serialized_start=
|
49
|
-
_globals['_PULLRESMESSAGESREQUEST']._serialized_end=
|
50
|
-
_globals['_PULLRESMESSAGESRESPONSE']._serialized_start=
|
51
|
-
_globals['_PULLRESMESSAGESRESPONSE']._serialized_end=
|
52
|
-
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=
|
53
|
-
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=
|
54
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=
|
55
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=
|
56
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=
|
57
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=
|
58
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=
|
59
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=
|
60
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=
|
61
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=
|
62
|
-
_globals['_SERVERAPPIO']._serialized_start=
|
63
|
-
_globals['_SERVERAPPIO']._serialized_end=
|
41
|
+
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=545
|
42
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_start=461
|
43
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_end=545
|
44
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=548
|
45
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=752
|
46
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=677
|
47
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=752
|
48
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_start=754
|
49
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_end=815
|
50
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_start=818
|
51
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_end=1045
|
52
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=970
|
53
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=1045
|
54
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=1047
|
55
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=1075
|
56
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=1077
|
57
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1204
|
58
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1206
|
59
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1289
|
60
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1291
|
61
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1321
|
62
|
+
_globals['_SERVERAPPIO']._serialized_start=1324
|
63
|
+
_globals['_SERVERAPPIO']._serialized_end=2448
|
64
64
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/serverappio_pb2.pyi
CHANGED
@@ -42,7 +42,7 @@ global___GetNodesResponse = GetNodesResponse
|
|
42
42
|
class PushInsMessagesRequest(google.protobuf.message.Message):
|
43
43
|
"""PushMessages messages"""
|
44
44
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
45
|
-
class
|
45
|
+
class MsgToDescendantMappingEntry(google.protobuf.message.Message):
|
46
46
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
47
47
|
KEY_FIELD_NUMBER: builtins.int
|
48
48
|
VALUE_FIELD_NUMBER: builtins.int
|
@@ -59,19 +59,19 @@ class PushInsMessagesRequest(google.protobuf.message.Message):
|
|
59
59
|
|
60
60
|
MESSAGES_LIST_FIELD_NUMBER: builtins.int
|
61
61
|
RUN_ID_FIELD_NUMBER: builtins.int
|
62
|
-
|
62
|
+
MSG_TO_DESCENDANT_MAPPING_FIELD_NUMBER: builtins.int
|
63
63
|
@property
|
64
64
|
def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
|
65
65
|
run_id: builtins.int
|
66
66
|
@property
|
67
|
-
def
|
67
|
+
def msg_to_descendant_mapping(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.message_pb2.ObjectIDs]: ...
|
68
68
|
def __init__(self,
|
69
69
|
*,
|
70
70
|
messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
|
71
71
|
run_id: builtins.int = ...,
|
72
|
-
|
72
|
+
msg_to_descendant_mapping: typing.Optional[typing.Mapping[typing.Text, flwr.proto.message_pb2.ObjectIDs]] = ...,
|
73
73
|
) -> None: ...
|
74
|
-
def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","
|
74
|
+
def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","msg_to_descendant_mapping",b"msg_to_descendant_mapping","run_id",b"run_id"]) -> None: ...
|
75
75
|
global___PushInsMessagesRequest = PushInsMessagesRequest
|
76
76
|
|
77
77
|
class PushInsMessagesResponse(google.protobuf.message.Message):
|
flwr/server/app.py
CHANGED
@@ -27,7 +27,7 @@ from collections.abc import Sequence
|
|
27
27
|
from logging import DEBUG, INFO, WARN
|
28
28
|
from pathlib import Path
|
29
29
|
from time import sleep
|
30
|
-
from typing import Any, Optional
|
30
|
+
from typing import Any, Callable, Optional, TypeVar
|
31
31
|
|
32
32
|
import grpc
|
33
33
|
import yaml
|
@@ -85,6 +85,7 @@ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
|
85
85
|
|
86
86
|
DATABASE = ":flwr-in-memory-state:"
|
87
87
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
88
|
+
P = TypeVar("P", ExecAuthPlugin, ExecAuthzPlugin)
|
88
89
|
|
89
90
|
|
90
91
|
try:
|
@@ -151,21 +152,16 @@ def run_superlink() -> None:
|
|
151
152
|
verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
|
152
153
|
|
153
154
|
auth_plugin: Optional[ExecAuthPlugin] = None
|
154
|
-
authz_plugin: Optional[ExecAuthzPlugin] = None
|
155
|
+
authz_plugin: Optional[ExecAuthzPlugin] = None
|
155
156
|
event_log_plugin: Optional[EventLogWriterPlugin] = None
|
156
157
|
# Load the auth plugin if the args.user_auth_config is provided
|
157
158
|
if cfg_path := getattr(args, "user_auth_config", None):
|
158
|
-
auth_plugin =
|
159
|
+
auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins(
|
160
|
+
Path(cfg_path), verify_tls_cert
|
161
|
+
)
|
159
162
|
# Enable event logging if the args.enable_event_log is True
|
160
163
|
if args.enable_event_log:
|
161
164
|
event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
|
162
|
-
# Enable authorization if the args.enable_authorization is True
|
163
|
-
if args.enable_authorization:
|
164
|
-
# pylint: disable=unused-variable
|
165
|
-
authz_plugin = _try_obtain_exec_authz_plugin( # noqa: F841
|
166
|
-
Path(cfg_path), verify_tls_cert
|
167
|
-
)
|
168
|
-
# pylint: enable=unused-variable
|
169
165
|
|
170
166
|
# Initialize StateFactory
|
171
167
|
state_factory = LinkStateFactory(args.database)
|
@@ -188,6 +184,7 @@ def run_superlink() -> None:
|
|
188
184
|
[args.executor_config] if args.executor_config else args.executor_config
|
189
185
|
),
|
190
186
|
auth_plugin=auth_plugin,
|
187
|
+
authz_plugin=authz_plugin,
|
191
188
|
event_log_plugin=event_log_plugin,
|
192
189
|
)
|
193
190
|
grpc_servers = [exec_server]
|
@@ -258,6 +255,7 @@ def run_superlink() -> None:
|
|
258
255
|
args.ssl_certfile,
|
259
256
|
state_factory,
|
260
257
|
ffs_factory,
|
258
|
+
objectstore_factory,
|
261
259
|
num_workers,
|
262
260
|
),
|
263
261
|
daemon=True,
|
@@ -483,62 +481,50 @@ def _try_load_public_keys_node_authentication(
|
|
483
481
|
return node_public_keys
|
484
482
|
|
485
483
|
|
486
|
-
def
|
484
|
+
def _try_obtain_exec_auth_plugins(
|
487
485
|
config_path: Path, verify_tls_cert: bool
|
488
|
-
) ->
|
486
|
+
) -> tuple[ExecAuthPlugin, ExecAuthzPlugin]:
|
487
|
+
"""Obtain Exec API authentication and authorization plugins."""
|
489
488
|
# Load YAML file
|
490
489
|
with config_path.open("r", encoding="utf-8") as file:
|
491
490
|
config: dict[str, Any] = yaml.safe_load(file)
|
492
491
|
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
)
|
504
|
-
except KeyError:
|
505
|
-
if auth_type != "":
|
506
|
-
sys.exit(
|
507
|
-
f'Authentication type "{auth_type}" is not supported. '
|
508
|
-
"Please provide a valid authentication type in the configuration."
|
492
|
+
def _load_plugin(
|
493
|
+
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
494
|
+
) -> P:
|
495
|
+
section_cfg = config.get(section, {})
|
496
|
+
auth_plugin_name = section_cfg.get(yaml_key, "")
|
497
|
+
try:
|
498
|
+
plugins: dict[str, type[P]] = loader()
|
499
|
+
plugin_cls: type[P] = plugins[auth_plugin_name]
|
500
|
+
return plugin_cls(
|
501
|
+
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
509
502
|
)
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
with config_path.open("r", encoding="utf-8") as file:
|
520
|
-
config: dict[str, Any] = yaml.safe_load(file)
|
503
|
+
except KeyError:
|
504
|
+
if auth_plugin_name:
|
505
|
+
sys.exit(
|
506
|
+
f"{yaml_key}: {auth_plugin_name} is not supported. "
|
507
|
+
f"Please provide a valid {section} type in the configuration."
|
508
|
+
)
|
509
|
+
sys.exit(f"No {section} type is provided in the configuration.")
|
510
|
+
except NotImplementedError:
|
511
|
+
sys.exit(f"No {section} plugins are currently supported.")
|
521
512
|
|
522
|
-
# Load authentication
|
523
|
-
|
524
|
-
|
513
|
+
# Load authentication plugin
|
514
|
+
auth_plugin = _load_plugin(
|
515
|
+
section="authentication",
|
516
|
+
yaml_key=AUTH_TYPE_YAML_KEY,
|
517
|
+
loader=get_exec_auth_plugins,
|
518
|
+
)
|
525
519
|
|
526
520
|
# Load authorization plugin
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
if authz_type != "":
|
535
|
-
sys.exit(
|
536
|
-
f'Authentication type "{authz_type}" is not supported. '
|
537
|
-
"Please provide a valid authorization type in the configuration."
|
538
|
-
)
|
539
|
-
sys.exit("No authorization type is provided in the configuration.")
|
540
|
-
except NotImplementedError:
|
541
|
-
sys.exit("No authorization plugins are currently supported.")
|
521
|
+
authz_plugin = _load_plugin(
|
522
|
+
section="authorization",
|
523
|
+
yaml_key=AUTHZ_TYPE_YAML_KEY,
|
524
|
+
loader=get_exec_authz_plugins,
|
525
|
+
)
|
526
|
+
|
527
|
+
return auth_plugin, authz_plugin
|
542
528
|
|
543
529
|
|
544
530
|
def _try_obtain_exec_event_log_writer_plugin() -> Optional[EventLogWriterPlugin]:
|
@@ -636,6 +622,7 @@ def _run_fleet_api_rest(
|
|
636
622
|
ssl_certfile: Optional[str],
|
637
623
|
state_factory: LinkStateFactory,
|
638
624
|
ffs_factory: FfsFactory,
|
625
|
+
objectstore_factory: ObjectStoreFactory,
|
639
626
|
num_workers: int,
|
640
627
|
) -> None:
|
641
628
|
"""Run ServerAppIo API (REST-based)."""
|
@@ -651,6 +638,7 @@ def _run_fleet_api_rest(
|
|
651
638
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
652
639
|
fast_api_app.state.STATE_FACTORY = state_factory
|
653
640
|
fast_api_app.state.FFS_FACTORY = ffs_factory
|
641
|
+
fast_api_app.state.OBJECTSTORE_FACTORY = objectstore_factory
|
654
642
|
|
655
643
|
uvicorn.run(
|
656
644
|
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
|
@@ -114,6 +114,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
114
114
|
return message_handler.pull_messages(
|
115
115
|
request=request,
|
116
116
|
state=self.state_factory.state(),
|
117
|
+
store=self.objectstore_factory.store(),
|
117
118
|
)
|
118
119
|
|
119
120
|
def PushMessages(
|
@@ -133,6 +134,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
133
134
|
res = message_handler.push_messages(
|
134
135
|
request=request,
|
135
136
|
state=self.state_factory.state(),
|
137
|
+
store=self.objectstore_factory.store(),
|
136
138
|
)
|
137
139
|
except InvalidRunStatusException as e:
|
138
140
|
abort_grpc_context(e.message, context)
|
@@ -14,10 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Fleet API message handlers."""
|
16
16
|
|
17
|
-
|
17
|
+
from logging import ERROR
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from flwr.common import Message
|
20
|
+
from flwr.common import Message, log
|
21
21
|
from flwr.common.constant import Status
|
22
22
|
from flwr.common.serde import (
|
23
23
|
fab_to_proto,
|
@@ -42,6 +42,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
42
42
|
SendNodeHeartbeatRequest,
|
43
43
|
SendNodeHeartbeatResponse,
|
44
44
|
)
|
45
|
+
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
45
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
46
47
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
47
48
|
GetRunRequest,
|
@@ -51,6 +52,9 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
51
52
|
from flwr.server.superlink.ffs.ffs import Ffs
|
52
53
|
from flwr.server.superlink.linkstate import LinkState
|
53
54
|
from flwr.server.superlink.utils import check_abort
|
55
|
+
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
|
56
|
+
|
57
|
+
from ...utils import store_mapping_and_register_objects
|
54
58
|
|
55
59
|
|
56
60
|
def create_node(
|
@@ -86,7 +90,9 @@ def send_node_heartbeat(
|
|
86
90
|
|
87
91
|
|
88
92
|
def pull_messages(
|
89
|
-
request: PullMessagesRequest,
|
93
|
+
request: PullMessagesRequest,
|
94
|
+
state: LinkState,
|
95
|
+
store: ObjectStore,
|
90
96
|
) -> PullMessagesResponse:
|
91
97
|
"""Pull Messages handler."""
|
92
98
|
# Get node_id if client node is not anonymous
|
@@ -98,14 +104,31 @@ def pull_messages(
|
|
98
104
|
|
99
105
|
# Convert to Messages
|
100
106
|
msg_proto = []
|
107
|
+
objects_to_pull: dict[str, ObjectIDs] = {}
|
101
108
|
for msg in message_list:
|
102
|
-
|
103
|
-
|
104
|
-
|
109
|
+
try:
|
110
|
+
msg_proto.append(message_to_proto(msg))
|
111
|
+
|
112
|
+
msg_object_id = msg.metadata.message_id
|
113
|
+
descendants = store.get_message_descendant_ids(msg_object_id)
|
114
|
+
# Include the object_id of the message itself
|
115
|
+
objects_to_pull[msg_object_id] = ObjectIDs(
|
116
|
+
object_ids=descendants + [msg_object_id]
|
117
|
+
)
|
118
|
+
except NoObjectInStoreError as e:
|
119
|
+
log(ERROR, e.message)
|
120
|
+
# Delete message ins from state
|
121
|
+
state.delete_messages(message_ins_ids={msg_object_id})
|
122
|
+
|
123
|
+
return PullMessagesResponse(
|
124
|
+
messages_list=msg_proto, objects_to_pull=objects_to_pull
|
125
|
+
)
|
105
126
|
|
106
127
|
|
107
128
|
def push_messages(
|
108
|
-
request: PushMessagesRequest,
|
129
|
+
request: PushMessagesRequest,
|
130
|
+
state: LinkState,
|
131
|
+
store: ObjectStore,
|
109
132
|
) -> PushMessagesResponse:
|
110
133
|
"""Push Messages handler."""
|
111
134
|
# Convert Message from proto
|
@@ -123,10 +146,14 @@ def push_messages(
|
|
123
146
|
# Store Message in State
|
124
147
|
message_id: Optional[str] = state.store_message_res(message=msg)
|
125
148
|
|
149
|
+
# Store Message object to descendants mapping and preregister objects
|
150
|
+
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
151
|
+
|
126
152
|
# Build response
|
127
153
|
response = PushMessagesResponse(
|
128
154
|
reconnect=Reconnect(reconnect=5),
|
129
155
|
results={str(message_id): 0},
|
156
|
+
objects_to_push=objects_to_push,
|
130
157
|
)
|
131
158
|
return response
|
132
159
|
|
@@ -43,6 +43,7 @@ from flwr.server.superlink.ffs.ffs import Ffs
|
|
43
43
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
44
44
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
45
45
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
46
|
+
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
46
47
|
|
47
48
|
try:
|
48
49
|
from starlette.applications import Starlette
|
@@ -113,9 +114,10 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
113
114
|
"""Pull PullMessages."""
|
114
115
|
# Get state from app
|
115
116
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
117
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
|
116
118
|
|
117
119
|
# Handle message
|
118
|
-
return message_handler.pull_messages(request=request, state=state)
|
120
|
+
return message_handler.pull_messages(request=request, state=state, store=store)
|
119
121
|
|
120
122
|
|
121
123
|
@rest_request_response(PushMessagesRequest)
|
@@ -123,9 +125,10 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
123
125
|
"""Pull PushMessages."""
|
124
126
|
# Get state from app
|
125
127
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
128
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
|
126
129
|
|
127
130
|
# Handle message
|
128
|
-
return message_handler.push_messages(request=request, state=state)
|
131
|
+
return message_handler.push_messages(request=request, state=state, store=store)
|
129
132
|
|
130
133
|
|
131
134
|
@rest_request_response(SendNodeHeartbeatRequest)
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
from os import urandom
|
19
19
|
from typing import Optional
|
20
|
-
from uuid import uuid4
|
21
20
|
|
22
21
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
23
22
|
from flwr.common.constant import (
|
@@ -246,7 +245,7 @@ def create_message_error_unavailable_res_message(
|
|
246
245
|
ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
|
247
246
|
metadata = Metadata(
|
248
247
|
run_id=ins_metadata.run_id,
|
249
|
-
message_id=
|
248
|
+
message_id="",
|
250
249
|
src_node_id=SUPERLINK_NODE_ID,
|
251
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
252
251
|
reply_to_message_id=ins_metadata.message_id,
|
@@ -256,7 +255,7 @@ def create_message_error_unavailable_res_message(
|
|
256
255
|
ttl=ttl,
|
257
256
|
)
|
258
257
|
|
259
|
-
|
258
|
+
msg = make_message(
|
260
259
|
metadata=metadata,
|
261
260
|
error=Error(
|
262
261
|
code=(
|
@@ -271,6 +270,8 @@ def create_message_error_unavailable_res_message(
|
|
271
270
|
),
|
272
271
|
),
|
273
272
|
)
|
273
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
274
|
+
return msg
|
274
275
|
|
275
276
|
|
276
277
|
def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
|
@@ -278,7 +279,7 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
|
|
278
279
|
that it isn't found."""
|
279
280
|
metadata = Metadata(
|
280
281
|
run_id=0, # Unknown
|
281
|
-
message_id=
|
282
|
+
message_id="",
|
282
283
|
src_node_id=SUPERLINK_NODE_ID,
|
283
284
|
dst_node_id=SUPERLINK_NODE_ID,
|
284
285
|
reply_to_message_id=reply_to_message_id,
|
@@ -288,13 +289,15 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
|
|
288
289
|
ttl=0,
|
289
290
|
)
|
290
291
|
|
291
|
-
|
292
|
+
msg = make_message(
|
292
293
|
metadata=metadata,
|
293
294
|
error=Error(
|
294
295
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
295
296
|
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
296
297
|
),
|
297
298
|
)
|
299
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
300
|
+
return msg
|
298
301
|
|
299
302
|
|
300
303
|
def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
|
@@ -16,14 +16,14 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import threading
|
19
|
-
from logging import DEBUG, INFO
|
19
|
+
from logging import DEBUG, ERROR, INFO
|
20
20
|
from typing import Optional
|
21
21
|
|
22
22
|
import grpc
|
23
23
|
|
24
24
|
from flwr.common import Message
|
25
25
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
26
|
-
from flwr.common.inflatable import check_body_len_consistency
|
26
|
+
from flwr.common.inflatable import check_body_len_consistency, get_desdendant_object_ids
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.serde import (
|
29
29
|
context_from_proto,
|
@@ -47,6 +47,7 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
47
47
|
PushLogsResponse,
|
48
48
|
)
|
49
49
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
50
|
+
ObjectIDs,
|
50
51
|
PullObjectRequest,
|
51
52
|
PullObjectResponse,
|
52
53
|
PushObjectRequest,
|
@@ -78,7 +79,9 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
78
79
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
79
80
|
from flwr.server.superlink.utils import abort_if
|
80
81
|
from flwr.server.utils.validator import validate_message
|
81
|
-
from flwr.supercore.object_store import ObjectStoreFactory
|
82
|
+
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
|
83
|
+
|
84
|
+
from ..utils import store_mapping_and_register_objects
|
82
85
|
|
83
86
|
|
84
87
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
@@ -158,10 +161,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
158
161
|
message_id: Optional[str] = state.store_message_ins(message=message)
|
159
162
|
message_ids.append(message_id)
|
160
163
|
|
164
|
+
# Init store
|
165
|
+
store = self.objectstore_factory.store()
|
166
|
+
|
167
|
+
# Store Message object to descendants mapping and preregister objects
|
168
|
+
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
169
|
+
|
161
170
|
return PushInsMessagesResponse(
|
162
171
|
message_ids=[
|
163
172
|
str(message_id) if message_id else "" for message_id in message_ids
|
164
|
-
]
|
173
|
+
],
|
174
|
+
objects_to_push=objects_to_push,
|
165
175
|
)
|
166
176
|
|
167
177
|
def PullMessages(
|
@@ -173,6 +183,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
173
183
|
# Init state
|
174
184
|
state: LinkState = self.state_factory.state()
|
175
185
|
|
186
|
+
# Init store
|
187
|
+
store = self.objectstore_factory.store()
|
188
|
+
|
176
189
|
# Abort if the run is not running
|
177
190
|
abort_if(
|
178
191
|
request.run_id,
|
@@ -186,6 +199,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
186
199
|
message_ids=set(request.message_ids)
|
187
200
|
)
|
188
201
|
|
202
|
+
# Register messages generated by LinkState in the Store for consistency
|
203
|
+
for msg_res in messages_res:
|
204
|
+
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
205
|
+
descendants = list(get_desdendant_object_ids(msg_res))
|
206
|
+
message_obj_id = msg_res.metadata.message_id
|
207
|
+
# Store mapping
|
208
|
+
store.set_message_descendant_ids(
|
209
|
+
msg_object_id=message_obj_id, descendant_ids=descendants
|
210
|
+
)
|
211
|
+
# Preregister
|
212
|
+
store.preregister(descendants + [message_obj_id])
|
213
|
+
|
189
214
|
# Delete the instruction Messages and their replies if found
|
190
215
|
message_ins_ids_to_delete = {
|
191
216
|
msg_res.metadata.reply_to_message_id for msg_res in messages_res
|
@@ -195,6 +220,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
195
220
|
|
196
221
|
# Convert Messages to proto
|
197
222
|
messages_list = []
|
223
|
+
objects_to_pull: dict[str, ObjectIDs] = {}
|
198
224
|
while messages_res:
|
199
225
|
msg = messages_res.pop(0)
|
200
226
|
|
@@ -207,7 +233,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
207
233
|
)
|
208
234
|
messages_list.append(message_to_proto(msg))
|
209
235
|
|
210
|
-
|
236
|
+
try:
|
237
|
+
msg_object_id = msg.metadata.message_id
|
238
|
+
descendants = store.get_message_descendant_ids(msg_object_id)
|
239
|
+
# Include the object_id of the message itself
|
240
|
+
objects_to_pull[msg_object_id] = ObjectIDs(
|
241
|
+
object_ids=descendants + [msg_object_id]
|
242
|
+
)
|
243
|
+
except NoObjectInStoreError as e:
|
244
|
+
log(ERROR, e.message)
|
245
|
+
# Delete message ins from state
|
246
|
+
state.delete_messages(message_ins_ids={msg_object_id})
|
247
|
+
|
248
|
+
return PullResMessagesResponse(
|
249
|
+
messages_list=messages_list, objects_to_pull=objects_to_pull
|
250
|
+
)
|
211
251
|
|
212
252
|
def GetRun(
|
213
253
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
flwr/server/superlink/utils.py
CHANGED
@@ -21,7 +21,11 @@ import grpc
|
|
21
21
|
|
22
22
|
from flwr.common.constant import Status, SubStatus
|
23
23
|
from flwr.common.typing import RunStatus
|
24
|
+
from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
|
25
|
+
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
26
|
+
from flwr.proto.serverappio_pb2 import PushInsMessagesRequest # pylint: disable=E0611
|
24
27
|
from flwr.server.superlink.linkstate import LinkState
|
28
|
+
from flwr.supercore.object_store import ObjectStore
|
25
29
|
|
26
30
|
_STATUS_TO_MSG = {
|
27
31
|
Status.PENDING: "Run is pending.",
|
@@ -63,3 +67,28 @@ def abort_if(
|
|
63
67
|
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
64
68
|
msg = check_abort(run_id, abort_status_list, state)
|
65
69
|
abort_grpc_context(msg, context)
|
70
|
+
|
71
|
+
|
72
|
+
def store_mapping_and_register_objects(
|
73
|
+
store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
|
74
|
+
) -> dict[str, ObjectIDs]:
|
75
|
+
"""Store Message object to descendants mapping and preregister objects."""
|
76
|
+
objects_to_push: dict[str, ObjectIDs] = {}
|
77
|
+
for (
|
78
|
+
message_obj_id,
|
79
|
+
descendant_obj_ids,
|
80
|
+
) in request.msg_to_descendant_mapping.items():
|
81
|
+
descendants = list(descendant_obj_ids.object_ids)
|
82
|
+
# Store mapping
|
83
|
+
store.set_message_descendant_ids(
|
84
|
+
msg_object_id=message_obj_id, descendant_ids=descendants
|
85
|
+
)
|
86
|
+
|
87
|
+
# Preregister
|
88
|
+
object_ids_just_registered = store.preregister(descendants + [message_obj_id])
|
89
|
+
# Keep track of objects that need to be pushed
|
90
|
+
objects_to_push[message_obj_id] = ObjectIDs(
|
91
|
+
object_ids=object_ids_just_registered
|
92
|
+
)
|
93
|
+
|
94
|
+
return objects_to_push
|
@@ -14,10 +14,11 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Flower ObjectStore."""
|
16
16
|
|
17
|
-
from .object_store import ObjectStore
|
17
|
+
from .object_store import NoObjectInStoreError, ObjectStore
|
18
18
|
from .object_store_factory import ObjectStoreFactory
|
19
19
|
|
20
20
|
__all__ = [
|
21
|
+
"NoObjectInStoreError",
|
21
22
|
"ObjectStore",
|
22
23
|
"ObjectStoreFactory",
|
23
24
|
]
|