flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +15 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +187 -35
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +92 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +53 -13
- flwr/common/exit/exit_code.py +20 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -35
- flwr/proto/control_pb2.pyi +71 -5
- flwr/proto/control_pb2_grpc.py +102 -0
- flwr/proto/control_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +139 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +3 -2
- flwr/supercore/constant.py +22 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/utils.py +20 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +13 -11
- flwr/superlink/servicer/control/control_servicer.py +152 -60
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -36,18 +36,27 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
|
36
36
|
from flwr.common.logger import log
|
|
37
37
|
from flwr.common.message import Message, remove_content_from_message
|
|
38
38
|
from flwr.common.retry_invoker import RetryInvoker
|
|
39
|
-
from flwr.common.serde import
|
|
39
|
+
from flwr.common.serde import (
|
|
40
|
+
fab_from_proto,
|
|
41
|
+
message_from_proto,
|
|
42
|
+
message_to_proto,
|
|
43
|
+
run_from_proto,
|
|
44
|
+
)
|
|
40
45
|
from flwr.common.typing import Fab, Run
|
|
41
46
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
42
47
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
48
|
+
ActivateNodeRequest,
|
|
49
|
+
ActivateNodeResponse,
|
|
50
|
+
DeactivateNodeRequest,
|
|
51
|
+
DeactivateNodeResponse,
|
|
47
52
|
PullMessagesRequest,
|
|
48
53
|
PullMessagesResponse,
|
|
49
54
|
PushMessagesRequest,
|
|
50
55
|
PushMessagesResponse,
|
|
56
|
+
RegisterNodeFleetRequest,
|
|
57
|
+
RegisterNodeFleetResponse,
|
|
58
|
+
UnregisterNodeFleetRequest,
|
|
59
|
+
UnregisterNodeFleetResponse,
|
|
51
60
|
)
|
|
52
61
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
53
62
|
SendNodeHeartbeatRequest,
|
|
@@ -64,6 +73,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
|
64
73
|
)
|
|
65
74
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
66
75
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
76
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
|
67
77
|
|
|
68
78
|
try:
|
|
69
79
|
import requests
|
|
@@ -71,8 +81,10 @@ except ModuleNotFoundError:
|
|
|
71
81
|
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
72
82
|
|
|
73
83
|
|
|
74
|
-
|
|
75
|
-
|
|
84
|
+
PATH_REGISTER_NODE: str = "/api/v0/fleet/register-node"
|
|
85
|
+
PATH_ACTIVATE_NODE: str = "/api/v0/fleet/activate-node"
|
|
86
|
+
PATH_DEACTIVATE_NODE: str = "/api/v0/fleet/deactivate-node"
|
|
87
|
+
PATH_UNREGISTER_NODE: str = "/api/v0/fleet/unregister-node"
|
|
76
88
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
77
89
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
78
90
|
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
|
@@ -99,10 +111,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
99
111
|
] = None,
|
|
100
112
|
) -> Iterator[
|
|
101
113
|
tuple[
|
|
114
|
+
int,
|
|
102
115
|
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
103
116
|
Callable[[Message, ObjectTree], set[str]],
|
|
104
|
-
Callable[[], Optional[int]],
|
|
105
|
-
Callable[[], None],
|
|
106
117
|
Callable[[int], Run],
|
|
107
118
|
Callable[[str, int], Fab],
|
|
108
119
|
Callable[[int, str], bytes],
|
|
@@ -134,15 +145,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
134
145
|
connection using the certificates will be established to an SSL-enabled
|
|
135
146
|
Flower server. Bytes won't work for the REST API.
|
|
136
147
|
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
137
|
-
|
|
148
|
+
SuperNode authentication is not supported for this transport type.
|
|
138
149
|
|
|
139
150
|
Returns
|
|
140
151
|
-------
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
152
|
+
node_id : int
|
|
153
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
154
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
155
|
+
get_run : Callable[[int], Run]
|
|
156
|
+
get_fab : Callable[[str, int], Fab]
|
|
146
157
|
pull_object : Callable[[str], bytes]
|
|
147
158
|
push_object : Callable[[str, bytes], None]
|
|
148
159
|
confirm_message_received : Callable[[str], None]
|
|
@@ -171,7 +182,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
171
182
|
"must be provided as a string path to the client.",
|
|
172
183
|
)
|
|
173
184
|
if authentication_keys is not None:
|
|
174
|
-
log(ERROR, "
|
|
185
|
+
log(ERROR, "SuperNode authentication is not supported for this transport type.")
|
|
186
|
+
|
|
187
|
+
# REST does NOT support node authentication
|
|
188
|
+
self_registered = False
|
|
189
|
+
if authentication_keys is None:
|
|
190
|
+
self_registered = True
|
|
191
|
+
authentication_keys = generate_key_pairs()
|
|
192
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
|
175
193
|
|
|
176
194
|
# Shared variables for inner functions
|
|
177
195
|
node: Optional[Node] = None
|
|
@@ -180,7 +198,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
180
198
|
retry_invoker.should_giveup = None
|
|
181
199
|
|
|
182
200
|
###########################################################################
|
|
183
|
-
#
|
|
201
|
+
# SuperNode functions
|
|
184
202
|
###########################################################################
|
|
185
203
|
|
|
186
204
|
def _request(
|
|
@@ -290,23 +308,35 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
290
308
|
|
|
291
309
|
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
292
310
|
|
|
293
|
-
def
|
|
294
|
-
"""
|
|
295
|
-
req =
|
|
311
|
+
def register_node() -> None:
|
|
312
|
+
"""Register node with SuperLink."""
|
|
313
|
+
req = RegisterNodeFleetRequest(public_key=node_pk)
|
|
296
314
|
|
|
297
315
|
# Send the request
|
|
298
|
-
res = _request(req,
|
|
316
|
+
res = _request(req, RegisterNodeFleetResponse, PATH_REGISTER_NODE)
|
|
299
317
|
if res is None:
|
|
300
|
-
|
|
318
|
+
raise RuntimeError("Failed to register node")
|
|
319
|
+
|
|
320
|
+
def activate_node() -> int:
|
|
321
|
+
"""Activate node and start heartbeat."""
|
|
322
|
+
req = ActivateNodeRequest(
|
|
323
|
+
public_key=node_pk,
|
|
324
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
# Send the request
|
|
328
|
+
res = _request(req, ActivateNodeResponse, PATH_ACTIVATE_NODE)
|
|
329
|
+
if res is None:
|
|
330
|
+
raise RuntimeError("Failed to activate node")
|
|
301
331
|
|
|
302
332
|
# Remember the node and start the heartbeat sender
|
|
303
333
|
nonlocal node
|
|
304
|
-
node = res.
|
|
334
|
+
node = Node(node_id=res.node_id)
|
|
305
335
|
heartbeat_sender.start()
|
|
306
336
|
return node.node_id
|
|
307
337
|
|
|
308
|
-
def
|
|
309
|
-
"""
|
|
338
|
+
def deactivate_node() -> None:
|
|
339
|
+
"""Deactivate node and stop heartbeat."""
|
|
310
340
|
nonlocal node
|
|
311
341
|
if node is None:
|
|
312
342
|
raise RuntimeError("Node instance missing")
|
|
@@ -314,13 +344,27 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
314
344
|
# Stop the heartbeat sender
|
|
315
345
|
heartbeat_sender.stop()
|
|
316
346
|
|
|
317
|
-
# Send
|
|
318
|
-
req =
|
|
347
|
+
# Send DeactivateNode request
|
|
348
|
+
req = DeactivateNodeRequest(node_id=node.node_id)
|
|
319
349
|
|
|
320
350
|
# Send the request
|
|
321
|
-
res = _request(req,
|
|
351
|
+
res = _request(req, DeactivateNodeResponse, PATH_DEACTIVATE_NODE)
|
|
322
352
|
if res is None:
|
|
323
|
-
|
|
353
|
+
raise RuntimeError("Failed to deactivate node")
|
|
354
|
+
|
|
355
|
+
def unregister_node() -> None:
|
|
356
|
+
"""Unregister node from SuperLink."""
|
|
357
|
+
nonlocal node
|
|
358
|
+
if node is None:
|
|
359
|
+
raise RuntimeError("Node instance missing")
|
|
360
|
+
|
|
361
|
+
# Send UnregisterNode request
|
|
362
|
+
req = UnregisterNodeFleetRequest(node_id=node.node_id)
|
|
363
|
+
|
|
364
|
+
# Send the request
|
|
365
|
+
res = _request(req, UnregisterNodeFleetResponse, PATH_UNREGISTER_NODE)
|
|
366
|
+
if res is None:
|
|
367
|
+
raise RuntimeError("Failed to unregister node")
|
|
324
368
|
|
|
325
369
|
# Cleanup
|
|
326
370
|
node = None
|
|
@@ -392,12 +436,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
392
436
|
# Send the request
|
|
393
437
|
res = _request(req, GetFabResponse, PATH_GET_FAB)
|
|
394
438
|
if res is None:
|
|
395
|
-
return Fab("", b"")
|
|
439
|
+
return Fab("", b"", {})
|
|
396
440
|
|
|
397
|
-
return
|
|
398
|
-
res.fab.hash_str,
|
|
399
|
-
res.fab.content,
|
|
400
|
-
)
|
|
441
|
+
return fab_from_proto(res.fab)
|
|
401
442
|
|
|
402
443
|
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
403
444
|
"""Pull the object from the SuperLink."""
|
|
@@ -439,12 +480,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
439
480
|
fn(object_id)
|
|
440
481
|
|
|
441
482
|
try:
|
|
483
|
+
if self_registered:
|
|
484
|
+
register_node()
|
|
485
|
+
node_id = activate_node()
|
|
442
486
|
# Yield methods
|
|
443
487
|
yield (
|
|
488
|
+
node_id,
|
|
444
489
|
receive,
|
|
445
490
|
send,
|
|
446
|
-
create_node,
|
|
447
|
-
delete_node,
|
|
448
491
|
get_run,
|
|
449
492
|
get_fab,
|
|
450
493
|
pull_object,
|
|
@@ -459,6 +502,8 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
459
502
|
if node is not None:
|
|
460
503
|
# Disable retrying
|
|
461
504
|
retry_invoker.max_tries = 1
|
|
462
|
-
|
|
505
|
+
deactivate_node()
|
|
506
|
+
if self_registered:
|
|
507
|
+
unregister_node()
|
|
463
508
|
except RequestsConnectionError:
|
|
464
509
|
pass
|
flwr/clientapp/__init__.py
CHANGED
|
@@ -19,7 +19,7 @@ from logging import DEBUG
|
|
|
19
19
|
from pathlib import Path
|
|
20
20
|
from typing import Callable, Optional
|
|
21
21
|
|
|
22
|
-
from flwr.
|
|
22
|
+
from flwr.clientapp.client_app import ClientApp, LoadClientAppError
|
|
23
23
|
from flwr.common.config import (
|
|
24
24
|
get_flwr_dir,
|
|
25
25
|
get_metadata_from_config,
|
flwr/common/constant.py
CHANGED
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
+
import os
|
|
21
|
+
|
|
20
22
|
TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
|
|
21
23
|
TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
|
|
22
24
|
TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter"
|
|
@@ -60,7 +62,9 @@ HEARTBEAT_DEFAULT_INTERVAL = 30
|
|
|
60
62
|
HEARTBEAT_CALL_TIMEOUT = 5
|
|
61
63
|
HEARTBEAT_BASE_MULTIPLIER = 0.8
|
|
62
64
|
HEARTBEAT_RANDOM_RANGE = (-0.1, 0.1)
|
|
63
|
-
|
|
65
|
+
HEARTBEAT_MIN_INTERVAL = 10
|
|
66
|
+
HEARTBEAT_MAX_INTERVAL = 1800 # 30 minutes
|
|
67
|
+
HEARTBEAT_INTERVAL_INF = 1e300 # Large value, disabling heartbeats
|
|
64
68
|
HEARTBEAT_PATIENCE = 2
|
|
65
69
|
RUN_FAILURE_DETAILS_NO_HEARTBEAT = "No heartbeat received from the run."
|
|
66
70
|
|
|
@@ -70,13 +74,23 @@ NODE_ID_NUM_BYTES = 8
|
|
|
70
74
|
|
|
71
75
|
# Constants for FAB
|
|
72
76
|
APP_DIR = "apps"
|
|
73
|
-
FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
|
|
74
77
|
FAB_CONFIG_FILE = "pyproject.toml"
|
|
75
78
|
FAB_DATE = (2024, 10, 1, 0, 0, 0)
|
|
76
79
|
FAB_HASH_TRUNCATION = 8
|
|
77
80
|
FAB_MAX_SIZE = 10 * 1024 * 1024 # 10 MB
|
|
78
81
|
FLWR_DIR = ".flwr" # The default Flower directory: ~/.flwr/
|
|
79
82
|
FLWR_HOME = "FLWR_HOME" # If set, override the default Flower directory
|
|
83
|
+
# FAB file include patterns (gitignore-style patterns)
|
|
84
|
+
FAB_INCLUDE_PATTERNS = (
|
|
85
|
+
"**/*.py",
|
|
86
|
+
"**/*.toml",
|
|
87
|
+
"**/*.md",
|
|
88
|
+
)
|
|
89
|
+
# FAB file exclude patterns (gitignore-style patterns)
|
|
90
|
+
FAB_EXCLUDE_PATTERNS = (
|
|
91
|
+
"**/__pycache__/**",
|
|
92
|
+
FAB_CONFIG_FILE, # Exclude the original pyproject.toml
|
|
93
|
+
)
|
|
80
94
|
|
|
81
95
|
# Constant for SuperLink
|
|
82
96
|
SUPERLINK_NODE_ID = 1
|
|
@@ -109,14 +123,14 @@ LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
|
|
|
109
123
|
# Retry configurations
|
|
110
124
|
MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
|
|
111
125
|
|
|
112
|
-
# Constants for
|
|
126
|
+
# Constants for account authentication
|
|
113
127
|
CREDENTIALS_DIR = ".credentials"
|
|
114
|
-
|
|
115
|
-
|
|
128
|
+
AUTHN_TYPE_JSON_KEY = "authn-type" # For key name in JSON file
|
|
129
|
+
AUTHN_TYPE_YAML_KEY = "authn_type" # For key name in YAML file
|
|
116
130
|
ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
|
|
117
131
|
REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"
|
|
118
132
|
|
|
119
|
-
# Constants for
|
|
133
|
+
# Constants for account authorization
|
|
120
134
|
AUTHZ_TYPE_YAML_KEY = "authz_type" # For key name in YAML file
|
|
121
135
|
|
|
122
136
|
# Constants for node authentication
|
|
@@ -135,7 +149,9 @@ GC_THRESHOLD = 200_000_000 # 200 MB
|
|
|
135
149
|
# Constants for Inflatable
|
|
136
150
|
HEAD_BODY_DIVIDER = b"\x00"
|
|
137
151
|
HEAD_VALUE_DIVIDER = " "
|
|
138
|
-
|
|
152
|
+
FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE = int(
|
|
153
|
+
os.getenv("FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE", "5242880")
|
|
154
|
+
) # 5 MB
|
|
139
155
|
|
|
140
156
|
# Constants for serialization
|
|
141
157
|
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
@@ -144,8 +160,12 @@ INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
|
144
160
|
FLWR_APP_TOKEN_LENGTH = 128 # Length of the token used
|
|
145
161
|
|
|
146
162
|
# Constants for object pushing and pulling
|
|
147
|
-
|
|
148
|
-
|
|
163
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES = int(
|
|
164
|
+
os.getenv("FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES", "2")
|
|
165
|
+
) # Default maximum number of concurrent pushes
|
|
166
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS = int(
|
|
167
|
+
os.getenv("FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS", "2")
|
|
168
|
+
) # Default maximum number of concurrent pulls
|
|
149
169
|
PULL_MAX_TIME = 7200 # Default maximum time to wait for pulling objects
|
|
150
170
|
PULL_MAX_TRIES_PER_OBJECT = 500 # Default maximum number of tries to pull an object
|
|
151
171
|
PULL_INITIAL_BACKOFF = 1 # Initial backoff time for pulling objects
|
|
@@ -154,9 +174,13 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
|
|
|
154
174
|
|
|
155
175
|
# ControlServicer constants
|
|
156
176
|
RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
|
|
157
|
-
|
|
177
|
+
NO_ACCOUNT_AUTH_MESSAGE = "ControlServicer initialized without account authentication"
|
|
158
178
|
NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
|
|
159
179
|
PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
|
|
180
|
+
SUPERNODE_NOT_CREATED_FROM_CLI_MESSAGE = "Invalid SuperNode credentials"
|
|
181
|
+
PUBLIC_KEY_ALREADY_IN_USE_MESSAGE = "Public key already in use"
|
|
182
|
+
PUBLIC_KEY_NOT_VALID = "The provided public key is not valid"
|
|
183
|
+
NODE_NOT_FOUND_MESSAGE = "Node ID not found for account"
|
|
160
184
|
|
|
161
185
|
|
|
162
186
|
class MessageType:
|
|
@@ -245,12 +269,23 @@ class CliOutputFormat:
|
|
|
245
269
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
246
270
|
|
|
247
271
|
|
|
248
|
-
class
|
|
249
|
-
"""
|
|
272
|
+
class AuthnType:
|
|
273
|
+
"""Account authentication types."""
|
|
250
274
|
|
|
275
|
+
NOOP = "noop"
|
|
251
276
|
OIDC = "oidc"
|
|
252
277
|
|
|
253
|
-
def __new__(cls) ->
|
|
278
|
+
def __new__(cls) -> AuthnType:
|
|
279
|
+
"""Prevent instantiation."""
|
|
280
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class AuthzType:
|
|
284
|
+
"""Account authorization types."""
|
|
285
|
+
|
|
286
|
+
NOOP = "noop"
|
|
287
|
+
|
|
288
|
+
def __new__(cls) -> AuthzType:
|
|
254
289
|
"""Prevent instantiation."""
|
|
255
290
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
256
291
|
|
|
@@ -281,3 +316,8 @@ class ExecPluginType:
|
|
|
281
316
|
"""Return all SuperExec plugin types."""
|
|
282
317
|
# Filter all constants (uppercase) of the class
|
|
283
318
|
return [v for k, v in vars(ExecPluginType).items() if k.isupper()]
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
# Constants for No-op auth plugins
|
|
322
|
+
NOOP_FLWR_AID = "<none>"
|
|
323
|
+
NOOP_ACCOUNT_NAME = "sys_noauth"
|
flwr/common/exit/exit_code.py
CHANGED
|
@@ -41,12 +41,16 @@ class ExitCode:
|
|
|
41
41
|
|
|
42
42
|
# SuperNode-specific exit codes (300-399)
|
|
43
43
|
SUPERNODE_REST_ADDRESS_INVALID = 300
|
|
44
|
-
SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301
|
|
45
|
-
|
|
44
|
+
# SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301 --- DELETED ---
|
|
45
|
+
SUPERNODE_NODE_AUTH_KEY_INVALID = 302
|
|
46
|
+
SUPERNODE_STARTED_WITHOUT_TLS_BUT_NODE_AUTH_ENABLED = 303
|
|
46
47
|
|
|
47
48
|
# SuperExec-specific exit codes (400-499)
|
|
48
49
|
SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
|
|
49
50
|
|
|
51
|
+
# FlowerCLI-specific exit codes (500-599)
|
|
52
|
+
FLWRCLI_NODE_AUTH_PUBLIC_KEY_INVALID = 500
|
|
53
|
+
|
|
50
54
|
# Common exit codes (600-699)
|
|
51
55
|
COMMON_ADDRESS_INVALID = 600
|
|
52
56
|
COMMON_MISSING_EXTRA_REST = 601
|
|
@@ -102,20 +106,26 @@ EXIT_CODE_HELP = {
|
|
|
102
106
|
"When using the REST API, please provide `https://` or "
|
|
103
107
|
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
|
|
104
108
|
),
|
|
105
|
-
ExitCode.
|
|
106
|
-
"Node authentication requires
|
|
107
|
-
"
|
|
108
|
-
"to be provided (providing only one of them is not sufficient)."
|
|
109
|
-
),
|
|
110
|
-
ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID: (
|
|
111
|
-
"Node authentication requires elliptic curve private and public key pair. "
|
|
112
|
-
"Please ensure that the file path points to a valid private/public key "
|
|
109
|
+
ExitCode.SUPERNODE_NODE_AUTH_KEY_INVALID: (
|
|
110
|
+
"Node authentication requires elliptic curve private key. "
|
|
111
|
+
"Please ensure that the file path points to a valid private key "
|
|
113
112
|
"file and try again."
|
|
114
113
|
),
|
|
114
|
+
ExitCode.SUPERNODE_STARTED_WITHOUT_TLS_BUT_NODE_AUTH_ENABLED: (
|
|
115
|
+
"The private key for SuperNode authentication was provided, but TLS is not "
|
|
116
|
+
"enabled. Node authentication can only be used when TLS is enabled."
|
|
117
|
+
),
|
|
115
118
|
# SuperExec-specific exit codes (400-499)
|
|
116
119
|
ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
|
|
117
120
|
"The YAML configuration for the SuperExec plugin is invalid."
|
|
118
121
|
),
|
|
122
|
+
# FlowerCLI-specific exit codes (500-599)
|
|
123
|
+
ExitCode.FLWRCLI_NODE_AUTH_PUBLIC_KEY_INVALID: (
|
|
124
|
+
"Node authentication requires a valid elliptic curve public key in the "
|
|
125
|
+
"SSH format and following a NIST standard elliptic curve (e.g. SECP384R1). "
|
|
126
|
+
"Please ensure that the file path points to a valid public key "
|
|
127
|
+
"file and try again."
|
|
128
|
+
),
|
|
119
129
|
# Common exit codes (600-699)
|
|
120
130
|
ExitCode.COMMON_ADDRESS_INVALID: (
|
|
121
131
|
"Please provide a valid URL, IPv4 or IPv6 address."
|
flwr/common/inflatable_utils.py
CHANGED
|
@@ -25,10 +25,10 @@ from typing import Callable, Optional, TypeVar
|
|
|
25
25
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
26
26
|
|
|
27
27
|
from .constant import (
|
|
28
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
29
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
28
30
|
HEAD_BODY_DIVIDER,
|
|
29
31
|
HEAD_VALUE_DIVIDER,
|
|
30
|
-
MAX_CONCURRENT_PULLS,
|
|
31
|
-
MAX_CONCURRENT_PUSHES,
|
|
32
32
|
PULL_BACKOFF_CAP,
|
|
33
33
|
PULL_INITIAL_BACKOFF,
|
|
34
34
|
PULL_MAX_TIME,
|
|
@@ -118,7 +118,7 @@ def push_objects(
|
|
|
118
118
|
*,
|
|
119
119
|
object_ids_to_push: Optional[set[str]] = None,
|
|
120
120
|
keep_objects: bool = False,
|
|
121
|
-
max_concurrent_pushes: int =
|
|
121
|
+
max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
122
122
|
) -> None:
|
|
123
123
|
"""Push multiple objects to the servicer.
|
|
124
124
|
|
|
@@ -137,7 +137,7 @@ def push_objects(
|
|
|
137
137
|
If `True`, the original objects will be kept in the `objects` dictionary
|
|
138
138
|
after pushing. If `False`, they will be removed from the dictionary to avoid
|
|
139
139
|
high memory usage.
|
|
140
|
-
max_concurrent_pushes : int (default:
|
|
140
|
+
max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
|
|
141
141
|
The maximum number of concurrent pushes to perform.
|
|
142
142
|
"""
|
|
143
143
|
lock = threading.Lock()
|
|
@@ -168,7 +168,7 @@ def push_object_contents_from_iterable(
|
|
|
168
168
|
object_contents: Iterable[tuple[str, bytes]],
|
|
169
169
|
push_object_fn: Callable[[str, bytes], None],
|
|
170
170
|
*,
|
|
171
|
-
max_concurrent_pushes: int =
|
|
171
|
+
max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
172
172
|
) -> None:
|
|
173
173
|
"""Push multiple object contents to the servicer.
|
|
174
174
|
|
|
@@ -181,7 +181,7 @@ def push_object_contents_from_iterable(
|
|
|
181
181
|
A function that takes an object ID and its content as bytes, and pushes
|
|
182
182
|
it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
|
|
183
183
|
if the object ID is not pre-registered.
|
|
184
|
-
max_concurrent_pushes : int (default:
|
|
184
|
+
max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
|
|
185
185
|
The maximum number of concurrent pushes to perform.
|
|
186
186
|
"""
|
|
187
187
|
|
|
@@ -210,7 +210,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
210
210
|
object_ids: list[str],
|
|
211
211
|
pull_object_fn: Callable[[str], bytes],
|
|
212
212
|
*,
|
|
213
|
-
max_concurrent_pulls: int =
|
|
213
|
+
max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
214
214
|
max_time: Optional[float] = PULL_MAX_TIME,
|
|
215
215
|
max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
|
|
216
216
|
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
|
@@ -227,7 +227,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
227
227
|
The function should raise `ObjectUnavailableError` if the object is not yet
|
|
228
228
|
available, or `ObjectIdNotPreregisteredError` if the object ID is not
|
|
229
229
|
pre-registered.
|
|
230
|
-
max_concurrent_pulls : int (default:
|
|
230
|
+
max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
|
|
231
231
|
The maximum number of concurrent pulls to perform.
|
|
232
232
|
max_time : Optional[float] (default: PULL_MAX_TIME)
|
|
233
233
|
The maximum time to wait for all pulls to complete. If `None`, waits
|
|
@@ -442,7 +442,7 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
|
|
|
442
442
|
confirm_object_received_fn: Callable[[str], None],
|
|
443
443
|
*,
|
|
444
444
|
return_type: type[T] = InflatableObject, # type: ignore
|
|
445
|
-
max_concurrent_pulls: int =
|
|
445
|
+
max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
446
446
|
max_time: Optional[float] = PULL_MAX_TIME,
|
|
447
447
|
max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
|
|
448
448
|
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
|
@@ -460,7 +460,7 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
|
|
|
460
460
|
A function to confirm that the object has been received.
|
|
461
461
|
return_type : type[T] (default: InflatableObject)
|
|
462
462
|
The type of the object to return. Must be a subclass of `InflatableObject`.
|
|
463
|
-
max_concurrent_pulls : int (default:
|
|
463
|
+
max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
|
|
464
464
|
The maximum number of concurrent pulls to perform.
|
|
465
465
|
max_time : Optional[float] (default: PULL_MAX_TIME)
|
|
466
466
|
The maximum time to wait for all pulls to complete. If `None`, waits
|
flwr/common/record/array.py
CHANGED
|
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, cast, overload
|
|
|
25
25
|
|
|
26
26
|
import numpy as np
|
|
27
27
|
|
|
28
|
-
from ..constant import
|
|
28
|
+
from ..constant import FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, SType
|
|
29
29
|
from ..inflatable import (
|
|
30
30
|
InflatableObject,
|
|
31
31
|
add_header_to_object_body,
|
|
@@ -272,8 +272,8 @@ class Array(InflatableObject):
|
|
|
272
272
|
chunks: list[tuple[str, InflatableObject]] = []
|
|
273
273
|
# memoryview allows for zero-copy slicing
|
|
274
274
|
data_view = memoryview(self.data)
|
|
275
|
-
for start in range(0, len(data_view),
|
|
276
|
-
end = min(start +
|
|
275
|
+
for start in range(0, len(data_view), FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE):
|
|
276
|
+
end = min(start + FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, len(data_view))
|
|
277
277
|
ac = ArrayChunk(data_view[start:end])
|
|
278
278
|
chunks.append((ac.object_id, ac))
|
|
279
279
|
|
|
@@ -147,11 +147,20 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
147
147
|
keep_input: bool = True,
|
|
148
148
|
) -> None: ...
|
|
149
149
|
|
|
150
|
+
# This is also required for PyTorch state dict because they are not strongly typed
|
|
151
|
+
@overload
|
|
152
|
+
def __init__( # noqa: E704
|
|
153
|
+
self,
|
|
154
|
+
torch_state_dict: dict[str, Any],
|
|
155
|
+
*,
|
|
156
|
+
keep_input: bool = True,
|
|
157
|
+
) -> None: ...
|
|
158
|
+
|
|
150
159
|
def __init__( # pylint: disable=too-many-arguments
|
|
151
160
|
self,
|
|
152
161
|
*args: Any,
|
|
153
162
|
numpy_ndarrays: list[NDArray] | None = None,
|
|
154
|
-
torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
|
|
163
|
+
torch_state_dict: OrderedDict[str, torch.Tensor] | dict[str, Any] | None = None,
|
|
155
164
|
array_dict: OrderedDict[str, Array] | None = None,
|
|
156
165
|
keep_input: bool = True,
|
|
157
166
|
) -> None:
|
|
@@ -16,57 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
|
-
from typing import cast
|
|
20
19
|
|
|
21
20
|
from cryptography.exceptions import InvalidSignature
|
|
22
21
|
from cryptography.fernet import Fernet
|
|
23
|
-
from cryptography.hazmat.primitives import hashes, hmac
|
|
22
|
+
from cryptography.hazmat.primitives import hashes, hmac
|
|
24
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
25
24
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
def generate_key_pairs() -> (
|
|
29
|
-
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
30
|
-
):
|
|
31
|
-
"""Generate private and public key pairs with Cryptography."""
|
|
32
|
-
private_key = ec.generate_private_key(ec.SECP384R1())
|
|
33
|
-
public_key = private_key.public_key()
|
|
34
|
-
return private_key, public_key
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def private_key_to_bytes(private_key: ec.EllipticCurvePrivateKey) -> bytes:
|
|
38
|
-
"""Serialize private key to bytes."""
|
|
39
|
-
return private_key.private_bytes(
|
|
40
|
-
encoding=serialization.Encoding.PEM,
|
|
41
|
-
format=serialization.PrivateFormat.PKCS8,
|
|
42
|
-
encryption_algorithm=serialization.NoEncryption(),
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def bytes_to_private_key(private_key_bytes: bytes) -> ec.EllipticCurvePrivateKey:
|
|
47
|
-
"""Deserialize private key from bytes."""
|
|
48
|
-
return cast(
|
|
49
|
-
ec.EllipticCurvePrivateKey,
|
|
50
|
-
serialization.load_pem_private_key(data=private_key_bytes, password=None),
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def public_key_to_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes:
|
|
55
|
-
"""Serialize public key to bytes."""
|
|
56
|
-
return public_key.public_bytes(
|
|
57
|
-
encoding=serialization.Encoding.PEM,
|
|
58
|
-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def bytes_to_public_key(public_key_bytes: bytes) -> ec.EllipticCurvePublicKey:
|
|
63
|
-
"""Deserialize public key from bytes."""
|
|
64
|
-
return cast(
|
|
65
|
-
ec.EllipticCurvePublicKey,
|
|
66
|
-
serialization.load_pem_public_key(data=public_key_bytes),
|
|
67
|
-
)
|
|
68
|
-
|
|
69
|
-
|
|
70
27
|
def generate_shared_key(
|
|
71
28
|
private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey
|
|
72
29
|
) -> bytes:
|
|
@@ -117,48 +74,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
|
117
74
|
return True
|
|
118
75
|
except InvalidSignature:
|
|
119
76
|
return False
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes:
|
|
123
|
-
"""Sign a message using the provided EC private key.
|
|
124
|
-
|
|
125
|
-
Parameters
|
|
126
|
-
----------
|
|
127
|
-
private_key : ec.EllipticCurvePrivateKey
|
|
128
|
-
The EC private key to sign the message with.
|
|
129
|
-
message : bytes
|
|
130
|
-
The message to be signed.
|
|
131
|
-
|
|
132
|
-
Returns
|
|
133
|
-
-------
|
|
134
|
-
bytes
|
|
135
|
-
The signature of the message.
|
|
136
|
-
"""
|
|
137
|
-
signature = private_key.sign(message, ec.ECDSA(hashes.SHA256()))
|
|
138
|
-
return signature
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
def verify_signature(
|
|
142
|
-
public_key: ec.EllipticCurvePublicKey, message: bytes, signature: bytes
|
|
143
|
-
) -> bool:
|
|
144
|
-
"""Verify a signature against a message using the provided EC public key.
|
|
145
|
-
|
|
146
|
-
Parameters
|
|
147
|
-
----------
|
|
148
|
-
public_key : ec.EllipticCurvePublicKey
|
|
149
|
-
The EC public key to verify the signature.
|
|
150
|
-
message : bytes
|
|
151
|
-
The original message.
|
|
152
|
-
signature : bytes
|
|
153
|
-
The signature to verify.
|
|
154
|
-
|
|
155
|
-
Returns
|
|
156
|
-
-------
|
|
157
|
-
bool
|
|
158
|
-
True if the signature is valid, False otherwise.
|
|
159
|
-
"""
|
|
160
|
-
try:
|
|
161
|
-
public_key.verify(signature, message, ec.ECDSA(hashes.SHA256()))
|
|
162
|
-
return True
|
|
163
|
-
except InvalidSignature:
|
|
164
|
-
return False
|