flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +11 -31
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +83 -0
- flwr/cli/ls.py +10 -40
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +15 -25
- flwr/cli/stop.py +91 -0
- flwr/cli/utils.py +109 -1
- flwr/client/app.py +3 -2
- flwr/client/client.py +1 -0
- flwr/client/clientapp/app.py +1 -0
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +1 -0
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +6 -1
- flwr/common/logger.py +1 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +75 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +2 -1
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +27 -3
- flwr/proto/exec_pb2.pyi +103 -0
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +53 -1
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/serverapp/app.py +9 -2
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +72 -22
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +31 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
- flwr/server/superlink/linkstate/linkstate.py +13 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
- flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +1 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/METADATA +8 -7
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/RECORD +100 -92
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/entry_points.txt +0 -0
flwr/proto/exec_pb2_grpc.pyi
CHANGED
|
@@ -14,6 +14,11 @@ class ExecStub:
|
|
|
14
14
|
flwr.proto.exec_pb2.StartRunResponse]
|
|
15
15
|
"""Start run upon request"""
|
|
16
16
|
|
|
17
|
+
StopRun: grpc.UnaryUnaryMultiCallable[
|
|
18
|
+
flwr.proto.exec_pb2.StopRunRequest,
|
|
19
|
+
flwr.proto.exec_pb2.StopRunResponse]
|
|
20
|
+
"""Stop run upon request"""
|
|
21
|
+
|
|
17
22
|
StreamLogs: grpc.UnaryStreamMultiCallable[
|
|
18
23
|
flwr.proto.exec_pb2.StreamLogsRequest,
|
|
19
24
|
flwr.proto.exec_pb2.StreamLogsResponse]
|
|
@@ -24,6 +29,16 @@ class ExecStub:
|
|
|
24
29
|
flwr.proto.exec_pb2.ListRunsResponse]
|
|
25
30
|
"""flwr ls command"""
|
|
26
31
|
|
|
32
|
+
GetLoginDetails: grpc.UnaryUnaryMultiCallable[
|
|
33
|
+
flwr.proto.exec_pb2.GetLoginDetailsRequest,
|
|
34
|
+
flwr.proto.exec_pb2.GetLoginDetailsResponse]
|
|
35
|
+
"""Get login details upon request"""
|
|
36
|
+
|
|
37
|
+
GetAuthTokens: grpc.UnaryUnaryMultiCallable[
|
|
38
|
+
flwr.proto.exec_pb2.GetAuthTokensRequest,
|
|
39
|
+
flwr.proto.exec_pb2.GetAuthTokensResponse]
|
|
40
|
+
"""Get auth tokens upon request"""
|
|
41
|
+
|
|
27
42
|
|
|
28
43
|
class ExecServicer(metaclass=abc.ABCMeta):
|
|
29
44
|
@abc.abstractmethod
|
|
@@ -34,6 +49,14 @@ class ExecServicer(metaclass=abc.ABCMeta):
|
|
|
34
49
|
"""Start run upon request"""
|
|
35
50
|
pass
|
|
36
51
|
|
|
52
|
+
@abc.abstractmethod
|
|
53
|
+
def StopRun(self,
|
|
54
|
+
request: flwr.proto.exec_pb2.StopRunRequest,
|
|
55
|
+
context: grpc.ServicerContext,
|
|
56
|
+
) -> flwr.proto.exec_pb2.StopRunResponse:
|
|
57
|
+
"""Stop run upon request"""
|
|
58
|
+
pass
|
|
59
|
+
|
|
37
60
|
@abc.abstractmethod
|
|
38
61
|
def StreamLogs(self,
|
|
39
62
|
request: flwr.proto.exec_pb2.StreamLogsRequest,
|
|
@@ -50,5 +73,21 @@ class ExecServicer(metaclass=abc.ABCMeta):
|
|
|
50
73
|
"""flwr ls command"""
|
|
51
74
|
pass
|
|
52
75
|
|
|
76
|
+
@abc.abstractmethod
|
|
77
|
+
def GetLoginDetails(self,
|
|
78
|
+
request: flwr.proto.exec_pb2.GetLoginDetailsRequest,
|
|
79
|
+
context: grpc.ServicerContext,
|
|
80
|
+
) -> flwr.proto.exec_pb2.GetLoginDetailsResponse:
|
|
81
|
+
"""Get login details upon request"""
|
|
82
|
+
pass
|
|
83
|
+
|
|
84
|
+
@abc.abstractmethod
|
|
85
|
+
def GetAuthTokens(self,
|
|
86
|
+
request: flwr.proto.exec_pb2.GetAuthTokensRequest,
|
|
87
|
+
context: grpc.ServicerContext,
|
|
88
|
+
) -> flwr.proto.exec_pb2.GetAuthTokensResponse:
|
|
89
|
+
"""Get auth tokens upon request"""
|
|
90
|
+
pass
|
|
91
|
+
|
|
53
92
|
|
|
54
93
|
def add_ExecServicer_to_server(servicer: ExecServicer, server: grpc.Server) -> None: ...
|
flwr/proto/fab_pb2.py
CHANGED
|
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
|
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"
|
|
18
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"Q\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
|
|
19
19
|
|
|
20
20
|
_globals = globals()
|
|
21
21
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -25,7 +25,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
25
25
|
_globals['_FAB']._serialized_start=59
|
|
26
26
|
_globals['_FAB']._serialized_end=99
|
|
27
27
|
_globals['_GETFABREQUEST']._serialized_start=101
|
|
28
|
-
_globals['_GETFABREQUEST']._serialized_end=
|
|
29
|
-
_globals['_GETFABRESPONSE']._serialized_start=
|
|
30
|
-
_globals['_GETFABRESPONSE']._serialized_end=
|
|
28
|
+
_globals['_GETFABREQUEST']._serialized_end=182
|
|
29
|
+
_globals['_GETFABRESPONSE']._serialized_start=184
|
|
30
|
+
_globals['_GETFABRESPONSE']._serialized_end=230
|
|
31
31
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fab_pb2.pyi
CHANGED
|
@@ -36,16 +36,19 @@ class GetFabRequest(google.protobuf.message.Message):
|
|
|
36
36
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
37
37
|
NODE_FIELD_NUMBER: builtins.int
|
|
38
38
|
HASH_STR_FIELD_NUMBER: builtins.int
|
|
39
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
39
40
|
@property
|
|
40
41
|
def node(self) -> flwr.proto.node_pb2.Node: ...
|
|
41
42
|
hash_str: typing.Text
|
|
43
|
+
run_id: builtins.int
|
|
42
44
|
def __init__(self,
|
|
43
45
|
*,
|
|
44
46
|
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
45
47
|
hash_str: typing.Text = ...,
|
|
48
|
+
run_id: builtins.int = ...,
|
|
46
49
|
) -> None: ...
|
|
47
50
|
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
|
|
48
|
-
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
|
|
51
|
+
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node","run_id",b"run_id"]) -> None: ...
|
|
49
52
|
global___GetFabRequest = GetFabRequest
|
|
50
53
|
|
|
51
54
|
class GetFabResponse(google.protobuf.message.Message):
|
flwr/proto/serverappio_pb2.py
CHANGED
|
@@ -20,7 +20,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
|
20
20
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"
|
|
23
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
|
|
24
24
|
|
|
25
25
|
_globals = globals()
|
|
26
26
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -32,21 +32,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
32
32
|
_globals['_GETNODESRESPONSE']._serialized_start=217
|
|
33
33
|
_globals['_GETNODESRESPONSE']._serialized_end=268
|
|
34
34
|
_globals['_PUSHTASKINSREQUEST']._serialized_start=270
|
|
35
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_end=
|
|
36
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_start=
|
|
37
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_end=
|
|
38
|
-
_globals['_PULLTASKRESREQUEST']._serialized_start=
|
|
39
|
-
_globals['_PULLTASKRESREQUEST']._serialized_end=
|
|
40
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_start=
|
|
41
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_end=
|
|
42
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=
|
|
43
|
-
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=
|
|
44
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=
|
|
45
|
-
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=
|
|
46
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=
|
|
47
|
-
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=
|
|
48
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=
|
|
49
|
-
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=
|
|
50
|
-
_globals['_SERVERAPPIO']._serialized_start=
|
|
51
|
-
_globals['_SERVERAPPIO']._serialized_end=
|
|
35
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_end=350
|
|
36
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_start=352
|
|
37
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_end=391
|
|
38
|
+
_globals['_PULLTASKRESREQUEST']._serialized_start=393
|
|
39
|
+
_globals['_PULLTASKRESREQUEST']._serialized_end=479
|
|
40
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_start=481
|
|
41
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_end=546
|
|
42
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548
|
|
43
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576
|
|
44
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578
|
|
45
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705
|
|
46
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707
|
|
47
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790
|
|
48
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792
|
|
49
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822
|
|
50
|
+
_globals['_SERVERAPPIO']._serialized_start=825
|
|
51
|
+
_globals['_SERVERAPPIO']._serialized_end=1752
|
|
52
52
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/serverappio_pb2.pyi
CHANGED
|
@@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
|
|
|
44
44
|
"""PushTaskIns messages"""
|
|
45
45
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
46
46
|
TASK_INS_LIST_FIELD_NUMBER: builtins.int
|
|
47
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
47
48
|
@property
|
|
48
49
|
def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
|
|
50
|
+
run_id: builtins.int
|
|
49
51
|
def __init__(self,
|
|
50
52
|
*,
|
|
51
53
|
task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
|
|
54
|
+
run_id: builtins.int = ...,
|
|
52
55
|
) -> None: ...
|
|
53
|
-
def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
|
|
56
|
+
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
|
|
54
57
|
global___PushTaskInsRequest = PushTaskInsRequest
|
|
55
58
|
|
|
56
59
|
class PushTaskInsResponse(google.protobuf.message.Message):
|
|
@@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
|
|
|
70
73
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
71
74
|
NODE_FIELD_NUMBER: builtins.int
|
|
72
75
|
TASK_IDS_FIELD_NUMBER: builtins.int
|
|
76
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
73
77
|
@property
|
|
74
78
|
def node(self) -> flwr.proto.node_pb2.Node: ...
|
|
75
79
|
@property
|
|
76
80
|
def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
|
|
81
|
+
run_id: builtins.int
|
|
77
82
|
def __init__(self,
|
|
78
83
|
*,
|
|
79
84
|
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
80
85
|
task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
|
86
|
+
run_id: builtins.int = ...,
|
|
81
87
|
) -> None: ...
|
|
82
88
|
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
|
|
83
|
-
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
|
|
89
|
+
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
|
|
84
90
|
global___PullTaskResRequest = PullTaskResRequest
|
|
85
91
|
|
|
86
92
|
class PullTaskResResponse(google.protobuf.message.Message):
|
|
@@ -62,6 +62,11 @@ class ServerAppIoStub(object):
|
|
|
62
62
|
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
|
|
63
63
|
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
|
|
64
64
|
)
|
|
65
|
+
self.GetRunStatus = channel.unary_unary(
|
|
66
|
+
'/flwr.proto.ServerAppIo/GetRunStatus',
|
|
67
|
+
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
68
|
+
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
69
|
+
)
|
|
65
70
|
self.PushLogs = channel.unary_unary(
|
|
66
71
|
'/flwr.proto.ServerAppIo/PushLogs',
|
|
67
72
|
request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
|
|
@@ -135,6 +140,13 @@ class ServerAppIoServicer(object):
|
|
|
135
140
|
context.set_details('Method not implemented!')
|
|
136
141
|
raise NotImplementedError('Method not implemented!')
|
|
137
142
|
|
|
143
|
+
def GetRunStatus(self, request, context):
|
|
144
|
+
"""Get the status of a given run
|
|
145
|
+
"""
|
|
146
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
147
|
+
context.set_details('Method not implemented!')
|
|
148
|
+
raise NotImplementedError('Method not implemented!')
|
|
149
|
+
|
|
138
150
|
def PushLogs(self, request, context):
|
|
139
151
|
"""Push ServerApp logs
|
|
140
152
|
"""
|
|
@@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
|
|
|
190
202
|
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
|
|
191
203
|
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
|
|
192
204
|
),
|
|
205
|
+
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
|
|
206
|
+
servicer.GetRunStatus,
|
|
207
|
+
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
|
|
208
|
+
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
|
|
209
|
+
),
|
|
193
210
|
'PushLogs': grpc.unary_unary_rpc_method_handler(
|
|
194
211
|
servicer.PushLogs,
|
|
195
212
|
request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
|
|
@@ -358,6 +375,23 @@ class ServerAppIo(object):
|
|
|
358
375
|
options, channel_credentials,
|
|
359
376
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
360
377
|
|
|
378
|
+
@staticmethod
|
|
379
|
+
def GetRunStatus(request,
|
|
380
|
+
target,
|
|
381
|
+
options=(),
|
|
382
|
+
channel_credentials=None,
|
|
383
|
+
call_credentials=None,
|
|
384
|
+
insecure=False,
|
|
385
|
+
compression=None,
|
|
386
|
+
wait_for_ready=None,
|
|
387
|
+
timeout=None,
|
|
388
|
+
metadata=None):
|
|
389
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
|
|
390
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
391
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
392
|
+
options, channel_credentials,
|
|
393
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
394
|
+
|
|
361
395
|
@staticmethod
|
|
362
396
|
def PushLogs(request,
|
|
363
397
|
target,
|
|
@@ -56,6 +56,11 @@ class ServerAppIoStub:
|
|
|
56
56
|
flwr.proto.run_pb2.UpdateRunStatusResponse]
|
|
57
57
|
"""Update the status of a given run"""
|
|
58
58
|
|
|
59
|
+
GetRunStatus: grpc.UnaryUnaryMultiCallable[
|
|
60
|
+
flwr.proto.run_pb2.GetRunStatusRequest,
|
|
61
|
+
flwr.proto.run_pb2.GetRunStatusResponse]
|
|
62
|
+
"""Get the status of a given run"""
|
|
63
|
+
|
|
59
64
|
PushLogs: grpc.UnaryUnaryMultiCallable[
|
|
60
65
|
flwr.proto.log_pb2.PushLogsRequest,
|
|
61
66
|
flwr.proto.log_pb2.PushLogsResponse]
|
|
@@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
|
|
|
135
140
|
"""Update the status of a given run"""
|
|
136
141
|
pass
|
|
137
142
|
|
|
143
|
+
@abc.abstractmethod
|
|
144
|
+
def GetRunStatus(self,
|
|
145
|
+
request: flwr.proto.run_pb2.GetRunStatusRequest,
|
|
146
|
+
context: grpc.ServicerContext,
|
|
147
|
+
) -> flwr.proto.run_pb2.GetRunStatusResponse:
|
|
148
|
+
"""Get the status of a given run"""
|
|
149
|
+
pass
|
|
150
|
+
|
|
138
151
|
@abc.abstractmethod
|
|
139
152
|
def PushLogs(self,
|
|
140
153
|
request: flwr.proto.log_pb2.PushLogsRequest,
|
flwr/proto/simulationio_pb2.py
CHANGED
|
@@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
|
18
18
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lwr/proto/simulationio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"\x1d\n\x1bPullSimulationInputsRequest\"\x80\x01\n\x1cPullSimulationInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"T\n\x1cPushSimulationOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1f\n\x1dPushSimulationOutputsResponse2\
|
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lwr/proto/simulationio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"\x1d\n\x1bPullSimulationInputsRequest\"\x80\x01\n\x1cPullSimulationInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"T\n\x1cPushSimulationOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1f\n\x1dPushSimulationOutputsResponse2\xd4\x04\n\x0cSimulationIo\x12k\n\x14PullSimulationInputs\x12\'.flwr.proto.PullSimulationInputsRequest\x1a(.flwr.proto.PullSimulationInputsResponse\"\x00\x12n\n\x15PushSimulationOutputs\x12(.flwr.proto.PushSimulationOutputsRequest\x1a).flwr.proto.PushSimulationOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x12k\n\x14GetFederationOptions\x12\'.flwr.proto.GetFederationOptionsRequest\x1a(.flwr.proto.GetFederationOptionsResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x62\x06proto3')
|
|
22
22
|
|
|
23
23
|
_globals = globals()
|
|
24
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -34,5 +34,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
34
34
|
_globals['_PUSHSIMULATIONOUTPUTSRESPONSE']._serialized_start=385
|
|
35
35
|
_globals['_PUSHSIMULATIONOUTPUTSRESPONSE']._serialized_end=416
|
|
36
36
|
_globals['_SIMULATIONIO']._serialized_start=419
|
|
37
|
-
_globals['_SIMULATIONIO']._serialized_end=
|
|
37
|
+
_globals['_SIMULATIONIO']._serialized_end=1015
|
|
38
38
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -41,6 +41,11 @@ class SimulationIoStub(object):
|
|
|
41
41
|
request_serializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsRequest.SerializeToString,
|
|
42
42
|
response_deserializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.FromString,
|
|
43
43
|
)
|
|
44
|
+
self.GetRunStatus = channel.unary_unary(
|
|
45
|
+
'/flwr.proto.SimulationIo/GetRunStatus',
|
|
46
|
+
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
47
|
+
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
48
|
+
)
|
|
44
49
|
|
|
45
50
|
|
|
46
51
|
class SimulationIoServicer(object):
|
|
@@ -81,6 +86,13 @@ class SimulationIoServicer(object):
|
|
|
81
86
|
context.set_details('Method not implemented!')
|
|
82
87
|
raise NotImplementedError('Method not implemented!')
|
|
83
88
|
|
|
89
|
+
def GetRunStatus(self, request, context):
|
|
90
|
+
"""Get Run Status
|
|
91
|
+
"""
|
|
92
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
93
|
+
context.set_details('Method not implemented!')
|
|
94
|
+
raise NotImplementedError('Method not implemented!')
|
|
95
|
+
|
|
84
96
|
|
|
85
97
|
def add_SimulationIoServicer_to_server(servicer, server):
|
|
86
98
|
rpc_method_handlers = {
|
|
@@ -109,6 +121,11 @@ def add_SimulationIoServicer_to_server(servicer, server):
|
|
|
109
121
|
request_deserializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsRequest.FromString,
|
|
110
122
|
response_serializer=flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.SerializeToString,
|
|
111
123
|
),
|
|
124
|
+
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
|
|
125
|
+
servicer.GetRunStatus,
|
|
126
|
+
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
|
|
127
|
+
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
|
|
128
|
+
),
|
|
112
129
|
}
|
|
113
130
|
generic_handler = grpc.method_handlers_generic_handler(
|
|
114
131
|
'flwr.proto.SimulationIo', rpc_method_handlers)
|
|
@@ -203,3 +220,20 @@ class SimulationIo(object):
|
|
|
203
220
|
flwr_dot_proto_dot_run__pb2.GetFederationOptionsResponse.FromString,
|
|
204
221
|
options, channel_credentials,
|
|
205
222
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
223
|
+
|
|
224
|
+
@staticmethod
|
|
225
|
+
def GetRunStatus(request,
|
|
226
|
+
target,
|
|
227
|
+
options=(),
|
|
228
|
+
channel_credentials=None,
|
|
229
|
+
call_credentials=None,
|
|
230
|
+
insecure=False,
|
|
231
|
+
compression=None,
|
|
232
|
+
wait_for_ready=None,
|
|
233
|
+
timeout=None,
|
|
234
|
+
metadata=None):
|
|
235
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.SimulationIo/GetRunStatus',
|
|
236
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
|
|
237
|
+
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
|
|
238
|
+
options, channel_credentials,
|
|
239
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
@@ -35,6 +35,11 @@ class SimulationIoStub:
|
|
|
35
35
|
flwr.proto.run_pb2.GetFederationOptionsResponse]
|
|
36
36
|
"""Get Federation Options"""
|
|
37
37
|
|
|
38
|
+
GetRunStatus: grpc.UnaryUnaryMultiCallable[
|
|
39
|
+
flwr.proto.run_pb2.GetRunStatusRequest,
|
|
40
|
+
flwr.proto.run_pb2.GetRunStatusResponse]
|
|
41
|
+
"""Get Run Status"""
|
|
42
|
+
|
|
38
43
|
|
|
39
44
|
class SimulationIoServicer(metaclass=abc.ABCMeta):
|
|
40
45
|
@abc.abstractmethod
|
|
@@ -77,5 +82,13 @@ class SimulationIoServicer(metaclass=abc.ABCMeta):
|
|
|
77
82
|
"""Get Federation Options"""
|
|
78
83
|
pass
|
|
79
84
|
|
|
85
|
+
@abc.abstractmethod
|
|
86
|
+
def GetRunStatus(self,
|
|
87
|
+
request: flwr.proto.run_pb2.GetRunStatusRequest,
|
|
88
|
+
context: grpc.ServicerContext,
|
|
89
|
+
) -> flwr.proto.run_pb2.GetRunStatusResponse:
|
|
90
|
+
"""Get Run Status"""
|
|
91
|
+
pass
|
|
92
|
+
|
|
80
93
|
|
|
81
94
|
def add_SimulationIoServicer_to_server(servicer: SimulationIoServicer, server: grpc.Server) -> None: ...
|
flwr/server/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower server app."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import csv
|
|
19
20
|
import importlib.util
|
|
@@ -24,9 +25,10 @@ from collections.abc import Sequence
|
|
|
24
25
|
from logging import DEBUG, INFO, WARN
|
|
25
26
|
from pathlib import Path
|
|
26
27
|
from time import sleep
|
|
27
|
-
from typing import Optional
|
|
28
|
+
from typing import Any, Optional
|
|
28
29
|
|
|
29
30
|
import grpc
|
|
31
|
+
import yaml
|
|
30
32
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
31
33
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
32
34
|
from cryptography.hazmat.primitives.serialization import (
|
|
@@ -37,8 +39,10 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
37
39
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
38
40
|
from flwr.common.address import parse_address
|
|
39
41
|
from flwr.common.args import try_obtain_server_certificates
|
|
42
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
40
43
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
41
44
|
from flwr.common.constant import (
|
|
45
|
+
AUTH_TYPE,
|
|
42
46
|
CLIENT_OCTET,
|
|
43
47
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
44
48
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -88,6 +92,15 @@ DATABASE = ":flwr-in-memory-state:"
|
|
|
88
92
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
89
93
|
|
|
90
94
|
|
|
95
|
+
try:
    from flwr.ee import get_exec_auth_plugins
except ImportError:
    # When `flwr.ee` cannot be imported, expose a stand-in with the same
    # signature. `_try_obtain_exec_auth_plugin` catches the NotImplementedError
    # it raises and exits with a readable message.

    def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
        """Stand-in for the `flwr.ee` plugin registry; always raises.

        Raises
        ------
        NotImplementedError
            Always, because no Exec API authentication plugins are available.
        """
        unsupported_msg = "No authentication plugins are currently supported."
        raise NotImplementedError(unsupported_msg)
|
|
102
|
+
|
|
103
|
+
|
|
91
104
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
92
105
|
*,
|
|
93
106
|
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -246,6 +259,12 @@ def run_superlink() -> None:
|
|
|
246
259
|
# Obtain certificates
|
|
247
260
|
certificates = try_obtain_server_certificates(args, args.fleet_api_type)
|
|
248
261
|
|
|
262
|
+
user_auth_config = _try_obtain_user_auth_config(args)
|
|
263
|
+
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
264
|
+
# user_auth_config is None only if the args.user_auth_config is not provided
|
|
265
|
+
if user_auth_config is not None:
|
|
266
|
+
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
|
|
267
|
+
|
|
249
268
|
# Initialize StateFactory
|
|
250
269
|
state_factory = LinkStateFactory(args.database)
|
|
251
270
|
|
|
@@ -263,6 +282,7 @@ def run_superlink() -> None:
|
|
|
263
282
|
config=parse_config_args(
|
|
264
283
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
265
284
|
),
|
|
285
|
+
auth_plugin=auth_plugin,
|
|
266
286
|
)
|
|
267
287
|
grpc_servers = [exec_server]
|
|
268
288
|
|
|
@@ -559,6 +579,32 @@ def _try_setup_node_authentication(
|
|
|
559
579
|
)
|
|
560
580
|
|
|
561
581
|
|
|
582
|
+
def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
    """Load the user authentication config given via `--user-auth-config`.

    Returns `None` only when `--user-auth-config` was not provided. An empty
    YAML file is returned as an empty dict, so the caller still attempts to set
    up authentication and reports the missing settings instead of silently
    running without user authentication.

    Exits the process if the file does not contain a YAML mapping.
    """
    if args.user_auth_config is None:
        return None

    with open(args.user_auth_config, encoding="utf-8") as file:
        config: Any = yaml.safe_load(file)

    # `yaml.safe_load` returns None for an empty document. The caller treats
    # None as "flag not provided", which would disable auth without any notice.
    if config is None:
        return {}
    # Anything other than a mapping would later fail with an AttributeError on
    # `.get`; reject it up front with a clear message.
    if not isinstance(config, dict):
        sys.exit(
            f'The user authentication configuration "{args.user_auth_config}" '
            "must be a YAML mapping."
        )
    return config
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
    """Instantiate the Exec API authentication plugin selected in `config`.

    The plugin class is looked up by the `AUTH_TYPE` entry of the
    `authentication` section. It is then constructed with that section as its
    `config`.

    Exits the process in any of these cases:
    - no authentication plugins are available;
    - no authentication type is given;
    - the given authentication type is not supported.
    """
    # An `authentication:` key with no value loads from YAML as None; treat it
    # like a missing section instead of crashing on `.get` below.
    auth_config: dict[str, Any] = config.get("authentication") or {}
    if not isinstance(auth_config, dict):
        sys.exit(
            'The "authentication" section of the configuration must be a mapping.'
        )
    auth_type: str = auth_config.get(AUTH_TYPE, "")

    try:
        all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
    except NotImplementedError:
        sys.exit("No authentication plugins are currently supported.")

    # Only the registry lookup decides "unsupported"/"missing". The plugin
    # constructor runs outside any handler, so a KeyError or NotImplementedError
    # raised by the plugin itself propagates instead of being misreported.
    auth_plugin_class = all_plugins.get(auth_type)
    if auth_plugin_class is None:
        if auth_type != "":
            sys.exit(
                f'Authentication type "{auth_type}" is not supported. '
                "Please provide a valid authentication type in the configuration."
            )
        sys.exit("No authentication type is provided in the configuration.")
    return auth_plugin_class(config=auth_config)
|
|
606
|
+
|
|
607
|
+
|
|
562
608
|
def _run_fleet_api_grpc_rere(
|
|
563
609
|
address: str,
|
|
564
610
|
state_factory: LinkStateFactory,
|
|
@@ -746,6 +792,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
746
792
|
type=str,
|
|
747
793
|
help="The SuperLink's public key (as a path str) to enable authentication.",
|
|
748
794
|
)
|
|
795
|
+
parser.add_argument(
|
|
796
|
+
"--user-auth-config",
|
|
797
|
+
help="The path to the user authentication configuration YAML file.",
|
|
798
|
+
type=str,
|
|
799
|
+
default=None,
|
|
800
|
+
)
|
|
749
801
|
|
|
750
802
|
|
|
751
803
|
def _add_args_serverappio_api(parser: argparse.ArgumentParser) -> None:
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
|
|
20
|
+
from flwr.common.typing import RunNotRunningException
|
|
21
|
+
|
|
20
22
|
from ..client_manager import ClientManager
|
|
21
23
|
from ..compat.driver_client_proxy import DriverClientProxy
|
|
22
24
|
from ..driver import Driver
|
|
@@ -74,7 +76,11 @@ def _update_client_manager(
|
|
|
74
76
|
# Loop until the driver is disconnected
|
|
75
77
|
registered_nodes: dict[int, DriverClientProxy] = {}
|
|
76
78
|
while not f_stop.is_set():
|
|
77
|
-
|
|
79
|
+
try:
|
|
80
|
+
all_node_ids = set(driver.get_node_ids())
|
|
81
|
+
except RunNotRunningException:
|
|
82
|
+
f_stop.set()
|
|
83
|
+
break
|
|
78
84
|
dead_nodes = set(registered_nodes).difference(all_node_ids)
|
|
79
85
|
new_nodes = all_node_ids.difference(registered_nodes)
|
|
80
86
|
|
|
@@ -14,19 +14,20 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower gRPC Driver."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import time
|
|
18
19
|
import warnings
|
|
19
20
|
from collections.abc import Iterable
|
|
20
|
-
from logging import DEBUG,
|
|
21
|
-
from typing import
|
|
21
|
+
from logging import DEBUG, WARNING
|
|
22
|
+
from typing import Optional, cast
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
25
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import
|
|
27
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
27
28
|
from flwr.common.grpc import create_channel
|
|
28
29
|
from flwr.common.logger import log
|
|
29
|
-
from flwr.common.retry_invoker import
|
|
30
|
+
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
30
31
|
from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
|
|
31
32
|
from flwr.common.typing import Run
|
|
32
33
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
@@ -203,7 +204,9 @@ class GrpcDriver(Driver):
|
|
|
203
204
|
task_ins_list.append(taskins)
|
|
204
205
|
# Call GrpcDriverStub method
|
|
205
206
|
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
206
|
-
PushTaskInsRequest(
|
|
207
|
+
PushTaskInsRequest(
|
|
208
|
+
task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
|
|
209
|
+
)
|
|
207
210
|
)
|
|
208
211
|
return list(res.task_ids)
|
|
209
212
|
|
|
@@ -215,7 +218,9 @@ class GrpcDriver(Driver):
|
|
|
215
218
|
"""
|
|
216
219
|
# Pull TaskRes
|
|
217
220
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
218
|
-
PullTaskResRequest(
|
|
221
|
+
PullTaskResRequest(
|
|
222
|
+
node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
|
|
223
|
+
)
|
|
219
224
|
)
|
|
220
225
|
# Convert TaskRes to Message
|
|
221
226
|
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
@@ -258,60 +263,3 @@ class GrpcDriver(Driver):
|
|
|
258
263
|
return
|
|
259
264
|
# Disconnect
|
|
260
265
|
self._disconnect()
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
264
|
-
"""Create a simple gRPC retry invoker."""
|
|
265
|
-
|
|
266
|
-
def _on_sucess(retry_state: RetryState) -> None:
|
|
267
|
-
if retry_state.tries > 1:
|
|
268
|
-
log(
|
|
269
|
-
INFO,
|
|
270
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
271
|
-
retry_state.elapsed_time,
|
|
272
|
-
retry_state.tries,
|
|
273
|
-
)
|
|
274
|
-
|
|
275
|
-
def _on_backoff(retry_state: RetryState) -> None:
|
|
276
|
-
if retry_state.tries == 1:
|
|
277
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
278
|
-
else:
|
|
279
|
-
log(
|
|
280
|
-
WARN,
|
|
281
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
282
|
-
retry_state.actual_wait,
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
def _on_giveup(retry_state: RetryState) -> None:
|
|
286
|
-
if retry_state.tries > 1:
|
|
287
|
-
log(
|
|
288
|
-
WARN,
|
|
289
|
-
"Giving up reconnection after %.2f seconds and %s tries.",
|
|
290
|
-
retry_state.elapsed_time,
|
|
291
|
-
retry_state.tries,
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
return RetryInvoker(
|
|
295
|
-
wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
|
|
296
|
-
recoverable_exceptions=grpc.RpcError,
|
|
297
|
-
max_tries=None,
|
|
298
|
-
max_time=None,
|
|
299
|
-
on_success=_on_sucess,
|
|
300
|
-
on_backoff=_on_backoff,
|
|
301
|
-
on_giveup=_on_giveup,
|
|
302
|
-
should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
|
|
307
|
-
"""Wrap the gRPC stub with a retry invoker."""
|
|
308
|
-
|
|
309
|
-
def make_lambda(original_method: Any) -> Any:
|
|
310
|
-
return lambda *args, **kwargs: retry_invoker.invoke(
|
|
311
|
-
original_method, *args, **kwargs
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
for method_name in vars(stub):
|
|
315
|
-
method = getattr(stub, method_name)
|
|
316
|
-
if callable(method):
|
|
317
|
-
setattr(stub, method_name, make_lambda(method))
|
|
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
|
|
|
142
142
|
# Pull TaskRes
|
|
143
143
|
task_res_list = self.state.get_task_res(task_ids=msg_ids)
|
|
144
144
|
# Delete tasks in state
|
|
145
|
-
|
|
145
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
146
|
+
task_ins_ids_to_delete = {
|
|
147
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
148
|
+
}
|
|
149
|
+
self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
146
150
|
# Convert TaskRes to Message
|
|
147
151
|
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
|
|
148
152
|
return msgs
|