flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +42 -18
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/app.py +111 -1
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -16,16 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
|
-
from flwr.common import serde
|
|
25
|
-
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
24
|
from flwr.server.client_proxy import ClientProxy
|
|
27
25
|
|
|
28
|
-
from ..driver.driver import
|
|
26
|
+
from ..driver.driver import Driver
|
|
29
27
|
|
|
30
28
|
SLEEP_TIME = 1
|
|
31
29
|
|
|
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
|
|
|
33
31
|
class DriverClientProxy(ClientProxy):
|
|
34
32
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
33
|
|
|
36
|
-
def __init__(
|
|
37
|
-
self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
|
|
38
|
-
):
|
|
34
|
+
def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
|
|
39
35
|
super().__init__(str(node_id))
|
|
40
36
|
self.node_id = node_id
|
|
41
37
|
self.driver = driver
|
|
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
|
|
|
116
112
|
timeout: Optional[float],
|
|
117
113
|
group_id: Optional[int],
|
|
118
114
|
) -> RecordSet:
|
|
119
|
-
task_ins = task_pb2.TaskIns( # pylint: disable=E1101
|
|
120
|
-
task_id="",
|
|
121
|
-
group_id=str(group_id) if group_id is not None else "",
|
|
122
|
-
run_id=self.run_id,
|
|
123
|
-
task=task_pb2.Task( # pylint: disable=E1101
|
|
124
|
-
producer=node_pb2.Node( # pylint: disable=E1101
|
|
125
|
-
node_id=0,
|
|
126
|
-
anonymous=True,
|
|
127
|
-
),
|
|
128
|
-
consumer=node_pb2.Node( # pylint: disable=E1101
|
|
129
|
-
node_id=self.node_id,
|
|
130
|
-
anonymous=self.anonymous,
|
|
131
|
-
),
|
|
132
|
-
task_type=task_type,
|
|
133
|
-
recordset=serde.recordset_to_proto(recordset),
|
|
134
|
-
ttl=DEFAULT_TTL,
|
|
135
|
-
),
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
# This would normally be recorded upon common.Message creation
|
|
139
|
-
# but this compatibility stack doesn't create Messages,
|
|
140
|
-
# so we need to inject `created_at` manually (needed for
|
|
141
|
-
# taskins validation by server.utils.validator)
|
|
142
|
-
task_ins.task.created_at = time.time()
|
|
143
115
|
|
|
144
|
-
|
|
145
|
-
|
|
116
|
+
# Create message
|
|
117
|
+
message = self.driver.create_message(
|
|
118
|
+
content=recordset,
|
|
119
|
+
message_type=task_type,
|
|
120
|
+
dst_node_id=self.node_id,
|
|
121
|
+
group_id=str(group_id) if group_id else "",
|
|
122
|
+
ttl=timeout,
|
|
146
123
|
)
|
|
147
124
|
|
|
148
|
-
#
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
raise ValueError("Unexpected number of task_ids")
|
|
125
|
+
# Push message
|
|
126
|
+
message_ids = list(self.driver.push_messages(messages=[message]))
|
|
127
|
+
if len(message_ids) != 1:
|
|
128
|
+
raise ValueError("Unexpected number of message_ids")
|
|
153
129
|
|
|
154
|
-
|
|
155
|
-
if
|
|
156
|
-
raise ValueError(f"Failed to
|
|
130
|
+
message_id = message_ids[0]
|
|
131
|
+
if message_id == "":
|
|
132
|
+
raise ValueError(f"Failed to send message to node {self.node_id}")
|
|
157
133
|
|
|
158
134
|
if timeout:
|
|
159
135
|
start_time = time.time()
|
|
160
136
|
|
|
161
137
|
while True:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
pull_task_res_res.task_res_list
|
|
172
|
-
)
|
|
173
|
-
if len(task_res_list) == 1:
|
|
174
|
-
task_res = task_res_list[0]
|
|
175
|
-
|
|
176
|
-
# This will raise an Exception if task_res carries an `error`
|
|
177
|
-
validate_task_res(task_res=task_res)
|
|
178
|
-
|
|
179
|
-
return serde.recordset_from_proto(task_res.task.recordset)
|
|
138
|
+
messages = list(self.driver.pull_messages(message_ids))
|
|
139
|
+
if len(messages) == 1:
|
|
140
|
+
msg: Message = messages[0]
|
|
141
|
+
if msg.has_error():
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
+
"It originated during client-side execution of a message."
|
|
145
|
+
)
|
|
146
|
+
return msg.content
|
|
180
147
|
|
|
181
148
|
if timeout is not None and time.time() > start_time + timeout:
|
|
182
149
|
raise RuntimeError("Timeout reached")
|
|
183
150
|
time.sleep(SLEEP_TIME)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def validate_task_res(
|
|
187
|
-
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
188
|
-
) -> None:
|
|
189
|
-
"""Validate if a TaskRes is empty or not."""
|
|
190
|
-
if not task_res.HasField("task"):
|
|
191
|
-
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
192
|
-
if task_res.task.HasField("error"):
|
|
193
|
-
raise ValueError("Exception during client-side task execution")
|
|
194
|
-
if not task_res.task.HasField("recordset"):
|
|
195
|
-
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|
flwr/server/driver/__init__.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,180 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Driver (abstract base class)."""
|
|
16
16
|
|
|
17
|
-
import time
|
|
18
|
-
import warnings
|
|
19
|
-
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
21
17
|
|
|
22
|
-
import
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Iterable, List, Optional
|
|
23
20
|
|
|
24
|
-
from flwr.common import
|
|
25
|
-
from flwr.common.grpc import create_channel
|
|
26
|
-
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
|
-
CreateRunRequest,
|
|
30
|
-
CreateRunResponse,
|
|
31
|
-
GetNodesRequest,
|
|
32
|
-
GetNodesResponse,
|
|
33
|
-
PullTaskResRequest,
|
|
34
|
-
PullTaskResResponse,
|
|
35
|
-
PushTaskInsRequest,
|
|
36
|
-
PushTaskInsResponse,
|
|
37
|
-
)
|
|
38
|
-
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
40
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
21
|
+
from flwr.common import Message, RecordSet
|
|
41
22
|
|
|
42
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
43
23
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
|
|
48
|
-
`GrpcDriverHelper` methods.
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class GrpcDriverHelper:
|
|
53
|
-
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
|
|
54
|
-
|
|
55
|
-
def __init__(
|
|
56
|
-
self,
|
|
57
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
58
|
-
root_certificates: Optional[bytes] = None,
|
|
59
|
-
) -> None:
|
|
60
|
-
self.driver_service_address = driver_service_address
|
|
61
|
-
self.root_certificates = root_certificates
|
|
62
|
-
self.channel: Optional[grpc.Channel] = None
|
|
63
|
-
self.stub: Optional[DriverStub] = None
|
|
64
|
-
|
|
65
|
-
def connect(self) -> None:
|
|
66
|
-
"""Connect to the Driver API."""
|
|
67
|
-
event(EventType.DRIVER_CONNECT)
|
|
68
|
-
if self.channel is not None or self.stub is not None:
|
|
69
|
-
log(WARNING, "Already connected")
|
|
70
|
-
return
|
|
71
|
-
self.channel = create_channel(
|
|
72
|
-
server_address=self.driver_service_address,
|
|
73
|
-
insecure=(self.root_certificates is None),
|
|
74
|
-
root_certificates=self.root_certificates,
|
|
75
|
-
)
|
|
76
|
-
self.stub = DriverStub(self.channel)
|
|
77
|
-
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
|
78
|
-
|
|
79
|
-
def disconnect(self) -> None:
|
|
80
|
-
"""Disconnect from the Driver API."""
|
|
81
|
-
event(EventType.DRIVER_DISCONNECT)
|
|
82
|
-
if self.channel is None or self.stub is None:
|
|
83
|
-
log(DEBUG, "Already disconnected")
|
|
84
|
-
return
|
|
85
|
-
channel = self.channel
|
|
86
|
-
self.channel = None
|
|
87
|
-
self.stub = None
|
|
88
|
-
channel.close()
|
|
89
|
-
log(DEBUG, "[Driver] Disconnected")
|
|
90
|
-
|
|
91
|
-
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
|
92
|
-
"""Request for run ID."""
|
|
93
|
-
# Check if channel is open
|
|
94
|
-
if self.stub is None:
|
|
95
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
96
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
97
|
-
|
|
98
|
-
# Call Driver API
|
|
99
|
-
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
100
|
-
return res
|
|
101
|
-
|
|
102
|
-
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
103
|
-
"""Get client IDs."""
|
|
104
|
-
# Check if channel is open
|
|
105
|
-
if self.stub is None:
|
|
106
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
107
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
108
|
-
|
|
109
|
-
# Call gRPC Driver API
|
|
110
|
-
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
111
|
-
return res
|
|
112
|
-
|
|
113
|
-
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
114
|
-
"""Schedule tasks."""
|
|
115
|
-
# Check if channel is open
|
|
116
|
-
if self.stub is None:
|
|
117
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
118
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
119
|
-
|
|
120
|
-
# Call gRPC Driver API
|
|
121
|
-
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
122
|
-
return res
|
|
123
|
-
|
|
124
|
-
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
125
|
-
"""Get task results."""
|
|
126
|
-
# Check if channel is open
|
|
127
|
-
if self.stub is None:
|
|
128
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
129
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
130
|
-
|
|
131
|
-
# Call Driver API
|
|
132
|
-
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
133
|
-
return res
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class Driver:
|
|
137
|
-
"""`Driver` class provides an interface to the Driver API.
|
|
138
|
-
|
|
139
|
-
Parameters
|
|
140
|
-
----------
|
|
141
|
-
driver_service_address : Optional[str]
|
|
142
|
-
The IPv4 or IPv6 address of the Driver API server.
|
|
143
|
-
Defaults to `"[::]:9091"`.
|
|
144
|
-
certificates : bytes (default: None)
|
|
145
|
-
Tuple containing root certificate, server certificate, and private key
|
|
146
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
147
|
-
three bytes elements in the following order:
|
|
148
|
-
|
|
149
|
-
* CA certificate.
|
|
150
|
-
* server certificate.
|
|
151
|
-
* server private key.
|
|
152
|
-
"""
|
|
153
|
-
|
|
154
|
-
def __init__(
|
|
155
|
-
self,
|
|
156
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
157
|
-
root_certificates: Optional[bytes] = None,
|
|
158
|
-
) -> None:
|
|
159
|
-
self.addr = driver_service_address
|
|
160
|
-
self.root_certificates = root_certificates
|
|
161
|
-
self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
|
|
162
|
-
self.run_id: Optional[int] = None
|
|
163
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
164
|
-
|
|
165
|
-
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
|
|
166
|
-
# Check if the GrpcDriverHelper is initialized
|
|
167
|
-
if self.grpc_driver_helper is None or self.run_id is None:
|
|
168
|
-
# Connect and create run
|
|
169
|
-
self.grpc_driver_helper = GrpcDriverHelper(
|
|
170
|
-
driver_service_address=self.addr,
|
|
171
|
-
root_certificates=self.root_certificates,
|
|
172
|
-
)
|
|
173
|
-
self.grpc_driver_helper.connect()
|
|
174
|
-
res = self.grpc_driver_helper.create_run(CreateRunRequest())
|
|
175
|
-
self.run_id = res.run_id
|
|
176
|
-
return self.grpc_driver_helper, self.run_id
|
|
177
|
-
|
|
178
|
-
def _check_message(self, message: Message) -> None:
|
|
179
|
-
# Check if the message is valid
|
|
180
|
-
if not (
|
|
181
|
-
message.metadata.run_id == self.run_id
|
|
182
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
183
|
-
and message.metadata.message_id == ""
|
|
184
|
-
and message.metadata.reply_to_message == ""
|
|
185
|
-
and message.metadata.ttl > 0
|
|
186
|
-
):
|
|
187
|
-
raise ValueError(f"Invalid message: {message}")
|
|
24
|
+
class Driver(ABC):
|
|
25
|
+
"""Abstract base Driver class for the Driver API."""
|
|
188
26
|
|
|
27
|
+
@abstractmethod
|
|
189
28
|
def create_message( # pylint: disable=too-many-arguments
|
|
190
29
|
self,
|
|
191
30
|
content: RecordSet,
|
|
@@ -223,35 +62,12 @@ class Driver:
|
|
|
223
62
|
message : Message
|
|
224
63
|
A new `Message` instance with the specified content and metadata.
|
|
225
64
|
"""
|
|
226
|
-
_, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
227
|
-
if ttl:
|
|
228
|
-
warnings.warn(
|
|
229
|
-
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
230
|
-
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
231
|
-
"version of Flower.",
|
|
232
|
-
stacklevel=2,
|
|
233
|
-
)
|
|
234
|
-
|
|
235
|
-
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
236
|
-
metadata = Metadata(
|
|
237
|
-
run_id=run_id,
|
|
238
|
-
message_id="", # Will be set by the server
|
|
239
|
-
src_node_id=self.node.node_id,
|
|
240
|
-
dst_node_id=dst_node_id,
|
|
241
|
-
reply_to_message="",
|
|
242
|
-
group_id=group_id,
|
|
243
|
-
ttl=ttl_,
|
|
244
|
-
message_type=message_type,
|
|
245
|
-
)
|
|
246
|
-
return Message(metadata=metadata, content=content)
|
|
247
65
|
|
|
66
|
+
@abstractmethod
|
|
248
67
|
def get_node_ids(self) -> List[int]:
|
|
249
68
|
"""Get node IDs."""
|
|
250
|
-
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
251
|
-
# Call GrpcDriverHelper method
|
|
252
|
-
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
|
|
253
|
-
return [node.node_id for node in res.nodes]
|
|
254
69
|
|
|
70
|
+
@abstractmethod
|
|
255
71
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
256
72
|
"""Push messages to specified node IDs.
|
|
257
73
|
|
|
@@ -269,22 +85,8 @@ class Driver:
|
|
|
269
85
|
An iterable of IDs for the messages that were sent, which can be used
|
|
270
86
|
to pull replies.
|
|
271
87
|
"""
|
|
272
|
-
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
|
|
273
|
-
# Construct TaskIns
|
|
274
|
-
task_ins_list: List[TaskIns] = []
|
|
275
|
-
for msg in messages:
|
|
276
|
-
# Check message
|
|
277
|
-
self._check_message(msg)
|
|
278
|
-
# Convert Message to TaskIns
|
|
279
|
-
taskins = message_to_taskins(msg)
|
|
280
|
-
# Add to list
|
|
281
|
-
task_ins_list.append(taskins)
|
|
282
|
-
# Call GrpcDriverHelper method
|
|
283
|
-
res = grpc_driver_helper.push_task_ins(
|
|
284
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
285
|
-
)
|
|
286
|
-
return list(res.task_ids)
|
|
287
88
|
|
|
89
|
+
@abstractmethod
|
|
288
90
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
289
91
|
"""Pull messages based on message IDs.
|
|
290
92
|
|
|
@@ -301,15 +103,8 @@ class Driver:
|
|
|
301
103
|
messages : Iterable[Message]
|
|
302
104
|
An iterable of messages received.
|
|
303
105
|
"""
|
|
304
|
-
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
|
|
305
|
-
# Pull TaskRes
|
|
306
|
-
res = grpc_driver.pull_task_res(
|
|
307
|
-
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
308
|
-
)
|
|
309
|
-
# Convert TaskRes to Message
|
|
310
|
-
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
311
|
-
return msgs
|
|
312
106
|
|
|
107
|
+
@abstractmethod
|
|
313
108
|
def send_and_receive(
|
|
314
109
|
self,
|
|
315
110
|
messages: Iterable[Message],
|
|
@@ -343,28 +138,3 @@ class Driver:
|
|
|
343
138
|
replies for all sent messages. A message remains valid until its TTL,
|
|
344
139
|
which is not affected by `timeout`.
|
|
345
140
|
"""
|
|
346
|
-
# Push messages
|
|
347
|
-
msg_ids = set(self.push_messages(messages))
|
|
348
|
-
|
|
349
|
-
# Pull messages
|
|
350
|
-
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
351
|
-
ret: List[Message] = []
|
|
352
|
-
while timeout is None or time.time() < end_time:
|
|
353
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
354
|
-
ret.extend(res_msgs)
|
|
355
|
-
msg_ids.difference_update(
|
|
356
|
-
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
357
|
-
)
|
|
358
|
-
if len(msg_ids) == 0:
|
|
359
|
-
break
|
|
360
|
-
# Sleep
|
|
361
|
-
time.sleep(3)
|
|
362
|
-
return ret
|
|
363
|
-
|
|
364
|
-
def close(self) -> None:
|
|
365
|
-
"""Disconnect from the SuperLink if connected."""
|
|
366
|
-
# Check if GrpcDriverHelper is initialized
|
|
367
|
-
if self.grpc_driver_helper is None:
|
|
368
|
-
return
|
|
369
|
-
# Disconnect
|
|
370
|
-
self.grpc_driver_helper.disconnect()
|