flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +15 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +187 -35
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +92 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +53 -13
- flwr/common/exit/exit_code.py +20 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -35
- flwr/proto/control_pb2.pyi +71 -5
- flwr/proto/control_pb2_grpc.py +102 -0
- flwr/proto/control_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +139 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +3 -2
- flwr/supercore/constant.py +22 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/utils.py +20 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +13 -11
- flwr/superlink/servicer/control/control_servicer.py +152 -60
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import json
|
|
19
|
+
import secrets
|
|
19
20
|
import threading
|
|
20
21
|
import time
|
|
21
22
|
import traceback
|
|
@@ -27,12 +28,13 @@ from typing import Callable, Optional
|
|
|
27
28
|
from uuid import uuid4
|
|
28
29
|
|
|
29
30
|
from flwr.app.error import Error
|
|
30
|
-
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
31
|
-
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
32
31
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
|
+
from flwr.clientapp.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
33
|
+
from flwr.clientapp.utils import get_load_client_app_fn
|
|
33
34
|
from flwr.common import Message
|
|
34
35
|
from flwr.common.constant import (
|
|
35
|
-
|
|
36
|
+
HEARTBEAT_INTERVAL_INF,
|
|
37
|
+
NOOP_FLWR_AID,
|
|
36
38
|
NUM_PARTITIONS_KEY,
|
|
37
39
|
PARTITION_ID_KEY,
|
|
38
40
|
ErrorCode,
|
|
@@ -40,6 +42,7 @@ from flwr.common.constant import (
|
|
|
40
42
|
from flwr.common.logger import log
|
|
41
43
|
from flwr.common.typing import Run
|
|
42
44
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
45
|
+
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
|
|
43
46
|
|
|
44
47
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
45
48
|
|
|
@@ -53,7 +56,17 @@ def _register_nodes(
|
|
|
53
56
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
54
57
|
state = state_factory.state()
|
|
55
58
|
for i in range(num_nodes):
|
|
56
|
-
node_id = state.create_node(
|
|
59
|
+
node_id = state.create_node(
|
|
60
|
+
# No node authentication in simulation;
|
|
61
|
+
# use NOOP_FLWR_AID as owner_aid and
|
|
62
|
+
# use random bytes as public key
|
|
63
|
+
NOOP_FLWR_AID,
|
|
64
|
+
secrets.token_bytes(32),
|
|
65
|
+
heartbeat_interval=HEARTBEAT_INTERVAL_INF,
|
|
66
|
+
)
|
|
67
|
+
state.acknowledge_node_heartbeat(
|
|
68
|
+
node_id=node_id, heartbeat_interval=HEARTBEAT_INTERVAL_INF
|
|
69
|
+
)
|
|
57
70
|
nodes_mapping[node_id] = i
|
|
58
71
|
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
|
59
72
|
return nodes_mapping
|
|
@@ -300,7 +313,7 @@ def start_vce(
|
|
|
300
313
|
if not state_factory:
|
|
301
314
|
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
|
302
315
|
# Create an empty in-memory state factory
|
|
303
|
-
state_factory = LinkStateFactory(
|
|
316
|
+
state_factory = LinkStateFactory(FLWR_IN_MEMORY_DB_NAME)
|
|
304
317
|
log(INFO, "Created new %s.", state_factory.__class__.__name__)
|
|
305
318
|
|
|
306
319
|
if num_supernodes:
|
|
@@ -17,17 +17,18 @@
|
|
|
17
17
|
|
|
18
18
|
import secrets
|
|
19
19
|
import threading
|
|
20
|
-
import time
|
|
21
20
|
from bisect import bisect_right
|
|
22
21
|
from collections import defaultdict
|
|
22
|
+
from collections.abc import Sequence
|
|
23
23
|
from dataclasses import dataclass, field
|
|
24
|
+
from datetime import datetime, timezone
|
|
24
25
|
from logging import ERROR, WARNING
|
|
25
26
|
from typing import Optional
|
|
26
27
|
|
|
27
28
|
from flwr.common import Context, Message, log, now
|
|
28
29
|
from flwr.common.constant import (
|
|
29
30
|
FLWR_APP_TOKEN_LENGTH,
|
|
30
|
-
|
|
31
|
+
HEARTBEAT_INTERVAL_INF,
|
|
31
32
|
HEARTBEAT_PATIENCE,
|
|
32
33
|
MESSAGE_TTL_TOLERANCE,
|
|
33
34
|
NODE_ID_NUM_BYTES,
|
|
@@ -39,8 +40,10 @@ from flwr.common.constant import (
|
|
|
39
40
|
)
|
|
40
41
|
from flwr.common.record import ConfigRecord
|
|
41
42
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
43
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
42
44
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
43
45
|
from flwr.server.utils import validate_message
|
|
46
|
+
from flwr.supercore.constant import NodeStatus
|
|
44
47
|
|
|
45
48
|
from .utils import (
|
|
46
49
|
check_node_availability_for_in_message,
|
|
@@ -69,10 +72,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
69
72
|
|
|
70
73
|
def __init__(self) -> None:
|
|
71
74
|
|
|
72
|
-
# Map node_id to
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
self.
|
|
75
|
+
# Map node_id to NodeInfo
|
|
76
|
+
self.nodes: dict[int, NodeInfo] = {}
|
|
77
|
+
self.node_public_key_to_node_id: dict[bytes, int] = {}
|
|
78
|
+
self.owner_to_node_ids: dict[str, set[int]] = {} # Quick lookup
|
|
76
79
|
|
|
77
80
|
# Map run_id to RunRecord
|
|
78
81
|
self.run_ids: dict[int, RunRecord] = {}
|
|
@@ -114,7 +117,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
114
117
|
)
|
|
115
118
|
return None
|
|
116
119
|
# Validate destination node ID
|
|
117
|
-
|
|
120
|
+
dst_node = self.nodes.get(message.metadata.dst_node_id)
|
|
121
|
+
if dst_node is None or dst_node.status not in [
|
|
122
|
+
NodeStatus.ONLINE,
|
|
123
|
+
NodeStatus.OFFLINE,
|
|
124
|
+
]:
|
|
118
125
|
log(
|
|
119
126
|
ERROR,
|
|
120
127
|
"Invalid destination node ID for Message: %s",
|
|
@@ -136,7 +143,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
136
143
|
|
|
137
144
|
# Find Message for node_id that were not delivered yet
|
|
138
145
|
message_ins_list: list[Message] = []
|
|
139
|
-
current_time =
|
|
146
|
+
current_time = now().timestamp()
|
|
140
147
|
with self.lock:
|
|
141
148
|
for _, msg_ins in self.message_ins_store.items():
|
|
142
149
|
if (
|
|
@@ -190,7 +197,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
190
197
|
return None
|
|
191
198
|
|
|
192
199
|
ins_metadata = msg_ins.metadata
|
|
193
|
-
if ins_metadata.created_at + ins_metadata.ttl <=
|
|
200
|
+
if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
|
|
194
201
|
log(
|
|
195
202
|
ERROR,
|
|
196
203
|
"Failed to store Message: the message it is replying to "
|
|
@@ -238,7 +245,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
238
245
|
ret: dict[str, Message] = {}
|
|
239
246
|
|
|
240
247
|
with self.lock:
|
|
241
|
-
current =
|
|
248
|
+
current = now().timestamp()
|
|
242
249
|
|
|
243
250
|
# Verify Message IDs
|
|
244
251
|
ret = verify_message_ids(
|
|
@@ -256,9 +263,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
256
263
|
inquired_in_message_ids=message_ids,
|
|
257
264
|
found_in_message_dict=self.message_ins_store,
|
|
258
265
|
node_id_to_online_until={
|
|
259
|
-
node_id: self.
|
|
266
|
+
node_id: self.nodes[node_id].online_until
|
|
260
267
|
for node_id in dst_node_ids
|
|
261
|
-
if node_id in self.
|
|
268
|
+
if node_id in self.nodes
|
|
269
|
+
and self.nodes[node_id].status != NodeStatus.UNREGISTERED
|
|
262
270
|
},
|
|
263
271
|
current_time=current,
|
|
264
272
|
)
|
|
@@ -330,7 +338,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
330
338
|
"""
|
|
331
339
|
return len(self.message_res_store)
|
|
332
340
|
|
|
333
|
-
def create_node(
|
|
341
|
+
def create_node(
|
|
342
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
|
343
|
+
) -> int:
|
|
334
344
|
"""Create, store in the link state, and return `node_id`."""
|
|
335
345
|
# Sample a random int64 as node_id
|
|
336
346
|
node_id = generate_rand_int_from_bytes(
|
|
@@ -338,28 +348,88 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
338
348
|
)
|
|
339
349
|
|
|
340
350
|
with self.lock:
|
|
341
|
-
if node_id in self.
|
|
351
|
+
if node_id in self.nodes:
|
|
342
352
|
log(ERROR, "Unexpected node registration failure.")
|
|
343
353
|
return 0
|
|
354
|
+
if public_key in self.node_public_key_to_node_id:
|
|
355
|
+
raise ValueError("Public key already in use")
|
|
344
356
|
|
|
345
|
-
#
|
|
346
|
-
self.
|
|
347
|
-
|
|
348
|
-
|
|
357
|
+
# The node is not activated upon creation
|
|
358
|
+
self.nodes[node_id] = NodeInfo(
|
|
359
|
+
node_id=node_id,
|
|
360
|
+
owner_aid=owner_aid,
|
|
361
|
+
status=NodeStatus.REGISTERED,
|
|
362
|
+
registered_at=now().isoformat(),
|
|
363
|
+
last_activated_at=None,
|
|
364
|
+
last_deactivated_at=None,
|
|
365
|
+
unregistered_at=None,
|
|
366
|
+
online_until=None,
|
|
367
|
+
heartbeat_interval=heartbeat_interval,
|
|
368
|
+
public_key=public_key,
|
|
349
369
|
)
|
|
370
|
+
self.node_public_key_to_node_id[public_key] = node_id
|
|
371
|
+
self.owner_to_node_ids.setdefault(owner_aid, set()).add(node_id)
|
|
350
372
|
return node_id
|
|
351
373
|
|
|
352
|
-
def delete_node(self, node_id: int) -> None:
|
|
374
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
|
353
375
|
"""Delete a node."""
|
|
354
376
|
with self.lock:
|
|
355
|
-
if
|
|
356
|
-
|
|
377
|
+
if (
|
|
378
|
+
not (node := self.nodes.get(node_id))
|
|
379
|
+
or node.status == NodeStatus.UNREGISTERED
|
|
380
|
+
or owner_aid != self.nodes[node_id].owner_aid
|
|
381
|
+
):
|
|
382
|
+
raise ValueError(
|
|
383
|
+
f"Node ID {node_id} already unregistered, not found or "
|
|
384
|
+
"the request was unauthorized."
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
node.status = NodeStatus.UNREGISTERED
|
|
388
|
+
current = now()
|
|
389
|
+
node.unregistered_at = current.isoformat()
|
|
390
|
+
# Set online_until to current timestamp on deletion, if it is in the future
|
|
391
|
+
node.online_until = min(node.online_until, current.timestamp())
|
|
392
|
+
|
|
393
|
+
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
394
|
+
"""Activate the node with the specified `node_id`."""
|
|
395
|
+
with self.lock:
|
|
396
|
+
self._check_and_tag_offline_nodes(node_ids=[node_id])
|
|
357
397
|
|
|
358
|
-
#
|
|
359
|
-
if
|
|
360
|
-
|
|
398
|
+
# Check if the node exists
|
|
399
|
+
if not (node := self.nodes.get(node_id)):
|
|
400
|
+
return False
|
|
401
|
+
|
|
402
|
+
# Only activate if the node is currently registered or offline
|
|
403
|
+
current_dt = now()
|
|
404
|
+
if node.status in (NodeStatus.REGISTERED, NodeStatus.OFFLINE):
|
|
405
|
+
node.status = NodeStatus.ONLINE
|
|
406
|
+
node.last_activated_at = current_dt.isoformat()
|
|
407
|
+
node.online_until = (
|
|
408
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
409
|
+
)
|
|
410
|
+
node.heartbeat_interval = heartbeat_interval
|
|
411
|
+
return True
|
|
412
|
+
return False
|
|
413
|
+
|
|
414
|
+
def deactivate_node(self, node_id: int) -> bool:
|
|
415
|
+
"""Deactivate the node with the specified `node_id`."""
|
|
416
|
+
with self.lock:
|
|
417
|
+
self._check_and_tag_offline_nodes(node_ids=[node_id])
|
|
418
|
+
|
|
419
|
+
# Check if the node exists
|
|
420
|
+
if not (node := self.nodes.get(node_id)):
|
|
421
|
+
return False
|
|
361
422
|
|
|
362
|
-
|
|
423
|
+
# Only deactivate if the node is currently online
|
|
424
|
+
current_dt = now()
|
|
425
|
+
if node.status == NodeStatus.ONLINE:
|
|
426
|
+
node.status = NodeStatus.OFFLINE
|
|
427
|
+
node.last_deactivated_at = current_dt.isoformat()
|
|
428
|
+
|
|
429
|
+
# Set online_until to current timestamp
|
|
430
|
+
node.online_until = current_dt.timestamp()
|
|
431
|
+
return True
|
|
432
|
+
return False
|
|
363
433
|
|
|
364
434
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
365
435
|
"""Return all available nodes.
|
|
@@ -372,36 +442,70 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
372
442
|
with self.lock:
|
|
373
443
|
if run_id not in self.run_ids:
|
|
374
444
|
return set()
|
|
375
|
-
current_time = time.time()
|
|
376
445
|
return {
|
|
377
|
-
node_id
|
|
378
|
-
for
|
|
379
|
-
if online_until > current_time
|
|
446
|
+
node.node_id
|
|
447
|
+
for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
380
448
|
}
|
|
381
449
|
|
|
382
|
-
def
|
|
383
|
-
|
|
450
|
+
def get_node_info(
|
|
451
|
+
self,
|
|
452
|
+
*,
|
|
453
|
+
node_ids: Optional[Sequence[int]] = None,
|
|
454
|
+
owner_aids: Optional[Sequence[str]] = None,
|
|
455
|
+
statuses: Optional[Sequence[str]] = None,
|
|
456
|
+
) -> Sequence[NodeInfo]:
|
|
457
|
+
"""Retrieve information about nodes based on the specified filters."""
|
|
384
458
|
with self.lock:
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
459
|
+
self._check_and_tag_offline_nodes()
|
|
460
|
+
result = []
|
|
461
|
+
for node_id in self.nodes.keys() if node_ids is None else node_ids:
|
|
462
|
+
if (node := self.nodes.get(node_id)) is None:
|
|
463
|
+
continue
|
|
464
|
+
if owner_aids is not None and node.owner_aid not in owner_aids:
|
|
465
|
+
continue
|
|
466
|
+
if statuses is not None and node.status not in statuses:
|
|
467
|
+
continue
|
|
468
|
+
result.append(node)
|
|
469
|
+
return result
|
|
470
|
+
|
|
471
|
+
def _check_and_tag_offline_nodes(
|
|
472
|
+
self, node_ids: Optional[list[int]] = None
|
|
473
|
+
) -> None:
|
|
474
|
+
with self.lock:
|
|
475
|
+
# Set all nodes of "online" status to "offline" if they've offline
|
|
476
|
+
current_ts = now().timestamp()
|
|
477
|
+
for node_id in node_ids or self.nodes.keys():
|
|
478
|
+
if (node := self.nodes.get(node_id)) is None:
|
|
479
|
+
continue
|
|
480
|
+
if node.status == NodeStatus.ONLINE:
|
|
481
|
+
if node.online_until <= current_ts:
|
|
482
|
+
node.status = NodeStatus.OFFLINE
|
|
483
|
+
node.last_deactivated_at = datetime.fromtimestamp(
|
|
484
|
+
node.online_until, tz=timezone.utc
|
|
485
|
+
).isoformat()
|
|
486
|
+
|
|
487
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
|
395
488
|
"""Get `public_key` for the specified `node_id`."""
|
|
396
489
|
with self.lock:
|
|
397
|
-
if
|
|
398
|
-
|
|
490
|
+
if (
|
|
491
|
+
node := self.nodes.get(node_id)
|
|
492
|
+
) is None or node.status == NodeStatus.UNREGISTERED:
|
|
493
|
+
raise ValueError(f"Node ID {node_id} not found")
|
|
494
|
+
return node.public_key
|
|
495
|
+
|
|
496
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
|
|
497
|
+
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
498
|
+
deleted."""
|
|
499
|
+
with self.lock:
|
|
500
|
+
node_id = self.node_public_key_to_node_id.get(public_key)
|
|
399
501
|
|
|
400
|
-
|
|
502
|
+
if node_id is None:
|
|
503
|
+
return None
|
|
401
504
|
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
505
|
+
node_info = self.nodes[node_id]
|
|
506
|
+
if node_info.status == NodeStatus.UNREGISTERED:
|
|
507
|
+
return None
|
|
508
|
+
return node_id
|
|
405
509
|
|
|
406
510
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
407
511
|
def create_run(
|
|
@@ -449,26 +553,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
449
553
|
log(ERROR, "Unexpected run creation failure.")
|
|
450
554
|
return 0
|
|
451
555
|
|
|
452
|
-
def clear_supernode_auth_keys(self) -> None:
|
|
453
|
-
"""Clear stored `node_public_keys` in the link state if any."""
|
|
454
|
-
with self.lock:
|
|
455
|
-
self.node_public_keys.clear()
|
|
456
|
-
|
|
457
|
-
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
458
|
-
"""Store a set of `node_public_keys` in the link state."""
|
|
459
|
-
with self.lock:
|
|
460
|
-
self.node_public_keys.update(public_keys)
|
|
461
|
-
|
|
462
|
-
def store_node_public_key(self, public_key: bytes) -> None:
|
|
463
|
-
"""Store a `node_public_key` in the link state."""
|
|
464
|
-
with self.lock:
|
|
465
|
-
self.node_public_keys.add(public_key)
|
|
466
|
-
|
|
467
|
-
def get_node_public_keys(self) -> set[bytes]:
|
|
468
|
-
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
469
|
-
with self.lock:
|
|
470
|
-
return self.node_public_keys.copy()
|
|
471
|
-
|
|
472
556
|
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
473
557
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
474
558
|
|
|
@@ -561,7 +645,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
561
645
|
current = now()
|
|
562
646
|
run_record = self.run_ids[run_id]
|
|
563
647
|
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
564
|
-
run_record.heartbeat_interval =
|
|
648
|
+
run_record.heartbeat_interval = HEARTBEAT_INTERVAL_INF
|
|
565
649
|
run_record.active_until = (
|
|
566
650
|
current.timestamp() + run_record.heartbeat_interval
|
|
567
651
|
)
|
|
@@ -608,13 +692,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
608
692
|
the node is marked as offline.
|
|
609
693
|
"""
|
|
610
694
|
with self.lock:
|
|
611
|
-
if
|
|
612
|
-
self.
|
|
613
|
-
|
|
614
|
-
|
|
695
|
+
if (
|
|
696
|
+
node := self.nodes.get(node_id)
|
|
697
|
+
) and node.status != NodeStatus.UNREGISTERED:
|
|
698
|
+
current_dt = now()
|
|
699
|
+
|
|
700
|
+
# Set timestamp if the status changes
|
|
701
|
+
if node.status != NodeStatus.ONLINE: # offline or registered
|
|
702
|
+
node.status = NodeStatus.ONLINE
|
|
703
|
+
node.last_activated_at = current_dt.isoformat()
|
|
704
|
+
|
|
705
|
+
# Refresh `online_until` and `heartbeat_interval`
|
|
706
|
+
node.online_until = (
|
|
707
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
615
708
|
)
|
|
709
|
+
node.heartbeat_interval = heartbeat_interval
|
|
616
710
|
return True
|
|
617
|
-
|
|
711
|
+
return False
|
|
618
712
|
|
|
619
713
|
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
|
620
714
|
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
|
@@ -16,11 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
+
from collections.abc import Sequence
|
|
19
20
|
from typing import Optional
|
|
20
21
|
|
|
21
22
|
from flwr.common import Context, Message
|
|
22
23
|
from flwr.common.record import ConfigRecord
|
|
23
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
25
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
24
26
|
from flwr.supercore.corestate import CoreState
|
|
25
27
|
|
|
26
28
|
|
|
@@ -128,13 +130,54 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
|
128
130
|
"""Get all instruction Message IDs for the given run_id."""
|
|
129
131
|
|
|
130
132
|
@abc.abstractmethod
|
|
131
|
-
def create_node(
|
|
133
|
+
def create_node(
|
|
134
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
|
135
|
+
) -> int:
|
|
132
136
|
"""Create, store in the link state, and return `node_id`."""
|
|
133
137
|
|
|
134
138
|
@abc.abstractmethod
|
|
135
|
-
def delete_node(self, node_id: int) -> None:
|
|
139
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
|
136
140
|
"""Remove `node_id` from the link state."""
|
|
137
141
|
|
|
142
|
+
@abc.abstractmethod
|
|
143
|
+
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
144
|
+
"""Activate the node with the specified `node_id`.
|
|
145
|
+
|
|
146
|
+
Transitions the node status to "online". The transition will fail
|
|
147
|
+
if the current status is not "registered" or "offline".
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
node_id : int
|
|
152
|
+
The identifier of the node to activate.
|
|
153
|
+
heartbeat_interval : float
|
|
154
|
+
The interval (in seconds) from the current timestamp within which
|
|
155
|
+
the next heartbeat from this node is expected to be received.
|
|
156
|
+
|
|
157
|
+
Returns
|
|
158
|
+
-------
|
|
159
|
+
bool
|
|
160
|
+
True if the status transition was successful, False otherwise.
|
|
161
|
+
"""
|
|
162
|
+
|
|
163
|
+
@abc.abstractmethod
|
|
164
|
+
def deactivate_node(self, node_id: int) -> bool:
|
|
165
|
+
"""Deactivate the node with the specified `node_id`.
|
|
166
|
+
|
|
167
|
+
Transitions the node status to "offline". The transition will fail
|
|
168
|
+
if the current status is not "online".
|
|
169
|
+
|
|
170
|
+
Parameters
|
|
171
|
+
----------
|
|
172
|
+
node_id : int
|
|
173
|
+
The identifier of the node to deactivate.
|
|
174
|
+
|
|
175
|
+
Returns
|
|
176
|
+
-------
|
|
177
|
+
bool
|
|
178
|
+
True if the status transition was successful, False otherwise.
|
|
179
|
+
"""
|
|
180
|
+
|
|
138
181
|
@abc.abstractmethod
|
|
139
182
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
140
183
|
"""Retrieve all currently stored node IDs as a set.
|
|
@@ -146,16 +189,72 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
|
146
189
|
"""
|
|
147
190
|
|
|
148
191
|
@abc.abstractmethod
|
|
149
|
-
def
|
|
150
|
-
"""
|
|
192
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
|
|
193
|
+
"""Get `node_id` for the specified `public_key` if it exists and is not deleted.
|
|
194
|
+
|
|
195
|
+
Parameters
|
|
196
|
+
----------
|
|
197
|
+
public_key : bytes
|
|
198
|
+
The public key of the node whose information is to be retrieved.
|
|
199
|
+
|
|
200
|
+
Returns
|
|
201
|
+
-------
|
|
202
|
+
Optional[int]
|
|
203
|
+
The `node_id` associated with the specified `public_key` if it exists
|
|
204
|
+
and is not deleted; otherwise, `None`.
|
|
205
|
+
"""
|
|
151
206
|
|
|
152
207
|
@abc.abstractmethod
|
|
153
|
-
def
|
|
154
|
-
|
|
208
|
+
def get_node_info(
|
|
209
|
+
self,
|
|
210
|
+
*,
|
|
211
|
+
node_ids: Optional[Sequence[int]] = None,
|
|
212
|
+
owner_aids: Optional[Sequence[str]] = None,
|
|
213
|
+
statuses: Optional[Sequence[str]] = None,
|
|
214
|
+
) -> Sequence[NodeInfo]:
|
|
215
|
+
"""Retrieve information about nodes based on the specified filters.
|
|
216
|
+
|
|
217
|
+
If a filter is set to None, it is ignored.
|
|
218
|
+
If multiple filters are provided, they are combined using AND logic.
|
|
219
|
+
|
|
220
|
+
Parameters
|
|
221
|
+
----------
|
|
222
|
+
node_ids : Optional[Sequence[int]] (default: None)
|
|
223
|
+
Sequence of node IDs to filter by. If a sequence is provided,
|
|
224
|
+
it is treated as an OR condition.
|
|
225
|
+
owner_aids : Optional[Sequence[str]] (default: None)
|
|
226
|
+
Sequence of owner account IDs to filter by. If a sequence is provided,
|
|
227
|
+
it is treated as an OR condition.
|
|
228
|
+
statuses : Optional[Sequence[str]] (default: None)
|
|
229
|
+
Sequence of node status values (e.g., "created", "activated")
|
|
230
|
+
to filter by. If a sequence is provided, it is treated as an OR condition.
|
|
231
|
+
|
|
232
|
+
Returns
|
|
233
|
+
-------
|
|
234
|
+
Sequence[NodeInfo]
|
|
235
|
+
A sequence of NodeInfo objects representing the nodes matching
|
|
236
|
+
the specified filters.
|
|
237
|
+
"""
|
|
155
238
|
|
|
156
239
|
@abc.abstractmethod
|
|
157
|
-
def
|
|
158
|
-
"""
|
|
240
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
|
241
|
+
"""Get `public_key` for the specified `node_id`.
|
|
242
|
+
|
|
243
|
+
Parameters
|
|
244
|
+
----------
|
|
245
|
+
node_id : int
|
|
246
|
+
The identifier of the node whose public key is to be retrieved.
|
|
247
|
+
|
|
248
|
+
Returns
|
|
249
|
+
-------
|
|
250
|
+
bytes
|
|
251
|
+
The public key associated with the specified `node_id`.
|
|
252
|
+
|
|
253
|
+
Raises
|
|
254
|
+
------
|
|
255
|
+
ValueError
|
|
256
|
+
If the specified `node_id` does not exist in the link state.
|
|
257
|
+
"""
|
|
159
258
|
|
|
160
259
|
@abc.abstractmethod
|
|
161
260
|
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
@@ -254,22 +353,6 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
|
254
353
|
The federation options for the run if it exists; None otherwise.
|
|
255
354
|
"""
|
|
256
355
|
|
|
257
|
-
@abc.abstractmethod
|
|
258
|
-
def clear_supernode_auth_keys(self) -> None:
|
|
259
|
-
"""Clear stored `node_public_keys` in the link state if any."""
|
|
260
|
-
|
|
261
|
-
@abc.abstractmethod
|
|
262
|
-
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
263
|
-
"""Store a set of `node_public_keys` in the link state."""
|
|
264
|
-
|
|
265
|
-
@abc.abstractmethod
|
|
266
|
-
def store_node_public_key(self, public_key: bytes) -> None:
|
|
267
|
-
"""Store a `node_public_key` in the link state."""
|
|
268
|
-
|
|
269
|
-
@abc.abstractmethod
|
|
270
|
-
def get_node_public_keys(self) -> set[bytes]:
|
|
271
|
-
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
272
|
-
|
|
273
356
|
@abc.abstractmethod
|
|
274
357
|
def acknowledge_node_heartbeat(
|
|
275
358
|
self, node_id: int, heartbeat_interval: float
|
|
@@ -19,6 +19,7 @@ from logging import DEBUG
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr.common.logger import log
|
|
22
|
+
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
|
|
22
23
|
|
|
23
24
|
from .in_memory_linkstate import InMemoryLinkState
|
|
24
25
|
from .linkstate import LinkState
|
|
@@ -44,7 +45,7 @@ class LinkStateFactory:
|
|
|
44
45
|
def state(self) -> LinkState:
|
|
45
46
|
"""Return a State instance and create it, if necessary."""
|
|
46
47
|
# InMemoryState
|
|
47
|
-
if self.database ==
|
|
48
|
+
if self.database == FLWR_IN_MEMORY_DB_NAME:
|
|
48
49
|
if self.state_instance is None:
|
|
49
50
|
self.state_instance = InMemoryLinkState()
|
|
50
51
|
log(DEBUG, "Using InMemoryState")
|