flwr-nightly 1.8.0.dev20240328__py3-none-any.whl → 1.8.0.dev20240402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +26 -13
- flwr/client/client_app.py +16 -0
- flwr/client/grpc_rere_client/connection.py +71 -29
- flwr/client/heartbeat.py +72 -0
- flwr/client/rest_client/connection.py +102 -28
- flwr/common/constant.py +20 -0
- flwr/common/logger.py +4 -4
- flwr/common/message.py +15 -0
- flwr/common/retry_invoker.py +24 -13
- flwr/proto/fleet_pb2.py +26 -26
- flwr/proto/fleet_pb2.pyi +5 -0
- flwr/server/driver/driver.py +15 -5
- flwr/server/server_app.py +3 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
- flwr/server/superlink/fleet/vce/vce_api.py +22 -4
- flwr/server/superlink/state/in_memory_state.py +25 -8
- flwr/server/superlink/state/sqlite_state.py +53 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/utils.py +56 -0
- flwr/server/workflow/default_workflows.py +1 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
- flwr/simulation/ray_transport/ray_actor.py +2 -22
- {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/RECORD +28 -26
- {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/entry_points.txt +0 -0
flwr/common/logger.py
CHANGED
|
@@ -164,13 +164,13 @@ logger = logging.getLogger(LOGGER_NAME) # pylint: disable=invalid-name
|
|
|
164
164
|
log = logger.log # pylint: disable=invalid-name
|
|
165
165
|
|
|
166
166
|
|
|
167
|
-
def
|
|
168
|
-
"""Warn the user when they use
|
|
167
|
+
def warn_preview_feature(name: str) -> None:
|
|
168
|
+
"""Warn the user when they use a preview feature."""
|
|
169
169
|
log(
|
|
170
170
|
WARN,
|
|
171
|
-
"""
|
|
171
|
+
"""PREVIEW FEATURE: %s
|
|
172
172
|
|
|
173
|
-
This is
|
|
173
|
+
This is a preview feature. It could change significantly or be removed
|
|
174
174
|
entirely in future versions of Flower.
|
|
175
175
|
""",
|
|
176
176
|
name,
|
flwr/common/message.py
CHANGED
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
19
|
import time
|
|
20
|
+
import warnings
|
|
20
21
|
from dataclasses import dataclass
|
|
21
22
|
|
|
22
23
|
from .record import RecordSet
|
|
@@ -311,6 +312,13 @@ class Message:
|
|
|
311
312
|
|
|
312
313
|
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
313
314
|
"""
|
|
315
|
+
if ttl:
|
|
316
|
+
warnings.warn(
|
|
317
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
318
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
319
|
+
"version of Flower.",
|
|
320
|
+
stacklevel=2,
|
|
321
|
+
)
|
|
314
322
|
# If no TTL passed, use default for message creation (will update after
|
|
315
323
|
# message creation)
|
|
316
324
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
@@ -349,6 +357,13 @@ class Message:
|
|
|
349
357
|
Message
|
|
350
358
|
A new `Message` instance representing the reply.
|
|
351
359
|
"""
|
|
360
|
+
if ttl:
|
|
361
|
+
warnings.warn(
|
|
362
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
363
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
364
|
+
"version of Flower.",
|
|
365
|
+
stacklevel=2,
|
|
366
|
+
)
|
|
352
367
|
# If no TTL passed, use default for message creation (will update after
|
|
353
368
|
# message creation)
|
|
354
369
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -107,7 +107,7 @@ class RetryInvoker:
|
|
|
107
107
|
|
|
108
108
|
Parameters
|
|
109
109
|
----------
|
|
110
|
-
|
|
110
|
+
wait_gen_factory: Callable[[], Generator[float, None, None]]
|
|
111
111
|
A generator yielding successive wait times in seconds. If the generator
|
|
112
112
|
is finite, the giveup event will be triggered when the generator raises
|
|
113
113
|
`StopIteration`.
|
|
@@ -129,12 +129,12 @@ class RetryInvoker:
|
|
|
129
129
|
data class object detailing the invocation.
|
|
130
130
|
on_giveup: Optional[Callable[[RetryState], None]] (default: None)
|
|
131
131
|
A callable to be executed in the event that `max_tries` or `max_time` is
|
|
132
|
-
exceeded, `should_giveup` returns True, or `
|
|
132
|
+
exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
|
|
133
133
|
`StopInteration`. The parameter is a data class object detailing the
|
|
134
134
|
invocation.
|
|
135
135
|
jitter: Optional[Callable[[float], float]] (default: full_jitter)
|
|
136
|
-
A function of the value yielded by `
|
|
137
|
-
to wait. This function helps distribute wait times stochastically to avoid
|
|
136
|
+
A function of the value yielded by `wait_gen_factory()` returning the actual
|
|
137
|
+
time to wait. This function helps distribute wait times stochastically to avoid
|
|
138
138
|
timing collisions across concurrent clients. Wait times are jittered by
|
|
139
139
|
default using the `full_jitter` function. To disable jittering, pass
|
|
140
140
|
`jitter=None`.
|
|
@@ -142,6 +142,13 @@ class RetryInvoker:
|
|
|
142
142
|
A function accepting an exception instance, returning whether or not
|
|
143
143
|
to give up prematurely before other give-up conditions are evaluated.
|
|
144
144
|
If set to None, the strategy is to never give up prematurely.
|
|
145
|
+
wait_function: Optional[Callable[[float], None]] (default: None)
|
|
146
|
+
A function that defines how to wait between retry attempts. It accepts
|
|
147
|
+
one argument, the wait time in seconds, allowing the use of various waiting
|
|
148
|
+
mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
|
|
149
|
+
for different execution environments. If set to `None`, the `wait_function`
|
|
150
|
+
defaults to `time.sleep`, which is ideal for synchronous operations. Custom
|
|
151
|
+
functions should manage execution flow to prevent blocking or interference.
|
|
145
152
|
|
|
146
153
|
Examples
|
|
147
154
|
--------
|
|
@@ -159,7 +166,7 @@ class RetryInvoker:
|
|
|
159
166
|
# pylint: disable-next=too-many-arguments
|
|
160
167
|
def __init__(
|
|
161
168
|
self,
|
|
162
|
-
|
|
169
|
+
wait_gen_factory: Callable[[], Generator[float, None, None]],
|
|
163
170
|
recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
|
|
164
171
|
max_tries: Optional[int],
|
|
165
172
|
max_time: Optional[float],
|
|
@@ -169,8 +176,9 @@ class RetryInvoker:
|
|
|
169
176
|
on_giveup: Optional[Callable[[RetryState], None]] = None,
|
|
170
177
|
jitter: Optional[Callable[[float], float]] = full_jitter,
|
|
171
178
|
should_giveup: Optional[Callable[[Exception], bool]] = None,
|
|
179
|
+
wait_function: Optional[Callable[[float], None]] = None,
|
|
172
180
|
) -> None:
|
|
173
|
-
self.
|
|
181
|
+
self.wait_gen_factory = wait_gen_factory
|
|
174
182
|
self.recoverable_exceptions = recoverable_exceptions
|
|
175
183
|
self.max_tries = max_tries
|
|
176
184
|
self.max_time = max_time
|
|
@@ -179,6 +187,9 @@ class RetryInvoker:
|
|
|
179
187
|
self.on_giveup = on_giveup
|
|
180
188
|
self.jitter = jitter
|
|
181
189
|
self.should_giveup = should_giveup
|
|
190
|
+
if wait_function is None:
|
|
191
|
+
wait_function = time.sleep
|
|
192
|
+
self.wait_function = wait_function
|
|
182
193
|
|
|
183
194
|
# pylint: disable-next=too-many-locals
|
|
184
195
|
def invoke(
|
|
@@ -212,13 +223,13 @@ class RetryInvoker:
|
|
|
212
223
|
Raises
|
|
213
224
|
------
|
|
214
225
|
Exception
|
|
215
|
-
If the number of tries exceeds `max_tries`, if the total time
|
|
216
|
-
|
|
226
|
+
If the number of tries exceeds `max_tries`, if the total time exceeds
|
|
227
|
+
`max_time`, if `wait_gen_factory()` generator raises `StopInteration`,
|
|
217
228
|
or if the `should_giveup` returns True for a raised exception.
|
|
218
229
|
|
|
219
230
|
Notes
|
|
220
231
|
-----
|
|
221
|
-
The time between retries is determined by the provided `
|
|
232
|
+
The time between retries is determined by the provided `wait_gen_factory()`
|
|
222
233
|
generator and can optionally be jittered using the `jitter` function.
|
|
223
234
|
The recoverable exceptions that trigger a retry, as well as conditions to
|
|
224
235
|
stop retries, are also determined by the class's initialization parameters.
|
|
@@ -231,13 +242,13 @@ class RetryInvoker:
|
|
|
231
242
|
handler(cast(RetryState, ref_state[0]))
|
|
232
243
|
|
|
233
244
|
try_cnt = 0
|
|
234
|
-
wait_generator = self.
|
|
235
|
-
start = time.
|
|
245
|
+
wait_generator = self.wait_gen_factory()
|
|
246
|
+
start = time.monotonic()
|
|
236
247
|
ref_state: List[Optional[RetryState]] = [None]
|
|
237
248
|
|
|
238
249
|
while True:
|
|
239
250
|
try_cnt += 1
|
|
240
|
-
elapsed_time = time.
|
|
251
|
+
elapsed_time = time.monotonic() - start
|
|
241
252
|
state = RetryState(
|
|
242
253
|
target=target,
|
|
243
254
|
args=args,
|
|
@@ -282,7 +293,7 @@ class RetryInvoker:
|
|
|
282
293
|
try_call_event_handler(self.on_backoff)
|
|
283
294
|
|
|
284
295
|
# Sleep
|
|
285
|
-
|
|
296
|
+
self.wait_function(state.actual_wait)
|
|
286
297
|
else:
|
|
287
298
|
# Trigger success event
|
|
288
299
|
try_call_event_handler(self.on_success)
|
flwr/proto/fleet_pb2.py
CHANGED
|
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
|
16
16
|
from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"
|
|
19
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
|
|
20
20
|
|
|
21
21
|
_globals = globals()
|
|
22
22
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -26,29 +26,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
26
26
|
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None
|
|
27
27
|
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
|
|
28
28
|
_globals['_CREATENODEREQUEST']._serialized_start=84
|
|
29
|
-
_globals['_CREATENODEREQUEST']._serialized_end=
|
|
30
|
-
_globals['_CREATENODERESPONSE']._serialized_start=
|
|
31
|
-
_globals['_CREATENODERESPONSE']._serialized_end=
|
|
32
|
-
_globals['_DELETENODEREQUEST']._serialized_start=
|
|
33
|
-
_globals['_DELETENODEREQUEST']._serialized_end=
|
|
34
|
-
_globals['_DELETENODERESPONSE']._serialized_start=
|
|
35
|
-
_globals['_DELETENODERESPONSE']._serialized_end=
|
|
36
|
-
_globals['_PINGREQUEST']._serialized_start=
|
|
37
|
-
_globals['_PINGREQUEST']._serialized_end=
|
|
38
|
-
_globals['_PINGRESPONSE']._serialized_start=
|
|
39
|
-
_globals['_PINGRESPONSE']._serialized_end=
|
|
40
|
-
_globals['_PULLTASKINSREQUEST']._serialized_start=
|
|
41
|
-
_globals['_PULLTASKINSREQUEST']._serialized_end=
|
|
42
|
-
_globals['_PULLTASKINSRESPONSE']._serialized_start=
|
|
43
|
-
_globals['_PULLTASKINSRESPONSE']._serialized_end=
|
|
44
|
-
_globals['_PUSHTASKRESREQUEST']._serialized_start=
|
|
45
|
-
_globals['_PUSHTASKRESREQUEST']._serialized_end=
|
|
46
|
-
_globals['_PUSHTASKRESRESPONSE']._serialized_start=
|
|
47
|
-
_globals['_PUSHTASKRESRESPONSE']._serialized_end=
|
|
48
|
-
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=
|
|
49
|
-
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=
|
|
50
|
-
_globals['_RECONNECT']._serialized_start=
|
|
51
|
-
_globals['_RECONNECT']._serialized_end=
|
|
52
|
-
_globals['_FLEET']._serialized_start=
|
|
53
|
-
_globals['_FLEET']._serialized_end=
|
|
29
|
+
_globals['_CREATENODEREQUEST']._serialized_end=126
|
|
30
|
+
_globals['_CREATENODERESPONSE']._serialized_start=128
|
|
31
|
+
_globals['_CREATENODERESPONSE']._serialized_end=180
|
|
32
|
+
_globals['_DELETENODEREQUEST']._serialized_start=182
|
|
33
|
+
_globals['_DELETENODEREQUEST']._serialized_end=233
|
|
34
|
+
_globals['_DELETENODERESPONSE']._serialized_start=235
|
|
35
|
+
_globals['_DELETENODERESPONSE']._serialized_end=255
|
|
36
|
+
_globals['_PINGREQUEST']._serialized_start=257
|
|
37
|
+
_globals['_PINGREQUEST']._serialized_end=325
|
|
38
|
+
_globals['_PINGRESPONSE']._serialized_start=327
|
|
39
|
+
_globals['_PINGRESPONSE']._serialized_end=358
|
|
40
|
+
_globals['_PULLTASKINSREQUEST']._serialized_start=360
|
|
41
|
+
_globals['_PULLTASKINSREQUEST']._serialized_end=430
|
|
42
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_start=432
|
|
43
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_end=539
|
|
44
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_start=541
|
|
45
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_end=605
|
|
46
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_start=608
|
|
47
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_end=782
|
|
48
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736
|
|
49
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782
|
|
50
|
+
_globals['_RECONNECT']._serialized_start=784
|
|
51
|
+
_globals['_RECONNECT']._serialized_end=814
|
|
52
|
+
_globals['_FLEET']._serialized_start=817
|
|
53
|
+
_globals['_FLEET']._serialized_end=1207
|
|
54
54
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fleet_pb2.pyi
CHANGED
|
@@ -16,8 +16,13 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
|
16
16
|
class CreateNodeRequest(google.protobuf.message.Message):
|
|
17
17
|
"""CreateNode messages"""
|
|
18
18
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
19
|
+
PING_INTERVAL_FIELD_NUMBER: builtins.int
|
|
20
|
+
ping_interval: builtins.float
|
|
19
21
|
def __init__(self,
|
|
22
|
+
*,
|
|
23
|
+
ping_interval: builtins.float = ...,
|
|
20
24
|
) -> None: ...
|
|
25
|
+
def ClearField(self, field_name: typing_extensions.Literal["ping_interval",b"ping_interval"]) -> None: ...
|
|
21
26
|
global___CreateNodeRequest = CreateNodeRequest
|
|
22
27
|
|
|
23
28
|
class CreateNodeResponse(google.protobuf.message.Message):
|
flwr/server/driver/driver.py
CHANGED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower driver service client."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
import time
|
|
18
|
+
import warnings
|
|
19
19
|
from typing import Iterable, List, Optional, Tuple
|
|
20
20
|
|
|
21
21
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
@@ -91,7 +91,7 @@ class Driver:
|
|
|
91
91
|
message_type: str,
|
|
92
92
|
dst_node_id: int,
|
|
93
93
|
group_id: str,
|
|
94
|
-
ttl: float =
|
|
94
|
+
ttl: Optional[float] = None,
|
|
95
95
|
) -> Message:
|
|
96
96
|
"""Create a new message with specified parameters.
|
|
97
97
|
|
|
@@ -111,10 +111,11 @@ class Driver:
|
|
|
111
111
|
group_id : str
|
|
112
112
|
The ID of the group to which this message is associated. In some settings,
|
|
113
113
|
this is used as the FL round.
|
|
114
|
-
ttl : float (default:
|
|
114
|
+
ttl : Optional[float] (default: None)
|
|
115
115
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
116
116
|
this message to receiving a reply. It specifies in seconds the duration for
|
|
117
|
-
which the message and its potential reply are considered valid.
|
|
117
|
+
which the message and its potential reply are considered valid. If unset,
|
|
118
|
+
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
|
|
118
119
|
|
|
119
120
|
Returns
|
|
120
121
|
-------
|
|
@@ -122,6 +123,15 @@ class Driver:
|
|
|
122
123
|
A new `Message` instance with the specified content and metadata.
|
|
123
124
|
"""
|
|
124
125
|
_, run_id = self._get_grpc_driver_and_run_id()
|
|
126
|
+
if ttl:
|
|
127
|
+
warnings.warn(
|
|
128
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
129
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
130
|
+
"version of Flower.",
|
|
131
|
+
stacklevel=2,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
125
135
|
metadata = Metadata(
|
|
126
136
|
run_id=run_id,
|
|
127
137
|
message_id="", # Will be set by the server
|
|
@@ -129,7 +139,7 @@ class Driver:
|
|
|
129
139
|
dst_node_id=dst_node_id,
|
|
130
140
|
reply_to_message="",
|
|
131
141
|
group_id=group_id,
|
|
132
|
-
ttl=
|
|
142
|
+
ttl=ttl_,
|
|
133
143
|
message_type=message_type,
|
|
134
144
|
)
|
|
135
145
|
return Message(metadata=metadata, content=content)
|
flwr/server/server_app.py
CHANGED
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
from typing import Callable, Optional
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, RecordSet
|
|
21
|
+
from flwr.common.logger import warn_preview_feature
|
|
21
22
|
from flwr.server.strategy import Strategy
|
|
22
23
|
|
|
23
24
|
from .client_manager import ClientManager
|
|
@@ -120,6 +121,8 @@ class ServerApp:
|
|
|
120
121
|
""",
|
|
121
122
|
)
|
|
122
123
|
|
|
124
|
+
warn_preview_feature("ServerApp-register-main-function")
|
|
125
|
+
|
|
123
126
|
# Register provided function with the ServerApp object
|
|
124
127
|
self._main = main_fn
|
|
125
128
|
|
|
@@ -43,7 +43,7 @@ def create_node(
|
|
|
43
43
|
) -> CreateNodeResponse:
|
|
44
44
|
"""."""
|
|
45
45
|
# Create node
|
|
46
|
-
node_id = state.create_node()
|
|
46
|
+
node_id = state.create_node(ping_interval=request.ping_interval)
|
|
47
47
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
48
48
|
|
|
49
49
|
|
|
@@ -63,7 +63,8 @@ def ping(
|
|
|
63
63
|
state: State, # pylint: disable=unused-argument
|
|
64
64
|
) -> PingResponse:
|
|
65
65
|
"""."""
|
|
66
|
-
|
|
66
|
+
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
67
|
+
return PingResponse(success=res)
|
|
67
68
|
|
|
68
69
|
|
|
69
70
|
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
@@ -21,6 +21,7 @@ from flwr.common.constant import MISSING_EXTRA_REST
|
|
|
21
21
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
22
|
CreateNodeRequest,
|
|
23
23
|
DeleteNodeRequest,
|
|
24
|
+
PingRequest,
|
|
24
25
|
PullTaskInsRequest,
|
|
25
26
|
PushTaskResRequest,
|
|
26
27
|
)
|
|
@@ -152,11 +153,38 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
|
|
|
152
153
|
)
|
|
153
154
|
|
|
154
155
|
|
|
156
|
+
async def ping(request: Request) -> Response:
|
|
157
|
+
"""Ping."""
|
|
158
|
+
_check_headers(request.headers)
|
|
159
|
+
|
|
160
|
+
# Get the request body as raw bytes
|
|
161
|
+
ping_request_bytes: bytes = await request.body()
|
|
162
|
+
|
|
163
|
+
# Deserialize ProtoBuf
|
|
164
|
+
ping_request_proto = PingRequest()
|
|
165
|
+
ping_request_proto.ParseFromString(ping_request_bytes)
|
|
166
|
+
|
|
167
|
+
# Get state from app
|
|
168
|
+
state: State = app.state.STATE_FACTORY.state()
|
|
169
|
+
|
|
170
|
+
# Handle message
|
|
171
|
+
ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
|
|
172
|
+
|
|
173
|
+
# Return serialized ProtoBuf
|
|
174
|
+
ping_response_bytes = ping_response_proto.SerializeToString()
|
|
175
|
+
return Response(
|
|
176
|
+
status_code=200,
|
|
177
|
+
content=ping_response_bytes,
|
|
178
|
+
headers={"Content-Type": "application/protobuf"},
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
|
|
155
182
|
routes = [
|
|
156
183
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
|
157
184
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
|
158
185
|
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
|
|
159
186
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
|
187
|
+
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
160
188
|
]
|
|
161
189
|
|
|
162
190
|
app: Starlette = Starlette(
|
|
@@ -22,8 +22,9 @@ import traceback
|
|
|
22
22
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
23
23
|
from typing import Callable, Dict, List, Optional
|
|
24
24
|
|
|
25
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
25
|
+
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
26
26
|
from flwr.client.node_state import NodeState
|
|
27
|
+
from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
|
|
27
28
|
from flwr.common.logger import log
|
|
28
29
|
from flwr.common.message import Error
|
|
29
30
|
from flwr.common.object_ref import load_app
|
|
@@ -43,7 +44,7 @@ def _register_nodes(
|
|
|
43
44
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
44
45
|
state = state_factory.state()
|
|
45
46
|
for i in range(num_nodes):
|
|
46
|
-
node_id = state.create_node()
|
|
47
|
+
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
|
|
47
48
|
nodes_mapping[node_id] = i
|
|
48
49
|
log(INFO, "Registered %i nodes", len(nodes_mapping))
|
|
49
50
|
return nodes_mapping
|
|
@@ -93,9 +94,18 @@ async def worker(
|
|
|
93
94
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
94
95
|
log(ERROR, ex)
|
|
95
96
|
log(ERROR, traceback.format_exc())
|
|
97
|
+
|
|
98
|
+
if isinstance(ex, ClientAppException):
|
|
99
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
100
|
+
elif isinstance(ex, LoadClientAppError):
|
|
101
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
102
|
+
else:
|
|
103
|
+
e_code = ErrorCode.UNKNOWN
|
|
104
|
+
|
|
96
105
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
97
|
-
|
|
98
|
-
|
|
106
|
+
out_mssg = message.create_error_reply(
|
|
107
|
+
error=Error(code=e_code, reason=reason)
|
|
108
|
+
)
|
|
99
109
|
|
|
100
110
|
finally:
|
|
101
111
|
if out_mssg:
|
|
@@ -223,6 +233,7 @@ async def run(
|
|
|
223
233
|
|
|
224
234
|
|
|
225
235
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
236
|
+
# pylint: disable=too-many-statements
|
|
226
237
|
def start_vce(
|
|
227
238
|
backend_name: str,
|
|
228
239
|
backend_config_json_stream: str,
|
|
@@ -341,6 +352,13 @@ def start_vce(
|
|
|
341
352
|
)
|
|
342
353
|
)
|
|
343
354
|
except LoadClientAppError as loadapp_ex:
|
|
355
|
+
f_stop_delay = 10
|
|
356
|
+
log(
|
|
357
|
+
ERROR,
|
|
358
|
+
"LoadClientAppError exception encountered. Terminating simulation in %is",
|
|
359
|
+
f_stop_delay,
|
|
360
|
+
)
|
|
361
|
+
time.sleep(f_stop_delay)
|
|
344
362
|
f_stop.set() # set termination event
|
|
345
363
|
raise loadapp_ex
|
|
346
364
|
except Exception as ex:
|
|
@@ -27,6 +27,8 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
|
27
27
|
from flwr.server.superlink.state.state import State
|
|
28
28
|
from flwr.server.utils import validate_task_ins_or_res
|
|
29
29
|
|
|
30
|
+
from .utils import make_node_unavailable_taskres
|
|
31
|
+
|
|
30
32
|
|
|
31
33
|
class InMemoryState(State):
|
|
32
34
|
"""In-memory State implementation."""
|
|
@@ -129,15 +131,32 @@ class InMemoryState(State):
|
|
|
129
131
|
with self.lock:
|
|
130
132
|
# Find TaskRes that were not delivered yet
|
|
131
133
|
task_res_list: List[TaskRes] = []
|
|
134
|
+
replied_task_ids: Set[UUID] = set()
|
|
132
135
|
for _, task_res in self.task_res_store.items():
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
and task_res.task.delivered_at == ""
|
|
136
|
-
):
|
|
136
|
+
reply_to = UUID(task_res.task.ancestry[0])
|
|
137
|
+
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
137
138
|
task_res_list.append(task_res)
|
|
139
|
+
replied_task_ids.add(reply_to)
|
|
138
140
|
if limit and len(task_res_list) == limit:
|
|
139
141
|
break
|
|
140
142
|
|
|
143
|
+
# Check if the node is offline
|
|
144
|
+
for task_id in task_ids - replied_task_ids:
|
|
145
|
+
if limit and len(task_res_list) == limit:
|
|
146
|
+
break
|
|
147
|
+
task_ins = self.task_ins_store.get(task_id)
|
|
148
|
+
if task_ins is None:
|
|
149
|
+
continue
|
|
150
|
+
node_id = task_ins.task.consumer.node_id
|
|
151
|
+
online_until, _ = self.node_ids[node_id]
|
|
152
|
+
# Generate a TaskRes containing an error reply if the node is offline.
|
|
153
|
+
if online_until < time.time():
|
|
154
|
+
err_taskres = make_node_unavailable_taskres(
|
|
155
|
+
ref_taskins=task_ins,
|
|
156
|
+
)
|
|
157
|
+
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
|
|
158
|
+
task_res_list.append(err_taskres)
|
|
159
|
+
|
|
141
160
|
# Mark all of them as delivered
|
|
142
161
|
delivered_at = now().isoformat()
|
|
143
162
|
for task_res in task_res_list:
|
|
@@ -182,16 +201,14 @@ class InMemoryState(State):
|
|
|
182
201
|
"""
|
|
183
202
|
return len(self.task_res_store)
|
|
184
203
|
|
|
185
|
-
def create_node(self) -> int:
|
|
204
|
+
def create_node(self, ping_interval: float) -> int:
|
|
186
205
|
"""Create, store in state, and return `node_id`."""
|
|
187
206
|
# Sample a random int64 as node_id
|
|
188
207
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
189
208
|
|
|
190
209
|
with self.lock:
|
|
191
210
|
if node_id not in self.node_ids:
|
|
192
|
-
|
|
193
|
-
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
194
|
-
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
|
|
211
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
195
212
|
return node_id
|
|
196
213
|
log(ERROR, "Unexpected node registration failure.")
|
|
197
214
|
return 0
|
|
@@ -30,6 +30,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
|
30
30
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
31
31
|
|
|
32
32
|
from .state import State
|
|
33
|
+
from .utils import make_node_unavailable_taskres
|
|
33
34
|
|
|
34
35
|
SQL_CREATE_TABLE_NODE = """
|
|
35
36
|
CREATE TABLE IF NOT EXISTS node(
|
|
@@ -344,6 +345,7 @@ class SqliteState(State):
|
|
|
344
345
|
|
|
345
346
|
return task_id
|
|
346
347
|
|
|
348
|
+
# pylint: disable-next=R0914
|
|
347
349
|
def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
|
|
348
350
|
"""Get TaskRes for task_ids.
|
|
349
351
|
|
|
@@ -374,7 +376,7 @@ class SqliteState(State):
|
|
|
374
376
|
AND delivered_at = ""
|
|
375
377
|
"""
|
|
376
378
|
|
|
377
|
-
data: Dict[str, Union[str, int]] = {}
|
|
379
|
+
data: Dict[str, Union[str, float, int]] = {}
|
|
378
380
|
|
|
379
381
|
if limit is not None:
|
|
380
382
|
query += " LIMIT :limit"
|
|
@@ -408,6 +410,54 @@ class SqliteState(State):
|
|
|
408
410
|
rows = self.query(query, data)
|
|
409
411
|
|
|
410
412
|
result = [dict_to_task_res(row) for row in rows]
|
|
413
|
+
|
|
414
|
+
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
415
|
+
# Assume the ancestry field only contains one element
|
|
416
|
+
data.clear()
|
|
417
|
+
replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
418
|
+
remaining_task_ids = task_ids - replied_task_ids
|
|
419
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
420
|
+
query = f"""
|
|
421
|
+
SELECT consumer_node_id
|
|
422
|
+
FROM task_ins
|
|
423
|
+
WHERE task_id IN ({placeholders});
|
|
424
|
+
"""
|
|
425
|
+
for index, task_id in enumerate(remaining_task_ids):
|
|
426
|
+
data[f"id_{index}"] = str(task_id)
|
|
427
|
+
node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
|
|
428
|
+
|
|
429
|
+
# 2. Query: Select offline nodes
|
|
430
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
|
|
431
|
+
query = f"""
|
|
432
|
+
SELECT node_id
|
|
433
|
+
FROM node
|
|
434
|
+
WHERE node_id IN ({placeholders})
|
|
435
|
+
AND online_until < :time;
|
|
436
|
+
"""
|
|
437
|
+
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
|
|
438
|
+
data["time"] = time.time()
|
|
439
|
+
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
|
|
440
|
+
|
|
441
|
+
# 3. Query: Select TaskIns for offline nodes
|
|
442
|
+
placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
|
|
443
|
+
query = f"""
|
|
444
|
+
SELECT *
|
|
445
|
+
FROM task_ins
|
|
446
|
+
WHERE consumer_node_id IN ({placeholders});
|
|
447
|
+
"""
|
|
448
|
+
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
449
|
+
task_ins_rows = self.query(query, data)
|
|
450
|
+
|
|
451
|
+
# Make TaskRes containing node unavailabe error
|
|
452
|
+
for row in task_ins_rows:
|
|
453
|
+
if limit and len(result) == limit:
|
|
454
|
+
break
|
|
455
|
+
task_ins = dict_to_task_ins(row)
|
|
456
|
+
err_taskres = make_node_unavailable_taskres(
|
|
457
|
+
ref_taskins=task_ins,
|
|
458
|
+
)
|
|
459
|
+
result.append(err_taskres)
|
|
460
|
+
|
|
411
461
|
return result
|
|
412
462
|
|
|
413
463
|
def num_task_ins(self) -> int:
|
|
@@ -468,7 +518,7 @@ class SqliteState(State):
|
|
|
468
518
|
|
|
469
519
|
return None
|
|
470
520
|
|
|
471
|
-
def create_node(self) -> int:
|
|
521
|
+
def create_node(self, ping_interval: float) -> int:
|
|
472
522
|
"""Create, store in state, and return `node_id`."""
|
|
473
523
|
# Sample a random int64 as node_id
|
|
474
524
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
@@ -478,9 +528,7 @@ class SqliteState(State):
|
|
|
478
528
|
)
|
|
479
529
|
|
|
480
530
|
try:
|
|
481
|
-
|
|
482
|
-
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
483
|
-
self.query(query, (node_id, time.time() + 1e9, 1e9))
|
|
531
|
+
self.query(query, (node_id, time.time() + ping_interval, ping_interval))
|
|
484
532
|
except sqlite3.IntegrityError:
|
|
485
533
|
log(ERROR, "Unexpected node registration failure.")
|
|
486
534
|
return 0
|
|
@@ -132,7 +132,7 @@ class State(abc.ABC):
|
|
|
132
132
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
133
133
|
|
|
134
134
|
@abc.abstractmethod
|
|
135
|
-
def create_node(self) -> int:
|
|
135
|
+
def create_node(self, ping_interval: float) -> int:
|
|
136
136
|
"""Create, store in state, and return `node_id`."""
|
|
137
137
|
|
|
138
138
|
@abc.abstractmethod
|