flwr-nightly 1.10.0.dev20240721__py3-none-any.whl → 1.10.0.dev20240723__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +20 -18
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +7 -5
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +28 -10
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +7 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +2 -2
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +17 -7
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +2 -1
- flwr/cli/new/templates/app/code/server.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +1 -1
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +13 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +13 -2
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +2 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +6 -6
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -4
- flwr/cli/run/run.py +35 -28
- flwr/client/app.py +3 -3
- flwr/client/grpc_rere_client/connection.py +6 -2
- flwr/client/node_state.py +3 -3
- flwr/client/rest_client/connection.py +6 -2
- flwr/client/supernode/app.py +12 -43
- flwr/common/config.py +23 -17
- flwr/common/context.py +7 -7
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -1
- flwr/proto/common_pb2.py +13 -1
- flwr/proto/common_pb2.pyi +114 -0
- flwr/proto/driver_pb2.py +22 -21
- flwr/proto/driver_pb2.pyi +7 -4
- flwr/proto/exec_pb2.py +18 -13
- flwr/proto/exec_pb2.pyi +27 -5
- flwr/proto/run_pb2.py +10 -9
- flwr/proto/run_pb2.pyi +7 -4
- flwr/proto/task_pb2.py +7 -8
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +6 -2
- flwr/server/run_serverapp.py +3 -5
- flwr/server/superlink/driver/driver_servicer.py +14 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +4 -4
- flwr/server/superlink/state/in_memory_state.py +2 -2
- flwr/server/superlink/state/sqlite_state.py +2 -2
- flwr/server/superlink/state/state.py +3 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +39 -11
- flwr/superexec/app.py +4 -5
- flwr/superexec/deployment.py +19 -8
- flwr/superexec/exec_grpc.py +3 -2
- flwr/superexec/exec_servicer.py +3 -1
- flwr/superexec/executor.py +10 -5
- flwr/superexec/simulation.py +41 -15
- {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/RECORD +74 -74
- {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/entry_points.txt +0 -0
flwr/proto/run_pb2.py
CHANGED
|
@@ -12,9 +12,10 @@ from google.protobuf.internal import builder as _builder
|
|
|
12
12
|
_sym_db = _symbol_database.Default()
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
15
16
|
|
|
16
17
|
|
|
17
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\
|
|
18
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc3\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
|
|
18
19
|
|
|
19
20
|
_globals = globals()
|
|
20
21
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -23,12 +24,12 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
23
24
|
DESCRIPTOR._options = None
|
|
24
25
|
_globals['_RUN_OVERRIDECONFIGENTRY']._options = None
|
|
25
26
|
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
|
|
26
|
-
_globals['_RUN']._serialized_start=
|
|
27
|
-
_globals['_RUN']._serialized_end=
|
|
28
|
-
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=
|
|
29
|
-
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=
|
|
30
|
-
_globals['_GETRUNREQUEST']._serialized_start=
|
|
31
|
-
_globals['_GETRUNREQUEST']._serialized_end=
|
|
32
|
-
_globals['_GETRUNRESPONSE']._serialized_start=
|
|
33
|
-
_globals['_GETRUNRESPONSE']._serialized_end=
|
|
27
|
+
_globals['_RUN']._serialized_start=65
|
|
28
|
+
_globals['_RUN']._serialized_end=260
|
|
29
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=187
|
|
30
|
+
_globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=260
|
|
31
|
+
_globals['_GETRUNREQUEST']._serialized_start=262
|
|
32
|
+
_globals['_GETRUNREQUEST']._serialized_end=293
|
|
33
|
+
_globals['_GETRUNRESPONSE']._serialized_start=295
|
|
34
|
+
_globals['_GETRUNRESPONSE']._serialized_end=341
|
|
34
35
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/run_pb2.pyi
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.transport_pb2
|
|
6
7
|
import google.protobuf.descriptor
|
|
7
8
|
import google.protobuf.internal.containers
|
|
8
9
|
import google.protobuf.message
|
|
@@ -18,12 +19,14 @@ class Run(google.protobuf.message.Message):
|
|
|
18
19
|
KEY_FIELD_NUMBER: builtins.int
|
|
19
20
|
VALUE_FIELD_NUMBER: builtins.int
|
|
20
21
|
key: typing.Text
|
|
21
|
-
|
|
22
|
+
@property
|
|
23
|
+
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
|
|
22
24
|
def __init__(self,
|
|
23
25
|
*,
|
|
24
26
|
key: typing.Text = ...,
|
|
25
|
-
value: typing.
|
|
27
|
+
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
|
|
26
28
|
) -> None: ...
|
|
29
|
+
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
27
30
|
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
28
31
|
|
|
29
32
|
RUN_ID_FIELD_NUMBER: builtins.int
|
|
@@ -34,13 +37,13 @@ class Run(google.protobuf.message.Message):
|
|
|
34
37
|
fab_id: typing.Text
|
|
35
38
|
fab_version: typing.Text
|
|
36
39
|
@property
|
|
37
|
-
def override_config(self) -> google.protobuf.internal.containers.
|
|
40
|
+
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
38
41
|
def __init__(self,
|
|
39
42
|
*,
|
|
40
43
|
run_id: builtins.int = ...,
|
|
41
44
|
fab_id: typing.Text = ...,
|
|
42
45
|
fab_version: typing.Text = ...,
|
|
43
|
-
override_config: typing.Optional[typing.Mapping[typing.Text,
|
|
46
|
+
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
44
47
|
) -> None: ...
|
|
45
48
|
def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
|
|
46
49
|
global___Run = Run
|
flwr/proto/task_pb2.py
CHANGED
|
@@ -14,21 +14,20 @@ _sym_db = _symbol_database.Default()
|
|
|
14
14
|
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
16
16
|
from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
|
17
|
-
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
18
17
|
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
|
22
21
|
|
|
23
22
|
_globals = globals()
|
|
24
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
25
24
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
|
|
26
25
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
27
26
|
DESCRIPTOR._options = None
|
|
28
|
-
_globals['_TASK']._serialized_start=
|
|
29
|
-
_globals['_TASK']._serialized_end=
|
|
30
|
-
_globals['_TASKINS']._serialized_start=
|
|
31
|
-
_globals['_TASKINS']._serialized_end=
|
|
32
|
-
_globals['_TASKRES']._serialized_start=
|
|
33
|
-
_globals['_TASKRES']._serialized_end=
|
|
27
|
+
_globals['_TASK']._serialized_start=113
|
|
28
|
+
_globals['_TASK']._serialized_end=378
|
|
29
|
+
_globals['_TASKINS']._serialized_start=380
|
|
30
|
+
_globals['_TASKINS']._serialized_end=472
|
|
31
|
+
_globals['_TASKRES']._serialized_start=474
|
|
32
|
+
_globals['_TASKRES']._serialized_end=566
|
|
34
33
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Context
|
|
21
|
+
from flwr.common import Context
|
|
22
22
|
|
|
23
23
|
from ..client_manager import ClientManager, SimpleClientManager
|
|
24
24
|
from ..history import History
|
|
@@ -35,9 +35,9 @@ class LegacyContext(Context):
|
|
|
35
35
|
client_manager: ClientManager
|
|
36
36
|
history: History
|
|
37
37
|
|
|
38
|
-
def __init__(
|
|
38
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
39
39
|
self,
|
|
40
|
-
|
|
40
|
+
context: Context,
|
|
41
41
|
config: Optional[ServerConfig] = None,
|
|
42
42
|
strategy: Optional[Strategy] = None,
|
|
43
43
|
client_manager: Optional[ClientManager] = None,
|
|
@@ -52,4 +52,5 @@ class LegacyContext(Context):
|
|
|
52
52
|
self.strategy = strategy
|
|
53
53
|
self.client_manager = client_manager
|
|
54
54
|
self.history = History()
|
|
55
|
-
|
|
55
|
+
|
|
56
|
+
super().__init__(**vars(context))
|
|
@@ -24,7 +24,11 @@ import grpc
|
|
|
24
24
|
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.serde import
|
|
27
|
+
from flwr.common.serde import (
|
|
28
|
+
message_from_taskres,
|
|
29
|
+
message_to_taskins,
|
|
30
|
+
user_config_from_proto,
|
|
31
|
+
)
|
|
28
32
|
from flwr.common.typing import Run
|
|
29
33
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
30
34
|
GetNodesRequest,
|
|
@@ -127,7 +131,7 @@ class GrpcDriver(Driver):
|
|
|
127
131
|
run_id=res.run.run_id,
|
|
128
132
|
fab_id=res.run.fab_id,
|
|
129
133
|
fab_version=res.run.fab_version,
|
|
130
|
-
override_config=
|
|
134
|
+
override_config=user_config_from_proto(res.run.override_config),
|
|
131
135
|
)
|
|
132
136
|
|
|
133
137
|
@property
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -19,7 +19,7 @@ import argparse
|
|
|
19
19
|
import sys
|
|
20
20
|
from logging import DEBUG, INFO, WARN
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
25
|
from flwr.common.config import (
|
|
@@ -30,6 +30,7 @@ from flwr.common.config import (
|
|
|
30
30
|
)
|
|
31
31
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
32
32
|
from flwr.common.object_ref import load_app
|
|
33
|
+
from flwr.common.typing import UserConfig
|
|
33
34
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
34
35
|
CreateRunRequest,
|
|
35
36
|
CreateRunResponse,
|
|
@@ -45,7 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
|
45
46
|
def run(
|
|
46
47
|
driver: Driver,
|
|
47
48
|
server_app_dir: str,
|
|
48
|
-
server_app_run_config:
|
|
49
|
+
server_app_run_config: UserConfig,
|
|
49
50
|
server_app_attr: Optional[str] = None,
|
|
50
51
|
loaded_server_app: Optional[ServerApp] = None,
|
|
51
52
|
) -> None:
|
|
@@ -56,9 +57,6 @@ def run(
|
|
|
56
57
|
"but not both."
|
|
57
58
|
)
|
|
58
59
|
|
|
59
|
-
if server_app_dir is not None:
|
|
60
|
-
sys.path.insert(0, str(Path(server_app_dir).absolute()))
|
|
61
|
-
|
|
62
60
|
# Load ServerApp if needed
|
|
63
61
|
def _load() -> ServerApp:
|
|
64
62
|
if server_app_attr:
|
|
@@ -23,6 +23,7 @@ from uuid import UUID
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.serde import user_config_from_proto, user_config_to_proto
|
|
26
27
|
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
27
28
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
28
29
|
CreateRunRequest,
|
|
@@ -72,7 +73,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
72
73
|
run_id = state.create_run(
|
|
73
74
|
request.fab_id,
|
|
74
75
|
request.fab_version,
|
|
75
|
-
|
|
76
|
+
user_config_from_proto(request.override_config),
|
|
76
77
|
)
|
|
77
78
|
return CreateRunResponse(run_id=run_id)
|
|
78
79
|
|
|
@@ -149,8 +150,18 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
149
150
|
|
|
150
151
|
# Retrieve run information
|
|
151
152
|
run = state.get_run(request.run_id)
|
|
152
|
-
|
|
153
|
-
|
|
153
|
+
|
|
154
|
+
if run is None:
|
|
155
|
+
return GetRunResponse()
|
|
156
|
+
|
|
157
|
+
return GetRunResponse(
|
|
158
|
+
run=Run(
|
|
159
|
+
run_id=run.run_id,
|
|
160
|
+
fab_id=run.fab_id,
|
|
161
|
+
fab_version=run.fab_version,
|
|
162
|
+
override_config=user_config_to_proto(run.override_config),
|
|
163
|
+
)
|
|
164
|
+
)
|
|
154
165
|
|
|
155
166
|
|
|
156
167
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
@@ -19,6 +19,7 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.serde import user_config_to_proto
|
|
22
23
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
23
24
|
CreateNodeRequest,
|
|
24
25
|
CreateNodeResponse,
|
|
@@ -113,5 +114,15 @@ def get_run(
|
|
|
113
114
|
) -> GetRunResponse:
|
|
114
115
|
"""Get run information."""
|
|
115
116
|
run = state.get_run(request.run_id)
|
|
116
|
-
|
|
117
|
-
|
|
117
|
+
|
|
118
|
+
if run is None:
|
|
119
|
+
return GetRunResponse()
|
|
120
|
+
|
|
121
|
+
return GetRunResponse(
|
|
122
|
+
run=Run(
|
|
123
|
+
run_id=run.run_id,
|
|
124
|
+
fab_id=run.fab_id,
|
|
125
|
+
fab_version=run.fab_version,
|
|
126
|
+
override_config=user_config_to_proto(run.override_config),
|
|
127
|
+
)
|
|
128
|
+
)
|
|
@@ -72,8 +72,8 @@ def _register_node_states(
|
|
|
72
72
|
node_states[node_id] = NodeState(
|
|
73
73
|
node_id=node_id,
|
|
74
74
|
node_config={
|
|
75
|
-
PARTITION_ID_KEY:
|
|
76
|
-
NUM_PARTITIONS_KEY:
|
|
75
|
+
PARTITION_ID_KEY: partition_id,
|
|
76
|
+
NUM_PARTITIONS_KEY: num_partitions,
|
|
77
77
|
},
|
|
78
78
|
)
|
|
79
79
|
|
|
@@ -347,8 +347,8 @@ def start_vce(
|
|
|
347
347
|
if client_app_attr:
|
|
348
348
|
app = _get_load_client_app_fn(
|
|
349
349
|
default_app_ref=client_app_attr,
|
|
350
|
-
|
|
351
|
-
|
|
350
|
+
project_dir=app_dir,
|
|
351
|
+
flwr_dir=flwr_dir,
|
|
352
352
|
multi_app=True,
|
|
353
353
|
)(run.fab_id, run.fab_version)
|
|
354
354
|
|
|
@@ -23,7 +23,7 @@ from uuid import UUID, uuid4
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import log, now
|
|
25
25
|
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
26
|
-
from flwr.common.typing import Run
|
|
26
|
+
from flwr.common.typing import Run, UserConfig
|
|
27
27
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
28
28
|
from flwr.server.superlink.state.state import State
|
|
29
29
|
from flwr.server.utils import validate_task_ins_or_res
|
|
@@ -279,7 +279,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
279
279
|
self,
|
|
280
280
|
fab_id: str,
|
|
281
281
|
fab_version: str,
|
|
282
|
-
override_config:
|
|
282
|
+
override_config: UserConfig,
|
|
283
283
|
) -> int:
|
|
284
284
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
285
285
|
# Sample a random int64 as run_id
|
|
@@ -25,7 +25,7 @@ from uuid import UUID, uuid4
|
|
|
25
25
|
|
|
26
26
|
from flwr.common import log, now
|
|
27
27
|
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
28
|
-
from flwr.common.typing import Run
|
|
28
|
+
from flwr.common.typing import Run, UserConfig
|
|
29
29
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
30
30
|
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
|
|
31
31
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -619,7 +619,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
619
619
|
self,
|
|
620
620
|
fab_id: str,
|
|
621
621
|
fab_version: str,
|
|
622
|
-
override_config:
|
|
622
|
+
override_config: UserConfig,
|
|
623
623
|
) -> int:
|
|
624
624
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
625
625
|
# Sample a random int64 as run_id
|
|
@@ -16,10 +16,10 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import List, Optional, Set
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common.typing import Run
|
|
22
|
+
from flwr.common.typing import Run, UserConfig
|
|
23
23
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
24
24
|
|
|
25
25
|
|
|
@@ -161,7 +161,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
161
161
|
self,
|
|
162
162
|
fab_id: str,
|
|
163
163
|
fab_version: str,
|
|
164
|
-
override_config:
|
|
164
|
+
override_config: UserConfig,
|
|
165
165
|
) -> int:
|
|
166
166
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
167
167
|
|
|
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
|
81
81
|
forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
|
|
82
82
|
aggregate_ndarrays: NDArrays = field(default_factory=list)
|
|
83
83
|
legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
|
|
84
|
+
failures: List[Exception] = field(default_factory=list)
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
class SecAggPlusWorkflow:
|
|
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
|
|
|
394
395
|
|
|
395
396
|
for msg in msgs:
|
|
396
397
|
if msg.has_error():
|
|
398
|
+
state.failures.append(Exception(msg.error))
|
|
397
399
|
continue
|
|
398
400
|
key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
399
401
|
node_id = msg.metadata.src_node_id
|
|
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
|
|
|
451
453
|
nid: [] for nid in state.active_node_ids
|
|
452
454
|
} # dest node ID -> list of src node IDs
|
|
453
455
|
for msg in msgs:
|
|
456
|
+
if msg.has_error():
|
|
457
|
+
state.failures.append(Exception(msg.error))
|
|
458
|
+
continue
|
|
454
459
|
node_id = msg.metadata.src_node_id
|
|
455
460
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
456
461
|
dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
|
|
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
|
|
|
515
520
|
# Sum collected masked vectors and compute active/dead node IDs
|
|
516
521
|
masked_vector = None
|
|
517
522
|
for msg in msgs:
|
|
523
|
+
if msg.has_error():
|
|
524
|
+
state.failures.append(Exception(msg.error))
|
|
525
|
+
continue
|
|
518
526
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
519
527
|
bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
|
|
520
528
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
|
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
|
|
|
528
536
|
|
|
529
537
|
# Backward compatibility with Strategy
|
|
530
538
|
for msg in msgs:
|
|
539
|
+
if msg.has_error():
|
|
540
|
+
state.failures.append(Exception(msg.error))
|
|
541
|
+
continue
|
|
531
542
|
fitres = compat.recordset_to_fitres(msg.content, True)
|
|
532
543
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
|
533
544
|
state.legacy_results.append((proxy, fitres))
|
|
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
|
|
|
584
595
|
for nid in state.sampled_node_ids:
|
|
585
596
|
collected_shares_dict[nid] = []
|
|
586
597
|
for msg in msgs:
|
|
598
|
+
if msg.has_error():
|
|
599
|
+
state.failures.append(Exception(msg.error))
|
|
600
|
+
continue
|
|
587
601
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
588
602
|
nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
|
|
589
603
|
shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
|
|
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
|
|
|
652
666
|
INFO,
|
|
653
667
|
"aggregate_fit: received %s results and %s failures",
|
|
654
668
|
len(results),
|
|
655
|
-
|
|
669
|
+
len(state.failures),
|
|
670
|
+
)
|
|
671
|
+
aggregated_result = context.strategy.aggregate_fit(
|
|
672
|
+
current_round, results, state.failures # type: ignore
|
|
656
673
|
)
|
|
657
|
-
aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
|
|
658
674
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
659
675
|
|
|
660
676
|
# Update the parameters and write history
|
flwr/simulation/__init__.py
CHANGED
|
@@ -82,7 +82,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
82
82
|
|
|
83
83
|
# Retrieve context
|
|
84
84
|
context = self.proxy_state.retrieve_context(run_id=run_id)
|
|
85
|
-
partition_id_str = context.node_config[PARTITION_ID_KEY]
|
|
85
|
+
partition_id_str = str(context.node_config[PARTITION_ID_KEY])
|
|
86
86
|
|
|
87
87
|
try:
|
|
88
88
|
self.actor_pool.submit_client_job(
|
|
@@ -25,15 +25,19 @@ from argparse import Namespace
|
|
|
25
25
|
from logging import DEBUG, ERROR, INFO, WARNING
|
|
26
26
|
from pathlib import Path
|
|
27
27
|
from time import sleep
|
|
28
|
-
from typing import
|
|
28
|
+
from typing import List, Optional
|
|
29
29
|
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
31
|
from flwr.client import ClientApp
|
|
32
32
|
from flwr.common import EventType, event, log
|
|
33
33
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
34
34
|
from flwr.common.constant import RUN_ID_NUM_BYTES
|
|
35
|
-
from flwr.common.logger import
|
|
36
|
-
|
|
35
|
+
from flwr.common.logger import (
|
|
36
|
+
set_logger_propagation,
|
|
37
|
+
update_console_handler,
|
|
38
|
+
warn_deprecated_feature_with_example,
|
|
39
|
+
)
|
|
40
|
+
from flwr.common.typing import Run, UserConfig
|
|
37
41
|
from flwr.server.driver import Driver, InMemoryDriver
|
|
38
42
|
from flwr.server.run_serverapp import run as run_server_app
|
|
39
43
|
from flwr.server.server_app import ServerApp
|
|
@@ -93,6 +97,14 @@ def run_simulation_from_cli() -> None:
|
|
|
93
97
|
"""Run Simulation Engine from the CLI."""
|
|
94
98
|
args = _parse_args_run_simulation().parse_args()
|
|
95
99
|
|
|
100
|
+
if args.enable_tf_gpu_growth:
|
|
101
|
+
warn_deprecated_feature_with_example(
|
|
102
|
+
"Passing `--enable-tf-gpu-growth` is deprecated.",
|
|
103
|
+
example_message="Instead, set the `TF_FORCE_GPU_ALLOW_GROWTH` environmnet "
|
|
104
|
+
"variable to true.",
|
|
105
|
+
code_example='TF_FORCE_GPU_ALLOW_GROWTH="true" flower-simulation <...>',
|
|
106
|
+
)
|
|
107
|
+
|
|
96
108
|
# We are supporting two modes for the CLI entrypoint:
|
|
97
109
|
# 1) Running an app dir containing a `pyproject.toml`
|
|
98
110
|
# 2) Running any ClientApp and ServerApp w/o pyproject.toml being present
|
|
@@ -223,6 +235,15 @@ def run_simulation(
|
|
|
223
235
|
When disabled, only INFO, WARNING and ERROR log messages will be shown. If
|
|
224
236
|
enabled, DEBUG-level logs will be displayed.
|
|
225
237
|
"""
|
|
238
|
+
if enable_tf_gpu_growth:
|
|
239
|
+
warn_deprecated_feature_with_example(
|
|
240
|
+
"Passing `enable_tf_gpu_growth=True` is deprecated.",
|
|
241
|
+
example_message="Instead, set the `TF_FORCE_GPU_ALLOW_GROWTH` environmnet "
|
|
242
|
+
"variable to true.",
|
|
243
|
+
code_example='import os;os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]="true"'
|
|
244
|
+
"\n\tflwr.simulation.run_simulationt(...)",
|
|
245
|
+
)
|
|
246
|
+
|
|
226
247
|
_run_simulation(
|
|
227
248
|
num_supernodes=num_supernodes,
|
|
228
249
|
client_app=client_app,
|
|
@@ -238,7 +259,7 @@ def run_simulation(
|
|
|
238
259
|
def run_serverapp_th(
|
|
239
260
|
server_app_attr: Optional[str],
|
|
240
261
|
server_app: Optional[ServerApp],
|
|
241
|
-
server_app_run_config:
|
|
262
|
+
server_app_run_config: UserConfig,
|
|
242
263
|
driver: Driver,
|
|
243
264
|
app_dir: str,
|
|
244
265
|
f_stop: threading.Event,
|
|
@@ -254,7 +275,7 @@ def run_serverapp_th(
|
|
|
254
275
|
exception_event: threading.Event,
|
|
255
276
|
_driver: Driver,
|
|
256
277
|
_server_app_dir: str,
|
|
257
|
-
_server_app_run_config:
|
|
278
|
+
_server_app_run_config: UserConfig,
|
|
258
279
|
_server_app_attr: Optional[str],
|
|
259
280
|
_server_app: Optional[ServerApp],
|
|
260
281
|
) -> None:
|
|
@@ -264,7 +285,7 @@ def run_serverapp_th(
|
|
|
264
285
|
"""
|
|
265
286
|
try:
|
|
266
287
|
if tf_gpu_growth:
|
|
267
|
-
log(INFO, "Enabling GPU growth for Tensorflow on the
|
|
288
|
+
log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
|
|
268
289
|
enable_gpu_growth()
|
|
269
290
|
|
|
270
291
|
# Run ServerApp
|
|
@@ -319,7 +340,7 @@ def _main_loop(
|
|
|
319
340
|
client_app_attr: Optional[str] = None,
|
|
320
341
|
server_app: Optional[ServerApp] = None,
|
|
321
342
|
server_app_attr: Optional[str] = None,
|
|
322
|
-
server_app_run_config: Optional[
|
|
343
|
+
server_app_run_config: Optional[UserConfig] = None,
|
|
323
344
|
) -> None:
|
|
324
345
|
"""Launch SuperLink with Simulation Engine, then ServerApp on a separate thread."""
|
|
325
346
|
# Initialize StateFactory
|
|
@@ -395,7 +416,7 @@ def _run_simulation(
|
|
|
395
416
|
backend_config: Optional[BackendConfig] = None,
|
|
396
417
|
client_app_attr: Optional[str] = None,
|
|
397
418
|
server_app_attr: Optional[str] = None,
|
|
398
|
-
server_app_run_config: Optional[
|
|
419
|
+
server_app_run_config: Optional[UserConfig] = None,
|
|
399
420
|
app_dir: str = "",
|
|
400
421
|
flwr_dir: Optional[str] = None,
|
|
401
422
|
run: Optional[Run] = None,
|
|
@@ -438,7 +459,7 @@ def _run_simulation(
|
|
|
438
459
|
A path to a `ServerApp` module to be loaded: For example: `server:app` or
|
|
439
460
|
`project.package.module:wrapper.app`."
|
|
440
461
|
|
|
441
|
-
server_app_run_config : Optional[
|
|
462
|
+
server_app_run_config : Optional[UserConfig]
|
|
442
463
|
Config dictionary that parameterizes the run config. It will be made accessible
|
|
443
464
|
to the ServerApp.
|
|
444
465
|
|
|
@@ -475,6 +496,14 @@ def _run_simulation(
|
|
|
475
496
|
if "init_args" not in backend_config:
|
|
476
497
|
backend_config["init_args"] = {}
|
|
477
498
|
|
|
499
|
+
# Set default client_resources if not passed
|
|
500
|
+
if "client_resources" not in backend_config:
|
|
501
|
+
backend_config["client_resources"] = {"num_cpus": 2, "num_gpus": 0}
|
|
502
|
+
|
|
503
|
+
# Initialization of backend config to enable GPU growth globally when set
|
|
504
|
+
if "actor" not in backend_config:
|
|
505
|
+
backend_config["actor"] = {"tensorflow": 0}
|
|
506
|
+
|
|
478
507
|
# Set logging level
|
|
479
508
|
logger = logging.getLogger("flwr")
|
|
480
509
|
if verbose_logging:
|
|
@@ -580,8 +609,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
580
609
|
parser.add_argument(
|
|
581
610
|
"--backend-config",
|
|
582
611
|
type=str,
|
|
583
|
-
default=
|
|
584
|
-
'"actor": {"tensorflow": 0}}',
|
|
612
|
+
default="{}",
|
|
585
613
|
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
|
586
614
|
"configure a backend. Values supported in <value> are those included by "
|
|
587
615
|
"`flwr.common.typing.ConfigsRecordValues`. ",
|
flwr/superexec/app.py
CHANGED
|
@@ -93,7 +93,9 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
|
|
|
93
93
|
)
|
|
94
94
|
parser.add_argument(
|
|
95
95
|
"--executor-config",
|
|
96
|
-
help="Key-value pairs for the executor config, separated by commas."
|
|
96
|
+
help="Key-value pairs for the executor config, separated by commas. "
|
|
97
|
+
'For example:\n\n`--executor-config superlink="superlink:9091",'
|
|
98
|
+
'root-certificates="certificates/superlink-ca.crt"`',
|
|
97
99
|
)
|
|
98
100
|
parser.add_argument(
|
|
99
101
|
"--insecure",
|
|
@@ -163,11 +165,8 @@ def _load_executor(
|
|
|
163
165
|
args: argparse.Namespace,
|
|
164
166
|
) -> Executor:
|
|
165
167
|
"""Get the executor plugin."""
|
|
166
|
-
if args.executor_dir is not None:
|
|
167
|
-
sys.path.insert(0, args.executor_dir)
|
|
168
|
-
|
|
169
168
|
executor_ref: str = args.executor
|
|
170
|
-
valid, error_msg = validate(executor_ref)
|
|
169
|
+
valid, error_msg = validate(executor_ref, project_dir=args.executor_dir)
|
|
171
170
|
if not valid and error_msg:
|
|
172
171
|
raise LoadExecutorError(error_msg) from None
|
|
173
172
|
|