flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240707__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +5 -9
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +21 -5
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +15 -10
- flwr/client/client_app.py +30 -5
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +4 -5
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +6 -3
- flwr/client/node_state_tests.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +12 -4
- flwr/client/typing.py +2 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +4 -1
- flwr/common/context.py +11 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +13 -0
- flwr/common/message.py +0 -17
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +61 -18
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +45 -26
- flwr/server/superlink/fleet/vce/vce_api.py +3 -8
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +5 -5
- flwr/server/superlink/state/sqlite_state.py +5 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +52 -37
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +17 -10
- flwr/simulation/run_simulation.py +47 -28
- flwr/superexec/deployment.py +109 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/RECORD +109 -98
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/entry_points.txt +0 -0
flwr/common/date.py
CHANGED
flwr/common/dp.py
CHANGED
flwr/common/grpc.py
CHANGED
flwr/common/logger.py
CHANGED
|
@@ -197,6 +197,19 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
|
|
200
|
+
def warn_unsupported_feature(name: str) -> None:
|
|
201
|
+
"""Warn the user when they use an unsupported feature."""
|
|
202
|
+
log(
|
|
203
|
+
WARN,
|
|
204
|
+
"""UNSUPPORTED FEATURE: %s
|
|
205
|
+
|
|
206
|
+
This is an unsupported feature. It will be removed
|
|
207
|
+
entirely in future versions of Flower.
|
|
208
|
+
""",
|
|
209
|
+
name,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
|
|
200
213
|
def set_logger_propagation(
|
|
201
214
|
child_logger: logging.Logger, value: bool = True
|
|
202
215
|
) -> logging.Logger:
|
flwr/common/message.py
CHANGED
|
@@ -48,10 +48,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
48
48
|
message_type : str
|
|
49
49
|
A string that encodes the action to be executed on
|
|
50
50
|
the receiving end.
|
|
51
|
-
partition_id : Optional[int]
|
|
52
|
-
An identifier that can be used when loading a particular
|
|
53
|
-
data partition for a ClientApp. Making use of this identifier
|
|
54
|
-
is more relevant when conducting simulations.
|
|
55
51
|
"""
|
|
56
52
|
|
|
57
53
|
def __init__( # pylint: disable=too-many-arguments
|
|
@@ -64,7 +60,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
64
60
|
group_id: str,
|
|
65
61
|
ttl: float,
|
|
66
62
|
message_type: str,
|
|
67
|
-
partition_id: int | None = None,
|
|
68
63
|
) -> None:
|
|
69
64
|
var_dict = {
|
|
70
65
|
"_run_id": run_id,
|
|
@@ -75,7 +70,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
75
70
|
"_group_id": group_id,
|
|
76
71
|
"_ttl": ttl,
|
|
77
72
|
"_message_type": message_type,
|
|
78
|
-
"_partition_id": partition_id,
|
|
79
73
|
}
|
|
80
74
|
self.__dict__.update(var_dict)
|
|
81
75
|
|
|
@@ -149,16 +143,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
149
143
|
"""Set message_type."""
|
|
150
144
|
self.__dict__["_message_type"] = value
|
|
151
145
|
|
|
152
|
-
@property
|
|
153
|
-
def partition_id(self) -> int | None:
|
|
154
|
-
"""An identifier telling which data partition a ClientApp should use."""
|
|
155
|
-
return cast(int, self.__dict__["_partition_id"])
|
|
156
|
-
|
|
157
|
-
@partition_id.setter
|
|
158
|
-
def partition_id(self, value: int) -> None:
|
|
159
|
-
"""Set partition_id."""
|
|
160
|
-
self.__dict__["_partition_id"] = value
|
|
161
|
-
|
|
162
146
|
def __repr__(self) -> str:
|
|
163
147
|
"""Return a string representation of this instance."""
|
|
164
148
|
view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
|
|
@@ -398,5 +382,4 @@ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
|
|
|
398
382
|
group_id=msg.metadata.group_id,
|
|
399
383
|
ttl=ttl,
|
|
400
384
|
message_type=msg.metadata.message_type,
|
|
401
|
-
partition_id=msg.metadata.partition_id,
|
|
402
385
|
)
|
flwr/common/version.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
1
15
|
"""Flower package version helper."""
|
|
2
16
|
|
|
3
17
|
import importlib.metadata as importlib_metadata
|
flwr/server/compat/app.py
CHANGED
flwr/server/compat/app_utils.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from typing import Iterable, List, Optional
|
|
20
20
|
|
|
21
21
|
from flwr.common import Message, RecordSet
|
|
22
|
+
from flwr.common.typing import Run
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class Driver(ABC):
|
|
25
26
|
"""Abstract base Driver class for the Driver API."""
|
|
26
27
|
|
|
28
|
+
@property
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def run(self) -> Run:
|
|
31
|
+
"""Run information."""
|
|
32
|
+
|
|
27
33
|
@abstractmethod
|
|
28
34
|
def create_message( # pylint: disable=too-many-arguments
|
|
29
35
|
self,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
19
|
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
|
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
+
from flwr.common.typing import Run
|
|
28
29
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
30
|
CreateRunRequest,
|
|
30
31
|
CreateRunResponse,
|
|
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
40
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
42
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
43
|
|
|
42
44
|
from .driver import Driver
|
|
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
|
46
48
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
47
49
|
[Driver] Error: Not connected.
|
|
48
50
|
|
|
49
|
-
Call `connect()` on the `
|
|
50
|
-
`
|
|
51
|
+
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
52
|
+
`GrpcDriverStub` methods.
|
|
51
53
|
"""
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
class
|
|
55
|
-
"""`
|
|
56
|
+
class GrpcDriverStub:
|
|
57
|
+
"""`GrpcDriverStub` provides access to the gRPC Driver API/service.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
driver_service_address : Optional[str]
|
|
62
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
63
|
+
Defaults to `"[::]:9091"`.
|
|
64
|
+
root_certificates : Optional[bytes] (default: None)
|
|
65
|
+
The PEM-encoded root certificates as a byte string.
|
|
66
|
+
If provided, a secure connection using the certificates will be
|
|
67
|
+
established to an SSL-enabled Flower server.
|
|
68
|
+
"""
|
|
56
69
|
|
|
57
70
|
def __init__(
|
|
58
71
|
self,
|
|
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
|
|
|
64
77
|
self.channel: Optional[grpc.Channel] = None
|
|
65
78
|
self.stub: Optional[DriverStub] = None
|
|
66
79
|
|
|
80
|
+
def is_connected(self) -> bool:
|
|
81
|
+
"""Return True if connected to the Driver API server, otherwise False."""
|
|
82
|
+
return self.channel is not None
|
|
83
|
+
|
|
67
84
|
def connect(self) -> None:
|
|
68
85
|
"""Connect to the Driver API."""
|
|
69
86
|
event(EventType.DRIVER_CONNECT)
|
|
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
|
|
|
95
112
|
# Check if channel is open
|
|
96
113
|
if self.stub is None:
|
|
97
114
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
98
|
-
raise ConnectionError("`
|
|
115
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
99
116
|
|
|
100
117
|
# Call Driver API
|
|
101
118
|
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
102
119
|
return res
|
|
103
120
|
|
|
121
|
+
def get_run(self, req: GetRunRequest) -> GetRunResponse:
|
|
122
|
+
"""Get run information."""
|
|
123
|
+
# Check if channel is open
|
|
124
|
+
if self.stub is None:
|
|
125
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
+
|
|
128
|
+
# Call gRPC Driver API
|
|
129
|
+
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
+
return res
|
|
131
|
+
|
|
104
132
|
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
105
133
|
"""Get client IDs."""
|
|
106
134
|
# Check if channel is open
|
|
107
135
|
if self.stub is None:
|
|
108
136
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
109
|
-
raise ConnectionError("`
|
|
137
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
110
138
|
|
|
111
139
|
# Call gRPC Driver API
|
|
112
140
|
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
|
|
|
117
145
|
# Check if channel is open
|
|
118
146
|
if self.stub is None:
|
|
119
147
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
120
|
-
raise ConnectionError("`
|
|
148
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
121
149
|
|
|
122
150
|
# Call gRPC Driver API
|
|
123
151
|
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
|
|
|
128
156
|
# Check if channel is open
|
|
129
157
|
if self.stub is None:
|
|
130
158
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
131
|
-
raise ConnectionError("`
|
|
159
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
132
160
|
|
|
133
161
|
# Call Driver API
|
|
134
162
|
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
|
|
|
140
168
|
|
|
141
169
|
Parameters
|
|
142
170
|
----------
|
|
143
|
-
|
|
144
|
-
The
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
-
three bytes elements in the following order:
|
|
150
|
-
|
|
151
|
-
* CA certificate.
|
|
152
|
-
* server certificate.
|
|
153
|
-
* server private key.
|
|
154
|
-
fab_id : str (default: None)
|
|
155
|
-
The identifier of the FAB used in the run.
|
|
156
|
-
fab_version : str (default: None)
|
|
157
|
-
The version of the FAB used in the run.
|
|
171
|
+
run_id : int
|
|
172
|
+
The identifier of the run.
|
|
173
|
+
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
+
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
+
If None, an instance connected to "[::]:9091" will be created.
|
|
158
176
|
"""
|
|
159
177
|
|
|
160
|
-
def __init__(
|
|
178
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
161
179
|
self,
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
fab_id: Optional[str] = None,
|
|
165
|
-
fab_version: Optional[str] = None,
|
|
180
|
+
run_id: int,
|
|
181
|
+
stub: Optional[GrpcDriverStub] = None,
|
|
166
182
|
) -> None:
|
|
167
|
-
self.
|
|
168
|
-
self.
|
|
169
|
-
self.
|
|
170
|
-
self.run_id: Optional[int] = None
|
|
171
|
-
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
183
|
+
self._run_id = run_id
|
|
184
|
+
self._run: Optional[Run] = None
|
|
185
|
+
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
173
186
|
self.node = Node(node_id=0, anonymous=True)
|
|
174
187
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
188
|
+
@property
|
|
189
|
+
def run(self) -> Run:
|
|
190
|
+
"""Run information."""
|
|
191
|
+
self._get_stub_and_run_id()
|
|
192
|
+
return Run(**vars(cast(Run, self._run)))
|
|
193
|
+
|
|
194
|
+
def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
|
|
195
|
+
# Check if is initialized
|
|
196
|
+
if self._run is None:
|
|
197
|
+
# Connect
|
|
198
|
+
if not self.stub.is_connected():
|
|
199
|
+
self.stub.connect()
|
|
200
|
+
# Get the run info
|
|
201
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
+
res = self.stub.get_run(req)
|
|
203
|
+
if not res.HasField("run"):
|
|
204
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
+
self._run = Run(
|
|
206
|
+
run_id=res.run.run_id,
|
|
207
|
+
fab_id=res.run.fab_id,
|
|
208
|
+
fab_version=res.run.fab_version,
|
|
182
209
|
)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
res = self.driver_helper.create_run(req)
|
|
186
|
-
self.run_id = res.run_id
|
|
187
|
-
return self.driver_helper, self.run_id
|
|
210
|
+
|
|
211
|
+
return self.stub, self._run.run_id
|
|
188
212
|
|
|
189
213
|
def _check_message(self, message: Message) -> None:
|
|
190
214
|
# Check if the message is valid
|
|
191
215
|
if not (
|
|
192
|
-
message.metadata.run_id == self.run_id
|
|
216
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
193
217
|
and message.metadata.src_node_id == self.node.node_id
|
|
194
218
|
and message.metadata.message_id == ""
|
|
195
219
|
and message.metadata.reply_to_message == ""
|
|
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
|
|
|
210
234
|
This method constructs a new `Message` with given content and metadata.
|
|
211
235
|
The `run_id` and `src_node_id` will be set automatically.
|
|
212
236
|
"""
|
|
213
|
-
_, run_id = self.
|
|
237
|
+
_, run_id = self._get_stub_and_run_id()
|
|
214
238
|
if ttl:
|
|
215
239
|
warnings.warn(
|
|
216
240
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
|
|
|
234
258
|
|
|
235
259
|
def get_node_ids(self) -> List[int]:
|
|
236
260
|
"""Get node IDs."""
|
|
237
|
-
|
|
238
|
-
# Call
|
|
239
|
-
res =
|
|
261
|
+
stub, run_id = self._get_stub_and_run_id()
|
|
262
|
+
# Call GrpcDriverStub method
|
|
263
|
+
res = stub.get_nodes(GetNodesRequest(run_id=run_id))
|
|
240
264
|
return [node.node_id for node in res.nodes]
|
|
241
265
|
|
|
242
266
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
|
|
|
245
269
|
This method takes an iterable of messages and sends each message
|
|
246
270
|
to the node specified in `dst_node_id`.
|
|
247
271
|
"""
|
|
248
|
-
|
|
272
|
+
stub, _ = self._get_stub_and_run_id()
|
|
249
273
|
# Construct TaskIns
|
|
250
274
|
task_ins_list: List[TaskIns] = []
|
|
251
275
|
for msg in messages:
|
|
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
|
|
|
255
279
|
taskins = message_to_taskins(msg)
|
|
256
280
|
# Add to list
|
|
257
281
|
task_ins_list.append(taskins)
|
|
258
|
-
# Call
|
|
259
|
-
res =
|
|
260
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
261
|
-
)
|
|
282
|
+
# Call GrpcDriverStub method
|
|
283
|
+
res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
262
284
|
return list(res.task_ids)
|
|
263
285
|
|
|
264
286
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
|
|
|
267
289
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
268
290
|
set of given message IDs.
|
|
269
291
|
"""
|
|
270
|
-
|
|
292
|
+
stub, _ = self._get_stub_and_run_id()
|
|
271
293
|
# Pull TaskRes
|
|
272
|
-
res =
|
|
294
|
+
res = stub.pull_task_res(
|
|
273
295
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
274
296
|
)
|
|
275
297
|
# Convert TaskRes to Message
|
|
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
|
|
|
308
330
|
|
|
309
331
|
def close(self) -> None:
|
|
310
332
|
"""Disconnect from the SuperLink if connected."""
|
|
311
|
-
# Check if
|
|
312
|
-
if self.
|
|
333
|
+
# Check if `connect` was called before
|
|
334
|
+
if not self.stub.is_connected():
|
|
313
335
|
return
|
|
314
336
|
# Disconnect
|
|
315
|
-
self.
|
|
337
|
+
self.stub.disconnect()
|
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
import warnings
|
|
20
|
-
from typing import Iterable, List, Optional
|
|
20
|
+
from typing import Iterable, List, Optional, cast
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
23
23
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
24
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.common.typing import Run
|
|
25
26
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
27
|
from flwr.server.superlink.state import StateFactory
|
|
27
28
|
|
|
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
|
|
|
33
34
|
|
|
34
35
|
Parameters
|
|
35
36
|
----------
|
|
37
|
+
run_id : int
|
|
38
|
+
The identifier of the run.
|
|
36
39
|
state_factory : StateFactory
|
|
37
40
|
A StateFactory embedding a state that this driver can interface with.
|
|
38
|
-
fab_id : str (default: None)
|
|
39
|
-
The identifier of the FAB used in the run.
|
|
40
|
-
fab_version : str (default: None)
|
|
41
|
-
The version of the FAB used in the run.
|
|
42
41
|
"""
|
|
43
42
|
|
|
44
43
|
def __init__(
|
|
45
44
|
self,
|
|
45
|
+
run_id: int,
|
|
46
46
|
state_factory: StateFactory,
|
|
47
|
-
fab_id: Optional[str] = None,
|
|
48
|
-
fab_version: Optional[str] = None,
|
|
49
47
|
) -> None:
|
|
50
|
-
self.
|
|
51
|
-
self.
|
|
52
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
53
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
48
|
+
self._run_id = run_id
|
|
49
|
+
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
51
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
55
52
|
|
|
56
53
|
def _check_message(self, message: Message) -> None:
|
|
54
|
+
self._init_run()
|
|
57
55
|
# Check if the message is valid
|
|
58
56
|
if not (
|
|
59
|
-
message.metadata.run_id == self.run_id
|
|
57
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
60
58
|
and message.metadata.src_node_id == self.node.node_id
|
|
61
59
|
and message.metadata.message_id == ""
|
|
62
60
|
and message.metadata.reply_to_message == ""
|
|
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
|
|
|
64
62
|
):
|
|
65
63
|
raise ValueError(f"Invalid message: {message}")
|
|
66
64
|
|
|
67
|
-
def
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
if
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
def _init_run(self) -> None:
|
|
66
|
+
"""Initialize the run."""
|
|
67
|
+
if self._run is not None:
|
|
68
|
+
return
|
|
69
|
+
run = self.state.get_run(self._run_id)
|
|
70
|
+
if run is None:
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
72
|
+
self._run = run
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def run(self) -> Run:
|
|
76
|
+
"""Run ID."""
|
|
77
|
+
self._init_run()
|
|
78
|
+
return Run(**vars(cast(Run, self._run)))
|
|
77
79
|
|
|
78
80
|
def create_message( # pylint: disable=too-many-arguments
|
|
79
81
|
self,
|
|
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
|
|
|
88
90
|
This method constructs a new `Message` with given content and metadata.
|
|
89
91
|
The `run_id` and `src_node_id` will be set automatically.
|
|
90
92
|
"""
|
|
91
|
-
|
|
93
|
+
self._init_run()
|
|
92
94
|
if ttl:
|
|
93
95
|
warnings.warn(
|
|
94
96
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
|
|
|
99
101
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
100
102
|
|
|
101
103
|
metadata = Metadata(
|
|
102
|
-
run_id=run_id,
|
|
104
|
+
run_id=cast(Run, self._run).run_id,
|
|
103
105
|
message_id="", # Will be set by the server
|
|
104
106
|
src_node_id=self.node.node_id,
|
|
105
107
|
dst_node_id=dst_node_id,
|
|
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
|
|
|
112
114
|
|
|
113
115
|
def get_node_ids(self) -> List[int]:
|
|
114
116
|
"""Get node IDs."""
|
|
115
|
-
|
|
116
|
-
return list(self.state.get_nodes(run_id))
|
|
117
|
+
self._init_run()
|
|
118
|
+
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
117
119
|
|
|
118
120
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
119
121
|
"""Push messages to specified node IDs.
|