flwr-nightly 1.19.0.dev20250611__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. flwr/client/grpc_rere_client/connection.py +47 -29
  2. flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
  3. flwr/client/rest_client/connection.py +70 -51
  4. flwr/common/inflatable.py +24 -0
  5. flwr/common/serde.py +2 -0
  6. flwr/common/typing.py +2 -0
  7. flwr/proto/fleet_pb2.py +12 -16
  8. flwr/proto/fleet_pb2.pyi +4 -19
  9. flwr/proto/fleet_pb2_grpc.py +34 -0
  10. flwr/proto/fleet_pb2_grpc.pyi +13 -0
  11. flwr/proto/message_pb2.py +15 -9
  12. flwr/proto/message_pb2.pyi +41 -0
  13. flwr/proto/run_pb2.py +24 -24
  14. flwr/proto/run_pb2.pyi +4 -1
  15. flwr/proto/serverappio_pb2.py +22 -26
  16. flwr/proto/serverappio_pb2.pyi +4 -19
  17. flwr/proto/serverappio_pb2_grpc.py +34 -0
  18. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  19. flwr/server/grid/grpc_grid.py +20 -9
  20. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
  21. flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
  22. flwr/server/superlink/fleet/rest_rere/rest_api.py +26 -2
  23. flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
  24. flwr/server/superlink/linkstate/linkstate.py +6 -2
  25. flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
  26. flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
  27. flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
  28. flwr/server/superlink/utils.py +23 -10
  29. flwr/supercore/object_store/in_memory_object_store.py +160 -33
  30. flwr/supercore/object_store/object_store.py +54 -7
  31. flwr/superexec/deployment.py +6 -2
  32. flwr/superexec/exec_servicer.py +4 -1
  33. flwr/superexec/executor.py +4 -0
  34. flwr/superexec/simulation.py +7 -1
  35. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +38 -38
  37. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
  38. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
flwr/proto/run_pb2.py CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
18
18
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xce\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"T\n\x1cGetFederationOptionsResponse\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x01 
\x01(\x0b\x32\x18.flwr.proto.ConfigRecordb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xe0\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x10\n\x08\x66lwr_aid\x18\x0b \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"T\n\x1cGetFederationOptionsResponse\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x01 
\x01(\x0b\x32\x18.flwr.proto.ConfigRecordb\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -30,27 +30,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
30
30
  _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None
31
31
  _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001'
32
32
  _globals['_RUN']._serialized_start=139
33
- _globals['_RUN']._serialized_end=473
34
- _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=400
35
- _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=473
36
- _globals['_RUNSTATUS']._serialized_start=475
37
- _globals['_RUNSTATUS']._serialized_end=539
38
- _globals['_GETRUNREQUEST']._serialized_start=541
39
- _globals['_GETRUNREQUEST']._serialized_end=604
40
- _globals['_GETRUNRESPONSE']._serialized_start=606
41
- _globals['_GETRUNRESPONSE']._serialized_end=652
42
- _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=654
43
- _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=737
44
- _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=739
45
- _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=764
46
- _globals['_GETRUNSTATUSREQUEST']._serialized_start=766
47
- _globals['_GETRUNSTATUSREQUEST']._serialized_end=836
48
- _globals['_GETRUNSTATUSRESPONSE']._serialized_start=839
49
- _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1016
50
- _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=941
51
- _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1016
52
- _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1018
53
- _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1063
54
- _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1065
55
- _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1149
33
+ _globals['_RUN']._serialized_end=491
34
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=418
35
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=491
36
+ _globals['_RUNSTATUS']._serialized_start=493
37
+ _globals['_RUNSTATUS']._serialized_end=557
38
+ _globals['_GETRUNREQUEST']._serialized_start=559
39
+ _globals['_GETRUNREQUEST']._serialized_end=622
40
+ _globals['_GETRUNRESPONSE']._serialized_start=624
41
+ _globals['_GETRUNRESPONSE']._serialized_end=670
42
+ _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=672
43
+ _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=755
44
+ _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=757
45
+ _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=782
46
+ _globals['_GETRUNSTATUSREQUEST']._serialized_start=784
47
+ _globals['_GETRUNSTATUSREQUEST']._serialized_end=854
48
+ _globals['_GETRUNSTATUSRESPONSE']._serialized_start=857
49
+ _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1034
50
+ _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=959
51
+ _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1034
52
+ _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1036
53
+ _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1081
54
+ _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1083
55
+ _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1167
56
56
  # @@protoc_insertion_point(module_scope)
flwr/proto/run_pb2.pyi CHANGED
@@ -41,6 +41,7 @@ class Run(google.protobuf.message.Message):
41
41
  RUNNING_AT_FIELD_NUMBER: builtins.int
42
42
  FINISHED_AT_FIELD_NUMBER: builtins.int
43
43
  STATUS_FIELD_NUMBER: builtins.int
44
+ FLWR_AID_FIELD_NUMBER: builtins.int
44
45
  run_id: builtins.int
45
46
  fab_id: typing.Text
46
47
  fab_version: typing.Text
@@ -53,6 +54,7 @@ class Run(google.protobuf.message.Message):
53
54
  finished_at: typing.Text
54
55
  @property
55
56
  def status(self) -> global___RunStatus: ...
57
+ flwr_aid: typing.Text
56
58
  def __init__(self,
57
59
  *,
58
60
  run_id: builtins.int = ...,
@@ -65,9 +67,10 @@ class Run(google.protobuf.message.Message):
65
67
  running_at: typing.Text = ...,
66
68
  finished_at: typing.Text = ...,
67
69
  status: typing.Optional[global___RunStatus] = ...,
70
+ flwr_aid: typing.Text = ...,
68
71
  ) -> None: ...
69
72
  def HasField(self, field_name: typing_extensions.Literal["status",b"status"]) -> builtins.bool: ...
70
- def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","finished_at",b"finished_at","override_config",b"override_config","pending_at",b"pending_at","run_id",b"run_id","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ...
73
+ def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","finished_at",b"finished_at","flwr_aid",b"flwr_aid","override_config",b"override_config","pending_at",b"pending_at","run_id",b"run_id","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ...
71
74
  global___Run = Run
72
75
 
73
76
  class RunStatus(google.protobuf.message.Message):
@@ -20,15 +20,13 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
20
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
21
21
 
22
22
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8d\x02\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x61\n\x19msg_to_descendant_mapping\x18\x03 \x03(\x0b\x32>.flwr.proto.PushInsMessagesRequest.MsgToDescendantMappingEntry\x1aT\n\x1bMsgToDescendantMappingEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 
\x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xe4\x08\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"\x8a\x01\n\x16PushInsMessagesRequest\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xcc\x01\n\x17PushInsMessagesResponse\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12O\n\x0fobjects_to_push\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PushInsMessagesResponse.ObjectsToPushEntry\x1aK\n\x12ObjectsToPushEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"=\n\x16PullResMessagesRequest\x12\x13\n\x0bmessage_ids\x18\x01 \x03(\t\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\xe3\x01\n\x17PullResMessagesResponse\x12*\n\rmessages_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.Message\x12O\n\x0fobjects_to_pull\x18\x02 \x03(\x0b\x32\x36.flwr.proto.PullResMessagesResponse.ObjectsToPullEntry\x1aK\n\x12ObjectsToPullEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.ObjectIDs:\x02\x38\x01\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xd7\t\n\x0bServerAppIo\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12Y\n\x0cPushMessages\x12\".flwr.proto.PushInsMessagesRequest\x1a#.flwr.proto.PushInsMessagesResponse\"\x00\x12Y\n\x0cPullMessages\x12\".flwr.proto.PullResMessagesRequest\x1a#.flwr.proto.PullResMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12_\n\x10SendAppHeartbeat\x12#.flwr.proto.SendAppHeartbeatRequest\x1a$.flwr.proto.SendAppHeartbeatResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
24
24
 
25
25
  _globals = globals()
26
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
27
27
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.serverappio_pb2', _globals)
28
28
  if _descriptor._USE_C_DESCRIPTORS == False:
29
29
  DESCRIPTOR._options = None
30
- _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._options = None
31
- _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_options = b'8\001'
32
30
  _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._options = None
33
31
  _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_options = b'8\001'
34
32
  _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._options = None
@@ -38,27 +36,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
38
36
  _globals['_GETNODESRESPONSE']._serialized_start=222
39
37
  _globals['_GETNODESRESPONSE']._serialized_end=273
40
38
  _globals['_PUSHINSMESSAGESREQUEST']._serialized_start=276
41
- _globals['_PUSHINSMESSAGESREQUEST']._serialized_end=545
42
- _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_start=461
43
- _globals['_PUSHINSMESSAGESREQUEST_MSGTODESCENDANTMAPPINGENTRY']._serialized_end=545
44
- _globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=548
45
- _globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=752
46
- _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=677
47
- _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=752
48
- _globals['_PULLRESMESSAGESREQUEST']._serialized_start=754
49
- _globals['_PULLRESMESSAGESREQUEST']._serialized_end=815
50
- _globals['_PULLRESMESSAGESRESPONSE']._serialized_start=818
51
- _globals['_PULLRESMESSAGESRESPONSE']._serialized_end=1045
52
- _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=970
53
- _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=1045
54
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=1047
55
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=1075
56
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=1077
57
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1204
58
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1206
59
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1289
60
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1291
61
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1321
62
- _globals['_SERVERAPPIO']._serialized_start=1324
63
- _globals['_SERVERAPPIO']._serialized_end=2448
39
+ _globals['_PUSHINSMESSAGESREQUEST']._serialized_end=414
40
+ _globals['_PUSHINSMESSAGESRESPONSE']._serialized_start=417
41
+ _globals['_PUSHINSMESSAGESRESPONSE']._serialized_end=621
42
+ _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_start=546
43
+ _globals['_PUSHINSMESSAGESRESPONSE_OBJECTSTOPUSHENTRY']._serialized_end=621
44
+ _globals['_PULLRESMESSAGESREQUEST']._serialized_start=623
45
+ _globals['_PULLRESMESSAGESREQUEST']._serialized_end=684
46
+ _globals['_PULLRESMESSAGESRESPONSE']._serialized_start=687
47
+ _globals['_PULLRESMESSAGESRESPONSE']._serialized_end=914
48
+ _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_start=839
49
+ _globals['_PULLRESMESSAGESRESPONSE_OBJECTSTOPULLENTRY']._serialized_end=914
50
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=916
51
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=944
52
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=946
53
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=1073
54
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=1075
55
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=1158
56
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=1160
57
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=1190
58
+ _globals['_SERVERAPPIO']._serialized_start=1193
59
+ _globals['_SERVERAPPIO']._serialized_end=2432
64
60
  # @@protoc_insertion_point(module_scope)
@@ -42,36 +42,21 @@ global___GetNodesResponse = GetNodesResponse
42
42
  class PushInsMessagesRequest(google.protobuf.message.Message):
43
43
  """PushMessages messages"""
44
44
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
45
- class MsgToDescendantMappingEntry(google.protobuf.message.Message):
46
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
47
- KEY_FIELD_NUMBER: builtins.int
48
- VALUE_FIELD_NUMBER: builtins.int
49
- key: typing.Text
50
- @property
51
- def value(self) -> flwr.proto.message_pb2.ObjectIDs: ...
52
- def __init__(self,
53
- *,
54
- key: typing.Text = ...,
55
- value: typing.Optional[flwr.proto.message_pb2.ObjectIDs] = ...,
56
- ) -> None: ...
57
- def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
58
- def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
59
-
60
45
  MESSAGES_LIST_FIELD_NUMBER: builtins.int
61
46
  RUN_ID_FIELD_NUMBER: builtins.int
62
- MSG_TO_DESCENDANT_MAPPING_FIELD_NUMBER: builtins.int
47
+ MESSAGE_OBJECT_TREES_FIELD_NUMBER: builtins.int
63
48
  @property
64
49
  def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
65
50
  run_id: builtins.int
66
51
  @property
67
- def msg_to_descendant_mapping(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.message_pb2.ObjectIDs]: ...
52
+ def message_object_trees(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.ObjectTree]: ...
68
53
  def __init__(self,
69
54
  *,
70
55
  messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
71
56
  run_id: builtins.int = ...,
72
- msg_to_descendant_mapping: typing.Optional[typing.Mapping[typing.Text, flwr.proto.message_pb2.ObjectIDs]] = ...,
57
+ message_object_trees: typing.Optional[typing.Iterable[flwr.proto.message_pb2.ObjectTree]] = ...,
73
58
  ) -> None: ...
74
- def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","msg_to_descendant_mapping",b"msg_to_descendant_mapping","run_id",b"run_id"]) -> None: ...
59
+ def ClearField(self, field_name: typing_extensions.Literal["message_object_trees",b"message_object_trees","messages_list",b"messages_list","run_id",b"run_id"]) -> None: ...
75
60
  global___PushInsMessagesRequest = PushInsMessagesRequest
76
61
 
77
62
  class PushInsMessagesResponse(google.protobuf.message.Message):
@@ -84,6 +84,11 @@ class ServerAppIoStub(object):
84
84
  request_serializer=flwr_dot_proto_dot_message__pb2.PullObjectRequest.SerializeToString,
85
85
  response_deserializer=flwr_dot_proto_dot_message__pb2.PullObjectResponse.FromString,
86
86
  )
87
+ self.ConfirmMessageReceived = channel.unary_unary(
88
+ '/flwr.proto.ServerAppIo/ConfirmMessageReceived',
89
+ request_serializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.SerializeToString,
90
+ response_deserializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.FromString,
91
+ )
87
92
 
88
93
 
89
94
  class ServerAppIoServicer(object):
@@ -180,6 +185,13 @@ class ServerAppIoServicer(object):
180
185
  context.set_details('Method not implemented!')
181
186
  raise NotImplementedError('Method not implemented!')
182
187
 
188
+ def ConfirmMessageReceived(self, request, context):
189
+ """Confirm Message Received
190
+ """
191
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
192
+ context.set_details('Method not implemented!')
193
+ raise NotImplementedError('Method not implemented!')
194
+
183
195
 
184
196
  def add_ServerAppIoServicer_to_server(servicer, server):
185
197
  rpc_method_handlers = {
@@ -248,6 +260,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
248
260
  request_deserializer=flwr_dot_proto_dot_message__pb2.PullObjectRequest.FromString,
249
261
  response_serializer=flwr_dot_proto_dot_message__pb2.PullObjectResponse.SerializeToString,
250
262
  ),
263
+ 'ConfirmMessageReceived': grpc.unary_unary_rpc_method_handler(
264
+ servicer.ConfirmMessageReceived,
265
+ request_deserializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.FromString,
266
+ response_serializer=flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.SerializeToString,
267
+ ),
251
268
  }
252
269
  generic_handler = grpc.method_handlers_generic_handler(
253
270
  'flwr.proto.ServerAppIo', rpc_method_handlers)
@@ -478,3 +495,20 @@ class ServerAppIo(object):
478
495
  flwr_dot_proto_dot_message__pb2.PullObjectResponse.FromString,
479
496
  options, channel_credentials,
480
497
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
498
+
499
+ @staticmethod
500
+ def ConfirmMessageReceived(request,
501
+ target,
502
+ options=(),
503
+ channel_credentials=None,
504
+ call_credentials=None,
505
+ insecure=False,
506
+ compression=None,
507
+ wait_for_ready=None,
508
+ timeout=None,
509
+ metadata=None):
510
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/ConfirmMessageReceived',
511
+ flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedRequest.SerializeToString,
512
+ flwr_dot_proto_dot_message__pb2.ConfirmMessageReceivedResponse.FromString,
513
+ options, channel_credentials,
514
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -78,6 +78,11 @@ class ServerAppIoStub:
78
78
  flwr.proto.message_pb2.PullObjectResponse]
79
79
  """Pull Object"""
80
80
 
81
+ ConfirmMessageReceived: grpc.UnaryUnaryMultiCallable[
82
+ flwr.proto.message_pb2.ConfirmMessageReceivedRequest,
83
+ flwr.proto.message_pb2.ConfirmMessageReceivedResponse]
84
+ """Confirm Message Received"""
85
+
81
86
 
82
87
  class ServerAppIoServicer(metaclass=abc.ABCMeta):
83
88
  @abc.abstractmethod
@@ -184,5 +189,13 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
184
189
  """Pull Object"""
185
190
  pass
186
191
 
192
+ @abc.abstractmethod
193
+ def ConfirmMessageReceived(self,
194
+ request: flwr.proto.message_pb2.ConfirmMessageReceivedRequest,
195
+ context: grpc.ServicerContext,
196
+ ) -> flwr.proto.message_pb2.ConfirmMessageReceivedResponse:
197
+ """Confirm Message Received"""
198
+ pass
199
+
187
200
 
188
201
  def add_ServerAppIoServicer_to_server(servicer: ServerAppIoServicer, server: grpc.Server) -> None: ...
@@ -28,7 +28,11 @@ from flwr.common.constant import (
28
28
  SUPERLINK_NODE_ID,
29
29
  )
30
30
  from flwr.common.grpc import create_channel, on_channel_state_change
31
- from flwr.common.inflatable import get_all_nested_objects
31
+ from flwr.common.inflatable import (
32
+ get_all_nested_objects,
33
+ get_object_tree,
34
+ no_object_id_recompute,
35
+ )
32
36
  from flwr.common.inflatable_grpc_utils import (
33
37
  make_pull_object_fn_grpc,
34
38
  make_push_object_fn_grpc,
@@ -43,7 +47,9 @@ from flwr.common.message import remove_content_from_message
43
47
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
44
48
  from flwr.common.serde import message_to_proto, run_from_proto
45
49
  from flwr.common.typing import Run
46
- from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
50
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
51
+ ConfirmMessageReceivedRequest,
52
+ )
47
53
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
48
54
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
55
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
@@ -213,18 +219,15 @@ class GrpcGrid(Grid):
213
219
  """Push one message and its associated objects."""
214
220
  # Compute mapping of message descendants
215
221
  all_objects = get_all_nested_objects(message)
216
- all_object_ids = list(all_objects.keys())
217
- msg_id = all_object_ids[-1] # Last object is the message itself
218
- descendant_ids = all_object_ids[:-1] # All but the last object are descendants
222
+ msg_id = message.object_id
223
+ object_tree = get_object_tree(message)
219
224
 
220
225
  # Call GrpcServerAppIoStub method
221
226
  res: PushInsMessagesResponse = self._stub.PushMessages(
222
227
  PushInsMessagesRequest(
223
228
  messages_list=[message_to_proto(remove_content_from_message(message))],
224
229
  run_id=run_id,
225
- msg_to_descendant_mapping={
226
- msg_id: ObjectIDs(object_ids=descendant_ids)
227
- },
230
+ message_object_trees=[object_tree],
228
231
  )
229
232
  )
230
233
 
@@ -262,7 +265,8 @@ class GrpcGrid(Grid):
262
265
  # Check message
263
266
  self._check_message(msg)
264
267
  # Try pushing message and its objects
265
- message_ids.append(self._try_push_message(run_id, msg))
268
+ with no_object_id_recompute():
269
+ message_ids.append(self._try_push_message(run_id, msg))
266
270
 
267
271
  except grpc.RpcError as e:
268
272
  if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
@@ -308,6 +312,13 @@ class GrpcGrid(Grid):
308
312
  run_id=run_id,
309
313
  ),
310
314
  )
315
+
316
+ # Confirm that the message has been received
317
+ self._stub.ConfirmMessageReceived(
318
+ ConfirmMessageReceivedRequest(
319
+ node=self.node, run_id=run_id, message_object_id=msg_id
320
+ )
321
+ )
311
322
  message = cast(
312
323
  Message, inflate_object_from_contents(msg_id, all_object_contents)
313
324
  )
@@ -40,6 +40,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
40
40
  SendNodeHeartbeatResponse,
41
41
  )
42
42
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
43
+ ConfirmMessageReceivedRequest,
44
+ ConfirmMessageReceivedResponse,
43
45
  PullObjectRequest,
44
46
  PullObjectResponse,
45
47
  PushObjectRequest,
@@ -151,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
151
153
  res = message_handler.get_run(
152
154
  request=request,
153
155
  state=self.state_factory.state(),
156
+ store=self.objectstore_factory.store(),
154
157
  )
155
158
  except InvalidRunStatusException as e:
156
159
  abort_grpc_context(e.message, context)
@@ -167,6 +170,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
167
170
  request=request,
168
171
  ffs=self.ffs_factory.ffs(),
169
172
  state=self.state_factory.state(),
173
+ store=self.objectstore_factory.store(),
170
174
  )
171
175
  except InvalidRunStatusException as e:
172
176
  abort_grpc_context(e.message, context)
@@ -219,3 +223,24 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
219
223
  abort_grpc_context(e.message, context)
220
224
 
221
225
  return res
226
+
227
+ def ConfirmMessageReceived(
228
+ self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
229
+ ) -> ConfirmMessageReceivedResponse:
230
+ """Confirm message received."""
231
+ log(
232
+ DEBUG,
233
+ "[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
234
+ request.message_object_id,
235
+ )
236
+
237
+ try:
238
+ res = message_handler.confirm_message_received(
239
+ request=request,
240
+ state=self.state_factory.state(),
241
+ store=self.objectstore_factory.store(),
242
+ )
243
+ except InvalidRunStatusException as e:
244
+ abort_grpc_context(e.message, context)
245
+
246
+ return res
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
44
44
  SendNodeHeartbeatResponse,
45
45
  )
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
+ ConfirmMessageReceivedRequest,
48
+ ConfirmMessageReceivedResponse,
47
49
  ObjectIDs,
48
50
  PullObjectRequest,
49
51
  PullObjectResponse,
@@ -146,6 +148,7 @@ def push_messages(
146
148
  msg.metadata.run_id,
147
149
  [Status.PENDING, Status.STARTING, Status.FINISHED],
148
150
  state,
151
+ store,
149
152
  )
150
153
  if abort_msg:
151
154
  raise InvalidRunStatusException(abort_msg)
@@ -165,7 +168,9 @@ def push_messages(
165
168
  return response
166
169
 
167
170
 
168
- def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
171
+ def get_run(
172
+ request: GetRunRequest, state: LinkState, store: ObjectStore
173
+ ) -> GetRunResponse:
169
174
  """Get run information."""
170
175
  run = state.get_run(request.run_id)
171
176
 
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
177
182
  request.run_id,
178
183
  [Status.PENDING, Status.STARTING, Status.FINISHED],
179
184
  state,
185
+ store,
180
186
  )
181
187
  if abort_msg:
182
188
  raise InvalidRunStatusException(abort_msg)
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
193
199
 
194
200
 
195
201
  def get_fab(
196
- request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
202
+ request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
197
203
  ) -> GetFabResponse:
198
204
  """Get FAB."""
199
205
  # Abort if the run is not running
@@ -201,6 +207,7 @@ def get_fab(
201
207
  request.run_id,
202
208
  [Status.PENDING, Status.STARTING, Status.FINISHED],
203
209
  state,
210
+ store,
204
211
  )
205
212
  if abort_msg:
206
213
  raise InvalidRunStatusException(abort_msg)
@@ -220,6 +227,7 @@ def push_object(
220
227
  request.run_id,
221
228
  [Status.PENDING, Status.STARTING, Status.FINISHED],
222
229
  state,
230
+ store,
223
231
  )
224
232
  if abort_msg:
225
233
  raise InvalidRunStatusException(abort_msg)
@@ -245,6 +253,7 @@ def pull_object(
245
253
  request.run_id,
246
254
  [Status.PENDING, Status.STARTING, Status.FINISHED],
247
255
  state,
256
+ store,
248
257
  )
249
258
  if abort_msg:
250
259
  raise InvalidRunStatusException(abort_msg)
@@ -259,3 +268,25 @@ def pull_object(
259
268
  object_content=content,
260
269
  )
261
270
  return PullObjectResponse(object_found=False, object_available=False)
271
+
272
+
273
+ def confirm_message_received(
274
+ request: ConfirmMessageReceivedRequest,
275
+ state: LinkState,
276
+ store: ObjectStore,
277
+ ) -> ConfirmMessageReceivedResponse:
278
+ """Confirm message received handler."""
279
+ abort_msg = check_abort(
280
+ request.run_id,
281
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
282
+ state,
283
+ store,
284
+ )
285
+ if abort_msg:
286
+ raise InvalidRunStatusException(abort_msg)
287
+
288
+ # Delete the message object
289
+ store.delete(request.message_object_id)
290
+ store.delete_message_descendant_ids(request.message_object_id)
291
+
292
+ return ConfirmMessageReceivedResponse()
@@ -39,6 +39,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
39
39
  SendNodeHeartbeatResponse,
40
40
  )
41
41
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
42
+ ConfirmMessageReceivedRequest,
43
+ ConfirmMessageReceivedResponse,
42
44
  PullObjectRequest,
43
45
  PullObjectResponse,
44
46
  PushObjectRequest,
@@ -176,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
176
178
  """GetRun."""
177
179
  # Get state from app
178
180
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
181
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
179
182
 
180
183
  # Handle message
181
- return message_handler.get_run(request=request, state=state)
184
+ return message_handler.get_run(request=request, state=state, store=store)
182
185
 
183
186
 
184
187
  @rest_request_response(GetFabRequest)
@@ -189,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
189
192
 
190
193
  # Get state from app
191
194
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
195
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
196
+
197
+ # Handle message
198
+ return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
199
+
200
+
201
+ @rest_request_response(ConfirmMessageReceivedRequest)
202
+ async def confirm_message_received(
203
+ request: ConfirmMessageReceivedRequest,
204
+ ) -> ConfirmMessageReceivedResponse:
205
+ """Confirm message received."""
206
+ # Get state from app
207
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
208
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
192
209
 
193
210
  # Handle message
194
- return message_handler.get_fab(request=request, ffs=ffs, state=state)
211
+ return message_handler.confirm_message_received(
212
+ request=request, state=state, store=store
213
+ )
195
214
 
196
215
 
197
216
  routes = [
@@ -204,6 +223,11 @@ routes = [
204
223
  Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
205
224
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
206
225
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
226
+ Route(
227
+ "/api/v0/fleet/confirm-message-received",
228
+ confirm_message_received,
229
+ methods=["POST"],
230
+ ),
207
231
  ]
208
232
 
209
233
  app: Starlette = Starlette(
@@ -18,6 +18,7 @@
18
18
  import threading
19
19
  import time
20
20
  from bisect import bisect_right
21
+ from collections import defaultdict
21
22
  from dataclasses import dataclass, field
22
23
  from logging import ERROR, WARNING
23
24
  from typing import Optional
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
79
80
  self.message_res_store: dict[str, Message] = {}
80
81
  self.message_ins_id_to_message_res_id: dict[str, str] = {}
81
82
 
83
+ # Map flwr_aid to run_ids for O(1) reverse index lookup
84
+ self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
85
+
82
86
  self.node_public_keys: set[bytes] = set()
83
87
 
84
88
  self.lock = threading.RLock()
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
398
402
  fab_hash: Optional[str],
399
403
  override_config: UserConfig,
400
404
  federation_options: ConfigRecord,
405
+ flwr_aid: Optional[str],
401
406
  ) -> int:
402
407
  """Create a new run for the specified `fab_hash`."""
403
408
  # Sample a random int64 as run_id
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
421
426
  sub_status="",
422
427
  details="",
423
428
  ),
429
+ flwr_aid=flwr_aid if flwr_aid else "",
424
430
  ),
425
431
  )
426
432
  self.run_ids[run_id] = run_record
433
+ # Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
434
+ if flwr_aid:
435
+ self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
427
436
 
428
437
  # Record federation options. Leave empty if not passed
429
438
  self.federation_options[run_id] = federation_options
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
451
460
  with self.lock:
452
461
  return self.node_public_keys.copy()
453
462
 
454
- def get_run_ids(self) -> set[int]:
455
- """Retrieve all run IDs."""
463
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
464
+ """Retrieve all run IDs if `flwr_aid` is not specified.
465
+
466
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
467
+ """
456
468
  with self.lock:
469
+ if flwr_aid is not None:
470
+ # Return run IDs for the specified flwr_aid
471
+ return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
457
472
  return set(self.run_ids.keys())
458
473
 
459
474
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_hash: Optional[str],
165
165
  override_config: UserConfig,
166
166
  federation_options: ConfigRecord,
167
+ flwr_aid: Optional[str],
167
168
  ) -> int:
168
169
  """Create a new run for the specified `fab_hash`."""
169
170
 
170
171
  @abc.abstractmethod
171
- def get_run_ids(self) -> set[int]:
172
- """Retrieve all run IDs."""
172
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
173
+ """Retrieve all run IDs if `flwr_aid` is not specified.
174
+
175
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
176
+ """
173
177
 
174
178
  @abc.abstractmethod
175
179
  def get_run(self, run_id: int) -> Optional[Run]: