flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +47 -27
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +32 -21
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +133 -54
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +39 -39
- flwr/client/typing.py +2 -2
- flwr/common/config.py +92 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +20 -11
- flwr/proto/exec_pb2.pyi +41 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -18
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +269 -70
- flwr/superexec/app.py +17 -11
- flwr/superexec/deployment.py +111 -35
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +6 -1
- flwr/superexec/executor.py +21 -0
- flwr/superexec/simulation.py +181 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
- flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
flwr/proto/exec_pb2.pyi
CHANGED
|
@@ -3,7 +3,9 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.transport_pb2
|
|
6
7
|
import google.protobuf.descriptor
|
|
8
|
+
import google.protobuf.internal.containers
|
|
7
9
|
import google.protobuf.message
|
|
8
10
|
import typing
|
|
9
11
|
import typing_extensions
|
|
@@ -12,13 +14,51 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
|
12
14
|
|
|
13
15
|
class StartRunRequest(google.protobuf.message.Message):
|
|
14
16
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
17
|
+
class OverrideConfigEntry(google.protobuf.message.Message):
|
|
18
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
19
|
+
KEY_FIELD_NUMBER: builtins.int
|
|
20
|
+
VALUE_FIELD_NUMBER: builtins.int
|
|
21
|
+
key: typing.Text
|
|
22
|
+
@property
|
|
23
|
+
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
|
|
24
|
+
def __init__(self,
|
|
25
|
+
*,
|
|
26
|
+
key: typing.Text = ...,
|
|
27
|
+
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
|
|
28
|
+
) -> None: ...
|
|
29
|
+
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
30
|
+
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
31
|
+
|
|
32
|
+
class FederationConfigEntry(google.protobuf.message.Message):
|
|
33
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
34
|
+
KEY_FIELD_NUMBER: builtins.int
|
|
35
|
+
VALUE_FIELD_NUMBER: builtins.int
|
|
36
|
+
key: typing.Text
|
|
37
|
+
@property
|
|
38
|
+
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
|
|
39
|
+
def __init__(self,
|
|
40
|
+
*,
|
|
41
|
+
key: typing.Text = ...,
|
|
42
|
+
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
|
|
43
|
+
) -> None: ...
|
|
44
|
+
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
45
|
+
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
46
|
+
|
|
15
47
|
FAB_FILE_FIELD_NUMBER: builtins.int
|
|
48
|
+
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
|
|
49
|
+
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
|
|
16
50
|
fab_file: builtins.bytes
|
|
51
|
+
@property
|
|
52
|
+
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
53
|
+
@property
|
|
54
|
+
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
17
55
|
def __init__(self,
|
|
18
56
|
*,
|
|
19
57
|
fab_file: builtins.bytes = ...,
|
|
58
|
+
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
59
|
+
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
20
60
|
) -> None: ...
|
|
21
|
-
def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ...
|
|
61
|
+
def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
|
|
22
62
|
global___StartRunRequest = StartRunRequest
|
|
23
63
|
|
|
24
64
|
class StartRunResponse(google.protobuf.message.Message):
|
flwr/proto/run_pb2.py
CHANGED
|
@@ -12,19 +12,24 @@ from google.protobuf.internal import builder as _builder
|
|
|
12
12
|
_sym_db = _symbol_database.Default()
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
15
16
|
|
|
16
17
|
|
|
17
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"
|
|
18
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc3\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
|
|
18
19
|
|
|
19
20
|
_globals = globals()
|
|
20
21
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
21
22
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals)
|
|
22
23
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
23
24
|
DESCRIPTOR._options = None
|
|
24
|
-
_globals['
|
|
25
|
-
_globals['
|
|
26
|
-
_globals['
|
|
27
|
-
_globals['
|
|
28
|
-
_globals['
|
|
29
|
-
_globals['
|
|
25
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._options = None
|
|
26
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
|
|
27
|
+
_globals['_RUN']._serialized_start=65
|
|
28
|
+
_globals['_RUN']._serialized_end=260
|
|
29
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=187
|
|
30
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=260
|
|
31
|
+
_globals['_GETRUNREQUEST']._serialized_start=262
|
|
32
|
+
_globals['_GETRUNREQUEST']._serialized_end=293
|
|
33
|
+
_globals['_GETRUNRESPONSE']._serialized_start=295
|
|
34
|
+
_globals['_GETRUNRESPONSE']._serialized_end=341
|
|
30
35
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/run_pb2.pyi
CHANGED
|
@@ -3,7 +3,9 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.transport_pb2
|
|
6
7
|
import google.protobuf.descriptor
|
|
8
|
+
import google.protobuf.internal.containers
|
|
7
9
|
import google.protobuf.message
|
|
8
10
|
import typing
|
|
9
11
|
import typing_extensions
|
|
@@ -12,19 +14,38 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
|
12
14
|
|
|
13
15
|
class Run(google.protobuf.message.Message):
|
|
14
16
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
17
|
+
class OverrideConfigEntry(google.protobuf.message.Message):
|
|
18
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
19
|
+
KEY_FIELD_NUMBER: builtins.int
|
|
20
|
+
VALUE_FIELD_NUMBER: builtins.int
|
|
21
|
+
key: typing.Text
|
|
22
|
+
@property
|
|
23
|
+
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
|
|
24
|
+
def __init__(self,
|
|
25
|
+
*,
|
|
26
|
+
key: typing.Text = ...,
|
|
27
|
+
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
|
|
28
|
+
) -> None: ...
|
|
29
|
+
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
30
|
+
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
31
|
+
|
|
15
32
|
RUN_ID_FIELD_NUMBER: builtins.int
|
|
16
33
|
FAB_ID_FIELD_NUMBER: builtins.int
|
|
17
34
|
FAB_VERSION_FIELD_NUMBER: builtins.int
|
|
35
|
+
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
|
|
18
36
|
run_id: builtins.int
|
|
19
37
|
fab_id: typing.Text
|
|
20
38
|
fab_version: typing.Text
|
|
39
|
+
@property
|
|
40
|
+
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
21
41
|
def __init__(self,
|
|
22
42
|
*,
|
|
23
43
|
run_id: builtins.int = ...,
|
|
24
44
|
fab_id: typing.Text = ...,
|
|
25
45
|
fab_version: typing.Text = ...,
|
|
46
|
+
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
26
47
|
) -> None: ...
|
|
27
|
-
def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ...
|
|
48
|
+
def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
|
|
28
49
|
global___Run = Run
|
|
29
50
|
|
|
30
51
|
class GetRunRequest(google.protobuf.message.Message):
|
flwr/proto/task_pb2.py
CHANGED
|
@@ -14,21 +14,20 @@ _sym_db = _symbol_database.Default()
|
|
|
14
14
|
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
16
16
|
from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
|
17
|
-
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
18
17
|
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
|
22
21
|
|
|
23
22
|
_globals = globals()
|
|
24
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
25
24
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
|
|
26
25
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
27
26
|
DESCRIPTOR._options = None
|
|
28
|
-
_globals['_TASK']._serialized_start=
|
|
29
|
-
_globals['_TASK']._serialized_end=
|
|
30
|
-
_globals['_TASKINS']._serialized_start=
|
|
31
|
-
_globals['_TASKINS']._serialized_end=
|
|
32
|
-
_globals['_TASKRES']._serialized_start=
|
|
33
|
-
_globals['_TASKRES']._serialized_end=
|
|
27
|
+
_globals['_TASK']._serialized_start=113
|
|
28
|
+
_globals['_TASK']._serialized_end=378
|
|
29
|
+
_globals['_TASKINS']._serialized_start=380
|
|
30
|
+
_globals['_TASKINS']._serialized_end=472
|
|
31
|
+
_globals['_TASKRES']._serialized_start=474
|
|
32
|
+
_globals['_TASKRES']._serialized_end=566
|
|
34
33
|
# @@protoc_insertion_point(module_scope)
|
flwr/server/__init__.py
CHANGED
|
@@ -28,6 +28,7 @@ from .run_serverapp import run_server_app as run_server_app
|
|
|
28
28
|
from .server import Server as Server
|
|
29
29
|
from .server_app import ServerApp as ServerApp
|
|
30
30
|
from .server_config import ServerConfig as ServerConfig
|
|
31
|
+
from .serverapp_components import ServerAppComponents as ServerAppComponents
|
|
31
32
|
|
|
32
33
|
__all__ = [
|
|
33
34
|
"ClientManager",
|
|
@@ -36,6 +37,7 @@ __all__ = [
|
|
|
36
37
|
"LegacyContext",
|
|
37
38
|
"Server",
|
|
38
39
|
"ServerApp",
|
|
40
|
+
"ServerAppComponents",
|
|
39
41
|
"ServerConfig",
|
|
40
42
|
"SimpleClientManager",
|
|
41
43
|
"run_server_app",
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Context
|
|
21
|
+
from flwr.common import Context
|
|
22
22
|
|
|
23
23
|
from ..client_manager import ClientManager, SimpleClientManager
|
|
24
24
|
from ..history import History
|
|
@@ -35,9 +35,9 @@ class LegacyContext(Context):
|
|
|
35
35
|
client_manager: ClientManager
|
|
36
36
|
history: History
|
|
37
37
|
|
|
38
|
-
def __init__(
|
|
38
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
39
39
|
self,
|
|
40
|
-
|
|
40
|
+
context: Context,
|
|
41
41
|
config: Optional[ServerConfig] = None,
|
|
42
42
|
strategy: Optional[Strategy] = None,
|
|
43
43
|
client_manager: Optional[ClientManager] = None,
|
|
@@ -52,4 +52,5 @@ class LegacyContext(Context):
|
|
|
52
52
|
self.strategy = strategy
|
|
53
53
|
self.client_manager = client_manager
|
|
54
54
|
self.history = History()
|
|
55
|
-
|
|
55
|
+
|
|
56
|
+
super().__init__(**vars(context))
|
|
@@ -16,19 +16,21 @@
|
|
|
16
16
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
|
-
from logging import DEBUG,
|
|
20
|
-
from typing import Iterable, List, Optional,
|
|
19
|
+
from logging import DEBUG, WARNING
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
24
24
|
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.serde import
|
|
27
|
+
from flwr.common.serde import (
|
|
28
|
+
message_from_taskres,
|
|
29
|
+
message_to_taskins,
|
|
30
|
+
user_config_from_proto,
|
|
31
|
+
)
|
|
28
32
|
from flwr.common.typing import Run
|
|
29
33
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
30
|
-
CreateRunRequest,
|
|
31
|
-
CreateRunResponse,
|
|
32
34
|
GetNodesRequest,
|
|
33
35
|
GetNodesResponse,
|
|
34
36
|
PullTaskResRequest,
|
|
@@ -53,167 +55,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
53
55
|
"""
|
|
54
56
|
|
|
55
57
|
|
|
56
|
-
class
|
|
57
|
-
"""`
|
|
58
|
+
class GrpcDriver(Driver):
|
|
59
|
+
"""`GrpcDriver` provides an interface to the Driver API.
|
|
58
60
|
|
|
59
61
|
Parameters
|
|
60
62
|
----------
|
|
61
|
-
|
|
62
|
-
The
|
|
63
|
-
|
|
63
|
+
run_id : int
|
|
64
|
+
The identifier of the run.
|
|
65
|
+
driver_service_address : str (default: "[::]:9091")
|
|
66
|
+
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
64
67
|
root_certificates : Optional[bytes] (default: None)
|
|
65
68
|
The PEM-encoded root certificates as a byte string.
|
|
66
69
|
If provided, a secure connection using the certificates will be
|
|
67
70
|
established to an SSL-enabled Flower server.
|
|
68
71
|
"""
|
|
69
72
|
|
|
70
|
-
def __init__(
|
|
73
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
71
74
|
self,
|
|
75
|
+
run_id: int,
|
|
72
76
|
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
73
77
|
root_certificates: Optional[bytes] = None,
|
|
74
78
|
) -> None:
|
|
75
|
-
self.
|
|
76
|
-
self.
|
|
77
|
-
self.
|
|
78
|
-
self.
|
|
79
|
+
self._run_id = run_id
|
|
80
|
+
self._addr = driver_service_address
|
|
81
|
+
self._cert = root_certificates
|
|
82
|
+
self._run: Optional[Run] = None
|
|
83
|
+
self._grpc_stub: Optional[DriverStub] = None
|
|
84
|
+
self._channel: Optional[grpc.Channel] = None
|
|
85
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
79
86
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
87
|
+
@property
|
|
88
|
+
def _is_connected(self) -> bool:
|
|
89
|
+
"""Check if connected to the Driver API server."""
|
|
90
|
+
return self._channel is not None
|
|
91
|
+
|
|
92
|
+
def _connect(self) -> None:
|
|
93
|
+
"""Connect to the Driver API.
|
|
83
94
|
|
|
84
|
-
|
|
85
|
-
"""
|
|
95
|
+
This will not call GetRun.
|
|
96
|
+
"""
|
|
86
97
|
event(EventType.DRIVER_CONNECT)
|
|
87
|
-
if self.
|
|
98
|
+
if self._is_connected:
|
|
88
99
|
log(WARNING, "Already connected")
|
|
89
100
|
return
|
|
90
|
-
self.
|
|
91
|
-
server_address=self.
|
|
92
|
-
insecure=(self.
|
|
93
|
-
root_certificates=self.
|
|
101
|
+
self._channel = create_channel(
|
|
102
|
+
server_address=self._addr,
|
|
103
|
+
insecure=(self._cert is None),
|
|
104
|
+
root_certificates=self._cert,
|
|
94
105
|
)
|
|
95
|
-
self.
|
|
96
|
-
log(DEBUG, "[Driver] Connected to %s", self.
|
|
106
|
+
self._grpc_stub = DriverStub(self._channel)
|
|
107
|
+
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
97
108
|
|
|
98
|
-
def
|
|
109
|
+
def _disconnect(self) -> None:
|
|
99
110
|
"""Disconnect from the Driver API."""
|
|
100
111
|
event(EventType.DRIVER_DISCONNECT)
|
|
101
|
-
if
|
|
112
|
+
if not self._is_connected:
|
|
102
113
|
log(DEBUG, "Already disconnected")
|
|
103
114
|
return
|
|
104
|
-
channel = self.
|
|
105
|
-
self.
|
|
106
|
-
self.
|
|
115
|
+
channel: grpc.Channel = self._channel
|
|
116
|
+
self._channel = None
|
|
117
|
+
self._grpc_stub = None
|
|
107
118
|
channel.close()
|
|
108
119
|
log(DEBUG, "[Driver] Disconnected")
|
|
109
120
|
|
|
110
|
-
def
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
-
|
|
128
|
-
# Call gRPC Driver API
|
|
129
|
-
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
-
return res
|
|
131
|
-
|
|
132
|
-
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
133
|
-
"""Get client IDs."""
|
|
134
|
-
# Check if channel is open
|
|
135
|
-
if self.stub is None:
|
|
136
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
137
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
138
|
-
|
|
139
|
-
# Call gRPC Driver API
|
|
140
|
-
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
141
|
-
return res
|
|
142
|
-
|
|
143
|
-
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
144
|
-
"""Schedule tasks."""
|
|
145
|
-
# Check if channel is open
|
|
146
|
-
if self.stub is None:
|
|
147
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
148
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
149
|
-
|
|
150
|
-
# Call gRPC Driver API
|
|
151
|
-
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
152
|
-
return res
|
|
153
|
-
|
|
154
|
-
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
155
|
-
"""Get task results."""
|
|
156
|
-
# Check if channel is open
|
|
157
|
-
if self.stub is None:
|
|
158
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
159
|
-
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
160
|
-
|
|
161
|
-
# Call Driver API
|
|
162
|
-
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
163
|
-
return res
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
class GrpcDriver(Driver):
|
|
167
|
-
"""`Driver` class provides an interface to the Driver API.
|
|
168
|
-
|
|
169
|
-
Parameters
|
|
170
|
-
----------
|
|
171
|
-
run_id : int
|
|
172
|
-
The identifier of the run.
|
|
173
|
-
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
-
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
-
If None, an instance connected to "[::]:9091" will be created.
|
|
176
|
-
"""
|
|
177
|
-
|
|
178
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
179
|
-
self,
|
|
180
|
-
run_id: int,
|
|
181
|
-
stub: Optional[GrpcDriverStub] = None,
|
|
182
|
-
) -> None:
|
|
183
|
-
self._run_id = run_id
|
|
184
|
-
self._run: Optional[Run] = None
|
|
185
|
-
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
186
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
121
|
+
def _init_run(self) -> None:
|
|
122
|
+
# Check if is initialized
|
|
123
|
+
if self._run is not None:
|
|
124
|
+
return
|
|
125
|
+
# Get the run info
|
|
126
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
127
|
+
res: GetRunResponse = self._stub.GetRun(req)
|
|
128
|
+
if not res.HasField("run"):
|
|
129
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
130
|
+
self._run = Run(
|
|
131
|
+
run_id=res.run.run_id,
|
|
132
|
+
fab_id=res.run.fab_id,
|
|
133
|
+
fab_version=res.run.fab_version,
|
|
134
|
+
override_config=user_config_from_proto(res.run.override_config),
|
|
135
|
+
)
|
|
187
136
|
|
|
188
137
|
@property
|
|
189
138
|
def run(self) -> Run:
|
|
190
139
|
"""Run information."""
|
|
191
|
-
self.
|
|
192
|
-
return Run(**vars(
|
|
140
|
+
self._init_run()
|
|
141
|
+
return Run(**vars(self._run))
|
|
193
142
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
# Get the run info
|
|
201
|
-
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
-
res = self.stub.get_run(req)
|
|
203
|
-
if not res.HasField("run"):
|
|
204
|
-
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
-
self._run = Run(
|
|
206
|
-
run_id=res.run.run_id,
|
|
207
|
-
fab_id=res.run.fab_id,
|
|
208
|
-
fab_version=res.run.fab_version,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
return self.stub, self._run.run_id
|
|
143
|
+
@property
|
|
144
|
+
def _stub(self) -> DriverStub:
|
|
145
|
+
"""Driver stub."""
|
|
146
|
+
if not self._is_connected:
|
|
147
|
+
self._connect()
|
|
148
|
+
return cast(DriverStub, self._grpc_stub)
|
|
212
149
|
|
|
213
150
|
def _check_message(self, message: Message) -> None:
|
|
214
151
|
# Check if the message is valid
|
|
215
152
|
if not (
|
|
216
|
-
|
|
153
|
+
# Assume self._run being initialized
|
|
154
|
+
message.metadata.run_id == self._run_id
|
|
217
155
|
and message.metadata.src_node_id == self.node.node_id
|
|
218
156
|
and message.metadata.message_id == ""
|
|
219
157
|
and message.metadata.reply_to_message == ""
|
|
@@ -234,7 +172,7 @@ class GrpcDriver(Driver):
|
|
|
234
172
|
This method constructs a new `Message` with given content and metadata.
|
|
235
173
|
The `run_id` and `src_node_id` will be set automatically.
|
|
236
174
|
"""
|
|
237
|
-
|
|
175
|
+
self._init_run()
|
|
238
176
|
if ttl:
|
|
239
177
|
warnings.warn(
|
|
240
178
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -245,7 +183,7 @@ class GrpcDriver(Driver):
|
|
|
245
183
|
|
|
246
184
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
247
185
|
metadata = Metadata(
|
|
248
|
-
run_id=
|
|
186
|
+
run_id=self._run_id,
|
|
249
187
|
message_id="", # Will be set by the server
|
|
250
188
|
src_node_id=self.node.node_id,
|
|
251
189
|
dst_node_id=dst_node_id,
|
|
@@ -258,9 +196,11 @@ class GrpcDriver(Driver):
|
|
|
258
196
|
|
|
259
197
|
def get_node_ids(self) -> List[int]:
|
|
260
198
|
"""Get node IDs."""
|
|
261
|
-
|
|
199
|
+
self._init_run()
|
|
262
200
|
# Call GrpcDriverStub method
|
|
263
|
-
res =
|
|
201
|
+
res: GetNodesResponse = self._stub.GetNodes(
|
|
202
|
+
GetNodesRequest(run_id=self._run_id)
|
|
203
|
+
)
|
|
264
204
|
return [node.node_id for node in res.nodes]
|
|
265
205
|
|
|
266
206
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -269,7 +209,7 @@ class GrpcDriver(Driver):
|
|
|
269
209
|
This method takes an iterable of messages and sends each message
|
|
270
210
|
to the node specified in `dst_node_id`.
|
|
271
211
|
"""
|
|
272
|
-
|
|
212
|
+
self._init_run()
|
|
273
213
|
# Construct TaskIns
|
|
274
214
|
task_ins_list: List[TaskIns] = []
|
|
275
215
|
for msg in messages:
|
|
@@ -280,7 +220,9 @@ class GrpcDriver(Driver):
|
|
|
280
220
|
# Add to list
|
|
281
221
|
task_ins_list.append(taskins)
|
|
282
222
|
# Call GrpcDriverStub method
|
|
283
|
-
res =
|
|
223
|
+
res: PushTaskInsResponse = self._stub.PushTaskIns(
|
|
224
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
225
|
+
)
|
|
284
226
|
return list(res.task_ids)
|
|
285
227
|
|
|
286
228
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -289,9 +231,9 @@ class GrpcDriver(Driver):
|
|
|
289
231
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
290
232
|
set of given message IDs.
|
|
291
233
|
"""
|
|
292
|
-
|
|
234
|
+
self._init_run()
|
|
293
235
|
# Pull TaskRes
|
|
294
|
-
res =
|
|
236
|
+
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
295
237
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
296
238
|
)
|
|
297
239
|
# Convert TaskRes to Message
|
|
@@ -331,7 +273,7 @@ class GrpcDriver(Driver):
|
|
|
331
273
|
def close(self) -> None:
|
|
332
274
|
"""Disconnect from the SuperLink if connected."""
|
|
333
275
|
# Check if `connect` was called before
|
|
334
|
-
if not self.
|
|
276
|
+
if not self._is_connected:
|
|
335
277
|
return
|
|
336
278
|
# Disconnect
|
|
337
|
-
self.
|
|
279
|
+
self._disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,13 +22,22 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
-
from flwr.common.config import
|
|
25
|
+
from flwr.common.config import (
|
|
26
|
+
get_flwr_dir,
|
|
27
|
+
get_fused_config,
|
|
28
|
+
get_project_config,
|
|
29
|
+
get_project_dir,
|
|
30
|
+
)
|
|
26
31
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
27
32
|
from flwr.common.object_ref import load_app
|
|
28
|
-
from flwr.
|
|
33
|
+
from flwr.common.typing import UserConfig
|
|
34
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
35
|
+
CreateRunRequest,
|
|
36
|
+
CreateRunResponse,
|
|
37
|
+
)
|
|
29
38
|
|
|
30
39
|
from .driver import Driver
|
|
31
|
-
from .driver.grpc_driver import GrpcDriver
|
|
40
|
+
from .driver.grpc_driver import GrpcDriver
|
|
32
41
|
from .server_app import LoadServerAppError, ServerApp
|
|
33
42
|
|
|
34
43
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -37,6 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
|
37
46
|
def run(
|
|
38
47
|
driver: Driver,
|
|
39
48
|
server_app_dir: str,
|
|
49
|
+
server_app_run_config: UserConfig,
|
|
40
50
|
server_app_attr: Optional[str] = None,
|
|
41
51
|
loaded_server_app: Optional[ServerApp] = None,
|
|
42
52
|
) -> None:
|
|
@@ -47,9 +57,6 @@ def run(
|
|
|
47
57
|
"but not both."
|
|
48
58
|
)
|
|
49
59
|
|
|
50
|
-
if server_app_dir is not None:
|
|
51
|
-
sys.path.insert(0, str(Path(server_app_dir).absolute()))
|
|
52
|
-
|
|
53
60
|
# Load ServerApp if needed
|
|
54
61
|
def _load() -> ServerApp:
|
|
55
62
|
if server_app_attr:
|
|
@@ -69,7 +76,9 @@ def run(
|
|
|
69
76
|
server_app = _load()
|
|
70
77
|
|
|
71
78
|
# Initialize Context
|
|
72
|
-
context = Context(
|
|
79
|
+
context = Context(
|
|
80
|
+
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
81
|
+
)
|
|
73
82
|
|
|
74
83
|
# Call ServerApp
|
|
75
84
|
server_app(driver=driver, context=context)
|
|
@@ -144,22 +153,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
144
153
|
"For more details, use: ``flower-server-app -h``"
|
|
145
154
|
)
|
|
146
155
|
|
|
147
|
-
|
|
148
|
-
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
149
|
-
)
|
|
156
|
+
# Initialize GrpcDriver
|
|
150
157
|
if args.run_id is not None:
|
|
151
158
|
# User provided `--run-id`, but not `server-app`
|
|
152
|
-
|
|
159
|
+
driver = GrpcDriver(
|
|
160
|
+
run_id=args.run_id,
|
|
161
|
+
driver_service_address=args.superlink,
|
|
162
|
+
root_certificates=root_certificates,
|
|
163
|
+
)
|
|
153
164
|
else:
|
|
154
165
|
# User provided `server-app`, but not `--run-id`
|
|
155
166
|
# Create run if run_id is not provided
|
|
156
|
-
|
|
167
|
+
driver = GrpcDriver(
|
|
168
|
+
run_id=0, # Will be overwritten
|
|
169
|
+
driver_service_address=args.superlink,
|
|
170
|
+
root_certificates=root_certificates,
|
|
171
|
+
)
|
|
172
|
+
# Create run
|
|
157
173
|
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
158
|
-
res =
|
|
159
|
-
|
|
174
|
+
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
175
|
+
# Overwrite driver._run_id
|
|
176
|
+
driver._run_id = res.run_id # pylint: disable=W0212
|
|
160
177
|
|
|
161
|
-
|
|
162
|
-
driver = GrpcDriver(run_id=run_id, stub=stub)
|
|
178
|
+
server_app_run_config = {}
|
|
163
179
|
|
|
164
180
|
# Dynamically obtain ServerApp path based on run_id
|
|
165
181
|
if args.run_id is not None:
|
|
@@ -168,7 +184,8 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
168
184
|
run_ = driver.run
|
|
169
185
|
server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
|
|
170
186
|
config = get_project_config(server_app_dir)
|
|
171
|
-
server_app_attr = config["
|
|
187
|
+
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
188
|
+
server_app_run_config = get_fused_config(run_, flwr_dir)
|
|
172
189
|
else:
|
|
173
190
|
# User provided `server-app`, but not `--run-id`
|
|
174
191
|
server_app_dir = str(Path(args.dir).absolute())
|
|
@@ -182,7 +199,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
182
199
|
)
|
|
183
200
|
|
|
184
201
|
# Run the ServerApp with the Driver
|
|
185
|
-
run(
|
|
202
|
+
run(
|
|
203
|
+
driver=driver,
|
|
204
|
+
server_app_dir=server_app_dir,
|
|
205
|
+
server_app_run_config=server_app_run_config,
|
|
206
|
+
server_app_attr=server_app_attr,
|
|
207
|
+
)
|
|
186
208
|
|
|
187
209
|
# Clean up
|
|
188
210
|
driver.close()
|