flwr-nightly 1.11.0.dev20240812__py3-none-any.whl → 1.11.0.dev20240815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/run/run.py +6 -2
- flwr/client/app.py +4 -3
- flwr/client/grpc_adapter_client/connection.py +3 -1
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +8 -2
- flwr/client/process/__init__.py +15 -0
- flwr/client/process/clientappio_servicer.py +145 -0
- flwr/client/rest_client/connection.py +9 -3
- flwr/common/config.py +7 -2
- flwr/common/record/recordset.py +9 -7
- flwr/common/record/typeddict.py +20 -58
- flwr/common/recordset_compat.py +6 -6
- flwr/common/serde.py +178 -1
- flwr/common/typing.py +17 -0
- flwr/proto/exec_pb2.py +16 -15
- flwr/proto/exec_pb2.pyi +7 -4
- flwr/proto/message_pb2.py +2 -2
- flwr/proto/message_pb2.pyi +4 -1
- flwr/server/app.py +12 -0
- flwr/server/driver/grpc_driver.py +1 -0
- flwr/server/superlink/driver/driver_grpc.py +3 -0
- flwr/server/superlink/driver/driver_servicer.py +14 -1
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/state/in_memory_state.py +7 -5
- flwr/server/superlink/state/sqlite_state.py +17 -7
- flwr/server/superlink/state/state.py +4 -3
- flwr/simulation/run_simulation.py +4 -1
- flwr/superexec/exec_servicer.py +1 -1
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/METADATA +1 -1
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/RECORD +33 -30
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/entry_points.txt +0 -0
flwr/common/serde.py
CHANGED
|
@@ -20,7 +20,12 @@ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar,
|
|
|
20
20
|
from google.protobuf.message import Message as GrpcMessage
|
|
21
21
|
|
|
22
22
|
# pylint: disable=E0611
|
|
23
|
+
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
23
24
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
25
|
+
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
26
|
+
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
27
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
28
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
24
29
|
from flwr.proto.node_pb2 import Node
|
|
25
30
|
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
|
26
31
|
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
|
@@ -32,6 +37,7 @@ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordVal
|
|
|
32
37
|
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
|
33
38
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
34
39
|
from flwr.proto.recordset_pb2 import Sint64List, StringList
|
|
40
|
+
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
35
41
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
36
42
|
from flwr.proto.transport_pb2 import (
|
|
37
43
|
ClientMessage,
|
|
@@ -44,7 +50,15 @@ from flwr.proto.transport_pb2 import (
|
|
|
44
50
|
)
|
|
45
51
|
|
|
46
52
|
# pylint: enable=E0611
|
|
47
|
-
from . import
|
|
53
|
+
from . import (
|
|
54
|
+
Array,
|
|
55
|
+
ConfigsRecord,
|
|
56
|
+
Context,
|
|
57
|
+
MetricsRecord,
|
|
58
|
+
ParametersRecord,
|
|
59
|
+
RecordSet,
|
|
60
|
+
typing,
|
|
61
|
+
)
|
|
48
62
|
from .message import Error, Message, Metadata
|
|
49
63
|
from .record.typeddict import TypedDict
|
|
50
64
|
|
|
@@ -673,6 +687,19 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
673
687
|
return message
|
|
674
688
|
|
|
675
689
|
|
|
690
|
+
# === FAB ===
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def fab_to_proto(fab: typing.Fab) -> ProtoFab:
    """Convert a Python `Fab` into its ProtoBuf counterpart.

    The hash string and the raw content are copied over unchanged.
    """
    hash_str = fab.hash_str
    content = fab.content
    return ProtoFab(content=content, hash_str=hash_str)
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def fab_from_proto(fab: ProtoFab) -> typing.Fab:
    """Convert a ProtoBuf `Fab` message into the Python `Fab` type.

    The hash string and content are passed positionally, in that order.
    """
    hash_str, content = fab.hash_str, fab.content
    return typing.Fab(hash_str, content)
|
|
701
|
+
|
|
702
|
+
|
|
676
703
|
# === User configs ===
|
|
677
704
|
|
|
678
705
|
|
|
@@ -716,3 +743,153 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
|
716
743
|
scalar_field = scalar_msg.WhichOneof("scalar")
|
|
717
744
|
scalar = getattr(scalar_msg, cast(str, scalar_field))
|
|
718
745
|
return cast(typing.UserConfigValue, scalar)
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
# === Metadata messages ===
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
    """Convert a `Metadata` object into a ProtoBuf `Metadata` message.

    Each listed attribute is read from ``metadata`` and handed to the
    ProtoBuf constructor under the same name, including ``created_at``.

    Parameters
    ----------
    metadata : Metadata
        The metadata to serialize.

    Returns
    -------
    ProtoMetadata
        A ProtoBuf message carrying the same field values.
    """
    field_names = (
        "run_id",
        "message_id",
        "src_node_id",
        "dst_node_id",
        "reply_to_message",
        "group_id",
        "ttl",
        "message_type",
        "created_at",
    )
    fields = {name: getattr(metadata, name) for name in field_names}
    return ProtoMetadata(**fields)  # pylint: disable=E1101
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
    """Deserialize `Metadata` from ProtoBuf.

    All fields written by `metadata_to_proto` are restored, including
    ``created_at``. The timestamp is not passed to the `Metadata`
    constructor, so it is assigned on the new object after construction.

    Parameters
    ----------
    metadata_proto : ProtoMetadata
        The ProtoBuf message to read from.

    Returns
    -------
    Metadata
        A `Metadata` object with the same field values as ``metadata_proto``.
    """
    metadata = Metadata(
        run_id=metadata_proto.run_id,
        message_id=metadata_proto.message_id,
        src_node_id=metadata_proto.src_node_id,
        dst_node_id=metadata_proto.dst_node_id,
        reply_to_message=metadata_proto.reply_to_message,
        group_id=metadata_proto.group_id,
        ttl=metadata_proto.ttl,
        message_type=metadata_proto.message_type,
    )
    # `metadata_to_proto` writes `created_at`; restore it here so that a
    # metadata-only round trip does not silently drop the timestamp.
    metadata.created_at = metadata_proto.created_at
    return metadata
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
# === Message messages ===
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def message_to_proto(message: Message) -> ProtoMessage:
    """Convert a `Message` into a ProtoBuf `Message`.

    Metadata is always serialized. Content and error are serialized only
    when the message reports having them; otherwise they are left unset.
    """
    metadata_proto = metadata_to_proto(message.metadata)

    content_proto = None
    if message.has_content():
        content_proto = recordset_to_proto(message.content)

    error_proto = None
    if message.has_error():
        error_proto = error_to_proto(message.error)

    return ProtoMessage(
        metadata=metadata_proto,
        content=content_proto,
        error=error_proto,
    )
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def message_from_proto(message_proto: ProtoMessage) -> Message:
    """Convert a ProtoBuf `Message` back into a `Message`.

    Content and error are deserialized only when the corresponding field is
    set on ``message_proto``. The original ``created_at`` timestamp is
    reapplied after construction, because building a `Message` resets it.
    """
    original_created_at = message_proto.metadata.created_at
    metadata = metadata_from_proto(message_proto.metadata)

    content = None
    if message_proto.HasField("content"):
        content = recordset_from_proto(message_proto.content)

    error = None
    if message_proto.HasField("error"):
        error = error_from_proto(message_proto.error)

    message = Message(metadata=metadata, content=content, error=error)
    # Construction stamps a fresh `created_at`; put the serialized one back.
    message.metadata.created_at = original_created_at
    return message
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
# === Context messages ===
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def context_to_proto(context: Context) -> ProtoContext:
    """Convert a `Context` into a ProtoBuf `Context` message.

    ``node_config`` and ``run_config`` are converted with
    `user_config_to_proto`; ``state`` is converted with `recordset_to_proto`.
    """
    node_config_proto = user_config_to_proto(context.node_config)
    state_proto = recordset_to_proto(context.state)
    run_config_proto = user_config_to_proto(context.run_config)
    return ProtoContext(
        node_id=context.node_id,
        node_config=node_config_proto,
        state=state_proto,
        run_config=run_config_proto,
    )
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
def context_from_proto(context_proto: ProtoContext) -> Context:
    """Convert a ProtoBuf `Context` message back into a `Context`.

    The two config maps go through `user_config_from_proto`; the state goes
    through `recordset_from_proto`.
    """
    node_config = user_config_from_proto(context_proto.node_config)
    state = recordset_from_proto(context_proto.state)
    run_config = user_config_from_proto(context_proto.run_config)
    return Context(
        node_id=context_proto.node_id,
        node_config=node_config,
        state=state,
        run_config=run_config,
    )
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
# === Run messages ===
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
def run_to_proto(run: typing.Run) -> ProtoRun:
    """Convert a `Run` into a ProtoBuf `Run` message.

    Identifier fields are copied as-is; ``override_config`` is converted
    with `user_config_to_proto`.
    """
    override_config_proto = user_config_to_proto(run.override_config)
    return ProtoRun(
        run_id=run.run_id,
        fab_id=run.fab_id,
        fab_version=run.fab_version,
        fab_hash=run.fab_hash,
        override_config=override_config_proto,
    )
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
def run_from_proto(run_proto: ProtoRun) -> typing.Run:
    """Convert a ProtoBuf `Run` message back into a `Run`.

    Identifier fields are copied as-is; ``override_config`` is converted
    with `user_config_from_proto`.
    """
    override_config = user_config_from_proto(run_proto.override_config)
    return typing.Run(
        run_id=run_proto.run_id,
        fab_id=run_proto.fab_id,
        fab_version=run_proto.fab_version,
        fab_hash=run_proto.fab_hash,
        override_config=override_config,
    )
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
# === ClientApp status messages ===
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def clientappstatus_to_proto(
    status: typing.ClientAppOutputStatus,
) -> ClientAppOutputStatus:
    """Convert a Python `ClientAppOutputStatus` into its ProtoBuf form.

    ``UNKNOWN_ERROR`` and ``DEADLINE_EXCEEDED`` map to the ProtoBuf members of
    the same name; anything else is encoded as ``SUCCESS``. The message text
    is passed through unchanged.
    """
    if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
        proto_code = ClientAppOutputCode.UNKNOWN_ERROR
    elif status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
        proto_code = ClientAppOutputCode.DEADLINE_EXCEEDED
    else:
        proto_code = ClientAppOutputCode.SUCCESS
    return ClientAppOutputStatus(code=proto_code, message=status.message)
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def clientappstatus_from_proto(
    msg: ClientAppOutputStatus,
) -> typing.ClientAppOutputStatus:
    """Deserialize `ClientAppOutputStatus` from ProtoBuf.

    ``SUCCESS`` and ``DEADLINE_EXCEEDED`` map to the Python members of the
    same name. ``UNKNOWN_ERROR`` maps to
    ``typing.ClientAppOutputCode.UNKNOWN_ERROR``, and so does any numeric
    code not recognized here. An unrecognized status is therefore never
    reported as a success.

    Parameters
    ----------
    msg : ClientAppOutputStatus
        The ProtoBuf status message to convert.

    Returns
    -------
    typing.ClientAppOutputStatus
        The equivalent Python status, carrying ``msg.message`` unchanged.
    """
    if msg.code == ClientAppOutputCode.SUCCESS:
        code = typing.ClientAppOutputCode.SUCCESS
    elif msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
        code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
    else:
        # Covers UNKNOWN_ERROR and any value outside the known set (e.g. a
        # code introduced by a newer peer's schema). Defaulting such values
        # to SUCCESS would mask a failed ClientApp run.
        code = typing.ClientAppOutputCode.UNKNOWN_ERROR
    return typing.ClientAppOutputStatus(code=code, message=msg.message)
|
flwr/common/typing.py
CHANGED
|
@@ -83,6 +83,22 @@ class Status:
|
|
|
83
83
|
message: str
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
class ClientAppOutputCode(Enum):
    """ClientAppIO status codes.

    `flwr.common.serde` converts each member to and from the ProtoBuf
    ``ClientAppOutputCode`` member of the same name.
    """

    SUCCESS = 0
    # NOTE(review): presumably the ClientApp did not finish in time — confirm.
    DEADLINE_EXCEEDED = 1
    # NOTE(review): presumably a catch-all for other failures — confirm.
    UNKNOWN_ERROR = 2
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class ClientAppOutputStatus:
    """ClientAppIO status.

    Attributes
    ----------
    code : ClientAppOutputCode
        The status code (see `ClientAppOutputCode`).
    message : str
        Free-form message accompanying ``code``.
    """

    code: ClientAppOutputCode
    message: str
|
|
100
|
+
|
|
101
|
+
|
|
86
102
|
@dataclass
|
|
87
103
|
class Parameters:
|
|
88
104
|
"""Model parameters."""
|
|
@@ -198,6 +214,7 @@ class Run:
|
|
|
198
214
|
run_id: int
|
|
199
215
|
fab_id: str
|
|
200
216
|
fab_version: str
|
|
217
|
+
fab_hash: str
|
|
201
218
|
override_config: UserConfig
|
|
202
219
|
|
|
203
220
|
|
flwr/proto/exec_pb2.py
CHANGED
|
@@ -12,10 +12,11 @@ from google.protobuf.internal import builder as _builder
|
|
|
12
12
|
_sym_db = _symbol_database.Default()
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
15
16
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
|
|
19
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xdf\x02\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3')
|
|
19
20
|
|
|
20
21
|
_globals = globals()
|
|
21
22
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -26,18 +27,18 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
26
27
|
_globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
|
|
27
28
|
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._options = None
|
|
28
29
|
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_options = b'8\001'
|
|
29
|
-
_globals['_STARTRUNREQUEST']._serialized_start=
|
|
30
|
-
_globals['_STARTRUNREQUEST']._serialized_end=
|
|
31
|
-
_globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=
|
|
32
|
-
_globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=
|
|
33
|
-
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=
|
|
34
|
-
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=
|
|
35
|
-
_globals['_STARTRUNRESPONSE']._serialized_start=
|
|
36
|
-
_globals['_STARTRUNRESPONSE']._serialized_end=
|
|
37
|
-
_globals['_STREAMLOGSREQUEST']._serialized_start=
|
|
38
|
-
_globals['_STREAMLOGSREQUEST']._serialized_end=
|
|
39
|
-
_globals['_STREAMLOGSRESPONSE']._serialized_start=
|
|
40
|
-
_globals['_STREAMLOGSRESPONSE']._serialized_end=
|
|
41
|
-
_globals['_EXEC']._serialized_start=
|
|
42
|
-
_globals['_EXEC']._serialized_end=
|
|
30
|
+
_globals['_STARTRUNREQUEST']._serialized_start=88
|
|
31
|
+
_globals['_STARTRUNREQUEST']._serialized_end=439
|
|
32
|
+
_globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=289
|
|
33
|
+
_globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=362
|
|
34
|
+
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=364
|
|
35
|
+
_globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=439
|
|
36
|
+
_globals['_STARTRUNRESPONSE']._serialized_start=441
|
|
37
|
+
_globals['_STARTRUNRESPONSE']._serialized_end=475
|
|
38
|
+
_globals['_STREAMLOGSREQUEST']._serialized_start=477
|
|
39
|
+
_globals['_STREAMLOGSREQUEST']._serialized_end=512
|
|
40
|
+
_globals['_STREAMLOGSRESPONSE']._serialized_start=514
|
|
41
|
+
_globals['_STREAMLOGSRESPONSE']._serialized_end=554
|
|
42
|
+
_globals['_EXEC']._serialized_start=557
|
|
43
|
+
_globals['_EXEC']._serialized_end=717
|
|
43
44
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/exec_pb2.pyi
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.fab_pb2
|
|
6
7
|
import flwr.proto.transport_pb2
|
|
7
8
|
import google.protobuf.descriptor
|
|
8
9
|
import google.protobuf.internal.containers
|
|
@@ -44,21 +45,23 @@ class StartRunRequest(google.protobuf.message.Message):
|
|
|
44
45
|
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
45
46
|
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
46
47
|
|
|
47
|
-
|
|
48
|
+
FAB_FIELD_NUMBER: builtins.int
|
|
48
49
|
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
|
|
49
50
|
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
|
|
50
|
-
|
|
51
|
+
@property
|
|
52
|
+
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
|
|
51
53
|
@property
|
|
52
54
|
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
53
55
|
@property
|
|
54
56
|
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
55
57
|
def __init__(self,
|
|
56
58
|
*,
|
|
57
|
-
|
|
59
|
+
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
|
|
58
60
|
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
59
61
|
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
60
62
|
) -> None: ...
|
|
61
|
-
def
|
|
63
|
+
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
|
|
64
|
+
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
|
|
62
65
|
global___StartRunRequest = StartRunRequest
|
|
63
66
|
|
|
64
67
|
class StartRunResponse(google.protobuf.message.Message):
|
flwr/proto/message_pb2.py
CHANGED
|
@@ -17,7 +17,7 @@ from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
|
|
17
17
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3')
|
|
21
21
|
|
|
22
22
|
_globals = globals()
|
|
23
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -37,5 +37,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
37
37
|
_globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497
|
|
38
38
|
_globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565
|
|
39
39
|
_globals['_METADATA']._serialized_start=568
|
|
40
|
-
_globals['_METADATA']._serialized_end=
|
|
40
|
+
_globals['_METADATA']._serialized_end=755
|
|
41
41
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/message_pb2.pyi
CHANGED
|
@@ -99,6 +99,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
99
99
|
GROUP_ID_FIELD_NUMBER: builtins.int
|
|
100
100
|
TTL_FIELD_NUMBER: builtins.int
|
|
101
101
|
MESSAGE_TYPE_FIELD_NUMBER: builtins.int
|
|
102
|
+
CREATED_AT_FIELD_NUMBER: builtins.int
|
|
102
103
|
run_id: builtins.int
|
|
103
104
|
message_id: typing.Text
|
|
104
105
|
src_node_id: builtins.int
|
|
@@ -107,6 +108,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
107
108
|
group_id: typing.Text
|
|
108
109
|
ttl: builtins.float
|
|
109
110
|
message_type: typing.Text
|
|
111
|
+
created_at: builtins.float
|
|
110
112
|
def __init__(self,
|
|
111
113
|
*,
|
|
112
114
|
run_id: builtins.int = ...,
|
|
@@ -117,6 +119,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
117
119
|
group_id: typing.Text = ...,
|
|
118
120
|
ttl: builtins.float = ...,
|
|
119
121
|
message_type: typing.Text = ...,
|
|
122
|
+
created_at: builtins.float = ...,
|
|
120
123
|
) -> None: ...
|
|
121
|
-
def ClearField(self, field_name: typing_extensions.Literal["dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
|
|
124
|
+
def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
|
|
122
125
|
global___Metadata = Metadata
|
flwr/server/app.py
CHANGED
|
@@ -34,6 +34,7 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
34
34
|
|
|
35
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
|
+
from flwr.common.config import get_flwr_dir
|
|
37
38
|
from flwr.common.constant import (
|
|
38
39
|
MISSING_EXTRA_REST,
|
|
39
40
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
@@ -57,6 +58,7 @@ from .server import Server, init_defaults, run_fl
|
|
|
57
58
|
from .server_config import ServerConfig
|
|
58
59
|
from .strategy import Strategy
|
|
59
60
|
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
|
61
|
+
from .superlink.ffs.ffs_factory import FfsFactory
|
|
60
62
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
61
63
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
62
64
|
generic_create_grpc_server,
|
|
@@ -72,6 +74,7 @@ ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatibl
|
|
|
72
74
|
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
|
|
73
75
|
|
|
74
76
|
DATABASE = ":flwr-in-memory-state:"
|
|
77
|
+
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
75
78
|
|
|
76
79
|
|
|
77
80
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
@@ -211,10 +214,14 @@ def run_superlink() -> None:
|
|
|
211
214
|
# Initialize StateFactory
|
|
212
215
|
state_factory = StateFactory(args.database)
|
|
213
216
|
|
|
217
|
+
# Initialize FfsFactory
|
|
218
|
+
ffs_factory = FfsFactory(args.storage_dir)
|
|
219
|
+
|
|
214
220
|
# Start Driver API
|
|
215
221
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
216
222
|
address=driver_address,
|
|
217
223
|
state_factory=state_factory,
|
|
224
|
+
ffs_factory=ffs_factory,
|
|
218
225
|
certificates=certificates,
|
|
219
226
|
)
|
|
220
227
|
|
|
@@ -610,6 +617,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
610
617
|
"Flower will just create a state in memory.",
|
|
611
618
|
default=DATABASE,
|
|
612
619
|
)
|
|
620
|
+
parser.add_argument(
|
|
621
|
+
"--storage-dir",
|
|
622
|
+
help="The base directory to store the objects for the Flower File System.",
|
|
623
|
+
default=BASE_DIR,
|
|
624
|
+
)
|
|
613
625
|
parser.add_argument(
|
|
614
626
|
"--auth-list-public-keys",
|
|
615
627
|
type=str,
|
|
@@ -24,6 +24,7 @@ from flwr.common.logger import log
|
|
|
24
24
|
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
|
25
25
|
add_DriverServicer_to_server,
|
|
26
26
|
)
|
|
27
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
27
28
|
from flwr.server.superlink.state import StateFactory
|
|
28
29
|
|
|
29
30
|
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
@@ -33,12 +34,14 @@ from .driver_servicer import DriverServicer
|
|
|
33
34
|
def run_driver_api_grpc(
|
|
34
35
|
address: str,
|
|
35
36
|
state_factory: StateFactory,
|
|
37
|
+
ffs_factory: FfsFactory,
|
|
36
38
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
37
39
|
) -> grpc.Server:
|
|
38
40
|
"""Run Driver API (gRPC, request-response)."""
|
|
39
41
|
# Create Driver API gRPC server
|
|
40
42
|
driver_servicer: grpc.Server = DriverServicer(
|
|
41
43
|
state_factory=state_factory,
|
|
44
|
+
ffs_factory=ffs_factory,
|
|
42
45
|
)
|
|
43
46
|
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
|
|
44
47
|
driver_grpc_server = generic_create_grpc_server(
|
|
@@ -43,6 +43,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
43
|
Run,
|
|
44
44
|
)
|
|
45
45
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
46
|
+
from flwr.server.superlink.ffs import Ffs
|
|
47
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
46
48
|
from flwr.server.superlink.state import State, StateFactory
|
|
47
49
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
48
50
|
|
|
@@ -50,8 +52,9 @@ from flwr.server.utils.validator import validate_task_ins_or_res
|
|
|
50
52
|
class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
51
53
|
"""Driver API servicer."""
|
|
52
54
|
|
|
53
|
-
def __init__(self, state_factory: StateFactory) -> None:
|
|
55
|
+
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
    """Keep references to the factories this servicer draws on.

    Parameters
    ----------
    state_factory : StateFactory
        Source of the `State` instance used when handling requests.
    ffs_factory : FfsFactory
        Source of the `Ffs` instance used to store uploaded FAB content.
    """
    self.state_factory, self.ffs_factory = state_factory, ffs_factory
|
|
55
58
|
|
|
56
59
|
def GetNodes(
|
|
57
60
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
@@ -71,9 +74,19 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
71
74
|
"""Create run ID."""
|
|
72
75
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
73
76
|
state: State = self.state_factory.state()
|
|
77
|
+
if request.HasField("fab") and request.fab.HasField("content"):
|
|
78
|
+
ffs: Ffs = self.ffs_factory.ffs()
|
|
79
|
+
fab_hash = ffs.put(request.fab.content, {})
|
|
80
|
+
_raise_if(
|
|
81
|
+
fab_hash != request.fab.hash_str,
|
|
82
|
+
f"FAB ({request.fab}) hash from request doesn't match contents",
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
fab_hash = ""
|
|
74
86
|
run_id = state.create_run(
|
|
75
87
|
request.fab_id,
|
|
76
88
|
request.fab_version,
|
|
89
|
+
fab_hash,
|
|
77
90
|
user_config_from_proto(request.override_config),
|
|
78
91
|
)
|
|
79
92
|
return CreateRunResponse(run_id=run_id)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Factory class that creates Ffs instances."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from logging import DEBUG
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from flwr.common.logger import log
|
|
22
|
+
|
|
23
|
+
from .disk_ffs import DiskFfs
|
|
24
|
+
from .ffs import Ffs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FfsFactory:
    """Factory class that creates Ffs instances.

    A single `DiskFfs` is created lazily on the first call to `ffs()` and
    reused for every later call.

    Parameters
    ----------
    base_dir : str
        The base directory used by DiskFfs to store objects.
        NOTE(review): the SuperLink CLI default appears to be a
        `pathlib.Path` (built with the `/` operator), not a `str` — confirm
        that DiskFfs accepts both.
    """

    def __init__(self, base_dir: str) -> None:
        self.base_dir = base_dir
        self.ffs_instance: Optional[Ffs] = None

    def ffs(self) -> Ffs:
        """Return a Ffs instance and create it, if necessary."""
        # Test for None explicitly: an already-created instance must be reused
        # even if it happens to be falsy (e.g. via `__len__` or `__bool__`).
        if self.ffs_instance is None:
            log(DEBUG, "Initializing DiskFfs")
            self.ffs_instance = DiskFfs(self.base_dir)

        log(DEBUG, "Using DiskFfs")
        return self.ffs_instance
|
|
@@ -277,11 +277,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
277
277
|
|
|
278
278
|
def create_run(
|
|
279
279
|
self,
|
|
280
|
-
fab_id: str,
|
|
281
|
-
fab_version: str,
|
|
280
|
+
fab_id: Optional[str],
|
|
281
|
+
fab_version: Optional[str],
|
|
282
|
+
fab_hash: Optional[str],
|
|
282
283
|
override_config: UserConfig,
|
|
283
284
|
) -> int:
|
|
284
|
-
"""Create a new run for the specified `
|
|
285
|
+
"""Create a new run for the specified `fab_hash`."""
|
|
285
286
|
# Sample a random int64 as run_id
|
|
286
287
|
with self.lock:
|
|
287
288
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -289,8 +290,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
289
290
|
if run_id not in self.run_ids:
|
|
290
291
|
self.run_ids[run_id] = Run(
|
|
291
292
|
run_id=run_id,
|
|
292
|
-
fab_id=fab_id,
|
|
293
|
-
fab_version=fab_version,
|
|
293
|
+
fab_id=fab_id if fab_id else "",
|
|
294
|
+
fab_version=fab_version if fab_version else "",
|
|
295
|
+
fab_hash=fab_hash if fab_hash else "",
|
|
294
296
|
override_config=override_config,
|
|
295
297
|
)
|
|
296
298
|
return run_id
|
|
@@ -65,6 +65,7 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
65
65
|
run_id INTEGER UNIQUE,
|
|
66
66
|
fab_id TEXT,
|
|
67
67
|
fab_version TEXT,
|
|
68
|
+
fab_hash TEXT,
|
|
68
69
|
override_config TEXT
|
|
69
70
|
);
|
|
70
71
|
"""
|
|
@@ -617,8 +618,9 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
617
618
|
|
|
618
619
|
def create_run(
|
|
619
620
|
self,
|
|
620
|
-
fab_id: str,
|
|
621
|
-
fab_version: str,
|
|
621
|
+
fab_id: Optional[str],
|
|
622
|
+
fab_version: Optional[str],
|
|
623
|
+
fab_hash: Optional[str],
|
|
622
624
|
override_config: UserConfig,
|
|
623
625
|
) -> int:
|
|
624
626
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
@@ -630,12 +632,19 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
630
632
|
# If run_id does not exist
|
|
631
633
|
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
|
|
632
634
|
query = (
|
|
633
|
-
"INSERT INTO run
|
|
634
|
-
"
|
|
635
|
-
|
|
636
|
-
self.query(
|
|
637
|
-
query, (run_id, fab_id, fab_version, json.dumps(override_config))
|
|
635
|
+
"INSERT INTO run "
|
|
636
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config)"
|
|
637
|
+
"VALUES (?, ?, ?, ?, ?);"
|
|
638
638
|
)
|
|
639
|
+
if fab_hash:
|
|
640
|
+
self.query(
|
|
641
|
+
query, (run_id, "", "", fab_hash, json.dumps(override_config))
|
|
642
|
+
)
|
|
643
|
+
else:
|
|
644
|
+
self.query(
|
|
645
|
+
query,
|
|
646
|
+
(run_id, fab_id, fab_version, "", json.dumps(override_config)),
|
|
647
|
+
)
|
|
639
648
|
return run_id
|
|
640
649
|
log(ERROR, "Unexpected run creation failure.")
|
|
641
650
|
return 0
|
|
@@ -702,6 +711,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
702
711
|
run_id=run_id,
|
|
703
712
|
fab_id=row["fab_id"],
|
|
704
713
|
fab_version=row["fab_version"],
|
|
714
|
+
fab_hash=row["fab_hash"],
|
|
705
715
|
override_config=json.loads(row["override_config"]),
|
|
706
716
|
)
|
|
707
717
|
except sqlite3.IntegrityError:
|
|
@@ -159,11 +159,12 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
159
159
|
@abc.abstractmethod
|
|
160
160
|
def create_run(
|
|
161
161
|
self,
|
|
162
|
-
fab_id: str,
|
|
163
|
-
fab_version: str,
|
|
162
|
+
fab_id: Optional[str],
|
|
163
|
+
fab_version: Optional[str],
|
|
164
|
+
fab_hash: Optional[str],
|
|
164
165
|
override_config: UserConfig,
|
|
165
166
|
) -> int:
|
|
166
|
-
"""Create a new run for the specified `
|
|
167
|
+
"""Create a new run for the specified `fab_hash`."""
|
|
167
168
|
|
|
168
169
|
@abc.abstractmethod
|
|
169
170
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
@@ -163,6 +163,7 @@ def run_simulation_from_cli() -> None:
|
|
|
163
163
|
run_id=run_id,
|
|
164
164
|
fab_id="",
|
|
165
165
|
fab_version="",
|
|
166
|
+
fab_hash="",
|
|
166
167
|
override_config=override_config,
|
|
167
168
|
)
|
|
168
169
|
|
|
@@ -529,7 +530,9 @@ def _run_simulation(
|
|
|
529
530
|
# If no `Run` object is set, create one
|
|
530
531
|
if run is None:
|
|
531
532
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
532
|
-
run = Run(
|
|
533
|
+
run = Run(
|
|
534
|
+
run_id=run_id, fab_id="", fab_version="", fab_hash="", override_config={}
|
|
535
|
+
)
|
|
533
536
|
|
|
534
537
|
args = (
|
|
535
538
|
num_supernodes,
|