flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
|
@@ -14,33 +14,40 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
import random
|
|
19
|
-
import threading
|
|
20
17
|
from collections.abc import Iterator
|
|
21
18
|
from contextlib import contextmanager
|
|
22
19
|
from copy import copy
|
|
23
|
-
from logging import ERROR, INFO, WARN
|
|
24
|
-
from typing import Callable, Optional, TypeVar, Union
|
|
20
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
+
from typing import Callable, Optional, TypeVar, Union, cast
|
|
25
22
|
|
|
26
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
27
24
|
from google.protobuf.message import Message as GrpcMessage
|
|
28
25
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
|
29
26
|
|
|
30
|
-
from flwr.
|
|
27
|
+
from flwr.app.metadata import Metadata
|
|
31
28
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
33
|
-
from flwr.common.constant import
|
|
34
|
-
PING_BASE_MULTIPLIER,
|
|
35
|
-
PING_CALL_TIMEOUT,
|
|
36
|
-
PING_DEFAULT_INTERVAL,
|
|
37
|
-
PING_RANDOM_RANGE,
|
|
38
|
-
)
|
|
30
|
+
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
|
39
31
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
32
|
+
from flwr.common.heartbeat import HeartbeatSender
|
|
33
|
+
from flwr.common.inflatable import (
|
|
34
|
+
get_all_nested_objects,
|
|
35
|
+
get_object_tree,
|
|
36
|
+
no_object_id_recompute,
|
|
37
|
+
)
|
|
38
|
+
from flwr.common.inflatable_rest_utils import (
|
|
39
|
+
make_pull_object_fn_rest,
|
|
40
|
+
make_push_object_fn_rest,
|
|
41
|
+
)
|
|
42
|
+
from flwr.common.inflatable_utils import (
|
|
43
|
+
inflate_object_from_contents,
|
|
44
|
+
pull_objects,
|
|
45
|
+
push_objects,
|
|
46
|
+
)
|
|
40
47
|
from flwr.common.logger import log
|
|
41
|
-
from flwr.common.message import Message,
|
|
48
|
+
from flwr.common.message import Message, remove_content_from_message
|
|
42
49
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
50
|
+
from flwr.common.serde import message_to_proto, run_from_proto
|
|
44
51
|
from flwr.common.typing import Fab, Run
|
|
45
52
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
46
53
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -48,13 +55,23 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
48
55
|
CreateNodeResponse,
|
|
49
56
|
DeleteNodeRequest,
|
|
50
57
|
DeleteNodeResponse,
|
|
51
|
-
PingRequest,
|
|
52
|
-
PingResponse,
|
|
53
58
|
PullMessagesRequest,
|
|
54
59
|
PullMessagesResponse,
|
|
55
60
|
PushMessagesRequest,
|
|
56
61
|
PushMessagesResponse,
|
|
57
62
|
)
|
|
63
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
64
|
+
SendNodeHeartbeatRequest,
|
|
65
|
+
SendNodeHeartbeatResponse,
|
|
66
|
+
)
|
|
67
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
68
|
+
ConfirmMessageReceivedRequest,
|
|
69
|
+
ConfirmMessageReceivedResponse,
|
|
70
|
+
PullObjectRequest,
|
|
71
|
+
PullObjectResponse,
|
|
72
|
+
PushObjectRequest,
|
|
73
|
+
PushObjectResponse,
|
|
74
|
+
)
|
|
58
75
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
59
76
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
60
77
|
|
|
@@ -68,9 +85,12 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
|
68
85
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
|
69
86
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
70
87
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
71
|
-
|
|
88
|
+
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
|
89
|
+
PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
|
|
90
|
+
PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
|
|
72
91
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
|
73
92
|
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
|
93
|
+
PATH_CONFIRM_MESSAGE_RECEIVED: str = "/api/v0/fleet/confirm-message-received"
|
|
74
94
|
|
|
75
95
|
T = TypeVar("T", bound=GrpcMessage)
|
|
76
96
|
|
|
@@ -91,10 +111,10 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
91
111
|
tuple[
|
|
92
112
|
Callable[[], Optional[Message]],
|
|
93
113
|
Callable[[Message], None],
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
114
|
+
Callable[[], Optional[int]],
|
|
115
|
+
Callable[[], None],
|
|
116
|
+
Callable[[int], Run],
|
|
117
|
+
Callable[[str, int], Fab],
|
|
98
118
|
]
|
|
99
119
|
]:
|
|
100
120
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -160,11 +180,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
160
180
|
# Shared variables for inner functions
|
|
161
181
|
metadata: Optional[Metadata] = None
|
|
162
182
|
node: Optional[Node] = None
|
|
163
|
-
ping_thread: Optional[threading.Thread] = None
|
|
164
|
-
ping_stop_event = threading.Event()
|
|
165
183
|
|
|
166
184
|
###########################################################################
|
|
167
|
-
#
|
|
185
|
+
# heartbeat/create_node/delete_node/receive/send/get_run functions
|
|
168
186
|
###########################################################################
|
|
169
187
|
|
|
170
188
|
def _request(
|
|
@@ -214,44 +232,47 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
214
232
|
grpc_res.ParseFromString(res.content)
|
|
215
233
|
return grpc_res
|
|
216
234
|
|
|
217
|
-
def
|
|
235
|
+
def send_node_heartbeat() -> bool:
|
|
218
236
|
# Get Node
|
|
219
237
|
if node is None:
|
|
220
238
|
log(ERROR, "Node instance missing")
|
|
221
|
-
return
|
|
239
|
+
return False
|
|
222
240
|
|
|
223
|
-
# Construct the
|
|
224
|
-
req =
|
|
241
|
+
# Construct the heartbeat request
|
|
242
|
+
req = SendNodeHeartbeatRequest(
|
|
243
|
+
node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
244
|
+
)
|
|
225
245
|
|
|
226
246
|
# Send the request
|
|
227
|
-
res = _request(
|
|
247
|
+
res = _request(
|
|
248
|
+
req, SendNodeHeartbeatResponse, PATH_SEND_NODE_HEARTBEAT, retry=False
|
|
249
|
+
)
|
|
228
250
|
if res is None:
|
|
229
|
-
return
|
|
251
|
+
return False
|
|
230
252
|
|
|
231
253
|
# Check if success
|
|
232
254
|
if not res.success:
|
|
233
|
-
raise RuntimeError(
|
|
255
|
+
raise RuntimeError(
|
|
256
|
+
"Heartbeat failed unexpectedly. The SuperLink does not "
|
|
257
|
+
"recognize this SuperNode."
|
|
258
|
+
)
|
|
259
|
+
return True
|
|
234
260
|
|
|
235
|
-
|
|
236
|
-
rd = random.uniform(*PING_RANDOM_RANGE)
|
|
237
|
-
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
|
|
238
|
-
next_interval *= PING_BASE_MULTIPLIER + rd
|
|
239
|
-
if not ping_stop_event.is_set():
|
|
240
|
-
ping_stop_event.wait(next_interval)
|
|
261
|
+
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
241
262
|
|
|
242
263
|
def create_node() -> Optional[int]:
|
|
243
264
|
"""Set create_node."""
|
|
244
|
-
req = CreateNodeRequest(
|
|
265
|
+
req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
|
|
245
266
|
|
|
246
267
|
# Send the request
|
|
247
268
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
|
248
269
|
if res is None:
|
|
249
270
|
return None
|
|
250
271
|
|
|
251
|
-
# Remember the node and the
|
|
252
|
-
nonlocal node
|
|
272
|
+
# Remember the node and start the heartbeat sender
|
|
273
|
+
nonlocal node
|
|
253
274
|
node = res.node
|
|
254
|
-
|
|
275
|
+
heartbeat_sender.start()
|
|
255
276
|
return node.node_id
|
|
256
277
|
|
|
257
278
|
def delete_node() -> None:
|
|
@@ -261,10 +282,8 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
261
282
|
log(ERROR, "Node instance missing")
|
|
262
283
|
return
|
|
263
284
|
|
|
264
|
-
# Stop the
|
|
265
|
-
|
|
266
|
-
if ping_thread is not None:
|
|
267
|
-
ping_thread.join()
|
|
285
|
+
# Stop the heartbeat sender
|
|
286
|
+
heartbeat_sender.stop()
|
|
268
287
|
|
|
269
288
|
# Send DeleteNode request
|
|
270
289
|
req = DeleteNodeRequest(node=node)
|
|
@@ -301,14 +320,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
301
320
|
):
|
|
302
321
|
message_proto = None
|
|
303
322
|
|
|
304
|
-
#
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
if message_proto
|
|
308
|
-
message = message_from_proto(message_proto)
|
|
309
|
-
metadata = copy(message.metadata)
|
|
323
|
+
# Construct the Message
|
|
324
|
+
in_message: Optional[Message] = None
|
|
325
|
+
|
|
326
|
+
if message_proto:
|
|
310
327
|
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
|
311
|
-
|
|
328
|
+
msg_id = message_proto.metadata.message_id
|
|
329
|
+
run_id = message_proto.metadata.run_id
|
|
330
|
+
|
|
331
|
+
def fn(request: PullObjectRequest) -> PullObjectResponse:
|
|
332
|
+
res = _request(
|
|
333
|
+
req=request, res_type=PullObjectResponse, api_path=PATH_PULL_OBJECT
|
|
334
|
+
)
|
|
335
|
+
if res is None:
|
|
336
|
+
raise ValueError("PushObjectResponse is None.")
|
|
337
|
+
return res
|
|
338
|
+
|
|
339
|
+
try:
|
|
340
|
+
all_object_contents = pull_objects(
|
|
341
|
+
list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
|
|
342
|
+
pull_object_fn=make_pull_object_fn_rest(
|
|
343
|
+
pull_object_rest=fn,
|
|
344
|
+
node=node,
|
|
345
|
+
run_id=run_id,
|
|
346
|
+
),
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
# Confirm that the message has been received
|
|
350
|
+
_request(
|
|
351
|
+
req=ConfirmMessageReceivedRequest(
|
|
352
|
+
node=node, run_id=run_id, message_object_id=msg_id
|
|
353
|
+
),
|
|
354
|
+
res_type=ConfirmMessageReceivedResponse,
|
|
355
|
+
api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
|
|
356
|
+
)
|
|
357
|
+
except ValueError as e:
|
|
358
|
+
log(
|
|
359
|
+
ERROR,
|
|
360
|
+
"Pulling objects failed. Potential irrecoverable error: %s",
|
|
361
|
+
str(e),
|
|
362
|
+
)
|
|
363
|
+
in_message = cast(
|
|
364
|
+
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
|
365
|
+
)
|
|
366
|
+
# The deflated message doesn't contain the message_id (its own object_id)
|
|
367
|
+
# Inject
|
|
368
|
+
in_message.metadata.__dict__["_message_id"] = msg_id
|
|
369
|
+
|
|
370
|
+
# Remember `metadata` of the in message
|
|
371
|
+
nonlocal metadata
|
|
372
|
+
metadata = copy(in_message.metadata) if in_message else None
|
|
373
|
+
|
|
374
|
+
return in_message
|
|
312
375
|
|
|
313
376
|
def send(message: Message) -> None:
|
|
314
377
|
"""Send Message result back to server."""
|
|
@@ -323,29 +386,72 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
323
386
|
log(ERROR, "No current message")
|
|
324
387
|
return
|
|
325
388
|
|
|
389
|
+
# Set message_id
|
|
390
|
+
message.metadata.__dict__["_message_id"] = message.object_id
|
|
326
391
|
# Validate out message
|
|
327
392
|
if not validate_out_message(message, metadata):
|
|
328
393
|
log(ERROR, "Invalid out message")
|
|
329
394
|
return
|
|
330
|
-
metadata = None
|
|
331
395
|
|
|
332
|
-
|
|
333
|
-
|
|
396
|
+
with no_object_id_recompute():
|
|
397
|
+
# Get all nested objects
|
|
398
|
+
all_objects = get_all_nested_objects(message)
|
|
399
|
+
object_tree = get_object_tree(message)
|
|
334
400
|
|
|
335
|
-
|
|
336
|
-
|
|
401
|
+
# Serialize Message
|
|
402
|
+
message_proto = message_to_proto(
|
|
403
|
+
message=remove_content_from_message(message)
|
|
404
|
+
)
|
|
405
|
+
req = PushMessagesRequest(
|
|
406
|
+
node=node,
|
|
407
|
+
messages_list=[message_proto],
|
|
408
|
+
message_object_trees=[object_tree],
|
|
409
|
+
)
|
|
337
410
|
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
411
|
+
# Send the request
|
|
412
|
+
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
|
413
|
+
if res:
|
|
414
|
+
log(
|
|
415
|
+
INFO,
|
|
416
|
+
"[Node] POST /%s: success, created result %s",
|
|
417
|
+
PATH_PUSH_MESSAGES,
|
|
418
|
+
res.results, # pylint: disable=no-member
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
if res and res.objects_to_push:
|
|
422
|
+
objs_to_push = set(res.objects_to_push[message.object_id].object_ids)
|
|
423
|
+
|
|
424
|
+
def fn(request: PushObjectRequest) -> PushObjectResponse:
|
|
425
|
+
res = _request(
|
|
426
|
+
req=request,
|
|
427
|
+
res_type=PushObjectResponse,
|
|
428
|
+
api_path=PATH_PUSH_OBJECT,
|
|
429
|
+
)
|
|
430
|
+
if res is None:
|
|
431
|
+
raise ValueError("PushObjectResponse is None.")
|
|
432
|
+
return res
|
|
433
|
+
|
|
434
|
+
try:
|
|
435
|
+
push_objects(
|
|
436
|
+
all_objects,
|
|
437
|
+
push_object_fn=make_push_object_fn_rest(
|
|
438
|
+
push_object_rest=fn,
|
|
439
|
+
node=node,
|
|
440
|
+
run_id=message_proto.metadata.run_id,
|
|
441
|
+
),
|
|
442
|
+
object_ids_to_push=objs_to_push,
|
|
443
|
+
)
|
|
444
|
+
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
|
445
|
+
except ValueError as e:
|
|
446
|
+
log(
|
|
447
|
+
ERROR,
|
|
448
|
+
"Pushing objects failed. Potential irrecoverable error: %s",
|
|
449
|
+
str(e),
|
|
450
|
+
)
|
|
451
|
+
log(ERROR, str(e))
|
|
342
452
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
"[Node] POST /%s: success, created result %s",
|
|
346
|
-
PATH_PUSH_MESSAGES,
|
|
347
|
-
res.results, # pylint: disable=no-member
|
|
348
|
-
)
|
|
453
|
+
# Cleanup
|
|
454
|
+
metadata = None
|
|
349
455
|
|
|
350
456
|
def get_run(run_id: int) -> Run:
|
|
351
457
|
# Construct the request
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Public Flower ClientApp APIs."""
|
flwr/common/__init__.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
"""Common components shared between server and client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from ..app.error import Error as Error
|
|
19
|
+
from ..app.metadata import Metadata as Metadata
|
|
18
20
|
from .constant import MessageType as MessageType
|
|
19
21
|
from .constant import MessageTypeLegacy as MessageTypeLegacy
|
|
20
22
|
from .context import Context as Context
|
|
@@ -23,9 +25,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
|
|
23
25
|
from .logger import configure as configure
|
|
24
26
|
from .logger import log as log
|
|
25
27
|
from .message import DEFAULT_TTL
|
|
26
|
-
from .message import Error as Error
|
|
27
28
|
from .message import Message as Message
|
|
28
|
-
from .message import Metadata as Metadata
|
|
29
29
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
|
30
30
|
from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
|
@@ -17,8 +17,10 @@
|
|
|
17
17
|
|
|
18
18
|
from .auth_plugin import CliAuthPlugin as CliAuthPlugin
|
|
19
19
|
from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
|
|
20
|
+
from .auth_plugin import ExecAuthzPlugin as ExecAuthzPlugin
|
|
20
21
|
|
|
21
22
|
__all__ = [
|
|
22
23
|
"CliAuthPlugin",
|
|
23
24
|
"ExecAuthPlugin",
|
|
25
|
+
"ExecAuthzPlugin",
|
|
24
26
|
]
|
|
@@ -20,7 +20,7 @@ from collections.abc import Sequence
|
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Optional, Union
|
|
22
22
|
|
|
23
|
-
from flwr.common.typing import
|
|
23
|
+
from flwr.common.typing import AccountInfo
|
|
24
24
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
25
25
|
|
|
26
26
|
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
|
@@ -33,6 +33,9 @@ class ExecAuthPlugin(ABC):
|
|
|
33
33
|
----------
|
|
34
34
|
user_auth_config_path : Path
|
|
35
35
|
Path to the YAML file containing the authentication configuration.
|
|
36
|
+
verify_tls_cert : bool
|
|
37
|
+
Boolean indicating whether to verify the TLS certificate
|
|
38
|
+
when making requests to the server.
|
|
36
39
|
"""
|
|
37
40
|
|
|
38
41
|
@abstractmethod
|
|
@@ -50,7 +53,7 @@ class ExecAuthPlugin(ABC):
|
|
|
50
53
|
@abstractmethod
|
|
51
54
|
def validate_tokens_in_metadata(
|
|
52
55
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
53
|
-
) -> tuple[bool, Optional[
|
|
56
|
+
) -> tuple[bool, Optional[AccountInfo]]:
|
|
54
57
|
"""Validate authentication tokens in the provided metadata."""
|
|
55
58
|
|
|
56
59
|
@abstractmethod
|
|
@@ -60,10 +63,33 @@ class ExecAuthPlugin(ABC):
|
|
|
60
63
|
@abstractmethod
|
|
61
64
|
def refresh_tokens(
|
|
62
65
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
63
|
-
) ->
|
|
66
|
+
) -> tuple[
|
|
67
|
+
Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
|
|
68
|
+
]:
|
|
64
69
|
"""Refresh authentication tokens in the provided metadata."""
|
|
65
70
|
|
|
66
71
|
|
|
72
|
+
class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
|
73
|
+
"""Abstract Flower Authorization Plugin class for ExecServicer.
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
user_auth_config_path : Path
|
|
78
|
+
Path to the YAML file containing the authorization configuration.
|
|
79
|
+
verify_tls_cert : bool
|
|
80
|
+
Boolean indicating whether to verify the TLS certificate
|
|
81
|
+
when making requests to the server.
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
@abstractmethod
|
|
85
|
+
def __init__(self, user_auth_config_path: Path, verify_tls_cert: bool):
|
|
86
|
+
"""Abstract constructor."""
|
|
87
|
+
|
|
88
|
+
@abstractmethod
|
|
89
|
+
def verify_user_authorization(self, account_info: AccountInfo) -> bool:
|
|
90
|
+
"""Verify user authorization request."""
|
|
91
|
+
|
|
92
|
+
|
|
67
93
|
class CliAuthPlugin(ABC):
|
|
68
94
|
"""Abstract Flower Auth Plugin class for CLI.
|
|
69
95
|
|
flwr/common/constant.py
CHANGED
|
@@ -55,13 +55,14 @@ EXEC_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{EXEC_API_PORT}"
|
|
|
55
55
|
SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{SIMULATIONIO_PORT}"
|
|
56
56
|
SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS = f"{CLIENT_OCTET}:{SIMULATIONIO_PORT}"
|
|
57
57
|
|
|
58
|
-
# Constants for
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
58
|
+
# Constants for heartbeat
|
|
59
|
+
HEARTBEAT_DEFAULT_INTERVAL = 30
|
|
60
|
+
HEARTBEAT_CALL_TIMEOUT = 5
|
|
61
|
+
HEARTBEAT_BASE_MULTIPLIER = 0.8
|
|
62
|
+
HEARTBEAT_RANDOM_RANGE = (-0.1, 0.1)
|
|
63
|
+
HEARTBEAT_MAX_INTERVAL = 1e300
|
|
64
|
+
HEARTBEAT_PATIENCE = 2
|
|
65
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT = "No heartbeat received from the run."
|
|
65
66
|
|
|
66
67
|
# IDs
|
|
67
68
|
RUN_ID_NUM_BYTES = 8
|
|
@@ -114,6 +115,9 @@ AUTH_TYPE_YAML_KEY = "auth_type" # For key name in YAML file
|
|
|
114
115
|
ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
|
|
115
116
|
REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"
|
|
116
117
|
|
|
118
|
+
# Constants for user authorization
|
|
119
|
+
AUTHZ_TYPE_YAML_KEY = "authz_type" # For key name in YAML file
|
|
120
|
+
|
|
117
121
|
# Constants for node authentication
|
|
118
122
|
PUBLIC_KEY_HEADER = "flwr-public-key-bin" # Must end with "-bin" for binary data
|
|
119
123
|
SIGNATURE_HEADER = "flwr-signature-bin" # Must end with "-bin" for binary data
|
|
@@ -121,9 +125,34 @@ TIMESTAMP_HEADER = "flwr-timestamp"
|
|
|
121
125
|
TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
|
|
122
126
|
SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
|
|
123
127
|
|
|
128
|
+
# Constants for grpc retry
|
|
129
|
+
GRPC_RETRY_MAX_DELAY = 20 # Maximum delay duration between two consecutive retries.
|
|
130
|
+
|
|
124
131
|
# Constants for ArrayRecord
|
|
125
132
|
GC_THRESHOLD = 200_000_000 # 200 MB
|
|
126
133
|
|
|
134
|
+
# Constants for Inflatable
|
|
135
|
+
HEAD_BODY_DIVIDER = b"\x00"
|
|
136
|
+
HEAD_VALUE_DIVIDER = " "
|
|
137
|
+
|
|
138
|
+
# Constants for serialization
|
|
139
|
+
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
140
|
+
|
|
141
|
+
# Constants for `flwr-serverapp` and `flwr-clientapp` CLI commands
|
|
142
|
+
FLWR_APP_TOKEN_LENGTH = 128 # Length of the token used
|
|
143
|
+
|
|
144
|
+
# Constants for object pushing and pulling
|
|
145
|
+
MAX_CONCURRENT_PUSHES = 8 # Default maximum number of concurrent pushes
|
|
146
|
+
MAX_CONCURRENT_PULLS = 8 # Default maximum number of concurrent pulls
|
|
147
|
+
PULL_MAX_TIME = 7200 # Default maximum time to wait for pulling objects
|
|
148
|
+
PULL_MAX_TRIES_PER_OBJECT = 500 # Default maximum number of tries to pull an object
|
|
149
|
+
PULL_INITIAL_BACKOFF = 1 # Initial backoff time for pulling objects
|
|
150
|
+
PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# ExecServicer constants
|
|
154
|
+
RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
|
|
155
|
+
|
|
127
156
|
|
|
128
157
|
class MessageType:
|
|
129
158
|
"""Message type."""
|
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
|
21
21
|
import grpc
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
23
23
|
|
|
24
|
-
from flwr.common.typing import
|
|
24
|
+
from flwr.common.typing import AccountInfo, LogEntry
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class EventLogWriterPlugin(ABC):
|
|
@@ -36,7 +36,7 @@ class EventLogWriterPlugin(ABC):
|
|
|
36
36
|
self,
|
|
37
37
|
request: GrpcMessage,
|
|
38
38
|
context: grpc.ServicerContext,
|
|
39
|
-
|
|
39
|
+
account_info: Optional[AccountInfo],
|
|
40
40
|
method_name: str,
|
|
41
41
|
) -> LogEntry:
|
|
42
42
|
"""Compose pre-event log entry from the provided request and context."""
|
|
@@ -46,7 +46,7 @@ class EventLogWriterPlugin(ABC):
|
|
|
46
46
|
self,
|
|
47
47
|
request: GrpcMessage,
|
|
48
48
|
context: grpc.ServicerContext,
|
|
49
|
-
|
|
49
|
+
account_info: Optional[AccountInfo],
|
|
50
50
|
method_name: str,
|
|
51
51
|
response: Optional[Union[GrpcMessage, BaseException]],
|
|
52
52
|
) -> LogEntry:
|
flwr/common/exit_handlers.py
CHANGED
|
@@ -30,6 +30,7 @@ SIGNAL_TO_EXIT_CODE: dict[int, int] = {
|
|
|
30
30
|
signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
|
|
31
31
|
signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
|
|
32
32
|
}
|
|
33
|
+
registered_exit_handlers: list[Callable[[], None]] = []
|
|
33
34
|
|
|
34
35
|
# SIGQUIT is not available on Windows
|
|
35
36
|
if hasattr(signal, "SIGQUIT"):
|
|
@@ -41,6 +42,7 @@ def register_exit_handlers(
|
|
|
41
42
|
exit_message: Optional[str] = None,
|
|
42
43
|
grpc_servers: Optional[list[Server]] = None,
|
|
43
44
|
bckg_threads: Optional[list[Thread]] = None,
|
|
45
|
+
exit_handlers: Optional[list[Callable[[], None]]] = None,
|
|
44
46
|
) -> None:
|
|
45
47
|
"""Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.
|
|
46
48
|
|
|
@@ -56,8 +58,12 @@ def register_exit_handlers(
|
|
|
56
58
|
bckg_threads: Optional[List[Thread]] (default: None)
|
|
57
59
|
An optional list of threads that need to be gracefully
|
|
58
60
|
terminated before exiting.
|
|
61
|
+
exit_handlers: Optional[List[Callable[[], None]]] (default: None)
|
|
62
|
+
An optional list of exit handlers to be called before exiting.
|
|
63
|
+
Additional exit handlers can be added using `add_exit_handler`.
|
|
59
64
|
"""
|
|
60
65
|
default_handlers: dict[int, Callable[[int, FrameType], None]] = {}
|
|
66
|
+
registered_exit_handlers.extend(exit_handlers or [])
|
|
61
67
|
|
|
62
68
|
def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
|
|
63
69
|
"""Exit handler to be registered with `signal.signal`.
|
|
@@ -68,6 +74,9 @@ def register_exit_handlers(
|
|
|
68
74
|
# Reset to default handler
|
|
69
75
|
signal.signal(signalnum, default_handlers[signalnum]) # type: ignore
|
|
70
76
|
|
|
77
|
+
for handler in registered_exit_handlers:
|
|
78
|
+
handler()
|
|
79
|
+
|
|
71
80
|
if grpc_servers is not None:
|
|
72
81
|
for grpc_server in grpc_servers:
|
|
73
82
|
grpc_server.stop(grace=1)
|
|
@@ -87,3 +96,24 @@ def register_exit_handlers(
|
|
|
87
96
|
for sig in SIGNAL_TO_EXIT_CODE:
|
|
88
97
|
default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
|
|
89
98
|
default_handlers[sig] = default_handler # type: ignore
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def add_exit_handler(exit_handler: Callable[[], None]) -> None:
|
|
102
|
+
"""Add an exit handler to be called on graceful exit.
|
|
103
|
+
|
|
104
|
+
This function allows you to register additional exit handlers
|
|
105
|
+
that will be executed when the application exits gracefully,
|
|
106
|
+
if `register_exit_handlers` was called.
|
|
107
|
+
|
|
108
|
+
Parameters
|
|
109
|
+
----------
|
|
110
|
+
exit_handler : Callable[[], None]
|
|
111
|
+
A callable that takes no arguments and performs cleanup or
|
|
112
|
+
other actions before the application exits.
|
|
113
|
+
|
|
114
|
+
Notes
|
|
115
|
+
-----
|
|
116
|
+
This method is not thread-safe, and it allows you to add the
|
|
117
|
+
same exit handler multiple times.
|
|
118
|
+
"""
|
|
119
|
+
registered_exit_handlers.append(exit_handler)
|