flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240620__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +3 -7
- flwr/cli/new/new.py +1 -1
- flwr/cli/run/run.py +8 -1
- flwr/client/client_app.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +12 -7
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +1 -1
- flwr/server/superlink/state/sqlite_state.py +1 -1
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +15 -8
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/RECORD +86 -86
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Flower command line interface."""
|
|
16
16
|
|
|
17
17
|
import typer
|
|
18
|
+
from typer.main import get_command
|
|
18
19
|
|
|
19
20
|
from .build import build
|
|
20
21
|
from .example import example
|
|
@@ -37,5 +38,7 @@ app.command()(run)
|
|
|
37
38
|
app.command()(build)
|
|
38
39
|
app.command()(install)
|
|
39
40
|
|
|
41
|
+
typer_click_object = get_command(app)
|
|
42
|
+
|
|
40
43
|
if __name__ == "__main__":
|
|
41
44
|
app()
|
flwr/cli/build.py
CHANGED
|
@@ -36,13 +36,9 @@ def build(
|
|
|
36
36
|
) -> str:
|
|
37
37
|
"""Build a Flower project into a Flower App Bundle (FAB).
|
|
38
38
|
|
|
39
|
-
You can run
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
You can also build a specific directory:
|
|
44
|
-
|
|
45
|
-
`flwr build --directory ./projects/flower-hello-world`
|
|
39
|
+
You can run ``flwr build`` without any arguments to bundle the current directory,
|
|
40
|
+
or you can use ``--directory`` to build a specific directory:
|
|
41
|
+
``flwr build --directory ./projects/flower-hello-world``.
|
|
46
42
|
"""
|
|
47
43
|
if directory is None:
|
|
48
44
|
directory = Path.cwd()
|
flwr/cli/new/new.py
CHANGED
flwr/cli/run/run.py
CHANGED
|
@@ -41,7 +41,10 @@ class Engine(str, Enum):
|
|
|
41
41
|
def run(
|
|
42
42
|
engine: Annotated[
|
|
43
43
|
Optional[Engine],
|
|
44
|
-
typer.Option(
|
|
44
|
+
typer.Option(
|
|
45
|
+
case_sensitive=False,
|
|
46
|
+
help="The engine to run FL with (currently only simulation is supported).",
|
|
47
|
+
),
|
|
45
48
|
] = None,
|
|
46
49
|
use_superexec: Annotated[
|
|
47
50
|
bool,
|
|
@@ -87,12 +90,16 @@ def run(
|
|
|
87
90
|
|
|
88
91
|
if engine == Engine.SIMULATION:
|
|
89
92
|
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
|
|
93
|
+
backend_config = config["flower"]["engine"]["simulation"].get(
|
|
94
|
+
"backend_config", None
|
|
95
|
+
)
|
|
90
96
|
|
|
91
97
|
typer.secho("Starting run... ", fg=typer.colors.BLUE)
|
|
92
98
|
_run_simulation(
|
|
93
99
|
server_app_attr=server_app_ref,
|
|
94
100
|
client_app_attr=client_app_ref,
|
|
95
101
|
num_supernodes=num_supernodes,
|
|
102
|
+
backend_config=backend_config,
|
|
96
103
|
)
|
|
97
104
|
else:
|
|
98
105
|
typer.secho(
|
flwr/client/client_app.py
CHANGED
flwr/client/mod/__init__.py
CHANGED
flwr/client/mod/utils.py
CHANGED
flwr/client/supernode/app.py
CHANGED
|
@@ -267,7 +267,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
267
267
|
"--flwr-dir",
|
|
268
268
|
default=None,
|
|
269
269
|
help="""The path containing installed Flower Apps.
|
|
270
|
-
By default, this value
|
|
270
|
+
By default, this value is equal to:
|
|
271
271
|
|
|
272
272
|
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
|
|
273
273
|
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
|
flwr/common/address.py
CHANGED
flwr/common/config.py
CHANGED
|
@@ -24,14 +24,16 @@ from flwr.cli.config_utils import validate_fields
|
|
|
24
24
|
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def get_flwr_dir() -> Path:
|
|
27
|
+
def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
28
28
|
"""Return the Flower home directory based on env variables."""
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
29
|
+
if provided_path is None or not Path(provided_path).is_dir():
|
|
30
|
+
return Path(
|
|
31
|
+
os.getenv(
|
|
32
|
+
FLWR_HOME,
|
|
33
|
+
f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
|
|
34
|
+
)
|
|
33
35
|
)
|
|
34
|
-
)
|
|
36
|
+
return Path(provided_path).absolute()
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
def get_project_dir(
|
flwr/common/constant.py
CHANGED
flwr/common/date.py
CHANGED
flwr/common/dp.py
CHANGED
flwr/common/grpc.py
CHANGED
flwr/common/version.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
1
15
|
"""Flower package version helper."""
|
|
2
16
|
|
|
3
17
|
import importlib.metadata as importlib_metadata
|
flwr/server/compat/app.py
CHANGED
flwr/server/compat/app_utils.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
|
|
|
19
19
|
from typing import Iterable, List, Optional
|
|
20
20
|
|
|
21
21
|
from flwr.common import Message, RecordSet
|
|
22
|
+
from flwr.common.typing import Run
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class Driver(ABC):
|
|
25
26
|
"""Abstract base Driver class for the Driver API."""
|
|
26
27
|
|
|
28
|
+
@property
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def run(self) -> Run:
|
|
31
|
+
"""Run information."""
|
|
32
|
+
|
|
27
33
|
@abstractmethod
|
|
28
34
|
def create_message( # pylint: disable=too-many-arguments
|
|
29
35
|
self,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
19
|
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple, cast
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
|
|
|
25
25
|
from flwr.common.grpc import create_channel
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
+
from flwr.common.typing import Run
|
|
28
29
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
30
|
CreateRunRequest,
|
|
30
31
|
CreateRunResponse,
|
|
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
40
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
42
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
43
|
|
|
42
44
|
from .driver import Driver
|
|
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
|
46
48
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
47
49
|
[Driver] Error: Not connected.
|
|
48
50
|
|
|
49
|
-
Call `connect()` on the `
|
|
50
|
-
`
|
|
51
|
+
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
52
|
+
`GrpcDriverStub` methods.
|
|
51
53
|
"""
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
class
|
|
55
|
-
"""`
|
|
56
|
+
class GrpcDriverStub:
|
|
57
|
+
"""`GrpcDriverStub` provides access to the gRPC Driver API/service.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
driver_service_address : Optional[str]
|
|
62
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
63
|
+
Defaults to `"[::]:9091"`.
|
|
64
|
+
root_certificates : Optional[bytes] (default: None)
|
|
65
|
+
The PEM-encoded root certificates as a byte string.
|
|
66
|
+
If provided, a secure connection using the certificates will be
|
|
67
|
+
established to an SSL-enabled Flower server.
|
|
68
|
+
"""
|
|
56
69
|
|
|
57
70
|
def __init__(
|
|
58
71
|
self,
|
|
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
|
|
|
64
77
|
self.channel: Optional[grpc.Channel] = None
|
|
65
78
|
self.stub: Optional[DriverStub] = None
|
|
66
79
|
|
|
80
|
+
def is_connected(self) -> bool:
|
|
81
|
+
"""Return True if connected to the Driver API server, otherwise False."""
|
|
82
|
+
return self.channel is not None
|
|
83
|
+
|
|
67
84
|
def connect(self) -> None:
|
|
68
85
|
"""Connect to the Driver API."""
|
|
69
86
|
event(EventType.DRIVER_CONNECT)
|
|
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
|
|
|
95
112
|
# Check if channel is open
|
|
96
113
|
if self.stub is None:
|
|
97
114
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
98
|
-
raise ConnectionError("`
|
|
115
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
99
116
|
|
|
100
117
|
# Call Driver API
|
|
101
118
|
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
102
119
|
return res
|
|
103
120
|
|
|
121
|
+
def get_run(self, req: GetRunRequest) -> GetRunResponse:
|
|
122
|
+
"""Get run information."""
|
|
123
|
+
# Check if channel is open
|
|
124
|
+
if self.stub is None:
|
|
125
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
126
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
127
|
+
|
|
128
|
+
# Call gRPC Driver API
|
|
129
|
+
res: GetRunResponse = self.stub.GetRun(request=req)
|
|
130
|
+
return res
|
|
131
|
+
|
|
104
132
|
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
105
133
|
"""Get client IDs."""
|
|
106
134
|
# Check if channel is open
|
|
107
135
|
if self.stub is None:
|
|
108
136
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
109
|
-
raise ConnectionError("`
|
|
137
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
110
138
|
|
|
111
139
|
# Call gRPC Driver API
|
|
112
140
|
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
|
|
|
117
145
|
# Check if channel is open
|
|
118
146
|
if self.stub is None:
|
|
119
147
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
120
|
-
raise ConnectionError("`
|
|
148
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
121
149
|
|
|
122
150
|
# Call gRPC Driver API
|
|
123
151
|
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
|
|
|
128
156
|
# Check if channel is open
|
|
129
157
|
if self.stub is None:
|
|
130
158
|
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
131
|
-
raise ConnectionError("`
|
|
159
|
+
raise ConnectionError("`GrpcDriverStub` instance not connected")
|
|
132
160
|
|
|
133
161
|
# Call Driver API
|
|
134
162
|
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
|
|
|
140
168
|
|
|
141
169
|
Parameters
|
|
142
170
|
----------
|
|
143
|
-
|
|
144
|
-
The
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
-
three bytes elements in the following order:
|
|
150
|
-
|
|
151
|
-
* CA certificate.
|
|
152
|
-
* server certificate.
|
|
153
|
-
* server private key.
|
|
154
|
-
fab_id : str (default: None)
|
|
155
|
-
The identifier of the FAB used in the run.
|
|
156
|
-
fab_version : str (default: None)
|
|
157
|
-
The version of the FAB used in the run.
|
|
171
|
+
run_id : int
|
|
172
|
+
The identifier of the run.
|
|
173
|
+
stub : Optional[GrpcDriverStub] (default: None)
|
|
174
|
+
The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
|
|
175
|
+
If None, an instance connected to "[::]:9091" will be created.
|
|
158
176
|
"""
|
|
159
177
|
|
|
160
|
-
def __init__(
|
|
178
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
161
179
|
self,
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
fab_id: Optional[str] = None,
|
|
165
|
-
fab_version: Optional[str] = None,
|
|
180
|
+
run_id: int,
|
|
181
|
+
stub: Optional[GrpcDriverStub] = None,
|
|
166
182
|
) -> None:
|
|
167
|
-
self.
|
|
168
|
-
self.
|
|
169
|
-
self.
|
|
170
|
-
self.run_id: Optional[int] = None
|
|
171
|
-
self.fab_id = fab_id if fab_id is not None else ""
|
|
172
|
-
self.fab_version = fab_version if fab_version is not None else ""
|
|
183
|
+
self._run_id = run_id
|
|
184
|
+
self._run: Optional[Run] = None
|
|
185
|
+
self.stub = stub if stub is not None else GrpcDriverStub()
|
|
173
186
|
self.node = Node(node_id=0, anonymous=True)
|
|
174
187
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
188
|
+
@property
|
|
189
|
+
def run(self) -> Run:
|
|
190
|
+
"""Run information."""
|
|
191
|
+
self._get_stub_and_run_id()
|
|
192
|
+
return Run(**vars(cast(Run, self._run)))
|
|
193
|
+
|
|
194
|
+
def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
|
|
195
|
+
# Check if is initialized
|
|
196
|
+
if self._run is None:
|
|
197
|
+
# Connect
|
|
198
|
+
if not self.stub.is_connected():
|
|
199
|
+
self.stub.connect()
|
|
200
|
+
# Get the run info
|
|
201
|
+
req = GetRunRequest(run_id=self._run_id)
|
|
202
|
+
res = self.stub.get_run(req)
|
|
203
|
+
if not res.HasField("run"):
|
|
204
|
+
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
|
|
205
|
+
self._run = Run(
|
|
206
|
+
run_id=res.run.run_id,
|
|
207
|
+
fab_id=res.run.fab_id,
|
|
208
|
+
fab_version=res.run.fab_version,
|
|
182
209
|
)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
res = self.driver_helper.create_run(req)
|
|
186
|
-
self.run_id = res.run_id
|
|
187
|
-
return self.driver_helper, self.run_id
|
|
210
|
+
|
|
211
|
+
return self.stub, self._run.run_id
|
|
188
212
|
|
|
189
213
|
def _check_message(self, message: Message) -> None:
|
|
190
214
|
# Check if the message is valid
|
|
191
215
|
if not (
|
|
192
|
-
message.metadata.run_id == self.run_id
|
|
216
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
193
217
|
and message.metadata.src_node_id == self.node.node_id
|
|
194
218
|
and message.metadata.message_id == ""
|
|
195
219
|
and message.metadata.reply_to_message == ""
|
|
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
|
|
|
210
234
|
This method constructs a new `Message` with given content and metadata.
|
|
211
235
|
The `run_id` and `src_node_id` will be set automatically.
|
|
212
236
|
"""
|
|
213
|
-
_, run_id = self.
|
|
237
|
+
_, run_id = self._get_stub_and_run_id()
|
|
214
238
|
if ttl:
|
|
215
239
|
warnings.warn(
|
|
216
240
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
|
|
|
234
258
|
|
|
235
259
|
def get_node_ids(self) -> List[int]:
|
|
236
260
|
"""Get node IDs."""
|
|
237
|
-
|
|
238
|
-
# Call
|
|
239
|
-
res =
|
|
261
|
+
stub, run_id = self._get_stub_and_run_id()
|
|
262
|
+
# Call GrpcDriverStub method
|
|
263
|
+
res = stub.get_nodes(GetNodesRequest(run_id=run_id))
|
|
240
264
|
return [node.node_id for node in res.nodes]
|
|
241
265
|
|
|
242
266
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
|
|
|
245
269
|
This method takes an iterable of messages and sends each message
|
|
246
270
|
to the node specified in `dst_node_id`.
|
|
247
271
|
"""
|
|
248
|
-
|
|
272
|
+
stub, _ = self._get_stub_and_run_id()
|
|
249
273
|
# Construct TaskIns
|
|
250
274
|
task_ins_list: List[TaskIns] = []
|
|
251
275
|
for msg in messages:
|
|
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
|
|
|
255
279
|
taskins = message_to_taskins(msg)
|
|
256
280
|
# Add to list
|
|
257
281
|
task_ins_list.append(taskins)
|
|
258
|
-
# Call
|
|
259
|
-
res =
|
|
260
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
261
|
-
)
|
|
282
|
+
# Call GrpcDriverStub method
|
|
283
|
+
res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
|
|
262
284
|
return list(res.task_ids)
|
|
263
285
|
|
|
264
286
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
|
|
|
267
289
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
268
290
|
set of given message IDs.
|
|
269
291
|
"""
|
|
270
|
-
|
|
292
|
+
stub, _ = self._get_stub_and_run_id()
|
|
271
293
|
# Pull TaskRes
|
|
272
|
-
res =
|
|
294
|
+
res = stub.pull_task_res(
|
|
273
295
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
274
296
|
)
|
|
275
297
|
# Convert TaskRes to Message
|
|
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
|
|
|
308
330
|
|
|
309
331
|
def close(self) -> None:
|
|
310
332
|
"""Disconnect from the SuperLink if connected."""
|
|
311
|
-
# Check if
|
|
312
|
-
if self.
|
|
333
|
+
# Check if `connect` was called before
|
|
334
|
+
if not self.stub.is_connected():
|
|
313
335
|
return
|
|
314
336
|
# Disconnect
|
|
315
|
-
self.
|
|
337
|
+
self.stub.disconnect()
|