flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (100) hide show
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +11 -31
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +83 -0
  10. flwr/cli/ls.py +10 -40
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  15. flwr/cli/run/__init__.py +1 -0
  16. flwr/cli/run/run.py +15 -25
  17. flwr/cli/stop.py +91 -0
  18. flwr/cli/utils.py +109 -1
  19. flwr/client/app.py +3 -2
  20. flwr/client/client.py +1 -0
  21. flwr/client/clientapp/app.py +1 -0
  22. flwr/client/clientapp/utils.py +1 -0
  23. flwr/client/grpc_adapter_client/connection.py +1 -1
  24. flwr/client/grpc_client/connection.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +3 -3
  26. flwr/client/message_handler/message_handler.py +1 -0
  27. flwr/client/mod/comms_mods.py +1 -0
  28. flwr/client/mod/localdp_mod.py +1 -1
  29. flwr/client/nodestate/__init__.py +1 -0
  30. flwr/client/nodestate/nodestate.py +1 -0
  31. flwr/client/nodestate/nodestate_factory.py +1 -0
  32. flwr/client/rest_client/connection.py +3 -3
  33. flwr/client/supernode/app.py +1 -0
  34. flwr/common/address.py +1 -0
  35. flwr/common/args.py +1 -0
  36. flwr/common/auth_plugin/__init__.py +24 -0
  37. flwr/common/auth_plugin/auth_plugin.py +111 -0
  38. flwr/common/config.py +3 -1
  39. flwr/common/constant.py +6 -1
  40. flwr/common/logger.py +1 -0
  41. flwr/common/message.py +1 -0
  42. flwr/common/object_ref.py +57 -54
  43. flwr/common/pyproject.py +1 -0
  44. flwr/common/record/__init__.py +1 -0
  45. flwr/common/record/parametersrecord.py +1 -0
  46. flwr/common/retry_invoker.py +75 -0
  47. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  48. flwr/common/telemetry.py +2 -1
  49. flwr/common/typing.py +12 -0
  50. flwr/common/version.py +1 -0
  51. flwr/proto/exec_pb2.py +27 -3
  52. flwr/proto/exec_pb2.pyi +103 -0
  53. flwr/proto/exec_pb2_grpc.py +102 -0
  54. flwr/proto/exec_pb2_grpc.pyi +39 -0
  55. flwr/proto/fab_pb2.py +4 -4
  56. flwr/proto/fab_pb2.pyi +4 -1
  57. flwr/proto/serverappio_pb2.py +18 -18
  58. flwr/proto/serverappio_pb2.pyi +8 -2
  59. flwr/proto/serverappio_pb2_grpc.py +34 -0
  60. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  61. flwr/proto/simulationio_pb2.py +2 -2
  62. flwr/proto/simulationio_pb2_grpc.py +34 -0
  63. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  64. flwr/server/app.py +53 -1
  65. flwr/server/compat/app_utils.py +7 -1
  66. flwr/server/driver/grpc_driver.py +11 -63
  67. flwr/server/driver/inmemory_driver.py +5 -1
  68. flwr/server/serverapp/app.py +9 -2
  69. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  70. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  71. flwr/server/superlink/driver/serverappio_servicer.py +72 -22
  72. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  73. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  74. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +31 -2
  77. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  78. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  79. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  81. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  82. flwr/server/superlink/linkstate/linkstate.py +13 -2
  83. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  84. flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
  85. flwr/server/superlink/utils.py +65 -0
  86. flwr/simulation/app.py +1 -0
  87. flwr/simulation/ray_transport/ray_actor.py +1 -0
  88. flwr/simulation/ray_transport/utils.py +1 -0
  89. flwr/simulation/run_simulation.py +1 -0
  90. flwr/superexec/app.py +1 -0
  91. flwr/superexec/deployment.py +1 -0
  92. flwr/superexec/exec_grpc.py +19 -1
  93. flwr/superexec/exec_servicer.py +76 -2
  94. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  95. flwr/superexec/executor.py +1 -0
  96. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/METADATA +8 -7
  97. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/RECORD +100 -92
  98. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/LICENSE +0 -0
  99. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/WHEEL +0 -0
  100. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,11 @@ class ExecStub:
14
14
  flwr.proto.exec_pb2.StartRunResponse]
15
15
  """Start run upon request"""
16
16
 
17
+ StopRun: grpc.UnaryUnaryMultiCallable[
18
+ flwr.proto.exec_pb2.StopRunRequest,
19
+ flwr.proto.exec_pb2.StopRunResponse]
20
+ """Stop run upon request"""
21
+
17
22
  StreamLogs: grpc.UnaryStreamMultiCallable[
18
23
  flwr.proto.exec_pb2.StreamLogsRequest,
19
24
  flwr.proto.exec_pb2.StreamLogsResponse]
@@ -24,6 +29,16 @@ class ExecStub:
24
29
  flwr.proto.exec_pb2.ListRunsResponse]
25
30
  """flwr ls command"""
26
31
 
32
+ GetLoginDetails: grpc.UnaryUnaryMultiCallable[
33
+ flwr.proto.exec_pb2.GetLoginDetailsRequest,
34
+ flwr.proto.exec_pb2.GetLoginDetailsResponse]
35
+ """Get login details upon request"""
36
+
37
+ GetAuthTokens: grpc.UnaryUnaryMultiCallable[
38
+ flwr.proto.exec_pb2.GetAuthTokensRequest,
39
+ flwr.proto.exec_pb2.GetAuthTokensResponse]
40
+ """Get auth tokens upon request"""
41
+
27
42
 
28
43
  class ExecServicer(metaclass=abc.ABCMeta):
29
44
  @abc.abstractmethod
@@ -34,6 +49,14 @@ class ExecServicer(metaclass=abc.ABCMeta):
34
49
  """Start run upon request"""
35
50
  pass
36
51
 
52
+ @abc.abstractmethod
53
+ def StopRun(self,
54
+ request: flwr.proto.exec_pb2.StopRunRequest,
55
+ context: grpc.ServicerContext,
56
+ ) -> flwr.proto.exec_pb2.StopRunResponse:
57
+ """Stop run upon request"""
58
+ pass
59
+
37
60
  @abc.abstractmethod
38
61
  def StreamLogs(self,
39
62
  request: flwr.proto.exec_pb2.StreamLogsRequest,
@@ -50,5 +73,21 @@ class ExecServicer(metaclass=abc.ABCMeta):
50
73
  """flwr ls command"""
51
74
  pass
52
75
 
76
+ @abc.abstractmethod
77
+ def GetLoginDetails(self,
78
+ request: flwr.proto.exec_pb2.GetLoginDetailsRequest,
79
+ context: grpc.ServicerContext,
80
+ ) -> flwr.proto.exec_pb2.GetLoginDetailsResponse:
81
+ """Get login details upon request"""
82
+ pass
83
+
84
+ @abc.abstractmethod
85
+ def GetAuthTokens(self,
86
+ request: flwr.proto.exec_pb2.GetAuthTokensRequest,
87
+ context: grpc.ServicerContext,
88
+ ) -> flwr.proto.exec_pb2.GetAuthTokensResponse:
89
+ """Get auth tokens upon request"""
90
+ pass
91
+
53
92
 
54
93
  def add_ExecServicer_to_server(servicer: ExecServicer, server: grpc.Server) -> None: ...
flwr/proto/fab_pb2.py CHANGED
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
 
17
17
 
18
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"A\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"Q\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
19
19
 
20
20
  _globals = globals()
21
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,7 +25,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
25
25
  _globals['_FAB']._serialized_start=59
26
26
  _globals['_FAB']._serialized_end=99
27
27
  _globals['_GETFABREQUEST']._serialized_start=101
28
- _globals['_GETFABREQUEST']._serialized_end=166
29
- _globals['_GETFABRESPONSE']._serialized_start=168
30
- _globals['_GETFABRESPONSE']._serialized_end=214
28
+ _globals['_GETFABREQUEST']._serialized_end=182
29
+ _globals['_GETFABRESPONSE']._serialized_start=184
30
+ _globals['_GETFABRESPONSE']._serialized_end=230
31
31
  # @@protoc_insertion_point(module_scope)
flwr/proto/fab_pb2.pyi CHANGED
@@ -36,16 +36,19 @@ class GetFabRequest(google.protobuf.message.Message):
36
36
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
37
37
  NODE_FIELD_NUMBER: builtins.int
38
38
  HASH_STR_FIELD_NUMBER: builtins.int
39
+ RUN_ID_FIELD_NUMBER: builtins.int
39
40
  @property
40
41
  def node(self) -> flwr.proto.node_pb2.Node: ...
41
42
  hash_str: typing.Text
43
+ run_id: builtins.int
42
44
  def __init__(self,
43
45
  *,
44
46
  node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
45
47
  hash_str: typing.Text = ...,
48
+ run_id: builtins.int = ...,
46
49
  ) -> None: ...
47
50
  def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
48
- def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
51
+ def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node","run_id",b"run_id"]) -> None: ...
49
52
  global___GetFabRequest = GetFabRequest
50
53
 
51
54
  class GetFabResponse(google.protobuf.message.Message):
@@ -20,7 +20,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
20
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
21
21
 
22
22
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xca\x06\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
24
24
 
25
25
  _globals = globals()
26
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -32,21 +32,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
32
32
  _globals['_GETNODESRESPONSE']._serialized_start=217
33
33
  _globals['_GETNODESRESPONSE']._serialized_end=268
34
34
  _globals['_PUSHTASKINSREQUEST']._serialized_start=270
35
- _globals['_PUSHTASKINSREQUEST']._serialized_end=334
36
- _globals['_PUSHTASKINSRESPONSE']._serialized_start=336
37
- _globals['_PUSHTASKINSRESPONSE']._serialized_end=375
38
- _globals['_PULLTASKRESREQUEST']._serialized_start=377
39
- _globals['_PULLTASKRESREQUEST']._serialized_end=447
40
- _globals['_PULLTASKRESRESPONSE']._serialized_start=449
41
- _globals['_PULLTASKRESRESPONSE']._serialized_end=514
42
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=516
43
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=544
44
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=546
45
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=673
46
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=675
47
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=758
48
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760
49
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790
50
- _globals['_SERVERAPPIO']._serialized_start=793
51
- _globals['_SERVERAPPIO']._serialized_end=1635
35
+ _globals['_PUSHTASKINSREQUEST']._serialized_end=350
36
+ _globals['_PUSHTASKINSRESPONSE']._serialized_start=352
37
+ _globals['_PUSHTASKINSRESPONSE']._serialized_end=391
38
+ _globals['_PULLTASKRESREQUEST']._serialized_start=393
39
+ _globals['_PULLTASKRESREQUEST']._serialized_end=479
40
+ _globals['_PULLTASKRESRESPONSE']._serialized_start=481
41
+ _globals['_PULLTASKRESRESPONSE']._serialized_end=546
42
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548
43
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576
44
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578
45
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705
46
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707
47
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790
48
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792
49
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822
50
+ _globals['_SERVERAPPIO']._serialized_start=825
51
+ _globals['_SERVERAPPIO']._serialized_end=1752
52
52
  # @@protoc_insertion_point(module_scope)
@@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
44
44
  """PushTaskIns messages"""
45
45
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
46
46
  TASK_INS_LIST_FIELD_NUMBER: builtins.int
47
+ RUN_ID_FIELD_NUMBER: builtins.int
47
48
  @property
48
49
  def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
50
+ run_id: builtins.int
49
51
  def __init__(self,
50
52
  *,
51
53
  task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
54
+ run_id: builtins.int = ...,
52
55
  ) -> None: ...
53
- def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
56
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
54
57
  global___PushTaskInsRequest = PushTaskInsRequest
55
58
 
56
59
  class PushTaskInsResponse(google.protobuf.message.Message):
@@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
70
73
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
71
74
  NODE_FIELD_NUMBER: builtins.int
72
75
  TASK_IDS_FIELD_NUMBER: builtins.int
76
+ RUN_ID_FIELD_NUMBER: builtins.int
73
77
  @property
74
78
  def node(self) -> flwr.proto.node_pb2.Node: ...
75
79
  @property
76
80
  def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
81
+ run_id: builtins.int
77
82
  def __init__(self,
78
83
  *,
79
84
  node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
80
85
  task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
86
+ run_id: builtins.int = ...,
81
87
  ) -> None: ...
82
88
  def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
83
- def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
89
+ def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
84
90
  global___PullTaskResRequest = PullTaskResRequest
85
91
 
86
92
  class PullTaskResResponse(google.protobuf.message.Message):
@@ -62,6 +62,11 @@ class ServerAppIoStub(object):
62
62
  request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
63
63
  response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
64
64
  )
65
+ self.GetRunStatus = channel.unary_unary(
66
+ '/flwr.proto.ServerAppIo/GetRunStatus',
67
+ request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
68
+ response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
69
+ )
65
70
  self.PushLogs = channel.unary_unary(
66
71
  '/flwr.proto.ServerAppIo/PushLogs',
67
72
  request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
@@ -135,6 +140,13 @@ class ServerAppIoServicer(object):
135
140
  context.set_details('Method not implemented!')
136
141
  raise NotImplementedError('Method not implemented!')
137
142
 
143
+ def GetRunStatus(self, request, context):
144
+ """Get the status of a given run
145
+ """
146
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
147
+ context.set_details('Method not implemented!')
148
+ raise NotImplementedError('Method not implemented!')
149
+
138
150
  def PushLogs(self, request, context):
139
151
  """Push ServerApp logs
140
152
  """
@@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
190
202
  request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
191
203
  response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
192
204
  ),
205
+ 'GetRunStatus': grpc.unary_unary_rpc_method_handler(
206
+ servicer.GetRunStatus,
207
+ request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
208
+ response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
209
+ ),
193
210
  'PushLogs': grpc.unary_unary_rpc_method_handler(
194
211
  servicer.PushLogs,
195
212
  request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
@@ -358,6 +375,23 @@ class ServerAppIo(object):
358
375
  options, channel_credentials,
359
376
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
360
377
 
378
+ @staticmethod
379
+ def GetRunStatus(request,
380
+ target,
381
+ options=(),
382
+ channel_credentials=None,
383
+ call_credentials=None,
384
+ insecure=False,
385
+ compression=None,
386
+ wait_for_ready=None,
387
+ timeout=None,
388
+ metadata=None):
389
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
390
+ flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
391
+ flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
392
+ options, channel_credentials,
393
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
394
+
361
395
  @staticmethod
362
396
  def PushLogs(request,
363
397
  target,
@@ -56,6 +56,11 @@ class ServerAppIoStub:
56
56
  flwr.proto.run_pb2.UpdateRunStatusResponse]
57
57
  """Update the status of a given run"""
58
58
 
59
+ GetRunStatus: grpc.UnaryUnaryMultiCallable[
60
+ flwr.proto.run_pb2.GetRunStatusRequest,
61
+ flwr.proto.run_pb2.GetRunStatusResponse]
62
+ """Get the status of a given run"""
63
+
59
64
  PushLogs: grpc.UnaryUnaryMultiCallable[
60
65
  flwr.proto.log_pb2.PushLogsRequest,
61
66
  flwr.proto.log_pb2.PushLogsResponse]
@@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
135
140
  """Update the status of a given run"""
136
141
  pass
137
142
 
143
+ @abc.abstractmethod
144
+ def GetRunStatus(self,
145
+ request: flwr.proto.run_pb2.GetRunStatusRequest,
146
+ context: grpc.ServicerContext,
147
+ ) -> flwr.proto.run_pb2.GetRunStatusResponse:
148
+ """Get the status of a given run"""
149
+ pass
150
+
138
151
  @abc.abstractmethod
139
152
  def PushLogs(self,
140
153
  request: flwr.proto.log_pb2.PushLogsRequest,
@@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
18
18
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lwr/proto/simulationio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"\x1d\n\x1bPullSimulationInputsRequest\"\x80\x01\n\x1cPullSimulationInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"T\n\x1cPushSimulationOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1f\n\x1dPushSimulationOutputsResponse2\xff\x03\n\x0cSimulationIo\x12k\n\x14PullSimulationInputs\x12\'.flwr.proto.PullSimulationInputsRequest\x1a(.flwr.proto.PullSimulationInputsResponse\"\x00\x12n\n\x15PushSimulationOutputs\x12(.flwr.proto.PushSimulationOutputsRequest\x1a).flwr.proto.PushSimulationOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12k\n\x14GetFederationOptions\x12\'.flwr.proto.GetFederationOptionsRequest\x1a(.flwr.proto.GetFederationOptionsResponse\"\x00\x62\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lwr/proto/simulationio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"\x1d\n\x1bPullSimulationInputsRequest\"\x80\x01\n\x1cPullSimulationInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"T\n\x1cPushSimulationOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1f\n\x1dPushSimulationOutputsResponse2\xd4\x04\n\x0cSimulationIo\x12k\n\x14PullSimulationInputs\x12\'.flwr.proto.PullSimulationInputsRequest\x1a(.flwr.proto.PullSimulationInputsResponse\"\x00\x12n\n\x15PushSimulationOutputs\x12(.flwr.proto.PushSimulationOutputsRequest\x1a).flwr.proto.PushSimulationOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12k\n\x14GetFederationOptions\x12\'.flwr.proto.GetFederationOptionsRequest\x1a(.flwr.proto.GetFederationOptionsResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x62\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -34,5 +34,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
34
34
  _globals['_PUSHSIMULATIONOUTPUTSRESPONSE']._serialized_start=385
35
35
  _globals['_PUSHSIMULATIONOUTPUTSRESPONSE']._serialized_end=416
36
36
  _globals['_SIMULATIONIO']._serialized_start=419
37
- _globals['_SIMULATIONIO']._serialized_end=930
37
+ _globals['_SIMULATIONIO']._serialized_end=1015
38
38
  # @@protoc_insertion_point(module_scope)
@@ -41,6 +41,11 @@ class SimulationIoStub(object):
41
41
  request_serializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsRequest.SerializeToString,
42
42
  response_deserializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.FromString,
43
43
  )
44
+ self.GetRunStatus = channel.unary_unary(
45
+ '/flwr.proto.SimulationIo/GetRunStatus',
46
+ request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
47
+ response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
48
+ )
44
49
 
45
50
 
46
51
  class SimulationIoServicer(object):
@@ -81,6 +86,13 @@ class SimulationIoServicer(object):
81
86
  context.set_details('Method not implemented!')
82
87
  raise NotImplementedError('Method not implemented!')
83
88
 
89
+ def GetRunStatus(self, request, context):
90
+ """Get Run Status
91
+ """
92
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
93
+ context.set_details('Method not implemented!')
94
+ raise NotImplementedError('Method not implemented!')
95
+
84
96
 
85
97
  def add_SimulationIoServicer_to_server(servicer, server):
86
98
  rpc_method_handlers = {
@@ -109,6 +121,11 @@ def add_SimulationIoServicer_to_server(servicer, server):
109
121
  request_deserializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsRequest.FromString,
110
122
  response_serializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.SerializeToString,
111
123
  ),
124
+ 'GetRunStatus': grpc.unary_unary_rpc_method_handler(
125
+ servicer.GetRunStatus,
126
+ request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
127
+ response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
128
+ ),
112
129
  }
113
130
  generic_handler = grpc.method_handlers_generic_handler(
114
131
  'flwr.proto.SimulationIo', rpc_method_handlers)
@@ -203,3 +220,20 @@ class SimulationIo(object):
203
220
  flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.FromString,
204
221
  options, channel_credentials,
205
222
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
223
+
224
+ @staticmethod
225
+ def GetRunStatus(request,
226
+ target,
227
+ options=(),
228
+ channel_credentials=None,
229
+ call_credentials=None,
230
+ insecure=False,
231
+ compression=None,
232
+ wait_for_ready=None,
233
+ timeout=None,
234
+ metadata=None):
235
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.SimulationIo/GetRunStatus',
236
+ flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
237
+ flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
238
+ options, channel_credentials,
239
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -35,6 +35,11 @@ class SimulationIoStub:
35
35
  flwr.proto.run_pb2.GetFederationOptionsResponse]
36
36
  """Get Federation Options"""
37
37
 
38
+ GetRunStatus: grpc.UnaryUnaryMultiCallable[
39
+ flwr.proto.run_pb2.GetRunStatusRequest,
40
+ flwr.proto.run_pb2.GetRunStatusResponse]
41
+ """Get Run Status"""
42
+
38
43
 
39
44
  class SimulationIoServicer(metaclass=abc.ABCMeta):
40
45
  @abc.abstractmethod
@@ -77,5 +82,13 @@ class SimulationIoServicer(metaclass=abc.ABCMeta):
77
82
  """Get Federation Options"""
78
83
  pass
79
84
 
85
+ @abc.abstractmethod
86
+ def GetRunStatus(self,
87
+ request: flwr.proto.run_pb2.GetRunStatusRequest,
88
+ context: grpc.ServicerContext,
89
+ ) -> flwr.proto.run_pb2.GetRunStatusResponse:
90
+ """Get Run Status"""
91
+ pass
92
+
80
93
 
81
94
  def add_SimulationIoServicer_to_server(servicer: SimulationIoServicer, server: grpc.Server) -> None: ...
flwr/server/app.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower server app."""
16
16
 
17
+
17
18
  import argparse
18
19
  import csv
19
20
  import importlib.util
@@ -24,9 +25,10 @@ from collections.abc import Sequence
24
25
  from logging import DEBUG, INFO, WARN
25
26
  from pathlib import Path
26
27
  from time import sleep
27
- from typing import Optional
28
+ from typing import Any, Optional
28
29
 
29
30
  import grpc
31
+ import yaml
30
32
  from cryptography.exceptions import UnsupportedAlgorithm
31
33
  from cryptography.hazmat.primitives.asymmetric import ec
32
34
  from cryptography.hazmat.primitives.serialization import (
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
37
39
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
38
40
  from flwr.common.address import parse_address
39
41
  from flwr.common.args import try_obtain_server_certificates
42
+ from flwr.common.auth_plugin import ExecAuthPlugin
40
43
  from flwr.common.config import get_flwr_dir, parse_config_args
41
44
  from flwr.common.constant import (
45
+ AUTH_TYPE,
42
46
  CLIENT_OCTET,
43
47
  EXEC_API_DEFAULT_SERVER_ADDRESS,
44
48
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -88,6 +92,15 @@ DATABASE = ":flwr-in-memory-state:"
88
92
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
89
93
 
90
94
 
95
+ try:
96
+ from flwr.ee import get_exec_auth_plugins
97
+ except ImportError:
98
+
99
+ def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
100
+ """Return all Exec API authentication plugins."""
101
+ raise NotImplementedError("No authentication plugins are currently supported.")
102
+
103
+
91
104
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
92
105
  *,
93
106
  server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -246,6 +259,12 @@ def run_superlink() -> None:
246
259
  # Obtain certificates
247
260
  certificates = try_obtain_server_certificates(args, args.fleet_api_type)
248
261
 
262
+ user_auth_config = _try_obtain_user_auth_config(args)
263
+ auth_plugin: Optional[ExecAuthPlugin] = None
264
+ # user_auth_config is None only if the args.user_auth_config is not provided
265
+ if user_auth_config is not None:
266
+ auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
267
+
249
268
  # Initialize StateFactory
250
269
  state_factory = LinkStateFactory(args.database)
251
270
 
@@ -263,6 +282,7 @@ def run_superlink() -> None:
263
282
  config=parse_config_args(
264
283
  [args.executor_config] if args.executor_config else args.executor_config
265
284
  ),
285
+ auth_plugin=auth_plugin,
266
286
  )
267
287
  grpc_servers = [exec_server]
268
288
 
@@ -559,6 +579,32 @@ def _try_setup_node_authentication(
559
579
  )
560
580
 
561
581
 
582
+ def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
583
+ if args.user_auth_config is not None:
584
+ with open(args.user_auth_config, encoding="utf-8") as file:
585
+ config: dict[str, Any] = yaml.safe_load(file)
586
+ return config
587
+ return None
588
+
589
+
590
+ def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
591
+ auth_config: dict[str, Any] = config.get("authentication", {})
592
+ auth_type: str = auth_config.get(AUTH_TYPE, "")
593
+ try:
594
+ all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
595
+ auth_plugin_class = all_plugins[auth_type]
596
+ return auth_plugin_class(config=auth_config)
597
+ except KeyError:
598
+ if auth_type != "":
599
+ sys.exit(
600
+ f'Authentication type "{auth_type}" is not supported. '
601
+ "Please provide a valid authentication type in the configuration."
602
+ )
603
+ sys.exit("No authentication type is provided in the configuration.")
604
+ except NotImplementedError:
605
+ sys.exit("No authentication plugins are currently supported.")
606
+
607
+
562
608
  def _run_fleet_api_grpc_rere(
563
609
  address: str,
564
610
  state_factory: LinkStateFactory,
@@ -746,6 +792,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
746
792
  type=str,
747
793
  help="The SuperLink's public key (as a path str) to enable authentication.",
748
794
  )
795
+ parser.add_argument(
796
+ "--user-auth-config",
797
+ help="The path to the user authentication configuration YAML file.",
798
+ type=str,
799
+ default=None,
800
+ )
749
801
 
750
802
 
751
803
  def _add_args_serverappio_api(parser: argparse.ArgumentParser) -> None:
@@ -17,6 +17,8 @@
17
17
 
18
18
  import threading
19
19
 
20
+ from flwr.common.typing import RunNotRunningException
21
+
20
22
  from ..client_manager import ClientManager
21
23
  from ..compat.driver_client_proxy import DriverClientProxy
22
24
  from ..driver import Driver
@@ -74,7 +76,11 @@ def _update_client_manager(
74
76
  # Loop until the driver is disconnected
75
77
  registered_nodes: dict[int, DriverClientProxy] = {}
76
78
  while not f_stop.is_set():
77
- all_node_ids = set(driver.get_node_ids())
79
+ try:
80
+ all_node_ids = set(driver.get_node_ids())
81
+ except RunNotRunningException:
82
+ f_stop.set()
83
+ break
78
84
  dead_nodes = set(registered_nodes).difference(all_node_ids)
79
85
  new_nodes = all_node_ids.difference(registered_nodes)
80
86
 
@@ -14,19 +14,20 @@
14
14
  # ==============================================================================
15
15
  """Flower gRPC Driver."""
16
16
 
17
+
17
18
  import time
18
19
  import warnings
19
20
  from collections.abc import Iterable
20
- from logging import DEBUG, INFO, WARN, WARNING
21
- from typing import Any, Optional, cast
21
+ from logging import DEBUG, WARNING
22
+ from typing import Optional, cast
22
23
 
23
24
  import grpc
24
25
 
25
26
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
26
- from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
27
28
  from flwr.common.grpc import create_channel
28
29
  from flwr.common.logger import log
29
- from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
30
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
30
31
  from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
31
32
  from flwr.common.typing import Run
32
33
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
203
204
  task_ins_list.append(taskins)
204
205
  # Call GrpcDriverStub method
205
206
  res: PushTaskInsResponse = self._stub.PushTaskIns(
206
- PushTaskInsRequest(task_ins_list=task_ins_list)
207
+ PushTaskInsRequest(
208
+ task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
209
+ )
207
210
  )
208
211
  return list(res.task_ids)
209
212
 
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
215
218
  """
216
219
  # Pull TaskRes
217
220
  res: PullTaskResResponse = self._stub.PullTaskRes(
218
- PullTaskResRequest(node=self.node, task_ids=message_ids)
221
+ PullTaskResRequest(
222
+ node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
223
+ )
219
224
  )
220
225
  # Convert TaskRes to Message
221
226
  msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
258
263
  return
259
264
  # Disconnect
260
265
  self._disconnect()
261
-
262
-
263
- def _make_simple_grpc_retry_invoker() -> RetryInvoker:
264
- """Create a simple gRPC retry invoker."""
265
-
266
- def _on_sucess(retry_state: RetryState) -> None:
267
- if retry_state.tries > 1:
268
- log(
269
- INFO,
270
- "Connection successful after %.2f seconds and %s tries.",
271
- retry_state.elapsed_time,
272
- retry_state.tries,
273
- )
274
-
275
- def _on_backoff(retry_state: RetryState) -> None:
276
- if retry_state.tries == 1:
277
- log(WARN, "Connection attempt failed, retrying...")
278
- else:
279
- log(
280
- WARN,
281
- "Connection attempt failed, retrying in %.2f seconds",
282
- retry_state.actual_wait,
283
- )
284
-
285
- def _on_giveup(retry_state: RetryState) -> None:
286
- if retry_state.tries > 1:
287
- log(
288
- WARN,
289
- "Giving up reconnection after %.2f seconds and %s tries.",
290
- retry_state.elapsed_time,
291
- retry_state.tries,
292
- )
293
-
294
- return RetryInvoker(
295
- wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
296
- recoverable_exceptions=grpc.RpcError,
297
- max_tries=None,
298
- max_time=None,
299
- on_success=_on_sucess,
300
- on_backoff=_on_backoff,
301
- on_giveup=_on_giveup,
302
- should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
303
- )
304
-
305
-
306
- def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
307
- """Wrap the gRPC stub with a retry invoker."""
308
-
309
- def make_lambda(original_method: Any) -> Any:
310
- return lambda *args, **kwargs: retry_invoker.invoke(
311
- original_method, *args, **kwargs
312
- )
313
-
314
- for method_name in vars(stub):
315
- method = getattr(stub, method_name)
316
- if callable(method):
317
- setattr(stub, method_name, make_lambda(method))
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
142
142
  # Pull TaskRes
143
143
  task_res_list = self.state.get_task_res(task_ids=msg_ids)
144
144
  # Delete tasks in state
145
- self.state.delete_tasks(msg_ids)
145
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
146
+ task_ins_ids_to_delete = {
147
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
148
+ }
149
+ self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
146
150
  # Convert TaskRes to Message
147
151
  msgs = [message_from_taskres(taskres) for taskres in task_res_list]
148
152
  return msgs