flwr-nightly 1.19.0.dev20250527__py3-none-any.whl → 1.19.0.dev20250529__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +3 -3
- flwr/cli/run/run.py +2 -6
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +5 -4
- flwr/client/grpc_rere_client/connection.py +2 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/common/auth_plugin/auth_plugin.py +8 -2
- flwr/common/inflatable.py +33 -2
- flwr/common/message.py +11 -0
- flwr/common/record/array.py +38 -1
- flwr/common/record/arrayrecord.py +34 -0
- flwr/common/serde.py +6 -1
- flwr/proto/fleet_pb2.py +16 -16
- flwr/proto/fleet_pb2.pyi +5 -5
- flwr/proto/message_pb2.py +10 -10
- flwr/proto/message_pb2.pyi +4 -4
- flwr/proto/serverappio_pb2.py +26 -26
- flwr/proto/serverappio_pb2.pyi +5 -5
- flwr/server/app.py +52 -56
- flwr/server/grid/grpc_grid.py +2 -1
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +11 -3
- flwr/server/superlink/fleet/rest_rere/rest_api.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +3 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -25
- flwr/server/superlink/linkstate/linkstate.py +9 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +11 -21
- flwr/server/superlink/linkstate/utils.py +23 -23
- flwr/server/superlink/serverappio/serverappio_servicer.py +16 -11
- flwr/server/superlink/utils.py +29 -0
- flwr/server/utils/validator.py +2 -2
- flwr/supercore/object_store/in_memory_object_store.py +30 -4
- flwr/supercore/object_store/object_store.py +48 -1
- flwr/superexec/exec_servicer.py +1 -2
- {flwr_nightly-1.19.0.dev20250527.dist-info → flwr_nightly-1.19.0.dev20250529.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250527.dist-info → flwr_nightly-1.19.0.dev20250529.dist-info}/RECORD +42 -42
- {flwr_nightly-1.19.0.dev20250527.dist-info → flwr_nightly-1.19.0.dev20250529.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250527.dist-info → flwr_nightly-1.19.0.dev20250529.dist-info}/entry_points.txt +0 -0
flwr/proto/serverappio_pb2.py
CHANGED
@@ -20,15 +20,15 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
20
20
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
21
21
|
|
22
22
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\
|
23
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8d\x02\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x61\n\x19msg_to_descendant_mapping\x18\x03 \x03(\x0b\x32>.flwr.proto.PushInsMessagesRequest.MsgToDescendantMappingEntry\x1aT\n\x1bMsgToDescendantMappingEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xe4\x08\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x62\x06proto3')
|
24
24
|
|
25
25
|
_globals = globals()
|
26
26
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
27
27
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.serverappio_pb2', _globals)
|
28
28
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
29
29
|
DESCRIPTOR._options = None
|
30
|
-
_globals['
|
31
|
-
_globals['
|
30
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._options = None
|
31
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_options = b'8\001'
|
32
32
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._options = None
|
33
33
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_options = b'8\001'
|
34
34
|
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._options = None
|
@@ -38,27 +38,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
38
38
|
_globals['_GETNODESRESPONSE']._serialized_start=222
|
39
39
|
_globals['_GETNODESRESPONSE']._serialized_end=273
|
40
40
|
_globals['_PUSHINSMESSAGESREQUEST']._serialized_start=276
|
41
|
-
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=
|
42
|
-
_globals['
|
43
|
-
_globals['
|
44
|
-
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=
|
45
|
-
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=
|
46
|
-
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=
|
47
|
-
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=
|
48
|
-
_globals['_PULLRESMESSAGESREQUEST']._serialized_start=
|
49
|
-
_globals['_PULLRESMESSAGESREQUEST']._serialized_end=
|
50
|
-
_globals['_PULLRESMESSAGESRESPONSE']._serialized_start=
|
51
|
-
_globals['_PULLRESMESSAGESRESPONSE']._serialized_end=
|
52
|
-
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=
|
53
|
-
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=
|
54
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=
|
55
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=
|
56
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=
|
57
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=
|
58
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=
|
59
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=
|
60
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=
|
61
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=
|
62
|
-
_globals['_SERVERAPPIO']._serialized_start=
|
63
|
-
_globals['_SERVERAPPIO']._serialized_end=
|
41
|
+
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=545
|
42
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_start=461
|
43
|
+
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_end=545
|
44
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=548
|
45
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=752
|
46
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=677
|
47
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=752
|
48
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_start=754
|
49
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_end=815
|
50
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_start=818
|
51
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_end=1045
|
52
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=970
|
53
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=1045
|
54
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=1047
|
55
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=1075
|
56
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=1077
|
57
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1204
|
58
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1206
|
59
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1289
|
60
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1291
|
61
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1321
|
62
|
+
_globals['_SERVERAPPIO']._serialized_start=1324
|
63
|
+
_globals['_SERVERAPPIO']._serialized_end=2448
|
64
64
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/serverappio_pb2.pyi
CHANGED
@@ -42,7 +42,7 @@ global___GetNodesResponse = GetNodesResponse
|
|
42
42
|
class PushInsMessagesRequest(google.protobuf.message.Message):
|
43
43
|
"""PushMessages messages"""
|
44
44
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
45
|
-
class
|
45
|
+
class MsgToDescendantMappingEntry(google.protobuf.message.Message):
|
46
46
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
47
47
|
KEY_FIELD_NUMBER: builtins.int
|
48
48
|
VALUE_FIELD_NUMBER: builtins.int
|
@@ -59,19 +59,19 @@ class PushInsMessagesRequest(google.protobuf.message.Message):
|
|
59
59
|
|
60
60
|
MESSAGES_LIST_FIELD_NUMBER: builtins.int
|
61
61
|
RUN_ID_FIELD_NUMBER: builtins.int
|
62
|
-
|
62
|
+
MSG_TO_DESCENDANT_MAPPING_FIELD_NUMBER: builtins.int
|
63
63
|
@property
|
64
64
|
def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
|
65
65
|
run_id: builtins.int
|
66
66
|
@property
|
67
|
-
def
|
67
|
+
def msg_to_descendant_mapping(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.message_pb2.ObjectIDs]: ...
|
68
68
|
def __init__(self,
|
69
69
|
*,
|
70
70
|
messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
|
71
71
|
run_id: builtins.int = ...,
|
72
|
-
|
72
|
+
msg_to_descendant_mapping: typing.Optional[typing.Mapping[typing.Text, flwr.proto.message_pb2.ObjectIDs]] = ...,
|
73
73
|
) -> None: ...
|
74
|
-
def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","
|
74
|
+
def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","msg_to_descendant_mapping",b"msg_to_descendant_mapping","run_id",b"run_id"]) -> None: ...
|
75
75
|
global___PushInsMessagesRequest = PushInsMessagesRequest
|
76
76
|
|
77
77
|
class PushInsMessagesResponse(google.protobuf.message.Message):
|
flwr/server/app.py
CHANGED
@@ -27,7 +27,7 @@ from collections.abc import Sequence
|
|
27
27
|
from logging import DEBUG, INFO, WARN
|
28
28
|
from pathlib import Path
|
29
29
|
from time import sleep
|
30
|
-
from typing import Any, Optional
|
30
|
+
from typing import Any, Callable, Optional, Union, cast
|
31
31
|
|
32
32
|
import grpc
|
33
33
|
import yaml
|
@@ -155,17 +155,14 @@ def run_superlink() -> None:
|
|
155
155
|
event_log_plugin: Optional[EventLogWriterPlugin] = None
|
156
156
|
# Load the auth plugin if the args.user_auth_config is provided
|
157
157
|
if cfg_path := getattr(args, "user_auth_config", None):
|
158
|
-
|
158
|
+
# pylint: disable=unused-variable
|
159
|
+
auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins( # noqa: F841
|
160
|
+
Path(cfg_path), verify_tls_cert
|
161
|
+
)
|
162
|
+
# pylint: enable=unused-variable
|
159
163
|
# Enable event logging if the args.enable_event_log is True
|
160
164
|
if args.enable_event_log:
|
161
165
|
event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
|
162
|
-
# Enable authorization if the args.enable_authorization is True
|
163
|
-
if args.enable_authorization:
|
164
|
-
# pylint: disable=unused-variable
|
165
|
-
authz_plugin = _try_obtain_exec_authz_plugin( # noqa: F841
|
166
|
-
Path(cfg_path), verify_tls_cert
|
167
|
-
)
|
168
|
-
# pylint: enable=unused-variable
|
169
166
|
|
170
167
|
# Initialize StateFactory
|
171
168
|
state_factory = LinkStateFactory(args.database)
|
@@ -258,6 +255,7 @@ def run_superlink() -> None:
|
|
258
255
|
args.ssl_certfile,
|
259
256
|
state_factory,
|
260
257
|
ffs_factory,
|
258
|
+
objectstore_factory,
|
261
259
|
num_workers,
|
262
260
|
),
|
263
261
|
daemon=True,
|
@@ -483,62 +481,58 @@ def _try_load_public_keys_node_authentication(
|
|
483
481
|
return node_public_keys
|
484
482
|
|
485
483
|
|
486
|
-
def
|
484
|
+
def _try_obtain_exec_auth_plugins(
|
487
485
|
config_path: Path, verify_tls_cert: bool
|
488
|
-
) ->
|
486
|
+
) -> tuple[ExecAuthPlugin, ExecAuthzPlugin]:
|
487
|
+
"""Obtain Exec API authentication and authorization plugins."""
|
489
488
|
# Load YAML file
|
490
489
|
with config_path.open("r", encoding="utf-8") as file:
|
491
490
|
config: dict[str, Any] = yaml.safe_load(file)
|
492
491
|
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
if auth_type != "":
|
506
|
-
sys.exit(
|
507
|
-
f'Authentication type "{auth_type}" is not supported. '
|
508
|
-
"Please provide a valid authentication type in the configuration."
|
492
|
+
def _load_plugin(
|
493
|
+
section: str,
|
494
|
+
yaml_key: str,
|
495
|
+
loader: Callable[[], dict[str, type[Union[ExecAuthPlugin, ExecAuthzPlugin]]]],
|
496
|
+
) -> Union[ExecAuthPlugin, ExecAuthzPlugin]:
|
497
|
+
section_cfg = config.get(section, {})
|
498
|
+
auth_plugin_name = section_cfg.get(yaml_key, "")
|
499
|
+
try:
|
500
|
+
plugins = loader()
|
501
|
+
plugin_cls = plugins[auth_plugin_name]
|
502
|
+
return plugin_cls(
|
503
|
+
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
509
504
|
)
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
with config_path.open("r", encoding="utf-8") as file:
|
520
|
-
config: dict[str, Any] = yaml.safe_load(file)
|
505
|
+
except KeyError:
|
506
|
+
if auth_plugin_name:
|
507
|
+
sys.exit(
|
508
|
+
f"{yaml_key}: {auth_plugin_name} is not supported. "
|
509
|
+
f"Please provide a valid {section} type in the configuration."
|
510
|
+
)
|
511
|
+
sys.exit(f"No {section} type is provided in the configuration.")
|
512
|
+
except NotImplementedError:
|
513
|
+
sys.exit(f"No {section} plugins are currently supported.")
|
521
514
|
|
522
|
-
# Load authentication
|
523
|
-
|
524
|
-
|
515
|
+
# Load authentication plugin
|
516
|
+
auth_plugin = cast(
|
517
|
+
ExecAuthPlugin,
|
518
|
+
_load_plugin(
|
519
|
+
section="authentication",
|
520
|
+
yaml_key=AUTH_TYPE_YAML_KEY,
|
521
|
+
loader=get_exec_auth_plugins,
|
522
|
+
),
|
523
|
+
)
|
525
524
|
|
526
525
|
# Load authorization plugin
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
"Please provide a valid authorization type in the configuration."
|
538
|
-
)
|
539
|
-
sys.exit("No authorization type is provided in the configuration.")
|
540
|
-
except NotImplementedError:
|
541
|
-
sys.exit("No authorization plugins are currently supported.")
|
526
|
+
authz_plugin = cast(
|
527
|
+
ExecAuthzPlugin,
|
528
|
+
_load_plugin(
|
529
|
+
section="authorization",
|
530
|
+
yaml_key=AUTHZ_TYPE_YAML_KEY,
|
531
|
+
loader=get_exec_authz_plugins,
|
532
|
+
),
|
533
|
+
)
|
534
|
+
|
535
|
+
return auth_plugin, authz_plugin
|
542
536
|
|
543
537
|
|
544
538
|
def _try_obtain_exec_event_log_writer_plugin() -> Optional[EventLogWriterPlugin]:
|
@@ -636,6 +630,7 @@ def _run_fleet_api_rest(
|
|
636
630
|
ssl_certfile: Optional[str],
|
637
631
|
state_factory: LinkStateFactory,
|
638
632
|
ffs_factory: FfsFactory,
|
633
|
+
objectstore_factory: ObjectStoreFactory,
|
639
634
|
num_workers: int,
|
640
635
|
) -> None:
|
641
636
|
"""Run ServerAppIo API (REST-based)."""
|
@@ -651,6 +646,7 @@ def _run_fleet_api_rest(
|
|
651
646
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
652
647
|
fast_api_app.state.STATE_FACTORY = state_factory
|
653
648
|
fast_api_app.state.FFS_FACTORY = ffs_factory
|
649
|
+
fast_api_app.state.OBJECTSTORE_FACTORY = objectstore_factory
|
654
650
|
|
655
651
|
uvicorn.run(
|
656
652
|
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
|
flwr/server/grid/grpc_grid.py
CHANGED
@@ -163,7 +163,7 @@ class GrpcGrid(Grid):
|
|
163
163
|
def _check_message(self, message: Message) -> None:
|
164
164
|
# Check if the message is valid
|
165
165
|
if not (
|
166
|
-
message.metadata.message_id
|
166
|
+
message.metadata.message_id != ""
|
167
167
|
and message.metadata.reply_to_message_id == ""
|
168
168
|
and message.metadata.ttl > 0
|
169
169
|
):
|
@@ -211,6 +211,7 @@ class GrpcGrid(Grid):
|
|
211
211
|
# Populate metadata
|
212
212
|
msg.metadata.__dict__["_run_id"] = run_id
|
213
213
|
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
214
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
214
215
|
# Check message
|
215
216
|
self._check_message(msg)
|
216
217
|
# Convert to proto
|
@@ -18,7 +18,7 @@
|
|
18
18
|
import time
|
19
19
|
from collections.abc import Iterable
|
20
20
|
from typing import Optional, cast
|
21
|
-
from uuid import
|
21
|
+
from uuid import uuid4
|
22
22
|
|
23
23
|
from flwr.common import Message, RecordDict
|
24
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
@@ -56,7 +56,7 @@ class InMemoryGrid(Grid):
|
|
56
56
|
def _check_message(self, message: Message) -> None:
|
57
57
|
# Check if the message is valid
|
58
58
|
if not (
|
59
|
-
message.metadata.message_id
|
59
|
+
message.metadata.message_id != ""
|
60
60
|
and message.metadata.reply_to_message_id == ""
|
61
61
|
and message.metadata.ttl > 0
|
62
62
|
and message.metadata.delivered_at == ""
|
@@ -111,6 +111,7 @@ class InMemoryGrid(Grid):
|
|
111
111
|
# Populate metadata
|
112
112
|
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
113
113
|
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
114
|
+
msg.metadata.__dict__["_message_id"] = str(uuid4())
|
114
115
|
# Check message
|
115
116
|
self._check_message(msg)
|
116
117
|
# Store in state
|
@@ -126,12 +127,12 @@ class InMemoryGrid(Grid):
|
|
126
127
|
This method is used to collect messages from the SuperLink that correspond to a
|
127
128
|
set of given message IDs.
|
128
129
|
"""
|
129
|
-
msg_ids =
|
130
|
+
msg_ids = set(message_ids)
|
130
131
|
# Pull Messages
|
131
132
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
132
133
|
# Get IDs of Messages these replies are for
|
133
134
|
message_ins_ids_to_delete = {
|
134
|
-
|
135
|
+
msg_res.metadata.reply_to_message_id for msg_res in message_res_list
|
135
136
|
}
|
136
137
|
# Delete
|
137
138
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
@@ -133,6 +133,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
133
133
|
res = message_handler.push_messages(
|
134
134
|
request=request,
|
135
135
|
state=self.state_factory.state(),
|
136
|
+
store=self.objectstore_factory.store(),
|
136
137
|
)
|
137
138
|
except InvalidRunStatusException as e:
|
138
139
|
abort_grpc_context(e.message, context)
|
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
|
-
from uuid import UUID
|
20
19
|
|
21
20
|
from flwr.common import Message
|
22
21
|
from flwr.common.constant import Status
|
@@ -52,6 +51,9 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
52
51
|
from flwr.server.superlink.ffs.ffs import Ffs
|
53
52
|
from flwr.server.superlink.linkstate import LinkState
|
54
53
|
from flwr.server.superlink.utils import check_abort
|
54
|
+
from flwr.supercore.object_store import ObjectStore
|
55
|
+
|
56
|
+
from ...utils import store_mapping_and_register_objects
|
55
57
|
|
56
58
|
|
57
59
|
def create_node(
|
@@ -106,7 +108,9 @@ def pull_messages(
|
|
106
108
|
|
107
109
|
|
108
110
|
def push_messages(
|
109
|
-
request: PushMessagesRequest,
|
111
|
+
request: PushMessagesRequest,
|
112
|
+
state: LinkState,
|
113
|
+
store: ObjectStore,
|
110
114
|
) -> PushMessagesResponse:
|
111
115
|
"""Push Messages handler."""
|
112
116
|
# Convert Message from proto
|
@@ -122,12 +126,16 @@ def push_messages(
|
|
122
126
|
raise InvalidRunStatusException(abort_msg)
|
123
127
|
|
124
128
|
# Store Message in State
|
125
|
-
message_id: Optional[
|
129
|
+
message_id: Optional[str] = state.store_message_res(message=msg)
|
130
|
+
|
131
|
+
# Store Message object to descendants mapping and preregister objects
|
132
|
+
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
126
133
|
|
127
134
|
# Build response
|
128
135
|
response = PushMessagesResponse(
|
129
136
|
reconnect=Reconnect(reconnect=5),
|
130
137
|
results={str(message_id): 0},
|
138
|
+
objects_to_push=objects_to_push,
|
131
139
|
)
|
132
140
|
return response
|
133
141
|
|
@@ -43,6 +43,7 @@ from flwr.server.superlink.ffs.ffs import Ffs
|
|
43
43
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
44
44
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
45
45
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
46
|
+
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
46
47
|
|
47
48
|
try:
|
48
49
|
from starlette.applications import Starlette
|
@@ -123,9 +124,10 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
123
124
|
"""Pull PushMessages."""
|
124
125
|
# Get state from app
|
125
126
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
127
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
|
126
128
|
|
127
129
|
# Handle message
|
128
|
-
return message_handler.push_messages(request=request, state=state)
|
130
|
+
return message_handler.push_messages(request=request, state=state, store=store)
|
129
131
|
|
130
132
|
|
131
133
|
@rest_request_response(SendNodeHeartbeatRequest)
|
@@ -25,6 +25,7 @@ from pathlib import Path
|
|
25
25
|
from queue import Empty, Queue
|
26
26
|
from time import sleep
|
27
27
|
from typing import Callable, Optional
|
28
|
+
from uuid import uuid4
|
28
29
|
|
29
30
|
from flwr.app.error import Error
|
30
31
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
@@ -134,6 +135,8 @@ def worker(
|
|
134
135
|
|
135
136
|
finally:
|
136
137
|
if out_mssg:
|
138
|
+
# Assign a message_id
|
139
|
+
out_mssg.metadata.__dict__["_message_id"] = str(uuid4())
|
137
140
|
# Store reply Messages in state
|
138
141
|
messageres_queue.put(out_mssg)
|
139
142
|
|
@@ -21,7 +21,6 @@ from bisect import bisect_right
|
|
21
21
|
from dataclasses import dataclass, field
|
22
22
|
from logging import ERROR, WARNING
|
23
23
|
from typing import Optional
|
24
|
-
from uuid import UUID, uuid4
|
25
24
|
|
26
25
|
from flwr.common import Context, Message, log, now
|
27
26
|
from flwr.common.constant import (
|
@@ -76,15 +75,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
76
75
|
self.run_ids: dict[int, RunRecord] = {}
|
77
76
|
self.contexts: dict[int, Context] = {}
|
78
77
|
self.federation_options: dict[int, ConfigRecord] = {}
|
79
|
-
self.message_ins_store: dict[
|
80
|
-
self.message_res_store: dict[
|
81
|
-
self.message_ins_id_to_message_res_id: dict[
|
78
|
+
self.message_ins_store: dict[str, Message] = {}
|
79
|
+
self.message_res_store: dict[str, Message] = {}
|
80
|
+
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
82
81
|
|
83
82
|
self.node_public_keys: set[bytes] = set()
|
84
83
|
|
85
84
|
self.lock = threading.RLock()
|
86
85
|
|
87
|
-
def store_message_ins(self, message: Message) -> Optional[
|
86
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
88
87
|
"""Store one Message."""
|
89
88
|
# Validate message
|
90
89
|
errors = validate_message(message, is_reply_message=False)
|
@@ -112,12 +111,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
112
111
|
)
|
113
112
|
return None
|
114
113
|
|
115
|
-
|
116
|
-
message_id = uuid4()
|
117
|
-
|
118
|
-
# Store Message
|
119
|
-
# pylint: disable-next=W0212
|
120
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
114
|
+
message_id = message.metadata.message_id
|
121
115
|
with self.lock:
|
122
116
|
self.message_ins_store[message_id] = message
|
123
117
|
|
@@ -153,7 +147,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
153
147
|
return message_ins_list
|
154
148
|
|
155
149
|
# pylint: disable=R0911
|
156
|
-
def store_message_res(self, message: Message) -> Optional[
|
150
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
157
151
|
"""Store one Message."""
|
158
152
|
# Validate message
|
159
153
|
errors = validate_message(message, is_reply_message=True)
|
@@ -165,7 +159,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
165
159
|
with self.lock:
|
166
160
|
# Check if the Message it is replying to exists and is valid
|
167
161
|
msg_ins_id = res_metadata.reply_to_message_id
|
168
|
-
msg_ins = self.message_ins_store.get(
|
162
|
+
msg_ins = self.message_ins_store.get(msg_ins_id)
|
169
163
|
|
170
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
171
165
|
# reply Message.
|
@@ -220,22 +214,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
220
214
|
log(ERROR, "`metadata.run_id` is invalid")
|
221
215
|
return None
|
222
216
|
|
223
|
-
|
224
|
-
message_id = uuid4()
|
225
|
-
|
226
|
-
# Store Message
|
227
|
-
# pylint: disable-next=W0212
|
228
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
217
|
+
message_id = message.metadata.message_id
|
229
218
|
with self.lock:
|
230
219
|
self.message_res_store[message_id] = message
|
231
|
-
self.message_ins_id_to_message_res_id[
|
220
|
+
self.message_ins_id_to_message_res_id[msg_ins_id] = message_id
|
232
221
|
|
233
222
|
# Return the new message_id
|
234
223
|
return message_id
|
235
224
|
|
236
|
-
def get_message_res(self, message_ids: set[
|
225
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
237
226
|
"""Get reply Messages for the given Message IDs."""
|
238
|
-
ret: dict[
|
227
|
+
ret: dict[str, Message] = {}
|
239
228
|
|
240
229
|
with self.lock:
|
241
230
|
current = time.time()
|
@@ -287,7 +276,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
287
276
|
|
288
277
|
return list(ret.values())
|
289
278
|
|
290
|
-
def delete_messages(self, message_ins_ids: set[
|
279
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
291
280
|
"""Delete a Message and its reply based on provided Message IDs."""
|
292
281
|
if not message_ins_ids:
|
293
282
|
return
|
@@ -304,9 +293,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
304
293
|
)
|
305
294
|
del self.message_res_store[message_res_id]
|
306
295
|
|
307
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
296
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
308
297
|
"""Get all instruction Message IDs for the given run_id."""
|
309
|
-
message_id_list: set[
|
298
|
+
message_id_list: set[str] = set()
|
310
299
|
with self.lock:
|
311
300
|
for message_id, message in self.message_ins_store.items():
|
312
301
|
if message.metadata.run_id == run_id:
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import abc
|
19
19
|
from typing import Optional
|
20
|
-
from uuid import UUID
|
21
20
|
|
22
21
|
from flwr.common import Context, Message
|
23
22
|
from flwr.common.record import ConfigRecord
|
@@ -28,13 +27,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
28
27
|
"""Abstract LinkState."""
|
29
28
|
|
30
29
|
@abc.abstractmethod
|
31
|
-
def store_message_ins(self, message: Message) -> Optional[
|
30
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
32
31
|
"""Store one Message.
|
33
32
|
|
34
33
|
Usually, the ServerAppIo API calls this to schedule instructions.
|
35
34
|
|
36
35
|
Stores the value of the `message` in the link state and, if successful,
|
37
|
-
returns the `message_id` (
|
36
|
+
returns the `message_id` (str) of the `message`. If, for any reason,
|
38
37
|
storing the `message` fails, `None` is returned.
|
39
38
|
|
40
39
|
Constraints
|
@@ -61,12 +60,12 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
61
60
|
"""
|
62
61
|
|
63
62
|
@abc.abstractmethod
|
64
|
-
def store_message_res(self, message: Message) -> Optional[
|
63
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
65
64
|
"""Store one Message.
|
66
65
|
|
67
66
|
Usually, the Fleet API calls this for Nodes returning results.
|
68
67
|
|
69
|
-
Stores the Message and, if successful, returns the `message_id` (
|
68
|
+
Stores the Message and, if successful, returns the `message_id` (str) of
|
70
69
|
the `message`. If storing the `message` fails, `None` is returned.
|
71
70
|
|
72
71
|
Constraints
|
@@ -78,7 +77,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
78
77
|
"""
|
79
78
|
|
80
79
|
@abc.abstractmethod
|
81
|
-
def get_message_res(self, message_ids: set[
|
80
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
82
81
|
"""Get reply Messages for the given Message IDs.
|
83
82
|
|
84
83
|
This method is typically called by the ServerAppIo API to obtain
|
@@ -94,7 +93,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
94
93
|
|
95
94
|
Parameters
|
96
95
|
----------
|
97
|
-
message_ids : set[
|
96
|
+
message_ids : set[str]
|
98
97
|
A set of Message IDs used to retrieve reply Messages responding to them.
|
99
98
|
|
100
99
|
Returns
|
@@ -113,18 +112,18 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
113
112
|
"""Calculate the number of reply Messages in store."""
|
114
113
|
|
115
114
|
@abc.abstractmethod
|
116
|
-
def delete_messages(self, message_ins_ids: set[
|
115
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
117
116
|
"""Delete a Message and its reply based on provided Message IDs.
|
118
117
|
|
119
118
|
Parameters
|
120
119
|
----------
|
121
|
-
message_ins_ids : set[
|
120
|
+
message_ins_ids : set[str]
|
122
121
|
A set of Message IDs. For each ID in the set, the corresponding
|
123
122
|
Message and its associated reply Message will be deleted.
|
124
123
|
"""
|
125
124
|
|
126
125
|
@abc.abstractmethod
|
127
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
126
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
128
127
|
"""Get all instruction Message IDs for the given run_id."""
|
129
128
|
|
130
129
|
@abc.abstractmethod
|