flwr-nightly 1.9.0.dev20240416__py3-none-any.whl → 1.9.0.dev20240420__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/{flower_toml.py → config_utils.py} +40 -7
- flwr/cli/new/new.py +9 -5
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +56 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +18 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +4 -0
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -0
- flwr/cli/run/run.py +2 -2
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +7 -53
- flwr/client/grpc_client/connection.py +2 -1
- flwr/client/grpc_rere_client/connection.py +16 -2
- flwr/client/rest_client/connection.py +87 -168
- flwr/client/supernode/__init__.py +22 -0
- flwr/client/supernode/app.py +107 -0
- flwr/common/record/recordset.py +67 -28
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +5 -5
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +4 -2
- flwr/server/driver/__init__.py +0 -2
- flwr/server/driver/abc_driver.py +140 -0
- flwr/server/driver/driver.py +124 -21
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/state/in_memory_state.py +13 -4
- flwr/server/superlink/state/sqlite_state.py +17 -5
- flwr/server/superlink/state/state.py +21 -3
- {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/RECORD +34 -32
- {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/entry_points.txt +1 -0
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/server/driver/grpc_driver.py +0 -129
- {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/WHEEL +0 -0
|
@@ -25,7 +25,7 @@ from flwr.common import serde
|
|
|
25
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
26
|
from flwr.server.client_proxy import ClientProxy
|
|
27
27
|
|
|
28
|
-
from ..driver.
|
|
28
|
+
from ..driver.driver import GrpcDriverHelper
|
|
29
29
|
|
|
30
30
|
SLEEP_TIME = 1
|
|
31
31
|
|
|
@@ -33,7 +33,9 @@ SLEEP_TIME = 1
|
|
|
33
33
|
class DriverClientProxy(ClientProxy):
|
|
34
34
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
35
35
|
|
|
36
|
-
def __init__(
|
|
36
|
+
def __init__(
|
|
37
|
+
self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
|
|
38
|
+
):
|
|
37
39
|
super().__init__(str(node_id))
|
|
38
40
|
self.node_id = node_id
|
|
39
41
|
self.driver = driver
|
flwr/server/driver/__init__.py
CHANGED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Driver (abstract base class)."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Iterable, List, Optional
|
|
20
|
+
|
|
21
|
+
from flwr.common import Message, RecordSet
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Driver(ABC):
|
|
25
|
+
"""Abstract base Driver class for the Driver API."""
|
|
26
|
+
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
29
|
+
self,
|
|
30
|
+
content: RecordSet,
|
|
31
|
+
message_type: str,
|
|
32
|
+
dst_node_id: int,
|
|
33
|
+
group_id: str,
|
|
34
|
+
ttl: Optional[float] = None,
|
|
35
|
+
) -> Message:
|
|
36
|
+
"""Create a new message with specified parameters.
|
|
37
|
+
|
|
38
|
+
This method constructs a new `Message` with given content and metadata.
|
|
39
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
content : RecordSet
|
|
44
|
+
The content for the new message. This holds records that are to be sent
|
|
45
|
+
to the destination node.
|
|
46
|
+
message_type : str
|
|
47
|
+
The type of the message, defining the action to be executed on
|
|
48
|
+
the receiving end.
|
|
49
|
+
dst_node_id : int
|
|
50
|
+
The ID of the destination node to which the message is being sent.
|
|
51
|
+
group_id : str
|
|
52
|
+
The ID of the group to which this message is associated. In some settings,
|
|
53
|
+
this is used as the FL round.
|
|
54
|
+
ttl : Optional[float] (default: None)
|
|
55
|
+
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
56
|
+
this message to receiving a reply. It specifies in seconds the duration for
|
|
57
|
+
which the message and its potential reply are considered valid. If unset,
|
|
58
|
+
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
|
|
59
|
+
|
|
60
|
+
Returns
|
|
61
|
+
-------
|
|
62
|
+
message : Message
|
|
63
|
+
A new `Message` instance with the specified content and metadata.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
@abstractmethod
|
|
67
|
+
def get_node_ids(self) -> List[int]:
|
|
68
|
+
"""Get node IDs."""
|
|
69
|
+
|
|
70
|
+
@abstractmethod
|
|
71
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
72
|
+
"""Push messages to specified node IDs.
|
|
73
|
+
|
|
74
|
+
This method takes an iterable of messages and sends each message
|
|
75
|
+
to the node specified in `dst_node_id`.
|
|
76
|
+
|
|
77
|
+
Parameters
|
|
78
|
+
----------
|
|
79
|
+
messages : Iterable[Message]
|
|
80
|
+
An iterable of messages to be sent.
|
|
81
|
+
|
|
82
|
+
Returns
|
|
83
|
+
-------
|
|
84
|
+
message_ids : Iterable[str]
|
|
85
|
+
An iterable of IDs for the messages that were sent, which can be used
|
|
86
|
+
to pull replies.
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
@abstractmethod
|
|
90
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
91
|
+
"""Pull messages based on message IDs.
|
|
92
|
+
|
|
93
|
+
This method is used to collect messages from the SuperLink
|
|
94
|
+
that correspond to a set of given message IDs.
|
|
95
|
+
|
|
96
|
+
Parameters
|
|
97
|
+
----------
|
|
98
|
+
message_ids : Iterable[str]
|
|
99
|
+
An iterable of message IDs for which reply messages are to be retrieved.
|
|
100
|
+
|
|
101
|
+
Returns
|
|
102
|
+
-------
|
|
103
|
+
messages : Iterable[Message]
|
|
104
|
+
An iterable of messages received.
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
@abstractmethod
|
|
108
|
+
def send_and_receive(
|
|
109
|
+
self,
|
|
110
|
+
messages: Iterable[Message],
|
|
111
|
+
*,
|
|
112
|
+
timeout: Optional[float] = None,
|
|
113
|
+
) -> Iterable[Message]:
|
|
114
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
115
|
+
|
|
116
|
+
This method sends a list of messages to their destination node IDs and then
|
|
117
|
+
waits for the replies. It continues to pull replies until either all
|
|
118
|
+
replies are received or the specified timeout duration is exceeded.
|
|
119
|
+
|
|
120
|
+
Parameters
|
|
121
|
+
----------
|
|
122
|
+
messages : Iterable[Message]
|
|
123
|
+
An iterable of messages to be sent.
|
|
124
|
+
timeout : Optional[float] (default: None)
|
|
125
|
+
The timeout duration in seconds. If specified, the method will wait for
|
|
126
|
+
replies for this duration. If `None`, there is no time limit and the method
|
|
127
|
+
will wait until replies for all messages are received.
|
|
128
|
+
|
|
129
|
+
Returns
|
|
130
|
+
-------
|
|
131
|
+
replies : Iterable[Message]
|
|
132
|
+
An iterable of reply messages received from the SuperLink.
|
|
133
|
+
|
|
134
|
+
Notes
|
|
135
|
+
-----
|
|
136
|
+
This method uses `push_messages` to send the messages and `pull_messages`
|
|
137
|
+
to collect the replies. If `timeout` is set, the method may not return
|
|
138
|
+
replies for all sent messages. A message remains valid until its TTL,
|
|
139
|
+
which is not affected by `timeout`.
|
|
140
|
+
"""
|
flwr/server/driver/driver.py
CHANGED
|
@@ -16,20 +16,121 @@
|
|
|
16
16
|
|
|
17
17
|
import time
|
|
18
18
|
import warnings
|
|
19
|
+
from logging import DEBUG, ERROR, WARNING
|
|
19
20
|
from typing import Iterable, List, Optional, Tuple
|
|
20
21
|
|
|
21
|
-
|
|
22
|
+
import grpc
|
|
23
|
+
|
|
24
|
+
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
25
|
+
from flwr.common.grpc import create_channel
|
|
26
|
+
from flwr.common.logger import log
|
|
22
27
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
23
28
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
24
29
|
CreateRunRequest,
|
|
30
|
+
CreateRunResponse,
|
|
25
31
|
GetNodesRequest,
|
|
32
|
+
GetNodesResponse,
|
|
26
33
|
PullTaskResRequest,
|
|
34
|
+
PullTaskResResponse,
|
|
27
35
|
PushTaskInsRequest,
|
|
36
|
+
PushTaskInsResponse,
|
|
28
37
|
)
|
|
38
|
+
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
29
39
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
30
40
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
31
41
|
|
|
32
|
-
|
|
42
|
+
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
43
|
+
|
|
44
|
+
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
45
|
+
[Driver] Error: Not connected.
|
|
46
|
+
|
|
47
|
+
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
|
|
48
|
+
`GrpcDriverHelper` methods.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class GrpcDriverHelper:
|
|
53
|
+
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
58
|
+
root_certificates: Optional[bytes] = None,
|
|
59
|
+
) -> None:
|
|
60
|
+
self.driver_service_address = driver_service_address
|
|
61
|
+
self.root_certificates = root_certificates
|
|
62
|
+
self.channel: Optional[grpc.Channel] = None
|
|
63
|
+
self.stub: Optional[DriverStub] = None
|
|
64
|
+
|
|
65
|
+
def connect(self) -> None:
|
|
66
|
+
"""Connect to the Driver API."""
|
|
67
|
+
event(EventType.DRIVER_CONNECT)
|
|
68
|
+
if self.channel is not None or self.stub is not None:
|
|
69
|
+
log(WARNING, "Already connected")
|
|
70
|
+
return
|
|
71
|
+
self.channel = create_channel(
|
|
72
|
+
server_address=self.driver_service_address,
|
|
73
|
+
insecure=(self.root_certificates is None),
|
|
74
|
+
root_certificates=self.root_certificates,
|
|
75
|
+
)
|
|
76
|
+
self.stub = DriverStub(self.channel)
|
|
77
|
+
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
|
78
|
+
|
|
79
|
+
def disconnect(self) -> None:
|
|
80
|
+
"""Disconnect from the Driver API."""
|
|
81
|
+
event(EventType.DRIVER_DISCONNECT)
|
|
82
|
+
if self.channel is None or self.stub is None:
|
|
83
|
+
log(DEBUG, "Already disconnected")
|
|
84
|
+
return
|
|
85
|
+
channel = self.channel
|
|
86
|
+
self.channel = None
|
|
87
|
+
self.stub = None
|
|
88
|
+
channel.close()
|
|
89
|
+
log(DEBUG, "[Driver] Disconnected")
|
|
90
|
+
|
|
91
|
+
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
|
92
|
+
"""Request for run ID."""
|
|
93
|
+
# Check if channel is open
|
|
94
|
+
if self.stub is None:
|
|
95
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
96
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
97
|
+
|
|
98
|
+
# Call Driver API
|
|
99
|
+
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
100
|
+
return res
|
|
101
|
+
|
|
102
|
+
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
103
|
+
"""Get client IDs."""
|
|
104
|
+
# Check if channel is open
|
|
105
|
+
if self.stub is None:
|
|
106
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
107
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
108
|
+
|
|
109
|
+
# Call gRPC Driver API
|
|
110
|
+
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
111
|
+
return res
|
|
112
|
+
|
|
113
|
+
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
114
|
+
"""Schedule tasks."""
|
|
115
|
+
# Check if channel is open
|
|
116
|
+
if self.stub is None:
|
|
117
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
118
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
119
|
+
|
|
120
|
+
# Call gRPC Driver API
|
|
121
|
+
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
122
|
+
return res
|
|
123
|
+
|
|
124
|
+
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
125
|
+
"""Get task results."""
|
|
126
|
+
# Check if channel is open
|
|
127
|
+
if self.stub is None:
|
|
128
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
129
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
130
|
+
|
|
131
|
+
# Call Driver API
|
|
132
|
+
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
133
|
+
return res
|
|
33
134
|
|
|
34
135
|
|
|
35
136
|
class Driver:
|
|
@@ -57,22 +158,22 @@ class Driver:
|
|
|
57
158
|
) -> None:
|
|
58
159
|
self.addr = driver_service_address
|
|
59
160
|
self.root_certificates = root_certificates
|
|
60
|
-
self.
|
|
161
|
+
self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
|
|
61
162
|
self.run_id: Optional[int] = None
|
|
62
163
|
self.node = Node(node_id=0, anonymous=True)
|
|
63
164
|
|
|
64
|
-
def
|
|
65
|
-
# Check if the
|
|
66
|
-
if self.
|
|
165
|
+
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
|
|
166
|
+
# Check if the GrpcDriverHelper is initialized
|
|
167
|
+
if self.grpc_driver_helper is None or self.run_id is None:
|
|
67
168
|
# Connect and create run
|
|
68
|
-
self.
|
|
169
|
+
self.grpc_driver_helper = GrpcDriverHelper(
|
|
69
170
|
driver_service_address=self.addr,
|
|
70
171
|
root_certificates=self.root_certificates,
|
|
71
172
|
)
|
|
72
|
-
self.
|
|
73
|
-
res = self.
|
|
173
|
+
self.grpc_driver_helper.connect()
|
|
174
|
+
res = self.grpc_driver_helper.create_run(CreateRunRequest())
|
|
74
175
|
self.run_id = res.run_id
|
|
75
|
-
return self.
|
|
176
|
+
return self.grpc_driver_helper, self.run_id
|
|
76
177
|
|
|
77
178
|
def _check_message(self, message: Message) -> None:
|
|
78
179
|
# Check if the message is valid
|
|
@@ -122,7 +223,7 @@ class Driver:
|
|
|
122
223
|
message : Message
|
|
123
224
|
A new `Message` instance with the specified content and metadata.
|
|
124
225
|
"""
|
|
125
|
-
_, run_id = self.
|
|
226
|
+
_, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
126
227
|
if ttl:
|
|
127
228
|
warnings.warn(
|
|
128
229
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -146,9 +247,9 @@ class Driver:
|
|
|
146
247
|
|
|
147
248
|
def get_node_ids(self) -> List[int]:
|
|
148
249
|
"""Get node IDs."""
|
|
149
|
-
|
|
150
|
-
# Call
|
|
151
|
-
res =
|
|
250
|
+
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
251
|
+
# Call GrpcDriverHelper method
|
|
252
|
+
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
|
|
152
253
|
return [node.node_id for node in res.nodes]
|
|
153
254
|
|
|
154
255
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
@@ -168,7 +269,7 @@ class Driver:
|
|
|
168
269
|
An iterable of IDs for the messages that were sent, which can be used
|
|
169
270
|
to pull replies.
|
|
170
271
|
"""
|
|
171
|
-
|
|
272
|
+
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
|
|
172
273
|
# Construct TaskIns
|
|
173
274
|
task_ins_list: List[TaskIns] = []
|
|
174
275
|
for msg in messages:
|
|
@@ -178,8 +279,10 @@ class Driver:
|
|
|
178
279
|
taskins = message_to_taskins(msg)
|
|
179
280
|
# Add to list
|
|
180
281
|
task_ins_list.append(taskins)
|
|
181
|
-
# Call
|
|
182
|
-
res =
|
|
282
|
+
# Call GrpcDriverHelper method
|
|
283
|
+
res = grpc_driver_helper.push_task_ins(
|
|
284
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
285
|
+
)
|
|
183
286
|
return list(res.task_ids)
|
|
184
287
|
|
|
185
288
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
@@ -198,7 +301,7 @@ class Driver:
|
|
|
198
301
|
messages : Iterable[Message]
|
|
199
302
|
An iterable of messages received.
|
|
200
303
|
"""
|
|
201
|
-
grpc_driver, _ = self.
|
|
304
|
+
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
|
|
202
305
|
# Pull TaskRes
|
|
203
306
|
res = grpc_driver.pull_task_res(
|
|
204
307
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
@@ -260,8 +363,8 @@ class Driver:
|
|
|
260
363
|
|
|
261
364
|
def close(self) -> None:
|
|
262
365
|
"""Disconnect from the SuperLink if connected."""
|
|
263
|
-
# Check if
|
|
264
|
-
if self.
|
|
366
|
+
# Check if GrpcDriverHelper is initialized
|
|
367
|
+
if self.grpc_driver_helper is None:
|
|
265
368
|
return
|
|
266
369
|
# Disconnect
|
|
267
|
-
self.
|
|
370
|
+
self.grpc_driver_helper.disconnect()
|
|
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
64
64
|
"""Create run ID."""
|
|
65
65
|
log(INFO, "DriverServicer.CreateRun")
|
|
66
66
|
state: State = self.state_factory.state()
|
|
67
|
-
run_id = state.create_run()
|
|
67
|
+
run_id = state.create_run("None/None", "None")
|
|
68
68
|
return CreateRunResponse(run_id=run_id)
|
|
69
69
|
|
|
70
70
|
def PushTaskIns(
|
|
@@ -33,6 +33,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
33
33
|
PushTaskResRequest,
|
|
34
34
|
PushTaskResResponse,
|
|
35
35
|
Reconnect,
|
|
36
|
+
Run,
|
|
36
37
|
)
|
|
37
38
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
38
39
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -109,4 +110,6 @@ def get_run(
|
|
|
109
110
|
request: GetRunRequest, state: State # pylint: disable=W0613
|
|
110
111
|
) -> GetRunResponse:
|
|
111
112
|
"""Get run information."""
|
|
112
|
-
|
|
113
|
+
run_id, fab_id, fab_version = state.get_run(request.run_id)
|
|
114
|
+
run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version)
|
|
115
|
+
return GetRunResponse(run=run)
|
|
@@ -36,7 +36,8 @@ class InMemoryState(State):
|
|
|
36
36
|
def __init__(self) -> None:
|
|
37
37
|
# Map node_id to (online_until, ping_interval)
|
|
38
38
|
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
39
|
-
|
|
39
|
+
# Map run_id to (fab_id, fab_version)
|
|
40
|
+
self.run_ids: Dict[int, Tuple[str, str]] = {}
|
|
40
41
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
41
42
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
42
43
|
self.lock = threading.Lock()
|
|
@@ -238,18 +239,26 @@ class InMemoryState(State):
|
|
|
238
239
|
if online_until > current_time
|
|
239
240
|
}
|
|
240
241
|
|
|
241
|
-
def create_run(self) -> int:
|
|
242
|
-
"""Create
|
|
242
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
243
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
243
244
|
# Sample a random int64 as run_id
|
|
244
245
|
with self.lock:
|
|
245
246
|
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
246
247
|
|
|
247
248
|
if run_id not in self.run_ids:
|
|
248
|
-
self.run_ids
|
|
249
|
+
self.run_ids[run_id] = (fab_id, fab_version)
|
|
249
250
|
return run_id
|
|
250
251
|
log(ERROR, "Unexpected run creation failure.")
|
|
251
252
|
return 0
|
|
252
253
|
|
|
254
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
255
|
+
"""Retrieve information about the run with the specified `run_id`."""
|
|
256
|
+
with self.lock:
|
|
257
|
+
if run_id not in self.run_ids:
|
|
258
|
+
log(ERROR, "`run_id` is invalid")
|
|
259
|
+
return 0, "", ""
|
|
260
|
+
return run_id, *self.run_ids[run_id]
|
|
261
|
+
|
|
253
262
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
254
263
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
255
264
|
with self.lock:
|
|
@@ -46,7 +46,9 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
|
46
46
|
|
|
47
47
|
SQL_CREATE_TABLE_RUN = """
|
|
48
48
|
CREATE TABLE IF NOT EXISTS run(
|
|
49
|
-
run_id
|
|
49
|
+
run_id INTEGER UNIQUE,
|
|
50
|
+
fab_id TEXT,
|
|
51
|
+
fab_version TEXT
|
|
50
52
|
);
|
|
51
53
|
"""
|
|
52
54
|
|
|
@@ -558,8 +560,8 @@ class SqliteState(State):
|
|
|
558
560
|
result: Set[int] = {row["node_id"] for row in rows}
|
|
559
561
|
return result
|
|
560
562
|
|
|
561
|
-
def create_run(self) -> int:
|
|
562
|
-
"""Create
|
|
563
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
564
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
563
565
|
# Sample a random int64 as run_id
|
|
564
566
|
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
565
567
|
|
|
@@ -567,12 +569,22 @@ class SqliteState(State):
|
|
|
567
569
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
568
570
|
# If run_id does not exist
|
|
569
571
|
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
|
|
570
|
-
query = "INSERT INTO run
|
|
571
|
-
self.query(query,
|
|
572
|
+
query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
|
|
573
|
+
self.query(query, (run_id, fab_id, fab_version))
|
|
572
574
|
return run_id
|
|
573
575
|
log(ERROR, "Unexpected run creation failure.")
|
|
574
576
|
return 0
|
|
575
577
|
|
|
578
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
579
|
+
"""Retrieve information about the run with the specified `run_id`."""
|
|
580
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
581
|
+
try:
|
|
582
|
+
row = self.query(query, (run_id,))[0]
|
|
583
|
+
return run_id, row["fab_id"], row["fab_version"]
|
|
584
|
+
except sqlite3.IntegrityError:
|
|
585
|
+
log(ERROR, "`run_id` does not exist.")
|
|
586
|
+
return 0, "", ""
|
|
587
|
+
|
|
576
588
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
577
589
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
578
590
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import List, Optional, Set
|
|
19
|
+
from typing import List, Optional, Set, Tuple
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -150,8 +150,26 @@ class State(abc.ABC):
|
|
|
150
150
|
"""
|
|
151
151
|
|
|
152
152
|
@abc.abstractmethod
|
|
153
|
-
def create_run(self) -> int:
|
|
154
|
-
"""Create
|
|
153
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
154
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
155
|
+
|
|
156
|
+
@abc.abstractmethod
|
|
157
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
158
|
+
"""Retrieve information about the run with the specified `run_id`.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
run_id : int
|
|
163
|
+
The identifier of the run.
|
|
164
|
+
|
|
165
|
+
Returns
|
|
166
|
+
-------
|
|
167
|
+
Tuple[int, str, str]
|
|
168
|
+
A tuple containing three elements:
|
|
169
|
+
- `run_id`: The identifier of the run, same as the specified `run_id`.
|
|
170
|
+
- `fab_id`: The identifier of the FAB used in the specified run.
|
|
171
|
+
- `fab_version`: The version of the FAB used in the specified run.
|
|
172
|
+
"""
|
|
155
173
|
|
|
156
174
|
@abc.abstractmethod
|
|
157
175
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|