flwr-nightly 1.11.0.dev20240813__py3-none-any.whl → 1.11.0.dev20240822__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +2 -2
- flwr/cli/install.py +3 -1
- flwr/cli/run/run.py +15 -11
- flwr/client/app.py +132 -14
- flwr/client/clientapp/__init__.py +22 -0
- flwr/client/clientapp/app.py +233 -0
- flwr/client/clientapp/clientappio_servicer.py +244 -0
- flwr/client/clientapp/utils.py +108 -0
- flwr/client/grpc_rere_client/connection.py +9 -1
- flwr/client/node_state.py +17 -4
- flwr/client/rest_client/connection.py +16 -3
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +36 -164
- flwr/common/__init__.py +4 -0
- flwr/common/config.py +31 -10
- flwr/common/record/configsrecord.py +49 -15
- flwr/common/record/metricsrecord.py +54 -14
- flwr/common/record/parametersrecord.py +84 -17
- flwr/common/record/recordset.py +80 -8
- flwr/common/record/typeddict.py +20 -58
- flwr/common/recordset_compat.py +6 -6
- flwr/common/serde.py +24 -2
- flwr/common/typing.py +1 -0
- flwr/proto/clientappio_pb2.py +17 -13
- flwr/proto/clientappio_pb2.pyi +24 -2
- flwr/proto/clientappio_pb2_grpc.py +34 -0
- flwr/proto/clientappio_pb2_grpc.pyi +13 -0
- flwr/proto/exec_pb2.py +16 -15
- flwr/proto/exec_pb2.pyi +7 -4
- flwr/proto/message_pb2.py +2 -2
- flwr/proto/message_pb2.pyi +4 -1
- flwr/server/app.py +15 -0
- flwr/server/driver/grpc_driver.py +1 -0
- flwr/server/run_serverapp.py +18 -2
- flwr/server/server.py +3 -1
- flwr/server/superlink/driver/driver_grpc.py +3 -0
- flwr/server/superlink/driver/driver_servicer.py +32 -4
- flwr/server/superlink/ffs/disk_ffs.py +6 -3
- flwr/server/superlink/ffs/ffs.py +3 -3
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +12 -4
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +8 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +16 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -2
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +7 -5
- flwr/server/superlink/state/sqlite_state.py +17 -7
- flwr/server/superlink/state/state.py +4 -3
- flwr/server/workflow/default_workflows.py +3 -1
- flwr/simulation/run_simulation.py +5 -67
- flwr/superexec/app.py +3 -3
- flwr/superexec/deployment.py +8 -9
- flwr/superexec/exec_servicer.py +1 -1
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240822.dist-info}/METADATA +2 -2
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240822.dist-info}/RECORD +58 -53
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240822.dist-info}/entry_points.txt +1 -1
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240822.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240822.dist-info}/WHEEL +0 -0
flwr/proto/exec_pb2.pyi
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.fab_pb2
|
|
6
7
|
import flwr.proto.transport_pb2
|
|
7
8
|
import google.protobuf.descriptor
|
|
8
9
|
import google.protobuf.internal.containers
|
|
@@ -44,21 +45,23 @@ class StartRunRequest(google.protobuf.message.Message):
|
|
|
44
45
|
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
|
|
45
46
|
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
|
|
46
47
|
|
|
47
|
-
|
|
48
|
+
FAB_FIELD_NUMBER: builtins.int
|
|
48
49
|
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
|
|
49
50
|
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
|
|
50
|
-
|
|
51
|
+
@property
|
|
52
|
+
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
|
|
51
53
|
@property
|
|
52
54
|
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
53
55
|
@property
|
|
54
56
|
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
|
|
55
57
|
def __init__(self,
|
|
56
58
|
*,
|
|
57
|
-
|
|
59
|
+
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
|
|
58
60
|
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
59
61
|
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
|
|
60
62
|
) -> None: ...
|
|
61
|
-
def
|
|
63
|
+
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
|
|
64
|
+
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
|
|
62
65
|
global___StartRunRequest = StartRunRequest
|
|
63
66
|
|
|
64
67
|
class StartRunResponse(google.protobuf.message.Message):
|
flwr/proto/message_pb2.py
CHANGED
|
@@ -17,7 +17,7 @@ from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
|
|
17
17
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3')
|
|
21
21
|
|
|
22
22
|
_globals = globals()
|
|
23
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -37,5 +37,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
37
37
|
_globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497
|
|
38
38
|
_globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565
|
|
39
39
|
_globals['_METADATA']._serialized_start=568
|
|
40
|
-
_globals['_METADATA']._serialized_end=
|
|
40
|
+
_globals['_METADATA']._serialized_end=755
|
|
41
41
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/message_pb2.pyi
CHANGED
|
@@ -99,6 +99,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
99
99
|
GROUP_ID_FIELD_NUMBER: builtins.int
|
|
100
100
|
TTL_FIELD_NUMBER: builtins.int
|
|
101
101
|
MESSAGE_TYPE_FIELD_NUMBER: builtins.int
|
|
102
|
+
CREATED_AT_FIELD_NUMBER: builtins.int
|
|
102
103
|
run_id: builtins.int
|
|
103
104
|
message_id: typing.Text
|
|
104
105
|
src_node_id: builtins.int
|
|
@@ -107,6 +108,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
107
108
|
group_id: typing.Text
|
|
108
109
|
ttl: builtins.float
|
|
109
110
|
message_type: typing.Text
|
|
111
|
+
created_at: builtins.float
|
|
110
112
|
def __init__(self,
|
|
111
113
|
*,
|
|
112
114
|
run_id: builtins.int = ...,
|
|
@@ -117,6 +119,7 @@ class Metadata(google.protobuf.message.Message):
|
|
|
117
119
|
group_id: typing.Text = ...,
|
|
118
120
|
ttl: builtins.float = ...,
|
|
119
121
|
message_type: typing.Text = ...,
|
|
122
|
+
created_at: builtins.float = ...,
|
|
120
123
|
) -> None: ...
|
|
121
|
-
def ClearField(self, field_name: typing_extensions.Literal["dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
|
|
124
|
+
def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
|
|
122
125
|
global___Metadata = Metadata
|
flwr/server/app.py
CHANGED
|
@@ -34,6 +34,7 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
34
34
|
|
|
35
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
|
+
from flwr.common.config import get_flwr_dir
|
|
37
38
|
from flwr.common.constant import (
|
|
38
39
|
MISSING_EXTRA_REST,
|
|
39
40
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
@@ -57,6 +58,7 @@ from .server import Server, init_defaults, run_fl
|
|
|
57
58
|
from .server_config import ServerConfig
|
|
58
59
|
from .strategy import Strategy
|
|
59
60
|
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
|
61
|
+
from .superlink.ffs.ffs_factory import FfsFactory
|
|
60
62
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
61
63
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
62
64
|
generic_create_grpc_server,
|
|
@@ -72,6 +74,7 @@ ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatibl
|
|
|
72
74
|
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
|
|
73
75
|
|
|
74
76
|
DATABASE = ":flwr-in-memory-state:"
|
|
77
|
+
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
75
78
|
|
|
76
79
|
|
|
77
80
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
@@ -211,10 +214,14 @@ def run_superlink() -> None:
|
|
|
211
214
|
# Initialize StateFactory
|
|
212
215
|
state_factory = StateFactory(args.database)
|
|
213
216
|
|
|
217
|
+
# Initialize FfsFactory
|
|
218
|
+
ffs_factory = FfsFactory(args.storage_dir)
|
|
219
|
+
|
|
214
220
|
# Start Driver API
|
|
215
221
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
216
222
|
address=driver_address,
|
|
217
223
|
state_factory=state_factory,
|
|
224
|
+
ffs_factory=ffs_factory,
|
|
218
225
|
certificates=certificates,
|
|
219
226
|
)
|
|
220
227
|
|
|
@@ -294,6 +301,7 @@ def run_superlink() -> None:
|
|
|
294
301
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
295
302
|
address=fleet_address,
|
|
296
303
|
state_factory=state_factory,
|
|
304
|
+
ffs_factory=ffs_factory,
|
|
297
305
|
certificates=certificates,
|
|
298
306
|
interceptors=interceptors,
|
|
299
307
|
)
|
|
@@ -480,6 +488,7 @@ def _try_obtain_certificates(
|
|
|
480
488
|
def _run_fleet_api_grpc_rere(
|
|
481
489
|
address: str,
|
|
482
490
|
state_factory: StateFactory,
|
|
491
|
+
ffs_factory: FfsFactory,
|
|
483
492
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
484
493
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
485
494
|
) -> grpc.Server:
|
|
@@ -487,6 +496,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
487
496
|
# Create Fleet API gRPC server
|
|
488
497
|
fleet_servicer = FleetServicer(
|
|
489
498
|
state_factory=state_factory,
|
|
499
|
+
ffs_factory=ffs_factory,
|
|
490
500
|
)
|
|
491
501
|
fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
|
|
492
502
|
fleet_grpc_server = generic_create_grpc_server(
|
|
@@ -610,6 +620,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
610
620
|
"Flower will just create a state in memory.",
|
|
611
621
|
default=DATABASE,
|
|
612
622
|
)
|
|
623
|
+
parser.add_argument(
|
|
624
|
+
"--storage-dir",
|
|
625
|
+
help="The base directory to store the objects for the Flower File System.",
|
|
626
|
+
default=BASE_DIR,
|
|
627
|
+
)
|
|
613
628
|
parser.add_argument(
|
|
614
629
|
"--auth-list-public-keys",
|
|
615
630
|
type=str,
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -21,6 +21,8 @@ from logging import DEBUG, INFO, WARN
|
|
|
21
21
|
from pathlib import Path
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
25
|
+
from flwr.cli.install import install_from_fab
|
|
24
26
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
27
|
from flwr.common.config import (
|
|
26
28
|
get_flwr_dir,
|
|
@@ -36,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
36
38
|
CreateRunRequest,
|
|
37
39
|
CreateRunResponse,
|
|
38
40
|
)
|
|
41
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
42
|
|
|
40
43
|
from .driver import Driver
|
|
41
44
|
from .driver.grpc_driver import GrpcDriver
|
|
@@ -87,7 +90,8 @@ def run(
|
|
|
87
90
|
log(DEBUG, "ServerApp finished running.")
|
|
88
91
|
|
|
89
92
|
|
|
90
|
-
|
|
93
|
+
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
94
|
+
def run_server_app() -> None:
|
|
91
95
|
"""Run Flower server app."""
|
|
92
96
|
event(EventType.RUN_SERVER_APP_ENTER)
|
|
93
97
|
|
|
@@ -164,7 +168,19 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
164
168
|
)
|
|
165
169
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
166
170
|
run_ = driver.run
|
|
167
|
-
|
|
171
|
+
if run_.fab_hash:
|
|
172
|
+
fab_req = GetFabRequest(hash_str=run_.fab_hash)
|
|
173
|
+
# pylint: disable-next=W0212
|
|
174
|
+
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
|
|
175
|
+
if fab_res.fab.hash_str != run_.fab_hash:
|
|
176
|
+
raise ValueError("FAB hashes don't match.")
|
|
177
|
+
|
|
178
|
+
install_from_fab(fab_res.fab.content, flwr_dir, True)
|
|
179
|
+
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
|
|
180
|
+
else:
|
|
181
|
+
fab_id, fab_version = run_.fab_id, run_.fab_version
|
|
182
|
+
|
|
183
|
+
app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
|
|
168
184
|
config = get_project_config(app_path)
|
|
169
185
|
else:
|
|
170
186
|
# User provided `app_dir`, but not `--run-id`
|
flwr/server/server.py
CHANGED
|
@@ -91,7 +91,7 @@ class Server:
|
|
|
91
91
|
# Initialize parameters
|
|
92
92
|
log(INFO, "[INIT]")
|
|
93
93
|
self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
|
|
94
|
-
log(INFO, "
|
|
94
|
+
log(INFO, "Starting evaluation of initial global parameters")
|
|
95
95
|
res = self.strategy.evaluate(0, parameters=self.parameters)
|
|
96
96
|
if res is not None:
|
|
97
97
|
log(
|
|
@@ -102,6 +102,8 @@ class Server:
|
|
|
102
102
|
)
|
|
103
103
|
history.add_loss_centralized(server_round=0, loss=res[0])
|
|
104
104
|
history.add_metrics_centralized(server_round=0, metrics=res[1])
|
|
105
|
+
else:
|
|
106
|
+
log(INFO, "Evaluation returned no results (`None`)")
|
|
105
107
|
|
|
106
108
|
# Run federated learning for num_rounds
|
|
107
109
|
start_time = timeit.default_timer()
|
|
@@ -24,6 +24,7 @@ from flwr.common.logger import log
|
|
|
24
24
|
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
|
|
25
25
|
add_DriverServicer_to_server,
|
|
26
26
|
)
|
|
27
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
27
28
|
from flwr.server.superlink.state import StateFactory
|
|
28
29
|
|
|
29
30
|
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
@@ -33,12 +34,14 @@ from .driver_servicer import DriverServicer
|
|
|
33
34
|
def run_driver_api_grpc(
|
|
34
35
|
address: str,
|
|
35
36
|
state_factory: StateFactory,
|
|
37
|
+
ffs_factory: FfsFactory,
|
|
36
38
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
37
39
|
) -> grpc.Server:
|
|
38
40
|
"""Run Driver API (gRPC, request-response)."""
|
|
39
41
|
# Create Driver API gRPC server
|
|
40
42
|
driver_servicer: grpc.Server = DriverServicer(
|
|
41
43
|
state_factory=state_factory,
|
|
44
|
+
ffs_factory=ffs_factory,
|
|
42
45
|
)
|
|
43
46
|
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
|
|
44
47
|
driver_grpc_server = generic_create_grpc_server(
|
|
@@ -23,7 +23,13 @@ from uuid import UUID
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
-
from flwr.common.serde import
|
|
26
|
+
from flwr.common.serde import (
|
|
27
|
+
fab_from_proto,
|
|
28
|
+
fab_to_proto,
|
|
29
|
+
user_config_from_proto,
|
|
30
|
+
user_config_to_proto,
|
|
31
|
+
)
|
|
32
|
+
from flwr.common.typing import Fab
|
|
27
33
|
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
28
34
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
35
|
CreateRunRequest,
|
|
@@ -43,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
49
|
Run,
|
|
44
50
|
)
|
|
45
51
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
46
54
|
from flwr.server.superlink.state import State, StateFactory
|
|
47
55
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
48
56
|
|
|
@@ -50,8 +58,9 @@ from flwr.server.utils.validator import validate_task_ins_or_res
|
|
|
50
58
|
class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
51
59
|
"""Driver API servicer."""
|
|
52
60
|
|
|
53
|
-
def __init__(self, state_factory: StateFactory) -> None:
|
|
61
|
+
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
|
|
54
62
|
self.state_factory = state_factory
|
|
63
|
+
self.ffs_factory = ffs_factory
|
|
55
64
|
|
|
56
65
|
def GetNodes(
|
|
57
66
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
@@ -71,9 +80,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
71
80
|
"""Create run ID."""
|
|
72
81
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
73
82
|
state: State = self.state_factory.state()
|
|
83
|
+
if request.HasField("fab"):
|
|
84
|
+
fab = fab_from_proto(request.fab)
|
|
85
|
+
ffs: Ffs = self.ffs_factory.ffs()
|
|
86
|
+
fab_hash = ffs.put(fab.content, {})
|
|
87
|
+
_raise_if(
|
|
88
|
+
fab_hash != fab.hash_str,
|
|
89
|
+
f"FAB ({fab.hash_str}) hash from request doesn't match contents",
|
|
90
|
+
)
|
|
91
|
+
else:
|
|
92
|
+
fab_hash = ""
|
|
74
93
|
run_id = state.create_run(
|
|
75
94
|
request.fab_id,
|
|
76
95
|
request.fab_version,
|
|
96
|
+
fab_hash,
|
|
77
97
|
user_config_from_proto(request.override_config),
|
|
78
98
|
)
|
|
79
99
|
return CreateRunResponse(run_id=run_id)
|
|
@@ -161,14 +181,22 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
161
181
|
fab_id=run.fab_id,
|
|
162
182
|
fab_version=run.fab_version,
|
|
163
183
|
override_config=user_config_to_proto(run.override_config),
|
|
184
|
+
fab_hash=run.fab_hash,
|
|
164
185
|
)
|
|
165
186
|
)
|
|
166
187
|
|
|
167
188
|
def GetFab(
|
|
168
189
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
169
190
|
) -> GetFabResponse:
|
|
170
|
-
"""
|
|
171
|
-
|
|
191
|
+
"""Get FAB from Ffs."""
|
|
192
|
+
log(DEBUG, "DriverServicer.GetFab")
|
|
193
|
+
|
|
194
|
+
ffs: Ffs = self.ffs_factory.ffs()
|
|
195
|
+
if result := ffs.get(request.hash_str):
|
|
196
|
+
fab = Fab(request.hash_str, result[0])
|
|
197
|
+
return GetFabResponse(fab=fab_to_proto(fab))
|
|
198
|
+
|
|
199
|
+
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
172
200
|
|
|
173
201
|
|
|
174
202
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import hashlib
|
|
18
18
|
import json
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import Dict, List, Tuple
|
|
20
|
+
from typing import Dict, List, Optional, Tuple
|
|
21
21
|
|
|
22
22
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
23
23
|
|
|
@@ -58,7 +58,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
58
58
|
|
|
59
59
|
return content_hash
|
|
60
60
|
|
|
61
|
-
def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
|
|
61
|
+
def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
|
|
62
62
|
"""Return tuple containing the object content and metadata.
|
|
63
63
|
|
|
64
64
|
Parameters
|
|
@@ -68,9 +68,12 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
68
68
|
|
|
69
69
|
Returns
|
|
70
70
|
-------
|
|
71
|
-
Tuple[bytes, Dict[str, str]]
|
|
71
|
+
Optional[Tuple[bytes, Dict[str, str]]]
|
|
72
72
|
A tuple containing the object content and metadata.
|
|
73
73
|
"""
|
|
74
|
+
if not (self.base_dir / key).exists():
|
|
75
|
+
return None
|
|
76
|
+
|
|
74
77
|
content = (self.base_dir / key).read_bytes()
|
|
75
78
|
meta = json.loads((self.base_dir / f"{key}.META").read_text())
|
|
76
79
|
|
flwr/server/superlink/ffs/ffs.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import Dict, List, Tuple
|
|
19
|
+
from typing import Dict, List, Optional, Tuple
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class Ffs(abc.ABC): # pylint: disable=R0904
|
|
@@ -40,7 +40,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
42
|
@abc.abstractmethod
|
|
43
|
-
def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
|
|
43
|
+
def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
|
|
44
44
|
"""Return tuple containing the object content and metadata.
|
|
45
45
|
|
|
46
46
|
Parameters
|
|
@@ -50,7 +50,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
50
50
|
|
|
51
51
|
Returns
|
|
52
52
|
-------
|
|
53
|
-
Tuple[bytes, Dict[str, str]]
|
|
53
|
+
Optional[Tuple[bytes, Dict[str, str]]]
|
|
54
54
|
A tuple containing the object content and metadata.
|
|
55
55
|
"""
|
|
56
56
|
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Factory class that creates Ffs instances."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from logging import DEBUG
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from flwr.common.logger import log
|
|
22
|
+
|
|
23
|
+
from .disk_ffs import DiskFfs
|
|
24
|
+
from .ffs import Ffs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FfsFactory:
|
|
28
|
+
"""Factory class that creates Ffs instances.
|
|
29
|
+
|
|
30
|
+
Parameters
|
|
31
|
+
----------
|
|
32
|
+
base_dir : str
|
|
33
|
+
The base directory used by DiskFfs to store objects.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(self, base_dir: str) -> None:
|
|
37
|
+
self.base_dir = base_dir
|
|
38
|
+
self.ffs_instance: Optional[Ffs] = None
|
|
39
|
+
|
|
40
|
+
def ffs(self) -> Ffs:
|
|
41
|
+
"""Return a Ffs instance and create it, if necessary."""
|
|
42
|
+
if not self.ffs_instance:
|
|
43
|
+
log(DEBUG, "Initializing DiskFfs")
|
|
44
|
+
self.ffs_instance = DiskFfs(self.base_dir)
|
|
45
|
+
|
|
46
|
+
log(DEBUG, "Using DiskFfs")
|
|
47
|
+
return self.ffs_instance
|
|
@@ -35,6 +35,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
35
35
|
PushTaskResResponse,
|
|
36
36
|
)
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
38
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
39
40
|
from flwr.server.superlink.state import StateFactory
|
|
40
41
|
|
|
@@ -42,18 +43,21 @@ from flwr.server.superlink.state import StateFactory
|
|
|
42
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
43
44
|
"""Fleet API servicer."""
|
|
44
45
|
|
|
45
|
-
def __init__(self, state_factory: StateFactory) -> None:
|
|
46
|
+
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
|
|
46
47
|
self.state_factory = state_factory
|
|
48
|
+
self.ffs_factory = ffs_factory
|
|
47
49
|
|
|
48
50
|
def CreateNode(
|
|
49
51
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
50
52
|
) -> CreateNodeResponse:
|
|
51
53
|
"""."""
|
|
52
54
|
log(INFO, "FleetServicer.CreateNode")
|
|
53
|
-
|
|
55
|
+
response = message_handler.create_node(
|
|
54
56
|
request=request,
|
|
55
57
|
state=self.state_factory.state(),
|
|
56
58
|
)
|
|
59
|
+
log(INFO, "FleetServicer: Created node_id=%s", response.node.node_id)
|
|
60
|
+
return response
|
|
57
61
|
|
|
58
62
|
def DeleteNode(
|
|
59
63
|
self, request: DeleteNodeRequest, context: grpc.ServicerContext
|
|
@@ -106,5 +110,9 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
106
110
|
def GetFab(
|
|
107
111
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
108
112
|
) -> GetFabResponse:
|
|
109
|
-
"""
|
|
110
|
-
|
|
113
|
+
"""Get FAB."""
|
|
114
|
+
log(DEBUG, "DriverServicer.GetFab")
|
|
115
|
+
return message_handler.get_fab(
|
|
116
|
+
request=request,
|
|
117
|
+
ffs=self.ffs_factory.ffs(),
|
|
118
|
+
)
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
|
-
from logging import WARNING
|
|
19
|
+
from logging import INFO, WARNING
|
|
20
20
|
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
@@ -128,9 +128,15 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
128
128
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
129
129
|
|
|
130
130
|
if isinstance(request, CreateNodeRequest):
|
|
131
|
-
|
|
131
|
+
response = self._create_authenticated_node(
|
|
132
132
|
client_public_key_bytes, request, context
|
|
133
133
|
)
|
|
134
|
+
log(
|
|
135
|
+
INFO,
|
|
136
|
+
"AuthenticateServerInterceptor: Created node_id=%s",
|
|
137
|
+
response.node.node_id,
|
|
138
|
+
)
|
|
139
|
+
return response
|
|
134
140
|
|
|
135
141
|
# Verify hmac value
|
|
136
142
|
hmac_value = base64.urlsafe_b64decode(
|
|
@@ -19,7 +19,9 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common.serde import user_config_to_proto
|
|
22
|
+
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
+
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
23
25
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
24
26
|
CreateNodeRequest,
|
|
25
27
|
CreateNodeResponse,
|
|
@@ -40,6 +42,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
40
42
|
Run,
|
|
41
43
|
)
|
|
42
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
43
46
|
from flwr.server.superlink.state import State
|
|
44
47
|
|
|
45
48
|
|
|
@@ -124,5 +127,17 @@ def get_run(
|
|
|
124
127
|
fab_id=run.fab_id,
|
|
125
128
|
fab_version=run.fab_version,
|
|
126
129
|
override_config=user_config_to_proto(run.override_config),
|
|
130
|
+
fab_hash=run.fab_hash,
|
|
127
131
|
)
|
|
128
132
|
)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def get_fab(
|
|
136
|
+
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
137
|
+
) -> GetFabResponse:
|
|
138
|
+
"""Get FAB."""
|
|
139
|
+
if result := ffs.get(request.hash_str):
|
|
140
|
+
fab = Fab(request.hash_str, result[0])
|
|
141
|
+
return GetFabResponse(fab=fab_to_proto(fab))
|
|
142
|
+
|
|
143
|
+
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
|
+
import sys
|
|
17
18
|
from logging import DEBUG, ERROR
|
|
18
19
|
from typing import Callable, Dict, Tuple, Union
|
|
19
20
|
|
|
@@ -111,8 +112,10 @@ class RayBackend(Backend):
|
|
|
111
112
|
if backend_config.get(self.init_args_key):
|
|
112
113
|
for k, v in backend_config[self.init_args_key].items():
|
|
113
114
|
ray_init_args[k] = v
|
|
114
|
-
|
|
115
|
-
|
|
115
|
+
ray.init(
|
|
116
|
+
runtime_env={"env_vars": {"PYTHONPATH": ":".join(sys.path)}},
|
|
117
|
+
**ray_init_args,
|
|
118
|
+
)
|
|
116
119
|
|
|
117
120
|
@property
|
|
118
121
|
def num_workers(self) -> int:
|
|
@@ -27,8 +27,8 @@ from time import sleep
|
|
|
27
27
|
from typing import Callable, Dict, Optional
|
|
28
28
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
|
+
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
30
31
|
from flwr.client.node_state import NodeState
|
|
31
|
-
from flwr.client.supernode.app import _get_load_client_app_fn
|
|
32
32
|
from flwr.common.constant import (
|
|
33
33
|
NUM_PARTITIONS_KEY,
|
|
34
34
|
PARTITION_ID_KEY,
|
|
@@ -345,7 +345,7 @@ def start_vce(
|
|
|
345
345
|
def _load() -> ClientApp:
|
|
346
346
|
|
|
347
347
|
if client_app_attr:
|
|
348
|
-
app =
|
|
348
|
+
app = get_load_client_app_fn(
|
|
349
349
|
default_app_ref=client_app_attr,
|
|
350
350
|
app_path=app_dir,
|
|
351
351
|
flwr_dir=flwr_dir,
|
|
@@ -277,11 +277,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
277
277
|
|
|
278
278
|
def create_run(
|
|
279
279
|
self,
|
|
280
|
-
fab_id: str,
|
|
281
|
-
fab_version: str,
|
|
280
|
+
fab_id: Optional[str],
|
|
281
|
+
fab_version: Optional[str],
|
|
282
|
+
fab_hash: Optional[str],
|
|
282
283
|
override_config: UserConfig,
|
|
283
284
|
) -> int:
|
|
284
|
-
"""Create a new run for the specified `
|
|
285
|
+
"""Create a new run for the specified `fab_hash`."""
|
|
285
286
|
# Sample a random int64 as run_id
|
|
286
287
|
with self.lock:
|
|
287
288
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -289,8 +290,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
289
290
|
if run_id not in self.run_ids:
|
|
290
291
|
self.run_ids[run_id] = Run(
|
|
291
292
|
run_id=run_id,
|
|
292
|
-
fab_id=fab_id,
|
|
293
|
-
fab_version=fab_version,
|
|
293
|
+
fab_id=fab_id if fab_id else "",
|
|
294
|
+
fab_version=fab_version if fab_version else "",
|
|
295
|
+
fab_hash=fab_hash if fab_hash else "",
|
|
294
296
|
override_config=override_config,
|
|
295
297
|
)
|
|
296
298
|
return run_id
|