flwr-nightly 1.11.0.dev20240812__py3-none-any.whl → 1.11.0.dev20240815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The link to further details on the original page is not reproduced in this text rendering.)

Files changed (33)
  1. flwr/cli/run/run.py +6 -2
  2. flwr/client/app.py +4 -3
  3. flwr/client/grpc_adapter_client/connection.py +3 -1
  4. flwr/client/grpc_client/connection.py +3 -2
  5. flwr/client/grpc_rere_client/connection.py +8 -2
  6. flwr/client/process/__init__.py +15 -0
  7. flwr/client/process/clientappio_servicer.py +145 -0
  8. flwr/client/rest_client/connection.py +9 -3
  9. flwr/common/config.py +7 -2
  10. flwr/common/record/recordset.py +9 -7
  11. flwr/common/record/typeddict.py +20 -58
  12. flwr/common/recordset_compat.py +6 -6
  13. flwr/common/serde.py +178 -1
  14. flwr/common/typing.py +17 -0
  15. flwr/proto/exec_pb2.py +16 -15
  16. flwr/proto/exec_pb2.pyi +7 -4
  17. flwr/proto/message_pb2.py +2 -2
  18. flwr/proto/message_pb2.pyi +4 -1
  19. flwr/server/app.py +12 -0
  20. flwr/server/driver/grpc_driver.py +1 -0
  21. flwr/server/superlink/driver/driver_grpc.py +3 -0
  22. flwr/server/superlink/driver/driver_servicer.py +14 -1
  23. flwr/server/superlink/ffs/ffs_factory.py +47 -0
  24. flwr/server/superlink/state/in_memory_state.py +7 -5
  25. flwr/server/superlink/state/sqlite_state.py +17 -7
  26. flwr/server/superlink/state/state.py +4 -3
  27. flwr/simulation/run_simulation.py +4 -1
  28. flwr/superexec/exec_servicer.py +1 -1
  29. {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/METADATA +1 -1
  30. {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/RECORD +33 -30
  31. {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/LICENSE +0 -0
  32. {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/WHEEL +0 -0
  33. {flwr_nightly-1.11.0.dev20240812.dist-info → flwr_nightly-1.11.0.dev20240815.dist-info}/entry_points.txt +0 -0
flwr/common/serde.py CHANGED
@@ -20,7 +20,12 @@ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar,
20
20
  from google.protobuf.message import Message as GrpcMessage
21
21
 
22
22
  # pylint: disable=E0611
23
+ from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
23
24
  from flwr.proto.error_pb2 import Error as ProtoError
25
+ from flwr.proto.fab_pb2 import Fab as ProtoFab
26
+ from flwr.proto.message_pb2 import Context as ProtoContext
27
+ from flwr.proto.message_pb2 import Message as ProtoMessage
28
+ from flwr.proto.message_pb2 import Metadata as ProtoMetadata
24
29
  from flwr.proto.node_pb2 import Node
25
30
  from flwr.proto.recordset_pb2 import Array as ProtoArray
26
31
  from flwr.proto.recordset_pb2 import BoolList, BytesList
@@ -32,6 +37,7 @@ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordVal
32
37
  from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
33
38
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
34
39
  from flwr.proto.recordset_pb2 import Sint64List, StringList
40
+ from flwr.proto.run_pb2 import Run as ProtoRun
35
41
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
36
42
  from flwr.proto.transport_pb2 import (
37
43
  ClientMessage,
@@ -44,7 +50,15 @@ from flwr.proto.transport_pb2 import (
44
50
  )
45
51
 
46
52
  # pylint: enable=E0611
47
- from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
53
+ from . import (
54
+ Array,
55
+ ConfigsRecord,
56
+ Context,
57
+ MetricsRecord,
58
+ ParametersRecord,
59
+ RecordSet,
60
+ typing,
61
+ )
48
62
  from .message import Error, Message, Metadata
49
63
  from .record.typeddict import TypedDict
50
64
 
@@ -673,6 +687,19 @@ def message_from_taskres(taskres: TaskRes) -> Message:
673
687
  return message
674
688
 
675
689
 
690
+ # === FAB ===
691
+
692
+
693
+ def fab_to_proto(fab: typing.Fab) -> ProtoFab:
694
+ """Create a proto Fab object from a Python Fab."""
695
+ return ProtoFab(hash_str=fab.hash_str, content=fab.content)
696
+
697
+
698
+ def fab_from_proto(fab: ProtoFab) -> typing.Fab:
699
+ """Create a Python Fab object from a proto Fab."""
700
+ return typing.Fab(fab.hash_str, fab.content)
701
+
702
+
676
703
  # === User configs ===
677
704
 
678
705
 
@@ -716,3 +743,153 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
716
743
  scalar_field = scalar_msg.WhichOneof("scalar")
717
744
  scalar = getattr(scalar_msg, cast(str, scalar_field))
718
745
  return cast(typing.UserConfigValue, scalar)
746
+
747
+
748
+ # === Metadata messages ===
749
+
750
+
751
+ def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
752
+ """Serialize `Metadata` to ProtoBuf."""
753
+ proto = ProtoMetadata( # pylint: disable=E1101
754
+ run_id=metadata.run_id,
755
+ message_id=metadata.message_id,
756
+ src_node_id=metadata.src_node_id,
757
+ dst_node_id=metadata.dst_node_id,
758
+ reply_to_message=metadata.reply_to_message,
759
+ group_id=metadata.group_id,
760
+ ttl=metadata.ttl,
761
+ message_type=metadata.message_type,
762
+ created_at=metadata.created_at,
763
+ )
764
+ return proto
765
+
766
+
767
+ def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
768
+ """Deserialize `Metadata` from ProtoBuf."""
769
+ metadata = Metadata(
770
+ run_id=metadata_proto.run_id,
771
+ message_id=metadata_proto.message_id,
772
+ src_node_id=metadata_proto.src_node_id,
773
+ dst_node_id=metadata_proto.dst_node_id,
774
+ reply_to_message=metadata_proto.reply_to_message,
775
+ group_id=metadata_proto.group_id,
776
+ ttl=metadata_proto.ttl,
777
+ message_type=metadata_proto.message_type,
778
+ )
779
+ return metadata
780
+
781
+
782
+ # === Message messages ===
783
+
784
+
785
+ def message_to_proto(message: Message) -> ProtoMessage:
786
+ """Serialize `Message` to ProtoBuf."""
787
+ proto = ProtoMessage(
788
+ metadata=metadata_to_proto(message.metadata),
789
+ content=(
790
+ recordset_to_proto(message.content) if message.has_content() else None
791
+ ),
792
+ error=error_to_proto(message.error) if message.has_error() else None,
793
+ )
794
+ return proto
795
+
796
+
797
+ def message_from_proto(message_proto: ProtoMessage) -> Message:
798
+ """Deserialize `Message` from ProtoBuf."""
799
+ created_at = message_proto.metadata.created_at
800
+ message = Message(
801
+ metadata=metadata_from_proto(message_proto.metadata),
802
+ content=(
803
+ recordset_from_proto(message_proto.content)
804
+ if message_proto.HasField("content")
805
+ else None
806
+ ),
807
+ error=(
808
+ error_from_proto(message_proto.error)
809
+ if message_proto.HasField("error")
810
+ else None
811
+ ),
812
+ )
813
+ # `.created_at` is set upon Message object construction
814
+ # we need to manually set it to the original value
815
+ message.metadata.created_at = created_at
816
+ return message
817
+
818
+
819
+ # === Context messages ===
820
+
821
+
822
+ def context_to_proto(context: Context) -> ProtoContext:
823
+ """Serialize `Context` to ProtoBuf."""
824
+ proto = ProtoContext(
825
+ node_id=context.node_id,
826
+ node_config=user_config_to_proto(context.node_config),
827
+ state=recordset_to_proto(context.state),
828
+ run_config=user_config_to_proto(context.run_config),
829
+ )
830
+ return proto
831
+
832
+
833
+ def context_from_proto(context_proto: ProtoContext) -> Context:
834
+ """Deserialize `Context` from ProtoBuf."""
835
+ context = Context(
836
+ node_id=context_proto.node_id,
837
+ node_config=user_config_from_proto(context_proto.node_config),
838
+ state=recordset_from_proto(context_proto.state),
839
+ run_config=user_config_from_proto(context_proto.run_config),
840
+ )
841
+ return context
842
+
843
+
844
+ # === Run messages ===
845
+
846
+
847
+ def run_to_proto(run: typing.Run) -> ProtoRun:
848
+ """Serialize `Run` to ProtoBuf."""
849
+ proto = ProtoRun(
850
+ run_id=run.run_id,
851
+ fab_id=run.fab_id,
852
+ fab_version=run.fab_version,
853
+ fab_hash=run.fab_hash,
854
+ override_config=user_config_to_proto(run.override_config),
855
+ )
856
+ return proto
857
+
858
+
859
+ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
860
+ """Deserialize `Run` from ProtoBuf."""
861
+ run = typing.Run(
862
+ run_id=run_proto.run_id,
863
+ fab_id=run_proto.fab_id,
864
+ fab_version=run_proto.fab_version,
865
+ fab_hash=run_proto.fab_hash,
866
+ override_config=user_config_from_proto(run_proto.override_config),
867
+ )
868
+ return run
869
+
870
+
871
+ # === ClientApp status messages ===
872
+
873
+
874
+ def clientappstatus_to_proto(
875
+ status: typing.ClientAppOutputStatus,
876
+ ) -> ClientAppOutputStatus:
877
+ """Serialize `ClientAppOutputStatus` to ProtoBuf."""
878
+ code = ClientAppOutputCode.SUCCESS
879
+ if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
880
+ code = ClientAppOutputCode.DEADLINE_EXCEEDED
881
+ if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
882
+ code = ClientAppOutputCode.UNKNOWN_ERROR
883
+ return ClientAppOutputStatus(code=code, message=status.message)
884
+
885
+
886
+ def clientappstatus_from_proto(
887
+ msg: ClientAppOutputStatus,
888
+ ) -> typing.ClientAppOutputStatus:
889
+ """Deserialize `ClientAppOutputStatus` from ProtoBuf."""
890
+ code = typing.ClientAppOutputCode.SUCCESS
891
+ if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
892
+ code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
893
+ if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
894
+ code = typing.ClientAppOutputCode.UNKNOWN_ERROR
895
+ return typing.ClientAppOutputStatus(code=code, message=msg.message)
flwr/common/typing.py CHANGED
@@ -83,6 +83,22 @@ class Status:
83
83
  message: str
84
84
 
85
85
 
86
+ class ClientAppOutputCode(Enum):
87
+ """ClientAppIO status codes."""
88
+
89
+ SUCCESS = 0
90
+ DEADLINE_EXCEEDED = 1
91
+ UNKNOWN_ERROR = 2
92
+
93
+
94
+ @dataclass
95
+ class ClientAppOutputStatus:
96
+ """ClientAppIO status."""
97
+
98
+ code: ClientAppOutputCode
99
+ message: str
100
+
101
+
86
102
  @dataclass
87
103
  class Parameters:
88
104
  """Model parameters."""
@@ -198,6 +214,7 @@ class Run:
198
214
  run_id: int
199
215
  fab_id: str
200
216
  fab_version: str
217
+ fab_hash: str
201
218
  override_config: UserConfig
202
219
 
203
220
 
flwr/proto/exec_pb2.py CHANGED
@@ -12,10 +12,11 @@ from google.protobuf.internal import builder as _builder
12
12
  _sym_db = _symbol_database.Default()
13
13
 
14
14
 
15
+ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
15
16
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
16
17
 
17
18
 
18
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd3\x02\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3')
19
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xdf\x02\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3')
19
20
 
20
21
  _globals = globals()
21
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -26,18 +27,18 @@ if _descriptor._USE_C_DESCRIPTORS == False:
26
27
  _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
27
28
  _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._options = None
28
29
  _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_options = b'8\001'
29
- _globals['_STARTRUNREQUEST']._serialized_start=66
30
- _globals['_STARTRUNREQUEST']._serialized_end=405
31
- _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=255
32
- _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=328
33
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=330
34
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=405
35
- _globals['_STARTRUNRESPONSE']._serialized_start=407
36
- _globals['_STARTRUNRESPONSE']._serialized_end=441
37
- _globals['_STREAMLOGSREQUEST']._serialized_start=443
38
- _globals['_STREAMLOGSREQUEST']._serialized_end=478
39
- _globals['_STREAMLOGSRESPONSE']._serialized_start=480
40
- _globals['_STREAMLOGSRESPONSE']._serialized_end=520
41
- _globals['_EXEC']._serialized_start=523
42
- _globals['_EXEC']._serialized_end=683
30
+ _globals['_STARTRUNREQUEST']._serialized_start=88
31
+ _globals['_STARTRUNREQUEST']._serialized_end=439
32
+ _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=289
33
+ _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=362
34
+ _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=364
35
+ _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=439
36
+ _globals['_STARTRUNRESPONSE']._serialized_start=441
37
+ _globals['_STARTRUNRESPONSE']._serialized_end=475
38
+ _globals['_STREAMLOGSREQUEST']._serialized_start=477
39
+ _globals['_STREAMLOGSREQUEST']._serialized_end=512
40
+ _globals['_STREAMLOGSRESPONSE']._serialized_start=514
41
+ _globals['_STREAMLOGSRESPONSE']._serialized_end=554
42
+ _globals['_EXEC']._serialized_start=557
43
+ _globals['_EXEC']._serialized_end=717
43
44
  # @@protoc_insertion_point(module_scope)
flwr/proto/exec_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.fab_pb2
6
7
  import flwr.proto.transport_pb2
7
8
  import google.protobuf.descriptor
8
9
  import google.protobuf.internal.containers
@@ -44,21 +45,23 @@ class StartRunRequest(google.protobuf.message.Message):
44
45
  def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
45
46
  def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
46
47
 
47
- FAB_FILE_FIELD_NUMBER: builtins.int
48
+ FAB_FIELD_NUMBER: builtins.int
48
49
  OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
49
50
  FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
50
- fab_file: builtins.bytes
51
+ @property
52
+ def fab(self) -> flwr.proto.fab_pb2.Fab: ...
51
53
  @property
52
54
  def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
53
55
  @property
54
56
  def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
55
57
  def __init__(self,
56
58
  *,
57
- fab_file: builtins.bytes = ...,
59
+ fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
58
60
  override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
59
61
  federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
60
62
  ) -> None: ...
61
- def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
63
+ def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
64
+ def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
62
65
  global___StartRunRequest = StartRunRequest
63
66
 
64
67
  class StartRunResponse(google.protobuf.message.Message):
flwr/proto/message_pb2.py CHANGED
@@ -17,7 +17,7 @@ from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
18
 
19
19
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xa7\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\tb\x06proto3')
20
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3')
21
21
 
22
22
  _globals = globals()
23
23
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -37,5 +37,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
37
37
  _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497
38
38
  _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565
39
39
  _globals['_METADATA']._serialized_start=568
40
- _globals['_METADATA']._serialized_end=735
40
+ _globals['_METADATA']._serialized_end=755
41
41
  # @@protoc_insertion_point(module_scope)
flwr/proto/message_pb2.pyi CHANGED
@@ -99,6 +99,7 @@ class Metadata(google.protobuf.message.Message):
99
99
  GROUP_ID_FIELD_NUMBER: builtins.int
100
100
  TTL_FIELD_NUMBER: builtins.int
101
101
  MESSAGE_TYPE_FIELD_NUMBER: builtins.int
102
+ CREATED_AT_FIELD_NUMBER: builtins.int
102
103
  run_id: builtins.int
103
104
  message_id: typing.Text
104
105
  src_node_id: builtins.int
@@ -107,6 +108,7 @@ class Metadata(google.protobuf.message.Message):
107
108
  group_id: typing.Text
108
109
  ttl: builtins.float
109
110
  message_type: typing.Text
111
+ created_at: builtins.float
110
112
  def __init__(self,
111
113
  *,
112
114
  run_id: builtins.int = ...,
@@ -117,6 +119,7 @@ class Metadata(google.protobuf.message.Message):
117
119
  group_id: typing.Text = ...,
118
120
  ttl: builtins.float = ...,
119
121
  message_type: typing.Text = ...,
122
+ created_at: builtins.float = ...,
120
123
  ) -> None: ...
121
- def ClearField(self, field_name: typing_extensions.Literal["dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
124
+ def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
122
125
  global___Metadata = Metadata
flwr/server/app.py CHANGED
@@ -34,6 +34,7 @@ from cryptography.hazmat.primitives.serialization import (
34
34
 
35
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
36
36
  from flwr.common.address import parse_address
37
+ from flwr.common.config import get_flwr_dir
37
38
  from flwr.common.constant import (
38
39
  MISSING_EXTRA_REST,
39
40
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -57,6 +58,7 @@ from .server import Server, init_defaults, run_fl
57
58
  from .server_config import ServerConfig
58
59
  from .strategy import Strategy
59
60
  from .superlink.driver.driver_grpc import run_driver_api_grpc
61
+ from .superlink.ffs.ffs_factory import FfsFactory
60
62
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
61
63
  from .superlink.fleet.grpc_bidi.grpc_server import (
62
64
  generic_create_grpc_server,
@@ -72,6 +74,7 @@ ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatibl
72
74
  ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
73
75
 
74
76
  DATABASE = ":flwr-in-memory-state:"
77
+ BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
75
78
 
76
79
 
77
80
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
@@ -211,10 +214,14 @@ def run_superlink() -> None:
211
214
  # Initialize StateFactory
212
215
  state_factory = StateFactory(args.database)
213
216
 
217
+ # Initialize FfsFactory
218
+ ffs_factory = FfsFactory(args.storage_dir)
219
+
214
220
  # Start Driver API
215
221
  driver_server: grpc.Server = run_driver_api_grpc(
216
222
  address=driver_address,
217
223
  state_factory=state_factory,
224
+ ffs_factory=ffs_factory,
218
225
  certificates=certificates,
219
226
  )
220
227
 
@@ -610,6 +617,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
610
617
  "Flower will just create a state in memory.",
611
618
  default=DATABASE,
612
619
  )
620
+ parser.add_argument(
621
+ "--storage-dir",
622
+ help="The base directory to store the objects for the Flower File System.",
623
+ default=BASE_DIR,
624
+ )
613
625
  parser.add_argument(
614
626
  "--auth-list-public-keys",
615
627
  type=str,
flwr/server/driver/grpc_driver.py CHANGED
@@ -131,6 +131,7 @@ class GrpcDriver(Driver):
131
131
  run_id=res.run.run_id,
132
132
  fab_id=res.run.fab_id,
133
133
  fab_version=res.run.fab_version,
134
+ fab_hash=res.run.fab_hash,
134
135
  override_config=user_config_from_proto(res.run.override_config),
135
136
  )
136
137
 
flwr/server/superlink/driver/driver_grpc.py CHANGED
@@ -24,6 +24,7 @@ from flwr.common.logger import log
24
24
  from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
25
  add_DriverServicer_to_server,
26
26
  )
27
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
27
28
  from flwr.server.superlink.state import StateFactory
28
29
 
29
30
  from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
@@ -33,12 +34,14 @@ from .driver_servicer import DriverServicer
33
34
  def run_driver_api_grpc(
34
35
  address: str,
35
36
  state_factory: StateFactory,
37
+ ffs_factory: FfsFactory,
36
38
  certificates: Optional[Tuple[bytes, bytes, bytes]],
37
39
  ) -> grpc.Server:
38
40
  """Run Driver API (gRPC, request-response)."""
39
41
  # Create Driver API gRPC server
40
42
  driver_servicer: grpc.Server = DriverServicer(
41
43
  state_factory=state_factory,
44
+ ffs_factory=ffs_factory,
42
45
  )
43
46
  driver_add_servicer_to_server_fn = add_DriverServicer_to_server
44
47
  driver_grpc_server = generic_create_grpc_server(
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -43,6 +43,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
43
  Run,
44
44
  )
45
45
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
46
+ from flwr.server.superlink.ffs import Ffs
47
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
46
48
  from flwr.server.superlink.state import State, StateFactory
47
49
  from flwr.server.utils.validator import validate_task_ins_or_res
48
50
 
@@ -50,8 +52,9 @@ from flwr.server.utils.validator import validate_task_ins_or_res
50
52
  class DriverServicer(driver_pb2_grpc.DriverServicer):
51
53
  """Driver API servicer."""
52
54
 
53
- def __init__(self, state_factory: StateFactory) -> None:
55
+ def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
54
56
  self.state_factory = state_factory
57
+ self.ffs_factory = ffs_factory
55
58
 
56
59
  def GetNodes(
57
60
  self, request: GetNodesRequest, context: grpc.ServicerContext
@@ -71,9 +74,19 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
71
74
  """Create run ID."""
72
75
  log(DEBUG, "DriverServicer.CreateRun")
73
76
  state: State = self.state_factory.state()
77
+ if request.HasField("fab") and request.fab.HasField("content"):
78
+ ffs: Ffs = self.ffs_factory.ffs()
79
+ fab_hash = ffs.put(request.fab.content, {})
80
+ _raise_if(
81
+ fab_hash != request.fab.hash_str,
82
+ f"FAB ({request.fab}) hash from request doesn't match contents",
83
+ )
84
+ else:
85
+ fab_hash = ""
74
86
  run_id = state.create_run(
75
87
  request.fab_id,
76
88
  request.fab_version,
89
+ fab_hash,
77
90
  user_config_from_proto(request.override_config),
78
91
  )
79
92
  return CreateRunResponse(run_id=run_id)
flwr/server/superlink/ffs/ffs_factory.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Factory class that creates Ffs instances."""
16
+
17
+
18
+ from logging import DEBUG
19
+ from typing import Optional
20
+
21
+ from flwr.common.logger import log
22
+
23
+ from .disk_ffs import DiskFfs
24
+ from .ffs import Ffs
25
+
26
+
27
+ class FfsFactory:
28
+ """Factory class that creates Ffs instances.
29
+
30
+ Parameters
31
+ ----------
32
+ base_dir : str
33
+ The base directory used by DiskFfs to store objects.
34
+ """
35
+
36
+ def __init__(self, base_dir: str) -> None:
37
+ self.base_dir = base_dir
38
+ self.ffs_instance: Optional[Ffs] = None
39
+
40
+ def ffs(self) -> Ffs:
41
+ """Return a Ffs instance and create it, if necessary."""
42
+ if not self.ffs_instance:
43
+ log(DEBUG, "Initializing DiskFfs")
44
+ self.ffs_instance = DiskFfs(self.base_dir)
45
+
46
+ log(DEBUG, "Using DiskFfs")
47
+ return self.ffs_instance
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -277,11 +277,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
277
277
 
278
278
  def create_run(
279
279
  self,
280
- fab_id: str,
281
- fab_version: str,
280
+ fab_id: Optional[str],
281
+ fab_version: Optional[str],
282
+ fab_hash: Optional[str],
282
283
  override_config: UserConfig,
283
284
  ) -> int:
284
- """Create a new run for the specified `fab_id` and `fab_version`."""
285
+ """Create a new run for the specified `fab_hash`."""
285
286
  # Sample a random int64 as run_id
286
287
  with self.lock:
287
288
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -289,8 +290,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
289
290
  if run_id not in self.run_ids:
290
291
  self.run_ids[run_id] = Run(
291
292
  run_id=run_id,
292
- fab_id=fab_id,
293
- fab_version=fab_version,
293
+ fab_id=fab_id if fab_id else "",
294
+ fab_version=fab_version if fab_version else "",
295
+ fab_hash=fab_hash if fab_hash else "",
294
296
  override_config=override_config,
295
297
  )
296
298
  return run_id
flwr/server/superlink/state/sqlite_state.py CHANGED
@@ -65,6 +65,7 @@ CREATE TABLE IF NOT EXISTS run(
65
65
  run_id INTEGER UNIQUE,
66
66
  fab_id TEXT,
67
67
  fab_version TEXT,
68
+ fab_hash TEXT,
68
69
  override_config TEXT
69
70
  );
70
71
  """
@@ -617,8 +618,9 @@ class SqliteState(State): # pylint: disable=R0904
617
618
 
618
619
  def create_run(
619
620
  self,
620
- fab_id: str,
621
- fab_version: str,
621
+ fab_id: Optional[str],
622
+ fab_version: Optional[str],
623
+ fab_hash: Optional[str],
622
624
  override_config: UserConfig,
623
625
  ) -> int:
624
626
  """Create a new run for the specified `fab_id` and `fab_version`."""
@@ -630,12 +632,19 @@ class SqliteState(State): # pylint: disable=R0904
630
632
  # If run_id does not exist
631
633
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
632
634
  query = (
633
- "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
634
- "VALUES (?, ?, ?, ?);"
635
- )
636
- self.query(
637
- query, (run_id, fab_id, fab_version, json.dumps(override_config))
635
+ "INSERT INTO run "
636
+ "(run_id, fab_id, fab_version, fab_hash, override_config)"
637
+ "VALUES (?, ?, ?, ?, ?);"
638
638
  )
639
+ if fab_hash:
640
+ self.query(
641
+ query, (run_id, "", "", fab_hash, json.dumps(override_config))
642
+ )
643
+ else:
644
+ self.query(
645
+ query,
646
+ (run_id, fab_id, fab_version, "", json.dumps(override_config)),
647
+ )
639
648
  return run_id
640
649
  log(ERROR, "Unexpected run creation failure.")
641
650
  return 0
@@ -702,6 +711,7 @@ class SqliteState(State): # pylint: disable=R0904
702
711
  run_id=run_id,
703
712
  fab_id=row["fab_id"],
704
713
  fab_version=row["fab_version"],
714
+ fab_hash=row["fab_hash"],
705
715
  override_config=json.loads(row["override_config"]),
706
716
  )
707
717
  except sqlite3.IntegrityError:
flwr/server/superlink/state/state.py CHANGED
@@ -159,11 +159,12 @@ class State(abc.ABC): # pylint: disable=R0904
159
159
  @abc.abstractmethod
160
160
  def create_run(
161
161
  self,
162
- fab_id: str,
163
- fab_version: str,
162
+ fab_id: Optional[str],
163
+ fab_version: Optional[str],
164
+ fab_hash: Optional[str],
164
165
  override_config: UserConfig,
165
166
  ) -> int:
166
- """Create a new run for the specified `fab_id` and `fab_version`."""
167
+ """Create a new run for the specified `fab_hash`."""
167
168
 
168
169
  @abc.abstractmethod
169
170
  def get_run(self, run_id: int) -> Optional[Run]:
flwr/simulation/run_simulation.py CHANGED
@@ -163,6 +163,7 @@ def run_simulation_from_cli() -> None:
163
163
  run_id=run_id,
164
164
  fab_id="",
165
165
  fab_version="",
166
+ fab_hash="",
166
167
  override_config=override_config,
167
168
  )
168
169
 
@@ -529,7 +530,9 @@ def _run_simulation(
529
530
  # If no `Run` object is set, create one
530
531
  if run is None:
531
532
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
532
- run = Run(run_id=run_id, fab_id="", fab_version="", override_config={})
533
+ run = Run(
534
+ run_id=run_id, fab_id="", fab_version="", fab_hash="", override_config={}
535
+ )
533
536
 
534
537
  args = (
535
538
  num_supernodes,