flwr-1.21.0-py3-none-any.whl → flwr-1.23.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import json
19
+ import secrets
19
20
  import threading
20
21
  import time
21
22
  import traceback
@@ -27,12 +28,13 @@ from typing import Callable, Optional
27
28
  from uuid import uuid4
28
29
 
29
30
  from flwr.app.error import Error
30
- from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
31
- from flwr.client.clientapp.utils import get_load_client_app_fn
32
31
  from flwr.client.run_info_store import DeprecatedRunInfoStore
32
+ from flwr.clientapp.client_app import ClientApp, ClientAppException, LoadClientAppError
33
+ from flwr.clientapp.utils import get_load_client_app_fn
33
34
  from flwr.common import Message
34
35
  from flwr.common.constant import (
35
- HEARTBEAT_MAX_INTERVAL,
36
+ HEARTBEAT_INTERVAL_INF,
37
+ NOOP_FLWR_AID,
36
38
  NUM_PARTITIONS_KEY,
37
39
  PARTITION_ID_KEY,
38
40
  ErrorCode,
@@ -40,6 +42,7 @@ from flwr.common.constant import (
40
42
  from flwr.common.logger import log
41
43
  from flwr.common.typing import Run
42
44
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
45
+ from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
43
46
 
44
47
  from .backend import Backend, error_messages_backends, supported_backends
45
48
 
@@ -53,7 +56,17 @@ def _register_nodes(
53
56
  nodes_mapping: NodeToPartitionMapping = {}
54
57
  state = state_factory.state()
55
58
  for i in range(num_nodes):
56
- node_id = state.create_node(heartbeat_interval=HEARTBEAT_MAX_INTERVAL)
59
+ node_id = state.create_node(
60
+ # No node authentication in simulation;
61
+ # use NOOP_FLWR_AID as owner_aid and
62
+ # use random bytes as public key
63
+ NOOP_FLWR_AID,
64
+ secrets.token_bytes(32),
65
+ heartbeat_interval=HEARTBEAT_INTERVAL_INF,
66
+ )
67
+ state.acknowledge_node_heartbeat(
68
+ node_id=node_id, heartbeat_interval=HEARTBEAT_INTERVAL_INF
69
+ )
57
70
  nodes_mapping[node_id] = i
58
71
  log(DEBUG, "Registered %i nodes", len(nodes_mapping))
59
72
  return nodes_mapping
@@ -300,7 +313,7 @@ def start_vce(
300
313
  if not state_factory:
301
314
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
302
315
  # Create an empty in-memory state factory
303
- state_factory = LinkStateFactory(":flwr-in-memory-state:")
316
+ state_factory = LinkStateFactory(FLWR_IN_MEMORY_DB_NAME)
304
317
  log(INFO, "Created new %s.", state_factory.__class__.__name__)
305
318
 
306
319
  if num_supernodes:
@@ -17,17 +17,18 @@
17
17
 
18
18
  import secrets
19
19
  import threading
20
- import time
21
20
  from bisect import bisect_right
22
21
  from collections import defaultdict
22
+ from collections.abc import Sequence
23
23
  from dataclasses import dataclass, field
24
+ from datetime import datetime, timezone
24
25
  from logging import ERROR, WARNING
25
26
  from typing import Optional
26
27
 
27
28
  from flwr.common import Context, Message, log, now
28
29
  from flwr.common.constant import (
29
30
  FLWR_APP_TOKEN_LENGTH,
30
- HEARTBEAT_MAX_INTERVAL,
31
+ HEARTBEAT_INTERVAL_INF,
31
32
  HEARTBEAT_PATIENCE,
32
33
  MESSAGE_TTL_TOLERANCE,
33
34
  NODE_ID_NUM_BYTES,
@@ -39,8 +40,10 @@ from flwr.common.constant import (
39
40
  )
40
41
  from flwr.common.record import ConfigRecord
41
42
  from flwr.common.typing import Run, RunStatus, UserConfig
43
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
42
44
  from flwr.server.superlink.linkstate.linkstate import LinkState
43
45
  from flwr.server.utils import validate_message
46
+ from flwr.supercore.constant import NodeStatus
44
47
 
45
48
  from .utils import (
46
49
  check_node_availability_for_in_message,
@@ -69,10 +72,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
69
72
 
70
73
  def __init__(self) -> None:
71
74
 
72
- # Map node_id to (online_until, heartbeat_interval)
73
- self.node_ids: dict[int, tuple[float, float]] = {}
74
- self.public_key_to_node_id: dict[bytes, int] = {}
75
- self.node_id_to_public_key: dict[int, bytes] = {}
75
+ # Map node_id to NodeInfo
76
+ self.nodes: dict[int, NodeInfo] = {}
77
+ self.node_public_key_to_node_id: dict[bytes, int] = {}
78
+ self.owner_to_node_ids: dict[str, set[int]] = {} # Quick lookup
76
79
 
77
80
  # Map run_id to RunRecord
78
81
  self.run_ids: dict[int, RunRecord] = {}
@@ -114,7 +117,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
114
117
  )
115
118
  return None
116
119
  # Validate destination node ID
117
- if message.metadata.dst_node_id not in self.node_ids:
120
+ dst_node = self.nodes.get(message.metadata.dst_node_id)
121
+ if dst_node is None or dst_node.status not in [
122
+ NodeStatus.ONLINE,
123
+ NodeStatus.OFFLINE,
124
+ ]:
118
125
  log(
119
126
  ERROR,
120
127
  "Invalid destination node ID for Message: %s",
@@ -136,7 +143,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
136
143
 
137
144
  # Find Message for node_id that were not delivered yet
138
145
  message_ins_list: list[Message] = []
139
- current_time = time.time()
146
+ current_time = now().timestamp()
140
147
  with self.lock:
141
148
  for _, msg_ins in self.message_ins_store.items():
142
149
  if (
@@ -190,7 +197,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
190
197
  return None
191
198
 
192
199
  ins_metadata = msg_ins.metadata
193
- if ins_metadata.created_at + ins_metadata.ttl <= time.time():
200
+ if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
194
201
  log(
195
202
  ERROR,
196
203
  "Failed to store Message: the message it is replying to "
@@ -238,7 +245,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
238
245
  ret: dict[str, Message] = {}
239
246
 
240
247
  with self.lock:
241
- current = time.time()
248
+ current = now().timestamp()
242
249
 
243
250
  # Verify Message IDs
244
251
  ret = verify_message_ids(
@@ -256,9 +263,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
256
263
  inquired_in_message_ids=message_ids,
257
264
  found_in_message_dict=self.message_ins_store,
258
265
  node_id_to_online_until={
259
- node_id: self.node_ids[node_id][0]
266
+ node_id: self.nodes[node_id].online_until
260
267
  for node_id in dst_node_ids
261
- if node_id in self.node_ids
268
+ if node_id in self.nodes
269
+ and self.nodes[node_id].status != NodeStatus.UNREGISTERED
262
270
  },
263
271
  current_time=current,
264
272
  )
@@ -330,7 +338,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
330
338
  """
331
339
  return len(self.message_res_store)
332
340
 
333
- def create_node(self, heartbeat_interval: float) -> int:
341
+ def create_node(
342
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
343
+ ) -> int:
334
344
  """Create, store in the link state, and return `node_id`."""
335
345
  # Sample a random int64 as node_id
336
346
  node_id = generate_rand_int_from_bytes(
@@ -338,28 +348,88 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
338
348
  )
339
349
 
340
350
  with self.lock:
341
- if node_id in self.node_ids:
351
+ if node_id in self.nodes:
342
352
  log(ERROR, "Unexpected node registration failure.")
343
353
  return 0
354
+ if public_key in self.node_public_key_to_node_id:
355
+ raise ValueError("Public key already in use")
344
356
 
345
- # Mark the node online until time.time() + heartbeat_interval
346
- self.node_ids[node_id] = (
347
- time.time() + heartbeat_interval,
348
- heartbeat_interval,
357
+ # The node is not activated upon creation
358
+ self.nodes[node_id] = NodeInfo(
359
+ node_id=node_id,
360
+ owner_aid=owner_aid,
361
+ status=NodeStatus.REGISTERED,
362
+ registered_at=now().isoformat(),
363
+ last_activated_at=None,
364
+ last_deactivated_at=None,
365
+ unregistered_at=None,
366
+ online_until=None,
367
+ heartbeat_interval=heartbeat_interval,
368
+ public_key=public_key,
349
369
  )
370
+ self.node_public_key_to_node_id[public_key] = node_id
371
+ self.owner_to_node_ids.setdefault(owner_aid, set()).add(node_id)
350
372
  return node_id
351
373
 
352
- def delete_node(self, node_id: int) -> None:
374
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
353
375
  """Delete a node."""
354
376
  with self.lock:
355
- if node_id not in self.node_ids:
356
- raise ValueError(f"Node {node_id} not found")
377
+ if (
378
+ not (node := self.nodes.get(node_id))
379
+ or node.status == NodeStatus.UNREGISTERED
380
+ or owner_aid != self.nodes[node_id].owner_aid
381
+ ):
382
+ raise ValueError(
383
+ f"Node ID {node_id} already unregistered, not found or "
384
+ "the request was unauthorized."
385
+ )
386
+
387
+ node.status = NodeStatus.UNREGISTERED
388
+ current = now()
389
+ node.unregistered_at = current.isoformat()
390
+ # Set online_until to current timestamp on deletion, if it is in the future
391
+ node.online_until = min(node.online_until, current.timestamp())
392
+
393
+ def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
394
+ """Activate the node with the specified `node_id`."""
395
+ with self.lock:
396
+ self._check_and_tag_offline_nodes(node_ids=[node_id])
357
397
 
358
- # Remove node ID <> public key mappings
359
- if pk := self.node_id_to_public_key.pop(node_id, None):
360
- del self.public_key_to_node_id[pk]
398
+ # Check if the node exists
399
+ if not (node := self.nodes.get(node_id)):
400
+ return False
401
+
402
+ # Only activate if the node is currently registered or offline
403
+ current_dt = now()
404
+ if node.status in (NodeStatus.REGISTERED, NodeStatus.OFFLINE):
405
+ node.status = NodeStatus.ONLINE
406
+ node.last_activated_at = current_dt.isoformat()
407
+ node.online_until = (
408
+ current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
409
+ )
410
+ node.heartbeat_interval = heartbeat_interval
411
+ return True
412
+ return False
413
+
414
+ def deactivate_node(self, node_id: int) -> bool:
415
+ """Deactivate the node with the specified `node_id`."""
416
+ with self.lock:
417
+ self._check_and_tag_offline_nodes(node_ids=[node_id])
418
+
419
+ # Check if the node exists
420
+ if not (node := self.nodes.get(node_id)):
421
+ return False
361
422
 
362
- del self.node_ids[node_id]
423
+ # Only deactivate if the node is currently online
424
+ current_dt = now()
425
+ if node.status == NodeStatus.ONLINE:
426
+ node.status = NodeStatus.OFFLINE
427
+ node.last_deactivated_at = current_dt.isoformat()
428
+
429
+ # Set online_until to current timestamp
430
+ node.online_until = current_dt.timestamp()
431
+ return True
432
+ return False
363
433
 
364
434
  def get_nodes(self, run_id: int) -> set[int]:
365
435
  """Return all available nodes.
@@ -372,36 +442,70 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
372
442
  with self.lock:
373
443
  if run_id not in self.run_ids:
374
444
  return set()
375
- current_time = time.time()
376
445
  return {
377
- node_id
378
- for node_id, (online_until, _) in self.node_ids.items()
379
- if online_until > current_time
446
+ node.node_id
447
+ for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
380
448
  }
381
449
 
382
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
383
- """Set `public_key` for the specified `node_id`."""
450
+ def get_node_info(
451
+ self,
452
+ *,
453
+ node_ids: Optional[Sequence[int]] = None,
454
+ owner_aids: Optional[Sequence[str]] = None,
455
+ statuses: Optional[Sequence[str]] = None,
456
+ ) -> Sequence[NodeInfo]:
457
+ """Retrieve information about nodes based on the specified filters."""
384
458
  with self.lock:
385
- if node_id not in self.node_ids:
386
- raise ValueError(f"Node {node_id} not found")
387
-
388
- if public_key in self.public_key_to_node_id:
389
- raise ValueError("Public key already in use")
390
-
391
- self.public_key_to_node_id[public_key] = node_id
392
- self.node_id_to_public_key[node_id] = public_key
393
-
394
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
459
+ self._check_and_tag_offline_nodes()
460
+ result = []
461
+ for node_id in self.nodes.keys() if node_ids is None else node_ids:
462
+ if (node := self.nodes.get(node_id)) is None:
463
+ continue
464
+ if owner_aids is not None and node.owner_aid not in owner_aids:
465
+ continue
466
+ if statuses is not None and node.status not in statuses:
467
+ continue
468
+ result.append(node)
469
+ return result
470
+
471
+ def _check_and_tag_offline_nodes(
472
+ self, node_ids: Optional[list[int]] = None
473
+ ) -> None:
474
+ with self.lock:
475
+ # Set all nodes of "online" status to "offline" if they've offline
476
+ current_ts = now().timestamp()
477
+ for node_id in node_ids or self.nodes.keys():
478
+ if (node := self.nodes.get(node_id)) is None:
479
+ continue
480
+ if node.status == NodeStatus.ONLINE:
481
+ if node.online_until <= current_ts:
482
+ node.status = NodeStatus.OFFLINE
483
+ node.last_deactivated_at = datetime.fromtimestamp(
484
+ node.online_until, tz=timezone.utc
485
+ ).isoformat()
486
+
487
+ def get_node_public_key(self, node_id: int) -> bytes:
395
488
  """Get `public_key` for the specified `node_id`."""
396
489
  with self.lock:
397
- if node_id not in self.node_ids:
398
- raise ValueError(f"Node {node_id} not found")
490
+ if (
491
+ node := self.nodes.get(node_id)
492
+ ) is None or node.status == NodeStatus.UNREGISTERED:
493
+ raise ValueError(f"Node ID {node_id} not found")
494
+ return node.public_key
495
+
496
+ def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
497
+ """Get `node_id` for the specified `public_key` if it exists and is not
498
+ deleted."""
499
+ with self.lock:
500
+ node_id = self.node_public_key_to_node_id.get(public_key)
399
501
 
400
- return self.node_id_to_public_key.get(node_id)
502
+ if node_id is None:
503
+ return None
401
504
 
402
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
403
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
404
- return self.public_key_to_node_id.get(node_public_key)
505
+ node_info = self.nodes[node_id]
506
+ if node_info.status == NodeStatus.UNREGISTERED:
507
+ return None
508
+ return node_id
405
509
 
406
510
  # pylint: disable=too-many-arguments,too-many-positional-arguments
407
511
  def create_run(
@@ -449,26 +553,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
449
553
  log(ERROR, "Unexpected run creation failure.")
450
554
  return 0
451
555
 
452
- def clear_supernode_auth_keys(self) -> None:
453
- """Clear stored `node_public_keys` in the link state if any."""
454
- with self.lock:
455
- self.node_public_keys.clear()
456
-
457
- def store_node_public_keys(self, public_keys: set[bytes]) -> None:
458
- """Store a set of `node_public_keys` in the link state."""
459
- with self.lock:
460
- self.node_public_keys.update(public_keys)
461
-
462
- def store_node_public_key(self, public_key: bytes) -> None:
463
- """Store a `node_public_key` in the link state."""
464
- with self.lock:
465
- self.node_public_keys.add(public_key)
466
-
467
- def get_node_public_keys(self) -> set[bytes]:
468
- """Retrieve all currently stored `node_public_keys` as a set."""
469
- with self.lock:
470
- return self.node_public_keys.copy()
471
-
472
556
  def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
473
557
  """Retrieve all run IDs if `flwr_aid` is not specified.
474
558
 
@@ -561,7 +645,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
561
645
  current = now()
562
646
  run_record = self.run_ids[run_id]
563
647
  if new_status.status in (Status.STARTING, Status.RUNNING):
564
- run_record.heartbeat_interval = HEARTBEAT_MAX_INTERVAL
648
+ run_record.heartbeat_interval = HEARTBEAT_INTERVAL_INF
565
649
  run_record.active_until = (
566
650
  current.timestamp() + run_record.heartbeat_interval
567
651
  )
@@ -608,13 +692,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
608
692
  the node is marked as offline.
609
693
  """
610
694
  with self.lock:
611
- if node_id in self.node_ids:
612
- self.node_ids[node_id] = (
613
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
614
- heartbeat_interval,
695
+ if (
696
+ node := self.nodes.get(node_id)
697
+ ) and node.status != NodeStatus.UNREGISTERED:
698
+ current_dt = now()
699
+
700
+ # Set timestamp if the status changes
701
+ if node.status != NodeStatus.ONLINE: # offline or registered
702
+ node.status = NodeStatus.ONLINE
703
+ node.last_activated_at = current_dt.isoformat()
704
+
705
+ # Refresh `online_until` and `heartbeat_interval`
706
+ node.online_until = (
707
+ current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
615
708
  )
709
+ node.heartbeat_interval = heartbeat_interval
616
710
  return True
617
- return False
711
+ return False
618
712
 
619
713
  def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
620
714
  """Acknowledge a heartbeat received from a ServerApp for a given run.
@@ -16,11 +16,13 @@
16
16
 
17
17
 
18
18
  import abc
19
+ from collections.abc import Sequence
19
20
  from typing import Optional
20
21
 
21
22
  from flwr.common import Context, Message
22
23
  from flwr.common.record import ConfigRecord
23
24
  from flwr.common.typing import Run, RunStatus, UserConfig
25
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
24
26
  from flwr.supercore.corestate import CoreState
25
27
 
26
28
 
@@ -128,13 +130,54 @@ class LinkState(CoreState): # pylint: disable=R0904
128
130
  """Get all instruction Message IDs for the given run_id."""
129
131
 
130
132
  @abc.abstractmethod
131
- def create_node(self, heartbeat_interval: float) -> int:
133
+ def create_node(
134
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
135
+ ) -> int:
132
136
  """Create, store in the link state, and return `node_id`."""
133
137
 
134
138
  @abc.abstractmethod
135
- def delete_node(self, node_id: int) -> None:
139
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
136
140
  """Remove `node_id` from the link state."""
137
141
 
142
+ @abc.abstractmethod
143
+ def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
144
+ """Activate the node with the specified `node_id`.
145
+
146
+ Transitions the node status to "online". The transition will fail
147
+ if the current status is not "registered" or "offline".
148
+
149
+ Parameters
150
+ ----------
151
+ node_id : int
152
+ The identifier of the node to activate.
153
+ heartbeat_interval : float
154
+ The interval (in seconds) from the current timestamp within which
155
+ the next heartbeat from this node is expected to be received.
156
+
157
+ Returns
158
+ -------
159
+ bool
160
+ True if the status transition was successful, False otherwise.
161
+ """
162
+
163
+ @abc.abstractmethod
164
+ def deactivate_node(self, node_id: int) -> bool:
165
+ """Deactivate the node with the specified `node_id`.
166
+
167
+ Transitions the node status to "offline". The transition will fail
168
+ if the current status is not "online".
169
+
170
+ Parameters
171
+ ----------
172
+ node_id : int
173
+ The identifier of the node to deactivate.
174
+
175
+ Returns
176
+ -------
177
+ bool
178
+ True if the status transition was successful, False otherwise.
179
+ """
180
+
138
181
  @abc.abstractmethod
139
182
  def get_nodes(self, run_id: int) -> set[int]:
140
183
  """Retrieve all currently stored node IDs as a set.
@@ -146,16 +189,72 @@ class LinkState(CoreState): # pylint: disable=R0904
146
189
  """
147
190
 
148
191
  @abc.abstractmethod
149
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
150
- """Set `public_key` for the specified `node_id`."""
192
+ def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
193
+ """Get `node_id` for the specified `public_key` if it exists and is not deleted.
194
+
195
+ Parameters
196
+ ----------
197
+ public_key : bytes
198
+ The public key of the node whose information is to be retrieved.
199
+
200
+ Returns
201
+ -------
202
+ Optional[int]
203
+ The `node_id` associated with the specified `public_key` if it exists
204
+ and is not deleted; otherwise, `None`.
205
+ """
151
206
 
152
207
  @abc.abstractmethod
153
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
154
- """Get `public_key` for the specified `node_id`."""
208
+ def get_node_info(
209
+ self,
210
+ *,
211
+ node_ids: Optional[Sequence[int]] = None,
212
+ owner_aids: Optional[Sequence[str]] = None,
213
+ statuses: Optional[Sequence[str]] = None,
214
+ ) -> Sequence[NodeInfo]:
215
+ """Retrieve information about nodes based on the specified filters.
216
+
217
+ If a filter is set to None, it is ignored.
218
+ If multiple filters are provided, they are combined using AND logic.
219
+
220
+ Parameters
221
+ ----------
222
+ node_ids : Optional[Sequence[int]] (default: None)
223
+ Sequence of node IDs to filter by. If a sequence is provided,
224
+ it is treated as an OR condition.
225
+ owner_aids : Optional[Sequence[str]] (default: None)
226
+ Sequence of owner account IDs to filter by. If a sequence is provided,
227
+ it is treated as an OR condition.
228
+ statuses : Optional[Sequence[str]] (default: None)
229
+ Sequence of node status values (e.g., "created", "activated")
230
+ to filter by. If a sequence is provided, it is treated as an OR condition.
231
+
232
+ Returns
233
+ -------
234
+ Sequence[NodeInfo]
235
+ A sequence of NodeInfo objects representing the nodes matching
236
+ the specified filters.
237
+ """
155
238
 
156
239
  @abc.abstractmethod
157
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
158
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
240
+ def get_node_public_key(self, node_id: int) -> bytes:
241
+ """Get `public_key` for the specified `node_id`.
242
+
243
+ Parameters
244
+ ----------
245
+ node_id : int
246
+ The identifier of the node whose public key is to be retrieved.
247
+
248
+ Returns
249
+ -------
250
+ bytes
251
+ The public key associated with the specified `node_id`.
252
+
253
+ Raises
254
+ ------
255
+ ValueError
256
+ If the specified `node_id` does not exist in the link state.
257
+ """
159
258
 
160
259
  @abc.abstractmethod
161
260
  def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -254,22 +353,6 @@ class LinkState(CoreState): # pylint: disable=R0904
254
353
  The federation options for the run if it exists; None otherwise.
255
354
  """
256
355
 
257
- @abc.abstractmethod
258
- def clear_supernode_auth_keys(self) -> None:
259
- """Clear stored `node_public_keys` in the link state if any."""
260
-
261
- @abc.abstractmethod
262
- def store_node_public_keys(self, public_keys: set[bytes]) -> None:
263
- """Store a set of `node_public_keys` in the link state."""
264
-
265
- @abc.abstractmethod
266
- def store_node_public_key(self, public_key: bytes) -> None:
267
- """Store a `node_public_key` in the link state."""
268
-
269
- @abc.abstractmethod
270
- def get_node_public_keys(self) -> set[bytes]:
271
- """Retrieve all currently stored `node_public_keys` as a set."""
272
-
273
356
  @abc.abstractmethod
274
357
  def acknowledge_node_heartbeat(
275
358
  self, node_id: int, heartbeat_interval: float
@@ -19,6 +19,7 @@ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
21
  from flwr.common.logger import log
22
+ from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
22
23
 
23
24
  from .in_memory_linkstate import InMemoryLinkState
24
25
  from .linkstate import LinkState
@@ -44,7 +45,7 @@ class LinkStateFactory:
44
45
  def state(self) -> LinkState:
45
46
  """Return a State instance and create it, if necessary."""
46
47
  # InMemoryState
47
- if self.database == ":flwr-in-memory-state:":
48
+ if self.database == FLWR_IN_MEMORY_DB_NAME:
48
49
  if self.state_instance is None:
49
50
  self.state_instance = InMemoryLinkState()
50
51
  log(DEBUG, "Using InMemoryState")