flwr-nightly 1.8.0.dev20240327__py3-none-any.whl → 1.8.0.dev20240402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +53 -29
- flwr/client/client_app.py +16 -0
- flwr/client/grpc_rere_client/connection.py +71 -29
- flwr/client/heartbeat.py +72 -0
- flwr/client/rest_client/connection.py +102 -28
- flwr/common/constant.py +20 -0
- flwr/common/logger.py +4 -4
- flwr/common/message.py +53 -14
- flwr/common/retry_invoker.py +24 -13
- flwr/proto/fleet_pb2.py +26 -26
- flwr/proto/fleet_pb2.pyi +5 -0
- flwr/server/compat/driver_client_proxy.py +16 -0
- flwr/server/driver/driver.py +15 -5
- flwr/server/server_app.py +3 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
- flwr/server/superlink/fleet/vce/vce_api.py +61 -27
- flwr/server/superlink/state/in_memory_state.py +25 -8
- flwr/server/superlink/state/sqlite_state.py +53 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/utils.py +56 -0
- flwr/server/workflow/default_workflows.py +1 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
- flwr/simulation/ray_transport/ray_actor.py +8 -24
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/RECORD +30 -28
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py
CHANGED
|
@@ -36,6 +36,13 @@ TRANSPORT_TYPES = [
|
|
|
36
36
|
TRANSPORT_TYPE_VCE,
|
|
37
37
|
]
|
|
38
38
|
|
|
39
|
+
# Constants for ping
|
|
40
|
+
PING_DEFAULT_INTERVAL = 30
|
|
41
|
+
PING_CALL_TIMEOUT = 5
|
|
42
|
+
PING_BASE_MULTIPLIER = 0.8
|
|
43
|
+
PING_RANDOM_RANGE = (-0.1, 0.1)
|
|
44
|
+
PING_MAX_INTERVAL = 1e300
|
|
45
|
+
|
|
39
46
|
|
|
40
47
|
class MessageType:
|
|
41
48
|
"""Message type."""
|
|
@@ -68,3 +75,16 @@ class SType:
|
|
|
68
75
|
def __new__(cls) -> SType:
|
|
69
76
|
"""Prevent instantiation."""
|
|
70
77
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ErrorCode:
|
|
81
|
+
"""Error codes for Message's Error."""
|
|
82
|
+
|
|
83
|
+
UNKNOWN = 0
|
|
84
|
+
LOAD_CLIENT_APP_EXCEPTION = 1
|
|
85
|
+
CLIENT_APP_RAISED_EXCEPTION = 2
|
|
86
|
+
NODE_UNAVAILABLE = 3
|
|
87
|
+
|
|
88
|
+
def __new__(cls) -> ErrorCode:
|
|
89
|
+
"""Prevent instantiation."""
|
|
90
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
flwr/common/logger.py
CHANGED
|
@@ -164,13 +164,13 @@ logger = logging.getLogger(LOGGER_NAME) # pylint: disable=invalid-name
|
|
|
164
164
|
log = logger.log # pylint: disable=invalid-name
|
|
165
165
|
|
|
166
166
|
|
|
167
|
-
def
|
|
168
|
-
"""Warn the user when they use
|
|
167
|
+
def warn_preview_feature(name: str) -> None:
|
|
168
|
+
"""Warn the user when they use a preview feature."""
|
|
169
169
|
log(
|
|
170
170
|
WARN,
|
|
171
|
-
"""
|
|
171
|
+
"""PREVIEW FEATURE: %s
|
|
172
172
|
|
|
173
|
-
This is
|
|
173
|
+
This is a preview feature. It could change significantly or be removed
|
|
174
174
|
entirely in future versions of Flower.
|
|
175
175
|
""",
|
|
176
176
|
name,
|
flwr/common/message.py
CHANGED
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
19
|
import time
|
|
20
|
+
import warnings
|
|
20
21
|
from dataclasses import dataclass
|
|
21
22
|
|
|
22
23
|
from .record import RecordSet
|
|
@@ -297,22 +298,40 @@ class Message:
|
|
|
297
298
|
partition_id=self.metadata.partition_id,
|
|
298
299
|
)
|
|
299
300
|
|
|
300
|
-
def create_error_reply(
|
|
301
|
-
self,
|
|
302
|
-
error: Error,
|
|
303
|
-
ttl: float,
|
|
304
|
-
) -> Message:
|
|
301
|
+
def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
|
|
305
302
|
"""Construct a reply message indicating an error happened.
|
|
306
303
|
|
|
307
304
|
Parameters
|
|
308
305
|
----------
|
|
309
306
|
error : Error
|
|
310
307
|
The error that was encountered.
|
|
311
|
-
ttl : float
|
|
312
|
-
Time-to-live for this message in seconds.
|
|
308
|
+
ttl : Optional[float] (default: None)
|
|
309
|
+
Time-to-live for this message in seconds. If unset, it will be set based
|
|
310
|
+
on the remaining time for the received message before it expires. This
|
|
311
|
+
follows the equation:
|
|
312
|
+
|
|
313
|
+
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
313
314
|
"""
|
|
315
|
+
if ttl:
|
|
316
|
+
warnings.warn(
|
|
317
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
318
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
319
|
+
"version of Flower.",
|
|
320
|
+
stacklevel=2,
|
|
321
|
+
)
|
|
322
|
+
# If no TTL passed, use default for message creation (will update after
|
|
323
|
+
# message creation)
|
|
324
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
314
325
|
# Create reply with error
|
|
315
|
-
message = Message(metadata=self._create_reply_metadata(
|
|
326
|
+
message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
|
|
327
|
+
|
|
328
|
+
if ttl is None:
|
|
329
|
+
# Set TTL equal to the remaining time for the received message to expire
|
|
330
|
+
ttl = self.metadata.ttl - (
|
|
331
|
+
message.metadata.created_at - self.metadata.created_at
|
|
332
|
+
)
|
|
333
|
+
message.metadata.ttl = ttl
|
|
334
|
+
|
|
316
335
|
return message
|
|
317
336
|
|
|
318
337
|
def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
|
|
@@ -327,18 +346,38 @@ class Message:
|
|
|
327
346
|
content : RecordSet
|
|
328
347
|
The content for the reply message.
|
|
329
348
|
ttl : Optional[float] (default: None)
|
|
330
|
-
Time-to-live for this message in seconds. If unset, it will
|
|
331
|
-
the
|
|
349
|
+
Time-to-live for this message in seconds. If unset, it will be set based
|
|
350
|
+
on the remaining time for the received message before it expires. This
|
|
351
|
+
follows the equation:
|
|
352
|
+
|
|
353
|
+
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
332
354
|
|
|
333
355
|
Returns
|
|
334
356
|
-------
|
|
335
357
|
Message
|
|
336
358
|
A new `Message` instance representing the reply.
|
|
337
359
|
"""
|
|
338
|
-
if ttl
|
|
339
|
-
|
|
360
|
+
if ttl:
|
|
361
|
+
warnings.warn(
|
|
362
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
363
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
364
|
+
"version of Flower.",
|
|
365
|
+
stacklevel=2,
|
|
366
|
+
)
|
|
367
|
+
# If no TTL passed, use default for message creation (will update after
|
|
368
|
+
# message creation)
|
|
369
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
340
370
|
|
|
341
|
-
|
|
342
|
-
metadata=self._create_reply_metadata(
|
|
371
|
+
message = Message(
|
|
372
|
+
metadata=self._create_reply_metadata(ttl_),
|
|
343
373
|
content=content,
|
|
344
374
|
)
|
|
375
|
+
|
|
376
|
+
if ttl is None:
|
|
377
|
+
# Set TTL equal to the remaining time for the received message to expire
|
|
378
|
+
ttl = self.metadata.ttl - (
|
|
379
|
+
message.metadata.created_at - self.metadata.created_at
|
|
380
|
+
)
|
|
381
|
+
message.metadata.ttl = ttl
|
|
382
|
+
|
|
383
|
+
return message
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -107,7 +107,7 @@ class RetryInvoker:
|
|
|
107
107
|
|
|
108
108
|
Parameters
|
|
109
109
|
----------
|
|
110
|
-
|
|
110
|
+
wait_gen_factory: Callable[[], Generator[float, None, None]]
|
|
111
111
|
A generator yielding successive wait times in seconds. If the generator
|
|
112
112
|
is finite, the giveup event will be triggered when the generator raises
|
|
113
113
|
`StopIteration`.
|
|
@@ -129,12 +129,12 @@ class RetryInvoker:
|
|
|
129
129
|
data class object detailing the invocation.
|
|
130
130
|
on_giveup: Optional[Callable[[RetryState], None]] (default: None)
|
|
131
131
|
A callable to be executed in the event that `max_tries` or `max_time` is
|
|
132
|
-
exceeded, `should_giveup` returns True, or `
|
|
132
|
+
exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
|
|
133
133
|
`StopInteration`. The parameter is a data class object detailing the
|
|
134
134
|
invocation.
|
|
135
135
|
jitter: Optional[Callable[[float], float]] (default: full_jitter)
|
|
136
|
-
A function of the value yielded by `
|
|
137
|
-
to wait. This function helps distribute wait times stochastically to avoid
|
|
136
|
+
A function of the value yielded by `wait_gen_factory()` returning the actual
|
|
137
|
+
time to wait. This function helps distribute wait times stochastically to avoid
|
|
138
138
|
timing collisions across concurrent clients. Wait times are jittered by
|
|
139
139
|
default using the `full_jitter` function. To disable jittering, pass
|
|
140
140
|
`jitter=None`.
|
|
@@ -142,6 +142,13 @@ class RetryInvoker:
|
|
|
142
142
|
A function accepting an exception instance, returning whether or not
|
|
143
143
|
to give up prematurely before other give-up conditions are evaluated.
|
|
144
144
|
If set to None, the strategy is to never give up prematurely.
|
|
145
|
+
wait_function: Optional[Callable[[float], None]] (default: None)
|
|
146
|
+
A function that defines how to wait between retry attempts. It accepts
|
|
147
|
+
one argument, the wait time in seconds, allowing the use of various waiting
|
|
148
|
+
mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
|
|
149
|
+
for different execution environments. If set to `None`, the `wait_function`
|
|
150
|
+
defaults to `time.sleep`, which is ideal for synchronous operations. Custom
|
|
151
|
+
functions should manage execution flow to prevent blocking or interference.
|
|
145
152
|
|
|
146
153
|
Examples
|
|
147
154
|
--------
|
|
@@ -159,7 +166,7 @@ class RetryInvoker:
|
|
|
159
166
|
# pylint: disable-next=too-many-arguments
|
|
160
167
|
def __init__(
|
|
161
168
|
self,
|
|
162
|
-
|
|
169
|
+
wait_gen_factory: Callable[[], Generator[float, None, None]],
|
|
163
170
|
recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
|
|
164
171
|
max_tries: Optional[int],
|
|
165
172
|
max_time: Optional[float],
|
|
@@ -169,8 +176,9 @@ class RetryInvoker:
|
|
|
169
176
|
on_giveup: Optional[Callable[[RetryState], None]] = None,
|
|
170
177
|
jitter: Optional[Callable[[float], float]] = full_jitter,
|
|
171
178
|
should_giveup: Optional[Callable[[Exception], bool]] = None,
|
|
179
|
+
wait_function: Optional[Callable[[float], None]] = None,
|
|
172
180
|
) -> None:
|
|
173
|
-
self.
|
|
181
|
+
self.wait_gen_factory = wait_gen_factory
|
|
174
182
|
self.recoverable_exceptions = recoverable_exceptions
|
|
175
183
|
self.max_tries = max_tries
|
|
176
184
|
self.max_time = max_time
|
|
@@ -179,6 +187,9 @@ class RetryInvoker:
|
|
|
179
187
|
self.on_giveup = on_giveup
|
|
180
188
|
self.jitter = jitter
|
|
181
189
|
self.should_giveup = should_giveup
|
|
190
|
+
if wait_function is None:
|
|
191
|
+
wait_function = time.sleep
|
|
192
|
+
self.wait_function = wait_function
|
|
182
193
|
|
|
183
194
|
# pylint: disable-next=too-many-locals
|
|
184
195
|
def invoke(
|
|
@@ -212,13 +223,13 @@ class RetryInvoker:
|
|
|
212
223
|
Raises
|
|
213
224
|
------
|
|
214
225
|
Exception
|
|
215
|
-
If the number of tries exceeds `max_tries`, if the total time
|
|
216
|
-
|
|
226
|
+
If the number of tries exceeds `max_tries`, if the total time exceeds
|
|
227
|
+
`max_time`, if `wait_gen_factory()` generator raises `StopInteration`,
|
|
217
228
|
or if the `should_giveup` returns True for a raised exception.
|
|
218
229
|
|
|
219
230
|
Notes
|
|
220
231
|
-----
|
|
221
|
-
The time between retries is determined by the provided `
|
|
232
|
+
The time between retries is determined by the provided `wait_gen_factory()`
|
|
222
233
|
generator and can optionally be jittered using the `jitter` function.
|
|
223
234
|
The recoverable exceptions that trigger a retry, as well as conditions to
|
|
224
235
|
stop retries, are also determined by the class's initialization parameters.
|
|
@@ -231,13 +242,13 @@ class RetryInvoker:
|
|
|
231
242
|
handler(cast(RetryState, ref_state[0]))
|
|
232
243
|
|
|
233
244
|
try_cnt = 0
|
|
234
|
-
wait_generator = self.
|
|
235
|
-
start = time.
|
|
245
|
+
wait_generator = self.wait_gen_factory()
|
|
246
|
+
start = time.monotonic()
|
|
236
247
|
ref_state: List[Optional[RetryState]] = [None]
|
|
237
248
|
|
|
238
249
|
while True:
|
|
239
250
|
try_cnt += 1
|
|
240
|
-
elapsed_time = time.
|
|
251
|
+
elapsed_time = time.monotonic() - start
|
|
241
252
|
state = RetryState(
|
|
242
253
|
target=target,
|
|
243
254
|
args=args,
|
|
@@ -282,7 +293,7 @@ class RetryInvoker:
|
|
|
282
293
|
try_call_event_handler(self.on_backoff)
|
|
283
294
|
|
|
284
295
|
# Sleep
|
|
285
|
-
|
|
296
|
+
self.wait_function(state.actual_wait)
|
|
286
297
|
else:
|
|
287
298
|
# Trigger success event
|
|
288
299
|
try_call_event_handler(self.on_success)
|
flwr/proto/fleet_pb2.py
CHANGED
|
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
|
16
16
|
from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"
|
|
19
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
|
|
20
20
|
|
|
21
21
|
_globals = globals()
|
|
22
22
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -26,29 +26,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
26
26
|
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None
|
|
27
27
|
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
|
|
28
28
|
_globals['_CREATENODEREQUEST']._serialized_start=84
|
|
29
|
-
_globals['_CREATENODEREQUEST']._serialized_end=
|
|
30
|
-
_globals['_CREATENODERESPONSE']._serialized_start=
|
|
31
|
-
_globals['_CREATENODERESPONSE']._serialized_end=
|
|
32
|
-
_globals['_DELETENODEREQUEST']._serialized_start=
|
|
33
|
-
_globals['_DELETENODEREQUEST']._serialized_end=
|
|
34
|
-
_globals['_DELETENODERESPONSE']._serialized_start=
|
|
35
|
-
_globals['_DELETENODERESPONSE']._serialized_end=
|
|
36
|
-
_globals['_PINGREQUEST']._serialized_start=
|
|
37
|
-
_globals['_PINGREQUEST']._serialized_end=
|
|
38
|
-
_globals['_PINGRESPONSE']._serialized_start=
|
|
39
|
-
_globals['_PINGRESPONSE']._serialized_end=
|
|
40
|
-
_globals['_PULLTASKINSREQUEST']._serialized_start=
|
|
41
|
-
_globals['_PULLTASKINSREQUEST']._serialized_end=
|
|
42
|
-
_globals['_PULLTASKINSRESPONSE']._serialized_start=
|
|
43
|
-
_globals['_PULLTASKINSRESPONSE']._serialized_end=
|
|
44
|
-
_globals['_PUSHTASKRESREQUEST']._serialized_start=
|
|
45
|
-
_globals['_PUSHTASKRESREQUEST']._serialized_end=
|
|
46
|
-
_globals['_PUSHTASKRESRESPONSE']._serialized_start=
|
|
47
|
-
_globals['_PUSHTASKRESRESPONSE']._serialized_end=
|
|
48
|
-
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=
|
|
49
|
-
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=
|
|
50
|
-
_globals['_RECONNECT']._serialized_start=
|
|
51
|
-
_globals['_RECONNECT']._serialized_end=
|
|
52
|
-
_globals['_FLEET']._serialized_start=
|
|
53
|
-
_globals['_FLEET']._serialized_end=
|
|
29
|
+
_globals['_CREATENODEREQUEST']._serialized_end=126
|
|
30
|
+
_globals['_CREATENODERESPONSE']._serialized_start=128
|
|
31
|
+
_globals['_CREATENODERESPONSE']._serialized_end=180
|
|
32
|
+
_globals['_DELETENODEREQUEST']._serialized_start=182
|
|
33
|
+
_globals['_DELETENODEREQUEST']._serialized_end=233
|
|
34
|
+
_globals['_DELETENODERESPONSE']._serialized_start=235
|
|
35
|
+
_globals['_DELETENODERESPONSE']._serialized_end=255
|
|
36
|
+
_globals['_PINGREQUEST']._serialized_start=257
|
|
37
|
+
_globals['_PINGREQUEST']._serialized_end=325
|
|
38
|
+
_globals['_PINGRESPONSE']._serialized_start=327
|
|
39
|
+
_globals['_PINGRESPONSE']._serialized_end=358
|
|
40
|
+
_globals['_PULLTASKINSREQUEST']._serialized_start=360
|
|
41
|
+
_globals['_PULLTASKINSREQUEST']._serialized_end=430
|
|
42
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_start=432
|
|
43
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_end=539
|
|
44
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_start=541
|
|
45
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_end=605
|
|
46
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_start=608
|
|
47
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_end=782
|
|
48
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736
|
|
49
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782
|
|
50
|
+
_globals['_RECONNECT']._serialized_start=784
|
|
51
|
+
_globals['_RECONNECT']._serialized_end=814
|
|
52
|
+
_globals['_FLEET']._serialized_start=817
|
|
53
|
+
_globals['_FLEET']._serialized_end=1207
|
|
54
54
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fleet_pb2.pyi
CHANGED
|
@@ -16,8 +16,13 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
|
16
16
|
class CreateNodeRequest(google.protobuf.message.Message):
|
|
17
17
|
"""CreateNode messages"""
|
|
18
18
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
19
|
+
PING_INTERVAL_FIELD_NUMBER: builtins.int
|
|
20
|
+
ping_interval: builtins.float
|
|
19
21
|
def __init__(self,
|
|
22
|
+
*,
|
|
23
|
+
ping_interval: builtins.float = ...,
|
|
20
24
|
) -> None: ...
|
|
25
|
+
def ClearField(self, field_name: typing_extensions.Literal["ping_interval",b"ping_interval"]) -> None: ...
|
|
21
26
|
global___CreateNodeRequest = CreateNodeRequest
|
|
22
27
|
|
|
23
28
|
class CreateNodeResponse(google.protobuf.message.Message):
|
|
@@ -170,8 +170,24 @@ class DriverClientProxy(ClientProxy):
|
|
|
170
170
|
)
|
|
171
171
|
if len(task_res_list) == 1:
|
|
172
172
|
task_res = task_res_list[0]
|
|
173
|
+
|
|
174
|
+
# This will raise an Exception if task_res carries an `error`
|
|
175
|
+
validate_task_res(task_res=task_res)
|
|
176
|
+
|
|
173
177
|
return serde.recordset_from_proto(task_res.task.recordset)
|
|
174
178
|
|
|
175
179
|
if timeout is not None and time.time() > start_time + timeout:
|
|
176
180
|
raise RuntimeError("Timeout reached")
|
|
177
181
|
time.sleep(SLEEP_TIME)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def validate_task_res(
|
|
185
|
+
task_res: task_pb2.TaskRes, # pylint: disable=E1101
|
|
186
|
+
) -> None:
|
|
187
|
+
"""Validate if a TaskRes is empty or not."""
|
|
188
|
+
if not task_res.HasField("task"):
|
|
189
|
+
raise ValueError("Invalid TaskRes, field `task` missing")
|
|
190
|
+
if task_res.task.HasField("error"):
|
|
191
|
+
raise ValueError("Exception during client-side task execution")
|
|
192
|
+
if not task_res.task.HasField("recordset"):
|
|
193
|
+
raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
|
flwr/server/driver/driver.py
CHANGED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower driver service client."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
import time
|
|
18
|
+
import warnings
|
|
19
19
|
from typing import Iterable, List, Optional, Tuple
|
|
20
20
|
|
|
21
21
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
@@ -91,7 +91,7 @@ class Driver:
|
|
|
91
91
|
message_type: str,
|
|
92
92
|
dst_node_id: int,
|
|
93
93
|
group_id: str,
|
|
94
|
-
ttl: float =
|
|
94
|
+
ttl: Optional[float] = None,
|
|
95
95
|
) -> Message:
|
|
96
96
|
"""Create a new message with specified parameters.
|
|
97
97
|
|
|
@@ -111,10 +111,11 @@ class Driver:
|
|
|
111
111
|
group_id : str
|
|
112
112
|
The ID of the group to which this message is associated. In some settings,
|
|
113
113
|
this is used as the FL round.
|
|
114
|
-
ttl : float (default:
|
|
114
|
+
ttl : Optional[float] (default: None)
|
|
115
115
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
116
116
|
this message to receiving a reply. It specifies in seconds the duration for
|
|
117
|
-
which the message and its potential reply are considered valid.
|
|
117
|
+
which the message and its potential reply are considered valid. If unset,
|
|
118
|
+
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
|
|
118
119
|
|
|
119
120
|
Returns
|
|
120
121
|
-------
|
|
@@ -122,6 +123,15 @@ class Driver:
|
|
|
122
123
|
A new `Message` instance with the specified content and metadata.
|
|
123
124
|
"""
|
|
124
125
|
_, run_id = self._get_grpc_driver_and_run_id()
|
|
126
|
+
if ttl:
|
|
127
|
+
warnings.warn(
|
|
128
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
129
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
130
|
+
"version of Flower.",
|
|
131
|
+
stacklevel=2,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
125
135
|
metadata = Metadata(
|
|
126
136
|
run_id=run_id,
|
|
127
137
|
message_id="", # Will be set by the server
|
|
@@ -129,7 +139,7 @@ class Driver:
|
|
|
129
139
|
dst_node_id=dst_node_id,
|
|
130
140
|
reply_to_message="",
|
|
131
141
|
group_id=group_id,
|
|
132
|
-
ttl=
|
|
142
|
+
ttl=ttl_,
|
|
133
143
|
message_type=message_type,
|
|
134
144
|
)
|
|
135
145
|
return Message(metadata=metadata, content=content)
|
flwr/server/server_app.py
CHANGED
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
from typing import Callable, Optional
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, RecordSet
|
|
21
|
+
from flwr.common.logger import warn_preview_feature
|
|
21
22
|
from flwr.server.strategy import Strategy
|
|
22
23
|
|
|
23
24
|
from .client_manager import ClientManager
|
|
@@ -120,6 +121,8 @@ class ServerApp:
|
|
|
120
121
|
""",
|
|
121
122
|
)
|
|
122
123
|
|
|
124
|
+
warn_preview_feature("ServerApp-register-main-function")
|
|
125
|
+
|
|
123
126
|
# Register provided function with the ServerApp object
|
|
124
127
|
self._main = main_fn
|
|
125
128
|
|
|
@@ -43,7 +43,7 @@ def create_node(
|
|
|
43
43
|
) -> CreateNodeResponse:
|
|
44
44
|
"""."""
|
|
45
45
|
# Create node
|
|
46
|
-
node_id = state.create_node()
|
|
46
|
+
node_id = state.create_node(ping_interval=request.ping_interval)
|
|
47
47
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
48
48
|
|
|
49
49
|
|
|
@@ -63,7 +63,8 @@ def ping(
|
|
|
63
63
|
state: State, # pylint: disable=unused-argument
|
|
64
64
|
) -> PingResponse:
|
|
65
65
|
"""."""
|
|
66
|
-
|
|
66
|
+
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
|
|
67
|
+
return PingResponse(success=res)
|
|
67
68
|
|
|
68
69
|
|
|
69
70
|
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
@@ -21,6 +21,7 @@ from flwr.common.constant import MISSING_EXTRA_REST
|
|
|
21
21
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
22
|
CreateNodeRequest,
|
|
23
23
|
DeleteNodeRequest,
|
|
24
|
+
PingRequest,
|
|
24
25
|
PullTaskInsRequest,
|
|
25
26
|
PushTaskResRequest,
|
|
26
27
|
)
|
|
@@ -152,11 +153,38 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
|
|
|
152
153
|
)
|
|
153
154
|
|
|
154
155
|
|
|
156
|
+
async def ping(request: Request) -> Response:
|
|
157
|
+
"""Ping."""
|
|
158
|
+
_check_headers(request.headers)
|
|
159
|
+
|
|
160
|
+
# Get the request body as raw bytes
|
|
161
|
+
ping_request_bytes: bytes = await request.body()
|
|
162
|
+
|
|
163
|
+
# Deserialize ProtoBuf
|
|
164
|
+
ping_request_proto = PingRequest()
|
|
165
|
+
ping_request_proto.ParseFromString(ping_request_bytes)
|
|
166
|
+
|
|
167
|
+
# Get state from app
|
|
168
|
+
state: State = app.state.STATE_FACTORY.state()
|
|
169
|
+
|
|
170
|
+
# Handle message
|
|
171
|
+
ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
|
|
172
|
+
|
|
173
|
+
# Return serialized ProtoBuf
|
|
174
|
+
ping_response_bytes = ping_response_proto.SerializeToString()
|
|
175
|
+
return Response(
|
|
176
|
+
status_code=200,
|
|
177
|
+
content=ping_response_bytes,
|
|
178
|
+
headers={"Content-Type": "application/protobuf"},
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
|
|
155
182
|
routes = [
|
|
156
183
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
|
157
184
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
|
158
185
|
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
|
|
159
186
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
|
187
|
+
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
160
188
|
]
|
|
161
189
|
|
|
162
190
|
app: Starlette = Starlette(
|
|
@@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Tuple, Union
|
|
|
20
20
|
|
|
21
21
|
import ray
|
|
22
22
|
|
|
23
|
-
from flwr.client.client_app import ClientApp
|
|
23
|
+
from flwr.client.client_app import ClientApp
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.message import Message
|
|
@@ -151,7 +151,6 @@ class RayBackend(Backend):
|
|
|
151
151
|
)
|
|
152
152
|
|
|
153
153
|
await future
|
|
154
|
-
|
|
155
154
|
# Fetch result
|
|
156
155
|
(
|
|
157
156
|
out_mssg,
|
|
@@ -160,13 +159,15 @@ class RayBackend(Backend):
|
|
|
160
159
|
|
|
161
160
|
return out_mssg, updated_context
|
|
162
161
|
|
|
163
|
-
except
|
|
162
|
+
except Exception as ex:
|
|
164
163
|
log(
|
|
165
164
|
ERROR,
|
|
166
165
|
"An exception was raised when processing a message by %s",
|
|
167
166
|
self.__class__.__name__,
|
|
168
167
|
)
|
|
169
|
-
|
|
168
|
+
# add actor back into pool
|
|
169
|
+
await self.pool.add_actor_back_to_pool(future)
|
|
170
|
+
raise ex
|
|
170
171
|
|
|
171
172
|
async def terminate(self) -> None:
|
|
172
173
|
"""Terminate all actors in actor pool."""
|