flwr-nightly 1.19.0.dev20250611__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +47 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +70 -51
- flwr/common/inflatable.py +24 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +2 -0
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +26 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +38 -38
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
flwr/proto/run_pb2.py
CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
|
|
18
18
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
19
19
|
|
20
20
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xe0\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x10\n\x08\x66lwr_aid\x18\x0b \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"T\n\x1cGetFederationOptionsResponse\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x18.flwr.proto.ConfigRecordb\x06proto3')
|
22
22
|
|
23
23
|
_globals = globals()
|
24
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -30,27 +30,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
30
30
|
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None
|
31
31
|
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001'
|
32
32
|
_globals['_RUN']._serialized_start=139
|
33
|
-
_globals['_RUN']._serialized_end=
|
34
|
-
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=
|
35
|
-
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=
|
36
|
-
_globals['_RUNSTATUS']._serialized_start=
|
37
|
-
_globals['_RUNSTATUS']._serialized_end=
|
38
|
-
_globals['_GETRUNREQUEST']._serialized_start=
|
39
|
-
_globals['_GETRUNREQUEST']._serialized_end=
|
40
|
-
_globals['_GETRUNRESPONSE']._serialized_start=
|
41
|
-
_globals['_GETRUNRESPONSE']._serialized_end=
|
42
|
-
_globals['_UPDATERUNSTATUSREQUEST']._serialized_start=
|
43
|
-
_globals['_UPDATERUNSTATUSREQUEST']._serialized_end=
|
44
|
-
_globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=
|
45
|
-
_globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=
|
46
|
-
_globals['_GETRUNSTATUSREQUEST']._serialized_start=
|
47
|
-
_globals['_GETRUNSTATUSREQUEST']._serialized_end=
|
48
|
-
_globals['_GETRUNSTATUSRESPONSE']._serialized_start=
|
49
|
-
_globals['_GETRUNSTATUSRESPONSE']._serialized_end=
|
50
|
-
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=
|
51
|
-
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=
|
52
|
-
_globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=
|
53
|
-
_globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=
|
54
|
-
_globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=
|
55
|
-
_globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=
|
33
|
+
_globals['_RUN']._serialized_end=491
|
34
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=418
|
35
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=491
|
36
|
+
_globals['_RUNSTATUS']._serialized_start=493
|
37
|
+
_globals['_RUNSTATUS']._serialized_end=557
|
38
|
+
_globals['_GETRUNREQUEST']._serialized_start=559
|
39
|
+
_globals['_GETRUNREQUEST']._serialized_end=622
|
40
|
+
_globals['_GETRUNRESPONSE']._serialized_start=624
|
41
|
+
_globals['_GETRUNRESPONSE']._serialized_end=670
|
42
|
+
_globals['_UPDATERUNSTATUSREQUEST']._serialized_start=672
|
43
|
+
_globals['_UPDATERUNSTATUSREQUEST']._serialized_end=755
|
44
|
+
_globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=757
|
45
|
+
_globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=782
|
46
|
+
_globals['_GETRUNSTATUSREQUEST']._serialized_start=784
|
47
|
+
_globals['_GETRUNSTATUSREQUEST']._serialized_end=854
|
48
|
+
_globals['_GETRUNSTATUSRESPONSE']._serialized_start=857
|
49
|
+
_globals['_GETRUNSTATUSRESPONSE']._serialized_end=1034
|
50
|
+
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=959
|
51
|
+
_globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1034
|
52
|
+
_globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1036
|
53
|
+
_globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1081
|
54
|
+
_globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1083
|
55
|
+
_globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1167
|
56
56
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/run_pb2.pyi
CHANGED
@@ -41,6 +41,7 @@ class Run(google.protobuf.message.Message):
|
|
41
41
|
RUNNING_AT_FIELD_NUMBER: builtins.int
|
42
42
|
FINISHED_AT_FIELD_NUMBER: builtins.int
|
43
43
|
STATUS_FIELD_NUMBER: builtins.int
|
44
|
+
FLWR_AID_FIELD_NUMBER: builtins.int
|
44
45
|
run_id: builtins.int
|
45
46
|
fab_id: typing.Text
|
46
47
|
fab_version: typing.Text
|
@@ -53,6 +54,7 @@ class Run(google.protobuf.message.Message):
|
|
53
54
|
finished_at: typing.Text
|
54
55
|
@property
|
55
56
|
def status(self) -> global___RunStatus: ...
|
57
|
+
flwr_aid: typing.Text
|
56
58
|
def __init__(self,
|
57
59
|
*,
|
58
60
|
run_id: builtins.int = ...,
|
@@ -65,9 +67,10 @@ class Run(google.protobuf.message.Message):
|
|
65
67
|
running_at: typing.Text = ...,
|
66
68
|
finished_at: typing.Text = ...,
|
67
69
|
status: typing.Optional[global___RunStatus] = ...,
|
70
|
+
flwr_aid: typing.Text = ...,
|
68
71
|
) -> None: ...
|
69
72
|
def HasField(self, field_name: typing_extensions.Literal["status",b"status"]) -> builtins.bool: ...
|
70
|
-
def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","finished_at",b"finished_at","override_config",b"override_config","pending_at",b"pending_at","run_id",b"run_id","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ...
|
73
|
+
def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","finished_at",b"finished_at","flwr_aid",b"flwr_aid","override_config",b"override_config","pending_at",b"pending_at","run_id",b"run_id","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ...
|
71
74
|
global___Run = Run
|
72
75
|
|
73
76
|
class RunStatus(google.protobuf.message.Message):
|
flwr/proto/serverappio_pb2.py
CHANGED
@@ -20,15 +20,13 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
20
20
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
21
21
|
|
22
22
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\
|
23
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8a\x01\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xd7\t\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
|
24
24
|
|
25
25
|
_globals = globals()
|
26
26
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
27
27
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.serverappio_pb2', _globals)
|
28
28
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
29
29
|
DESCRIPTOR._options = None
|
30
|
-
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._options = None
|
31
|
-
_globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_options = b'8\001'
|
32
30
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._options = None
|
33
31
|
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_options = b'8\001'
|
34
32
|
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._options = None
|
@@ -38,27 +36,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
38
36
|
_globals['_GETNODESRESPONSE']._serialized_start=222
|
39
37
|
_globals['_GETNODESRESPONSE']._serialized_end=273
|
40
38
|
_globals['_PUSHINSMESSAGESREQUEST']._serialized_start=276
|
41
|
-
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=
|
42
|
-
_globals['
|
43
|
-
_globals['
|
44
|
-
_globals['
|
45
|
-
_globals['
|
46
|
-
_globals['
|
47
|
-
_globals['
|
48
|
-
_globals['
|
49
|
-
_globals['
|
50
|
-
_globals['
|
51
|
-
_globals['
|
52
|
-
_globals['
|
53
|
-
_globals['
|
54
|
-
_globals['
|
55
|
-
_globals['
|
56
|
-
_globals['
|
57
|
-
_globals['
|
58
|
-
_globals['
|
59
|
-
_globals['
|
60
|
-
_globals['
|
61
|
-
_globals['
|
62
|
-
_globals['_SERVERAPPIO']._serialized_start=1324
|
63
|
-
_globals['_SERVERAPPIO']._serialized_end=2448
|
39
|
+
_globals['_PUSHINSMESSAGESREQUEST']._serialized_end=414
|
40
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=417
|
41
|
+
_globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=621
|
42
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=546
|
43
|
+
_globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=621
|
44
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_start=623
|
45
|
+
_globals['_PULLRESMESSAGESREQUEST']._serialized_end=684
|
46
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_start=687
|
47
|
+
_globals['_PULLRESMESSAGESRESPONSE']._serialized_end=914
|
48
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=839
|
49
|
+
_globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=914
|
50
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=916
|
51
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=944
|
52
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=946
|
53
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1073
|
54
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1075
|
55
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1158
|
56
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1160
|
57
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1190
|
58
|
+
_globals['_SERVERAPPIO']._serialized_start=1193
|
59
|
+
_globals['_SERVERAPPIO']._serialized_end=2432
|
64
60
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/serverappio_pb2.pyi
CHANGED
@@ -42,36 +42,21 @@ global___GetNodesResponse = GetNodesResponse
|
|
42
42
|
class PushInsMessagesRequest(google.protobuf.message.Message):
|
43
43
|
"""PushMessages messages"""
|
44
44
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
45
|
-
class MsgToDescendantMappingEntry(google.protobuf.message.Message):
|
46
|
-
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
47
|
-
KEY_FIELD_NUMBER: builtins.int
|
48
|
-
VALUE_FIELD_NUMBER: builtins.int
|
49
|
-
key: typing.Text
|
50
|
-
@property
|
51
|
-
def value(self) -> flwr.proto.message_pb2.ObjectIDs: ...
|
52
|
-
def __init__(self,
|
53
|
-
*,
|
54
|
-
key: typing.Text = ...,
|
55
|
-
value: typing.Optional[flwr.proto.message_pb2.ObjectIDs] = ...,
|
56
|
-
) -> None: ...
|
57
|
-
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
58
|
-
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
59
|
-
|
60
45
|
MESSAGES_LIST_FIELD_NUMBER: builtins.int
|
61
46
|
RUN_ID_FIELD_NUMBER: builtins.int
|
62
|
-
|
47
|
+
MESSAGE_OBJECT_TREES_FIELD_NUMBER: builtins.int
|
63
48
|
@property
|
64
49
|
def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
|
65
50
|
run_id: builtins.int
|
66
51
|
@property
|
67
|
-
def
|
52
|
+
def message_object_trees(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.ObjectTree]: ...
|
68
53
|
def __init__(self,
|
69
54
|
*,
|
70
55
|
messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
|
71
56
|
run_id: builtins.int = ...,
|
72
|
-
|
57
|
+
message_object_trees: typing.Optional[typing.Iterable[flwr.proto.message_pb2.ObjectTree]] = ...,
|
73
58
|
) -> None: ...
|
74
|
-
def ClearField(self, field_name: typing_extensions.Literal["
|
59
|
+
def ClearField(self, field_name: typing_extensions.Literal["message_object_trees",b"message_object_trees","messages_list",b"messages_list","run_id",b"run_id"]) -> None: ...
|
75
60
|
global___PushInsMessagesRequest = PushInsMessagesRequest
|
76
61
|
|
77
62
|
class PushInsMessagesResponse(google.protobuf.message.Message):
|
@@ -84,6 +84,11 @@ class ServerAppIoStub(object):
|
|
84
84
|
request_serializer=flwr_dot_proto_dot_message__pb2.PullObjectRequest.SerializeToString,
|
85
85
|
response_deserializer=flwr_dot_proto_dot_message__pb2.PullObjectResponse.FromString,
|
86
86
|
)
|
87
|
+
self.ConfirmMessageReceived = channel.unary_unary(
|
88
|
+
'/flwr.proto.ServerAppIo/ConfirmMessageReceived',
|
89
|
+
request_serializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.SerializeToString,
|
90
|
+
response_deserializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.FromString,
|
91
|
+
)
|
87
92
|
|
88
93
|
|
89
94
|
class ServerAppIoServicer(object):
|
@@ -180,6 +185,13 @@ class ServerAppIoServicer(object):
|
|
180
185
|
context.set_details('Method not implemented!')
|
181
186
|
raise NotImplementedError('Method not implemented!')
|
182
187
|
|
188
|
+
def ConfirmMessageReceived(self, request, context):
|
189
|
+
"""Confirm Message Received
|
190
|
+
"""
|
191
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
192
|
+
context.set_details('Method not implemented!')
|
193
|
+
raise NotImplementedError('Method not implemented!')
|
194
|
+
|
183
195
|
|
184
196
|
def add_ServerAppIoServicer_to_server(servicer, server):
|
185
197
|
rpc_method_handlers = {
|
@@ -248,6 +260,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
|
|
248
260
|
request_deserializer=flwr_dot_proto_dot_message__pb2.PullObjectRequest.FromString,
|
249
261
|
response_serializer=flwr_dot_proto_dot_message__pb2.PullObjectResponse.SerializeToString,
|
250
262
|
),
|
263
|
+
'ConfirmMessageReceived': grpc.unary_unary_rpc_method_handler(
|
264
|
+
servicer.ConfirmMessageReceived,
|
265
|
+
request_deserializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.FromString,
|
266
|
+
response_serializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.SerializeToString,
|
267
|
+
),
|
251
268
|
}
|
252
269
|
generic_handler = grpc.method_handlers_generic_handler(
|
253
270
|
'flwr.proto.ServerAppIo', rpc_method_handlers)
|
@@ -478,3 +495,20 @@ class ServerAppIo(object):
|
|
478
495
|
flwr_dot_proto_dot_message__pb2.PullObjectResponse.FromString,
|
479
496
|
options, channel_credentials,
|
480
497
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
498
|
+
|
499
|
+
@staticmethod
|
500
|
+
def ConfirmMessageReceived(request,
|
501
|
+
target,
|
502
|
+
options=(),
|
503
|
+
channel_credentials=None,
|
504
|
+
call_credentials=None,
|
505
|
+
insecure=False,
|
506
|
+
compression=None,
|
507
|
+
wait_for_ready=None,
|
508
|
+
timeout=None,
|
509
|
+
metadata=None):
|
510
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/ConfirmMessageReceived',
|
511
|
+
flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.SerializeToString,
|
512
|
+
flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.FromString,
|
513
|
+
options, channel_credentials,
|
514
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
@@ -78,6 +78,11 @@ class ServerAppIoStub:
|
|
78
78
|
flwr.proto.message_pb2.PullObjectResponse]
|
79
79
|
"""Pull Object"""
|
80
80
|
|
81
|
+
ConfirmMessageReceived: grpc.UnaryUnaryMultiCallable[
|
82
|
+
flwr.proto.message_pb2.ConfirmMessageReceivedRequest,
|
83
|
+
flwr.proto.message_pb2.ConfirmMessageReceivedResponse]
|
84
|
+
"""Confirm Message Received"""
|
85
|
+
|
81
86
|
|
82
87
|
class ServerAppIoServicer(metaclass=abc.ABCMeta):
|
83
88
|
@abc.abstractmethod
|
@@ -184,5 +189,13 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
|
|
184
189
|
"""Pull Object"""
|
185
190
|
pass
|
186
191
|
|
192
|
+
@abc.abstractmethod
|
193
|
+
def ConfirmMessageReceived(self,
|
194
|
+
request: flwr.proto.message_pb2.ConfirmMessageReceivedRequest,
|
195
|
+
context: grpc.ServicerContext,
|
196
|
+
) -> flwr.proto.message_pb2.ConfirmMessageReceivedResponse:
|
197
|
+
"""Confirm Message Received"""
|
198
|
+
pass
|
199
|
+
|
187
200
|
|
188
201
|
def add_ServerAppIoServicer_to_server(servicer: ServerAppIoServicer, server: grpc.Server) -> None: ...
|
flwr/server/grid/grpc_grid.py
CHANGED
@@ -28,7 +28,11 @@ from flwr.common.constant import (
|
|
28
28
|
SUPERLINK_NODE_ID,
|
29
29
|
)
|
30
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
31
|
-
from flwr.common.inflatable import
|
31
|
+
from flwr.common.inflatable import (
|
32
|
+
get_all_nested_objects,
|
33
|
+
get_object_tree,
|
34
|
+
no_object_id_recompute,
|
35
|
+
)
|
32
36
|
from flwr.common.inflatable_grpc_utils import (
|
33
37
|
make_pull_object_fn_grpc,
|
34
38
|
make_push_object_fn_grpc,
|
@@ -43,7 +47,9 @@ from flwr.common.message import remove_content_from_message
|
|
43
47
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
44
48
|
from flwr.common.serde import message_to_proto, run_from_proto
|
45
49
|
from flwr.common.typing import Run
|
46
|
-
from flwr.proto.message_pb2 import
|
50
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
51
|
+
ConfirmMessageReceivedRequest,
|
52
|
+
)
|
47
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
48
54
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
49
55
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
@@ -213,18 +219,15 @@ class GrpcGrid(Grid):
|
|
213
219
|
"""Push one message and its associated objects."""
|
214
220
|
# Compute mapping of message descendants
|
215
221
|
all_objects = get_all_nested_objects(message)
|
216
|
-
|
217
|
-
|
218
|
-
descendant_ids = all_object_ids[:-1] # All but the last object are descendants
|
222
|
+
msg_id = message.object_id
|
223
|
+
object_tree = get_object_tree(message)
|
219
224
|
|
220
225
|
# Call GrpcServerAppIoStub method
|
221
226
|
res: PushInsMessagesResponse = self._stub.PushMessages(
|
222
227
|
PushInsMessagesRequest(
|
223
228
|
messages_list=[message_to_proto(remove_content_from_message(message))],
|
224
229
|
run_id=run_id,
|
225
|
-
|
226
|
-
msg_id: ObjectIDs(object_ids=descendant_ids)
|
227
|
-
},
|
230
|
+
message_object_trees=[object_tree],
|
228
231
|
)
|
229
232
|
)
|
230
233
|
|
@@ -262,7 +265,8 @@ class GrpcGrid(Grid):
|
|
262
265
|
# Check message
|
263
266
|
self._check_message(msg)
|
264
267
|
# Try pushing message and its objects
|
265
|
-
|
268
|
+
with no_object_id_recompute():
|
269
|
+
message_ids.append(self._try_push_message(run_id, msg))
|
266
270
|
|
267
271
|
except grpc.RpcError as e:
|
268
272
|
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
@@ -308,6 +312,13 @@ class GrpcGrid(Grid):
|
|
308
312
|
run_id=run_id,
|
309
313
|
),
|
310
314
|
)
|
315
|
+
|
316
|
+
# Confirm that the message has been received
|
317
|
+
self._stub.ConfirmMessageReceived(
|
318
|
+
ConfirmMessageReceivedRequest(
|
319
|
+
node=self.node, run_id=run_id, message_object_id=msg_id
|
320
|
+
)
|
321
|
+
)
|
311
322
|
message = cast(
|
312
323
|
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
313
324
|
)
|
@@ -40,6 +40,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
40
40
|
SendNodeHeartbeatResponse,
|
41
41
|
)
|
42
42
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
43
|
+
ConfirmMessageReceivedRequest,
|
44
|
+
ConfirmMessageReceivedResponse,
|
43
45
|
PullObjectRequest,
|
44
46
|
PullObjectResponse,
|
45
47
|
PushObjectRequest,
|
@@ -151,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
151
153
|
res = message_handler.get_run(
|
152
154
|
request=request,
|
153
155
|
state=self.state_factory.state(),
|
156
|
+
store=self.objectstore_factory.store(),
|
154
157
|
)
|
155
158
|
except InvalidRunStatusException as e:
|
156
159
|
abort_grpc_context(e.message, context)
|
@@ -167,6 +170,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
167
170
|
request=request,
|
168
171
|
ffs=self.ffs_factory.ffs(),
|
169
172
|
state=self.state_factory.state(),
|
173
|
+
store=self.objectstore_factory.store(),
|
170
174
|
)
|
171
175
|
except InvalidRunStatusException as e:
|
172
176
|
abort_grpc_context(e.message, context)
|
@@ -219,3 +223,24 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
219
223
|
abort_grpc_context(e.message, context)
|
220
224
|
|
221
225
|
return res
|
226
|
+
|
227
|
+
def ConfirmMessageReceived(
|
228
|
+
self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
|
229
|
+
) -> ConfirmMessageReceivedResponse:
|
230
|
+
"""Confirm message received."""
|
231
|
+
log(
|
232
|
+
DEBUG,
|
233
|
+
"[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
|
234
|
+
request.message_object_id,
|
235
|
+
)
|
236
|
+
|
237
|
+
try:
|
238
|
+
res = message_handler.confirm_message_received(
|
239
|
+
request=request,
|
240
|
+
state=self.state_factory.state(),
|
241
|
+
store=self.objectstore_factory.store(),
|
242
|
+
)
|
243
|
+
except InvalidRunStatusException as e:
|
244
|
+
abort_grpc_context(e.message, context)
|
245
|
+
|
246
|
+
return res
|
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
44
44
|
SendNodeHeartbeatResponse,
|
45
45
|
)
|
46
46
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
47
|
+
ConfirmMessageReceivedRequest,
|
48
|
+
ConfirmMessageReceivedResponse,
|
47
49
|
ObjectIDs,
|
48
50
|
PullObjectRequest,
|
49
51
|
PullObjectResponse,
|
@@ -146,6 +148,7 @@ def push_messages(
|
|
146
148
|
msg.metadata.run_id,
|
147
149
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
148
150
|
state,
|
151
|
+
store,
|
149
152
|
)
|
150
153
|
if abort_msg:
|
151
154
|
raise InvalidRunStatusException(abort_msg)
|
@@ -165,7 +168,9 @@ def push_messages(
|
|
165
168
|
return response
|
166
169
|
|
167
170
|
|
168
|
-
def get_run(
|
171
|
+
def get_run(
|
172
|
+
request: GetRunRequest, state: LinkState, store: ObjectStore
|
173
|
+
) -> GetRunResponse:
|
169
174
|
"""Get run information."""
|
170
175
|
run = state.get_run(request.run_id)
|
171
176
|
|
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
177
182
|
request.run_id,
|
178
183
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
179
184
|
state,
|
185
|
+
store,
|
180
186
|
)
|
181
187
|
if abort_msg:
|
182
188
|
raise InvalidRunStatusException(abort_msg)
|
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
193
199
|
|
194
200
|
|
195
201
|
def get_fab(
|
196
|
-
request: GetFabRequest, ffs: Ffs, state: LinkState
|
202
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
|
197
203
|
) -> GetFabResponse:
|
198
204
|
"""Get FAB."""
|
199
205
|
# Abort if the run is not running
|
@@ -201,6 +207,7 @@ def get_fab(
|
|
201
207
|
request.run_id,
|
202
208
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
203
209
|
state,
|
210
|
+
store,
|
204
211
|
)
|
205
212
|
if abort_msg:
|
206
213
|
raise InvalidRunStatusException(abort_msg)
|
@@ -220,6 +227,7 @@ def push_object(
|
|
220
227
|
request.run_id,
|
221
228
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
222
229
|
state,
|
230
|
+
store,
|
223
231
|
)
|
224
232
|
if abort_msg:
|
225
233
|
raise InvalidRunStatusException(abort_msg)
|
@@ -245,6 +253,7 @@ def pull_object(
|
|
245
253
|
request.run_id,
|
246
254
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
247
255
|
state,
|
256
|
+
store,
|
248
257
|
)
|
249
258
|
if abort_msg:
|
250
259
|
raise InvalidRunStatusException(abort_msg)
|
@@ -259,3 +268,25 @@ def pull_object(
|
|
259
268
|
object_content=content,
|
260
269
|
)
|
261
270
|
return PullObjectResponse(object_found=False, object_available=False)
|
271
|
+
|
272
|
+
|
273
|
+
def confirm_message_received(
|
274
|
+
request: ConfirmMessageReceivedRequest,
|
275
|
+
state: LinkState,
|
276
|
+
store: ObjectStore,
|
277
|
+
) -> ConfirmMessageReceivedResponse:
|
278
|
+
"""Confirm message received handler."""
|
279
|
+
abort_msg = check_abort(
|
280
|
+
request.run_id,
|
281
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
282
|
+
state,
|
283
|
+
store,
|
284
|
+
)
|
285
|
+
if abort_msg:
|
286
|
+
raise InvalidRunStatusException(abort_msg)
|
287
|
+
|
288
|
+
# Delete the message object
|
289
|
+
store.delete(request.message_object_id)
|
290
|
+
store.delete_message_descendant_ids(request.message_object_id)
|
291
|
+
|
292
|
+
return ConfirmMessageReceivedResponse()
|
@@ -39,6 +39,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
39
39
|
SendNodeHeartbeatResponse,
|
40
40
|
)
|
41
41
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
42
|
+
ConfirmMessageReceivedRequest,
|
43
|
+
ConfirmMessageReceivedResponse,
|
42
44
|
PullObjectRequest,
|
43
45
|
PullObjectResponse,
|
44
46
|
PushObjectRequest,
|
@@ -176,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
176
178
|
"""GetRun."""
|
177
179
|
# Get state from app
|
178
180
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
181
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
179
182
|
|
180
183
|
# Handle message
|
181
|
-
return message_handler.get_run(request=request, state=state)
|
184
|
+
return message_handler.get_run(request=request, state=state, store=store)
|
182
185
|
|
183
186
|
|
184
187
|
@rest_request_response(GetFabRequest)
|
@@ -189,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
189
192
|
|
190
193
|
# Get state from app
|
191
194
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
195
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
196
|
+
|
197
|
+
# Handle message
|
198
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
|
199
|
+
|
200
|
+
|
201
|
+
@rest_request_response(ConfirmMessageReceivedRequest)
|
202
|
+
async def confirm_message_received(
|
203
|
+
request: ConfirmMessageReceivedRequest,
|
204
|
+
) -> ConfirmMessageReceivedResponse:
|
205
|
+
"""Confirm message received."""
|
206
|
+
# Get state from app
|
207
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
208
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
192
209
|
|
193
210
|
# Handle message
|
194
|
-
return message_handler.
|
211
|
+
return message_handler.confirm_message_received(
|
212
|
+
request=request, state=state, store=store
|
213
|
+
)
|
195
214
|
|
196
215
|
|
197
216
|
routes = [
|
@@ -204,6 +223,11 @@ routes = [
|
|
204
223
|
Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
|
205
224
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
206
225
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
226
|
+
Route(
|
227
|
+
"/api/v0/fleet/confirm-message-received",
|
228
|
+
confirm_message_received,
|
229
|
+
methods=["POST"],
|
230
|
+
),
|
207
231
|
]
|
208
232
|
|
209
233
|
app: Starlette = Starlette(
|
@@ -18,6 +18,7 @@
|
|
18
18
|
import threading
|
19
19
|
import time
|
20
20
|
from bisect import bisect_right
|
21
|
+
from collections import defaultdict
|
21
22
|
from dataclasses import dataclass, field
|
22
23
|
from logging import ERROR, WARNING
|
23
24
|
from typing import Optional
|
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
79
80
|
self.message_res_store: dict[str, Message] = {}
|
80
81
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
81
82
|
|
83
|
+
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
84
|
+
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
85
|
+
|
82
86
|
self.node_public_keys: set[bytes] = set()
|
83
87
|
|
84
88
|
self.lock = threading.RLock()
|
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
398
402
|
fab_hash: Optional[str],
|
399
403
|
override_config: UserConfig,
|
400
404
|
federation_options: ConfigRecord,
|
405
|
+
flwr_aid: Optional[str],
|
401
406
|
) -> int:
|
402
407
|
"""Create a new run for the specified `fab_hash`."""
|
403
408
|
# Sample a random int64 as run_id
|
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
421
426
|
sub_status="",
|
422
427
|
details="",
|
423
428
|
),
|
429
|
+
flwr_aid=flwr_aid if flwr_aid else "",
|
424
430
|
),
|
425
431
|
)
|
426
432
|
self.run_ids[run_id] = run_record
|
433
|
+
# Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
|
434
|
+
if flwr_aid:
|
435
|
+
self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
|
427
436
|
|
428
437
|
# Record federation options. Leave empty if not passed
|
429
438
|
self.federation_options[run_id] = federation_options
|
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
451
460
|
with self.lock:
|
452
461
|
return self.node_public_keys.copy()
|
453
462
|
|
454
|
-
def get_run_ids(self) -> set[int]:
|
455
|
-
"""Retrieve all run IDs.
|
463
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
464
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
465
|
+
|
466
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
467
|
+
"""
|
456
468
|
with self.lock:
|
469
|
+
if flwr_aid is not None:
|
470
|
+
# Return run IDs for the specified flwr_aid
|
471
|
+
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
457
472
|
return set(self.run_ids.keys())
|
458
473
|
|
459
474
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
164
164
|
fab_hash: Optional[str],
|
165
165
|
override_config: UserConfig,
|
166
166
|
federation_options: ConfigRecord,
|
167
|
+
flwr_aid: Optional[str],
|
167
168
|
) -> int:
|
168
169
|
"""Create a new run for the specified `fab_hash`."""
|
169
170
|
|
170
171
|
@abc.abstractmethod
|
171
|
-
def get_run_ids(self) -> set[int]:
|
172
|
-
"""Retrieve all run IDs.
|
172
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
173
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
174
|
+
|
175
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
176
|
+
"""
|
173
177
|
|
174
178
|
@abc.abstractmethod
|
175
179
|
def get_run(self, run_id: int) -> Optional[Run]:
|