flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (292)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/federation/__init__.py +24 -0
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +317 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +170 -130
  21. flwr/cli/new/new.py +33 -50
  22. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  30. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  33. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  34. flwr/cli/pull.py +10 -5
  35. flwr/cli/run/run.py +77 -30
  36. flwr/cli/run_utils.py +130 -0
  37. flwr/cli/stop.py +25 -7
  38. flwr/cli/supernode/ls.py +16 -8
  39. flwr/cli/supernode/register.py +9 -4
  40. flwr/cli/supernode/unregister.py +5 -3
  41. flwr/cli/utils.py +376 -16
  42. flwr/client/__init__.py +1 -1
  43. flwr/client/dpfedavg_numpy_client.py +4 -1
  44. flwr/client/grpc_adapter_client/connection.py +6 -7
  45. flwr/client/grpc_rere_client/connection.py +10 -11
  46. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  47. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  48. flwr/client/message_handler/message_handler.py +2 -2
  49. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  50. flwr/client/numpy_client.py +1 -1
  51. flwr/client/rest_client/connection.py +12 -14
  52. flwr/client/run_info_store.py +4 -5
  53. flwr/client/typing.py +1 -1
  54. flwr/clientapp/client_app.py +9 -10
  55. flwr/clientapp/mod/centraldp_mods.py +16 -17
  56. flwr/clientapp/mod/localdp_mod.py +8 -9
  57. flwr/clientapp/typing.py +1 -1
  58. flwr/clientapp/utils.py +3 -3
  59. flwr/common/address.py +1 -2
  60. flwr/common/args.py +3 -4
  61. flwr/common/config.py +13 -16
  62. flwr/common/constant.py +5 -2
  63. flwr/common/differential_privacy.py +3 -4
  64. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  65. flwr/common/exit/exit.py +15 -2
  66. flwr/common/exit/exit_code.py +19 -0
  67. flwr/common/exit/exit_handler.py +6 -2
  68. flwr/common/exit/signal_handler.py +5 -5
  69. flwr/common/grpc.py +6 -6
  70. flwr/common/inflatable_protobuf_utils.py +1 -1
  71. flwr/common/inflatable_utils.py +38 -21
  72. flwr/common/logger.py +19 -19
  73. flwr/common/message.py +4 -4
  74. flwr/common/object_ref.py +7 -7
  75. flwr/common/record/array.py +3 -3
  76. flwr/common/record/arrayrecord.py +18 -30
  77. flwr/common/record/configrecord.py +3 -3
  78. flwr/common/record/recorddict.py +5 -5
  79. flwr/common/record/typeddict.py +9 -2
  80. flwr/common/recorddict_compat.py +7 -10
  81. flwr/common/retry_invoker.py +20 -20
  82. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  83. flwr/common/serde.py +5 -4
  84. flwr/common/serde_utils.py +2 -2
  85. flwr/common/telemetry.py +9 -5
  86. flwr/common/typing.py +52 -37
  87. flwr/compat/client/app.py +38 -37
  88. flwr/compat/client/grpc_client/connection.py +11 -11
  89. flwr/compat/server/app.py +5 -6
  90. flwr/proto/appio_pb2.py +13 -3
  91. flwr/proto/appio_pb2.pyi +134 -65
  92. flwr/proto/appio_pb2_grpc.py +20 -0
  93. flwr/proto/appio_pb2_grpc.pyi +27 -0
  94. flwr/proto/clientappio_pb2.py +17 -7
  95. flwr/proto/clientappio_pb2.pyi +15 -0
  96. flwr/proto/clientappio_pb2_grpc.py +206 -40
  97. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  98. flwr/proto/control_pb2.py +71 -52
  99. flwr/proto/control_pb2.pyi +277 -111
  100. flwr/proto/control_pb2_grpc.py +249 -40
  101. flwr/proto/control_pb2_grpc.pyi +185 -52
  102. flwr/proto/error_pb2.py +13 -3
  103. flwr/proto/error_pb2.pyi +24 -6
  104. flwr/proto/error_pb2_grpc.py +20 -0
  105. flwr/proto/error_pb2_grpc.pyi +27 -0
  106. flwr/proto/fab_pb2.py +14 -4
  107. flwr/proto/fab_pb2.pyi +59 -31
  108. flwr/proto/fab_pb2_grpc.py +20 -0
  109. flwr/proto/fab_pb2_grpc.pyi +27 -0
  110. flwr/proto/federation_pb2.py +38 -0
  111. flwr/proto/federation_pb2.pyi +56 -0
  112. flwr/proto/federation_pb2_grpc.py +24 -0
  113. flwr/proto/federation_pb2_grpc.pyi +31 -0
  114. flwr/proto/fleet_pb2.py +14 -4
  115. flwr/proto/fleet_pb2.pyi +137 -61
  116. flwr/proto/fleet_pb2_grpc.py +189 -48
  117. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  118. flwr/proto/grpcadapter_pb2.py +14 -4
  119. flwr/proto/grpcadapter_pb2.pyi +38 -16
  120. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  121. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  122. flwr/proto/heartbeat_pb2.py +17 -7
  123. flwr/proto/heartbeat_pb2.pyi +51 -22
  124. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  125. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  126. flwr/proto/log_pb2.py +13 -3
  127. flwr/proto/log_pb2.pyi +34 -11
  128. flwr/proto/log_pb2_grpc.py +20 -0
  129. flwr/proto/log_pb2_grpc.pyi +27 -0
  130. flwr/proto/message_pb2.py +15 -5
  131. flwr/proto/message_pb2.pyi +154 -86
  132. flwr/proto/message_pb2_grpc.py +20 -0
  133. flwr/proto/message_pb2_grpc.pyi +27 -0
  134. flwr/proto/node_pb2.py +15 -5
  135. flwr/proto/node_pb2.pyi +50 -25
  136. flwr/proto/node_pb2_grpc.py +20 -0
  137. flwr/proto/node_pb2_grpc.pyi +27 -0
  138. flwr/proto/recorddict_pb2.py +13 -3
  139. flwr/proto/recorddict_pb2.pyi +184 -107
  140. flwr/proto/recorddict_pb2_grpc.py +20 -0
  141. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  142. flwr/proto/run_pb2.py +40 -31
  143. flwr/proto/run_pb2.pyi +149 -84
  144. flwr/proto/run_pb2_grpc.py +20 -0
  145. flwr/proto/run_pb2_grpc.pyi +27 -0
  146. flwr/proto/serverappio_pb2.py +13 -3
  147. flwr/proto/serverappio_pb2.pyi +32 -8
  148. flwr/proto/serverappio_pb2_grpc.py +246 -65
  149. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  150. flwr/proto/simulationio_pb2.py +16 -8
  151. flwr/proto/simulationio_pb2.pyi +15 -0
  152. flwr/proto/simulationio_pb2_grpc.py +162 -41
  153. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  154. flwr/proto/transport_pb2.py +20 -10
  155. flwr/proto/transport_pb2.pyi +249 -160
  156. flwr/proto/transport_pb2_grpc.py +35 -4
  157. flwr/proto/transport_pb2_grpc.pyi +38 -8
  158. flwr/server/app.py +38 -17
  159. flwr/server/client_manager.py +4 -5
  160. flwr/server/client_proxy.py +10 -11
  161. flwr/server/compat/app.py +4 -5
  162. flwr/server/compat/app_utils.py +2 -1
  163. flwr/server/compat/grid_client_proxy.py +10 -12
  164. flwr/server/compat/legacy_context.py +3 -4
  165. flwr/server/fleet_event_log_interceptor.py +2 -1
  166. flwr/server/grid/grid.py +2 -3
  167. flwr/server/grid/grpc_grid.py +10 -8
  168. flwr/server/grid/inmemory_grid.py +4 -4
  169. flwr/server/run_serverapp.py +2 -3
  170. flwr/server/server.py +34 -39
  171. flwr/server/server_app.py +7 -8
  172. flwr/server/server_config.py +1 -2
  173. flwr/server/serverapp/app.py +34 -28
  174. flwr/server/serverapp_components.py +4 -5
  175. flwr/server/strategy/aggregate.py +9 -8
  176. flwr/server/strategy/bulyan.py +13 -11
  177. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  178. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  179. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  180. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  181. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  182. flwr/server/strategy/fedadagrad.py +18 -14
  183. flwr/server/strategy/fedadam.py +16 -14
  184. flwr/server/strategy/fedavg.py +16 -17
  185. flwr/server/strategy/fedavg_android.py +15 -15
  186. flwr/server/strategy/fedavgm.py +21 -18
  187. flwr/server/strategy/fedmedian.py +2 -3
  188. flwr/server/strategy/fedopt.py +11 -10
  189. flwr/server/strategy/fedprox.py +10 -9
  190. flwr/server/strategy/fedtrimmedavg.py +12 -11
  191. flwr/server/strategy/fedxgb_bagging.py +13 -11
  192. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  193. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  194. flwr/server/strategy/fedyogi.py +16 -14
  195. flwr/server/strategy/krum.py +12 -11
  196. flwr/server/strategy/qfedavg.py +16 -15
  197. flwr/server/strategy/strategy.py +6 -9
  198. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  199. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  200. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  201. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  202. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  203. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  204. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  205. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
  206. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  207. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  208. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  209. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  210. flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
  211. flwr/server/superlink/linkstate/linkstate.py +59 -43
  212. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  213. flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
  214. flwr/server/superlink/linkstate/utils.py +6 -6
  215. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  216. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  217. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  218. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  219. flwr/server/superlink/utils.py +4 -6
  220. flwr/server/typing.py +1 -1
  221. flwr/server/utils/tensorboard.py +15 -8
  222. flwr/server/workflow/default_workflows.py +5 -5
  223. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  224. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  225. flwr/serverapp/strategy/bulyan.py +16 -15
  226. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  227. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  228. flwr/serverapp/strategy/fedadagrad.py +10 -11
  229. flwr/serverapp/strategy/fedadam.py +10 -11
  230. flwr/serverapp/strategy/fedavg.py +9 -10
  231. flwr/serverapp/strategy/fedavgm.py +17 -16
  232. flwr/serverapp/strategy/fedmedian.py +2 -2
  233. flwr/serverapp/strategy/fedopt.py +10 -11
  234. flwr/serverapp/strategy/fedprox.py +7 -8
  235. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  236. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  237. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  238. flwr/serverapp/strategy/fedyogi.py +9 -11
  239. flwr/serverapp/strategy/krum.py +7 -7
  240. flwr/serverapp/strategy/multikrum.py +9 -9
  241. flwr/serverapp/strategy/qfedavg.py +17 -16
  242. flwr/serverapp/strategy/strategy.py +6 -9
  243. flwr/serverapp/strategy/strategy_utils.py +7 -8
  244. flwr/simulation/app.py +46 -42
  245. flwr/simulation/legacy_app.py +12 -12
  246. flwr/simulation/ray_transport/ray_actor.py +10 -11
  247. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  248. flwr/simulation/run_simulation.py +43 -43
  249. flwr/simulation/simulationio_connection.py +4 -4
  250. flwr/supercore/cli/flower_superexec.py +3 -4
  251. flwr/supercore/constant.py +31 -1
  252. flwr/supercore/corestate/corestate.py +24 -3
  253. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  254. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  255. flwr/supercore/ffs/disk_ffs.py +1 -2
  256. flwr/supercore/ffs/ffs.py +1 -2
  257. flwr/supercore/ffs/ffs_factory.py +1 -2
  258. flwr/{common → supercore}/heartbeat.py +20 -25
  259. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  260. flwr/supercore/object_store/object_store.py +1 -2
  261. flwr/supercore/object_store/object_store_factory.py +1 -2
  262. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  263. flwr/supercore/primitives/asymmetric.py +1 -1
  264. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  265. flwr/supercore/sqlite_mixin.py +37 -34
  266. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  267. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  268. flwr/supercore/superexec/run_superexec.py +9 -13
  269. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  270. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  271. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  272. flwr/superlink/federation/__init__.py +24 -0
  273. flwr/superlink/federation/federation_manager.py +64 -0
  274. flwr/superlink/federation/noop_federation_manager.py +71 -0
  275. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  276. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  277. flwr/superlink/servicer/control/control_grpc.py +5 -6
  278. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  279. flwr/superlink/servicer/control/control_servicer.py +102 -18
  280. flwr/supernode/cli/flower_supernode.py +58 -3
  281. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  282. flwr/supernode/nodestate/nodestate.py +7 -8
  283. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  284. flwr/supernode/runtime/run_clientapp.py +41 -22
  285. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  286. flwr/supernode/start_client_internal.py +158 -42
  287. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  288. flwr-1.24.0.dist-info/RECORD +454 -0
  289. flwr/supercore/object_store/utils.py +0 -43
  290. flwr-1.23.0.dist-info/RECORD +0 -439
  291. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  292. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -15,7 +15,6 @@
15
15
  """In-memory LinkState implementation."""
16
16
 
17
17
 
18
- import secrets
19
18
  import threading
20
19
  from bisect import bisect_right
21
20
  from collections import defaultdict
@@ -23,12 +22,9 @@ from collections.abc import Sequence
23
22
  from dataclasses import dataclass, field
24
23
  from datetime import datetime, timezone
25
24
  from logging import ERROR, WARNING
26
- from typing import Optional
27
25
 
28
26
  from flwr.common import Context, Message, log, now
29
27
  from flwr.common.constant import (
30
- FLWR_APP_TOKEN_LENGTH,
31
- HEARTBEAT_INTERVAL_INF,
32
28
  HEARTBEAT_PATIENCE,
33
29
  MESSAGE_TTL_TOLERANCE,
34
30
  NODE_ID_NUM_BYTES,
@@ -44,6 +40,9 @@ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
44
40
  from flwr.server.superlink.linkstate.linkstate import LinkState
45
41
  from flwr.server.utils import validate_message
46
42
  from flwr.supercore.constant import NodeStatus
43
+ from flwr.supercore.corestate.in_memory_corestate import InMemoryCoreState
44
+ from flwr.supercore.object_store.object_store import ObjectStore
45
+ from flwr.superlink.federation import FederationManager
47
46
 
48
47
  from .utils import (
49
48
  check_node_availability_for_in_message,
@@ -60,17 +59,18 @@ class RunRecord: # pylint: disable=R0902
60
59
  """The record of a specific run, including its status and timestamps."""
61
60
 
62
61
  run: Run
63
- active_until: float = 0.0
64
- heartbeat_interval: float = 0.0
65
62
  logs: list[tuple[float, str]] = field(default_factory=list)
66
63
  log_lock: threading.Lock = field(default_factory=threading.Lock)
67
64
  lock: threading.RLock = field(default_factory=threading.RLock)
68
65
 
69
66
 
70
- class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
67
+ class InMemoryLinkState(LinkState, InMemoryCoreState): # pylint: disable=R0902,R0904
71
68
  """In-memory LinkState implementation."""
72
69
 
73
- def __init__(self) -> None:
70
+ def __init__(
71
+ self, federation_manager: FederationManager, object_store: ObjectStore
72
+ ) -> None:
73
+ super().__init__(object_store)
74
74
 
75
75
  # Map node_id to NodeInfo
76
76
  self.nodes: dict[int, NodeInfo] = {}
@@ -85,19 +85,21 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
85
85
  self.message_res_store: dict[str, Message] = {}
86
86
  self.message_ins_id_to_message_res_id: dict[str, str] = {}
87
87
 
88
- # Store run ID to token mapping and token to run ID mapping
89
- self.token_store: dict[int, str] = {}
90
- self.token_to_run_id: dict[str, int] = {}
91
- self.lock_token_store = threading.Lock()
92
-
93
88
  # Map flwr_aid to run_ids for O(1) reverse index lookup
94
89
  self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
95
90
 
96
91
  self.node_public_keys: set[bytes] = set()
97
92
 
98
93
  self.lock = threading.RLock()
94
+ federation_manager.linkstate = self
95
+ self._federation_manager = federation_manager
96
+
97
+ @property
98
+ def federation_manager(self) -> FederationManager:
99
+ """Get the FederationManager instance."""
100
+ return self._federation_manager
99
101
 
100
- def store_message_ins(self, message: Message) -> Optional[str]:
102
+ def store_message_ins(self, message: Message) -> str | None:
101
103
  """Store one Message."""
102
104
  # Validate message
103
105
  errors = validate_message(message, is_reply_message=False)
@@ -108,6 +110,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
108
110
  if message.metadata.run_id not in self.run_ids:
109
111
  log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
110
112
  return None
113
+ federation = self.run_ids[message.metadata.run_id].run.federation
111
114
  # Validate source node ID
112
115
  if message.metadata.src_node_id != SUPERLINK_NODE_ID:
113
116
  log(
@@ -118,10 +121,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
118
121
  return None
119
122
  # Validate destination node ID
120
123
  dst_node = self.nodes.get(message.metadata.dst_node_id)
121
- if dst_node is None or dst_node.status not in [
122
- NodeStatus.ONLINE,
123
- NodeStatus.OFFLINE,
124
- ]:
124
+ if (
125
+ # Node must exist
126
+ dst_node is None
127
+ # Node must be online or offline
128
+ or dst_node.status not in (NodeStatus.ONLINE, NodeStatus.OFFLINE)
129
+ # Node must belong to the same federation
130
+ or not self.federation_manager.has_node(dst_node.node_id, federation)
131
+ ):
125
132
  log(
126
133
  ERROR,
127
134
  "Invalid destination node ID for Message: %s",
@@ -136,21 +143,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
136
143
  # Return the new message_id
137
144
  return message_id
138
145
 
139
- def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
146
+ def _check_stored_messages(self, message_ids: set[str]) -> None:
147
+ """Check and delete the message if it's invalid."""
148
+ with self.lock:
149
+ invalid_msg_ids: set[str] = set()
150
+ current = now().timestamp()
151
+ for msg_id in message_ids:
152
+ if not (message := self.message_ins_store.get(msg_id)):
153
+ continue
154
+
155
+ # Check if the message has expired
156
+ available_until = message.metadata.created_at + message.metadata.ttl
157
+ if available_until <= current:
158
+ invalid_msg_ids.add(msg_id)
159
+ continue
160
+
161
+ # Check if the destination node and the source node are still in the
162
+ # same federation
163
+ src_node_id = message.metadata.src_node_id
164
+ dst_node_id = message.metadata.dst_node_id
165
+ filtered = self.federation_manager.filter_nodes(
166
+ {src_node_id, dst_node_id},
167
+ self.run_ids[message.metadata.run_id].run.federation,
168
+ )
169
+ if len(filtered) != 2: # Not both nodes are in the federation
170
+ invalid_msg_ids.add(msg_id)
171
+
172
+ # Delete all invalid messages
173
+ self.delete_messages(invalid_msg_ids)
174
+
175
+ def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
140
176
  """Get all Messages that have not been delivered yet."""
141
177
  if limit is not None and limit < 1:
142
178
  raise AssertionError("`limit` must be >= 1")
143
179
 
144
180
  # Find Message for node_id that were not delivered yet
145
181
  message_ins_list: list[Message] = []
146
- current_time = now().timestamp()
147
182
  with self.lock:
148
- for _, msg_ins in self.message_ins_store.items():
183
+ for msg_id in list(self.message_ins_store.keys()):
184
+ self._check_stored_messages({msg_id})
185
+
149
186
  if (
150
- msg_ins.metadata.dst_node_id == node_id
187
+ (msg_ins := self.message_ins_store.get(msg_id))
188
+ and msg_ins.metadata.dst_node_id == node_id
151
189
  and msg_ins.metadata.delivered_at == ""
152
- and msg_ins.metadata.created_at + msg_ins.metadata.ttl
153
- > current_time
154
190
  ):
155
191
  message_ins_list.append(msg_ins)
156
192
  if limit and len(message_ins_list) == limit:
@@ -165,7 +201,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
165
201
  return message_ins_list
166
202
 
167
203
  # pylint: disable=R0911
168
- def store_message_res(self, message: Message) -> Optional[str]:
204
+ def store_message_res(self, message: Message) -> str | None:
169
205
  """Store one Message."""
170
206
  # Validate message
171
207
  errors = validate_message(message, is_reply_message=True)
@@ -177,6 +213,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
177
213
  with self.lock:
178
214
  # Check if the Message it is replying to exists and is valid
179
215
  msg_ins_id = res_metadata.reply_to_message_id
216
+ self._check_stored_messages({msg_ins_id})
180
217
  msg_ins = self.message_ins_store.get(msg_ins_id)
181
218
 
182
219
  # Ensure that dst_node_id of original Message matches the src_node_id of
@@ -196,22 +233,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
196
233
  )
197
234
  return None
198
235
 
199
- ins_metadata = msg_ins.metadata
200
- if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
201
- log(
202
- ERROR,
203
- "Failed to store Message: the message it is replying to "
204
- "(with ID %s) has expired",
205
- msg_ins_id,
206
- )
207
- return None
208
-
209
236
  # Fail if the Message TTL exceeds the
210
237
  # expiration time of the Message it replies to.
211
238
  # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
212
239
  # res_metadata.created_at + res_metadata.ttl
213
240
  # A small tolerance is introduced to account
214
241
  # for floating-point precision issues.
242
+ ins_metadata = msg_ins.metadata
215
243
  max_allowed_ttl = (
216
244
  ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
217
245
  )
@@ -245,6 +273,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
245
273
  ret: dict[str, Message] = {}
246
274
 
247
275
  with self.lock:
276
+ self._check_stored_messages(message_ids)
248
277
  current = now().timestamp()
249
278
 
250
279
  # Verify Message IDs
@@ -339,7 +368,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
339
368
  return len(self.message_res_store)
340
369
 
341
370
  def create_node(
342
- self, owner_aid: str, public_key: bytes, heartbeat_interval: float
371
+ self,
372
+ owner_aid: str,
373
+ owner_name: str,
374
+ public_key: bytes,
375
+ heartbeat_interval: float,
343
376
  ) -> int:
344
377
  """Create, store in the link state, and return `node_id`."""
345
378
  # Sample a random int64 as node_id
@@ -358,6 +391,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
358
391
  self.nodes[node_id] = NodeInfo(
359
392
  node_id=node_id,
360
393
  owner_aid=owner_aid,
394
+ owner_name=owner_name,
361
395
  status=NodeStatus.REGISTERED,
362
396
  registered_at=now().isoformat(),
363
397
  last_activated_at=None,
@@ -442,17 +476,19 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
442
476
  with self.lock:
443
477
  if run_id not in self.run_ids:
444
478
  return set()
445
- return {
479
+ federation = self.run_ids[run_id].run.federation
480
+ node_ids = {
446
481
  node.node_id
447
482
  for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
448
483
  }
484
+ return self.federation_manager.filter_nodes(node_ids, federation)
449
485
 
450
486
  def get_node_info(
451
487
  self,
452
488
  *,
453
- node_ids: Optional[Sequence[int]] = None,
454
- owner_aids: Optional[Sequence[str]] = None,
455
- statuses: Optional[Sequence[str]] = None,
489
+ node_ids: Sequence[int] | None = None,
490
+ owner_aids: Sequence[str] | None = None,
491
+ statuses: Sequence[str] | None = None,
456
492
  ) -> Sequence[NodeInfo]:
457
493
  """Retrieve information about nodes based on the specified filters."""
458
494
  with self.lock:
@@ -468,9 +504,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
468
504
  result.append(node)
469
505
  return result
470
506
 
471
- def _check_and_tag_offline_nodes(
472
- self, node_ids: Optional[list[int]] = None
473
- ) -> None:
507
+ def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
474
508
  with self.lock:
475
509
  # Set all nodes of "online" status to "offline" if they've offline
476
510
  current_ts = now().timestamp()
@@ -493,7 +527,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
493
527
  raise ValueError(f"Node ID {node_id} not found")
494
528
  return node.public_key
495
529
 
496
- def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
530
+ def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
497
531
  """Get `node_id` for the specified `public_key` if it exists and is not
498
532
  deleted."""
499
533
  with self.lock:
@@ -510,14 +544,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
510
544
  # pylint: disable=too-many-arguments,too-many-positional-arguments
511
545
  def create_run(
512
546
  self,
513
- fab_id: Optional[str],
514
- fab_version: Optional[str],
515
- fab_hash: Optional[str],
547
+ fab_id: str | None,
548
+ fab_version: str | None,
549
+ fab_hash: str | None,
516
550
  override_config: UserConfig,
551
+ federation: str,
517
552
  federation_options: ConfigRecord,
518
- flwr_aid: Optional[str],
553
+ flwr_aid: str | None,
519
554
  ) -> int:
520
- """Create a new run for the specified `fab_hash`."""
555
+ """Create a new run."""
521
556
  # Sample a random int64 as run_id
522
557
  with self.lock:
523
558
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -540,6 +575,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
540
575
  details="",
541
576
  ),
542
577
  flwr_aid=flwr_aid if flwr_aid else "",
578
+ federation=federation,
543
579
  ),
544
580
  )
545
581
  self.run_ids[run_id] = run_record
@@ -553,7 +589,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
553
589
  log(ERROR, "Unexpected run creation failure.")
554
590
  return 0
555
591
 
556
- def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
592
+ def get_run_ids(self, flwr_aid: str | None) -> set[int]:
557
593
  """Retrieve all run IDs if `flwr_aid` is not specified.
558
594
 
559
595
  Otherwise, retrieve all run IDs for the specified `flwr_aid`.
@@ -564,30 +600,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
564
600
  return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
565
601
  return set(self.run_ids.keys())
566
602
 
567
- def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
568
- """Check if any runs are no longer active.
569
-
570
- Marks runs with status 'starting' or 'running' as failed
571
- if they have not sent a heartbeat before `active_until`.
572
- """
573
- current = now()
574
- for record in (self.run_ids.get(run_id) for run_id in run_ids):
575
- if record is None:
576
- continue
577
- with record.lock:
578
- if record.run.status.status in (Status.STARTING, Status.RUNNING):
579
- if record.active_until < current.timestamp():
580
- record.run.status = RunStatus(
581
- status=Status.FINISHED,
582
- sub_status=SubStatus.FAILED,
583
- details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
584
- )
585
- record.run.finished_at = now().isoformat()
586
-
587
- def get_run(self, run_id: int) -> Optional[Run]:
603
+ def get_run(self, run_id: int) -> Run | None:
588
604
  """Retrieve information about the run with the specified `run_id`."""
589
- # Check if runs are still active
590
- self._check_and_tag_inactive_run(run_ids={run_id})
605
+ # Clean up expired tokens; this will flag inactive runs as needed
606
+ self._cleanup_expired_tokens()
591
607
 
592
608
  with self.lock:
593
609
  if run_id not in self.run_ids:
@@ -597,8 +613,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
597
613
 
598
614
  def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
599
615
  """Retrieve the statuses for the specified runs."""
600
- # Check if runs are still active
601
- self._check_and_tag_inactive_run(run_ids=run_ids)
616
+ # Clean up expired tokens; this will flag inactive runs as needed
617
+ self._cleanup_expired_tokens()
602
618
 
603
619
  with self.lock:
604
620
  return {
@@ -609,8 +625,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
609
625
 
610
626
  def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
611
627
  """Update the status of the run with the specified `run_id`."""
612
- # Check if runs are still active
613
- self._check_and_tag_inactive_run(run_ids={run_id})
628
+ # Clean up expired tokens; this will flag inactive runs as needed
629
+ self._cleanup_expired_tokens()
614
630
 
615
631
  with self.lock:
616
632
  # Check if the run_id exists
@@ -640,17 +656,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
640
656
  )
641
657
  return False
642
658
 
643
- # Initialize heartbeat_interval and active_until
644
- # when switching to starting or running
659
+ # Update the run status
645
660
  current = now()
646
661
  run_record = self.run_ids[run_id]
647
- if new_status.status in (Status.STARTING, Status.RUNNING):
648
- run_record.heartbeat_interval = HEARTBEAT_INTERVAL_INF
649
- run_record.active_until = (
650
- current.timestamp() + run_record.heartbeat_interval
651
- )
652
-
653
- # Update the run status
654
662
  if new_status.status == Status.STARTING:
655
663
  run_record.run.starting_at = current.isoformat()
656
664
  elif new_status.status == Status.RUNNING:
@@ -660,7 +668,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
660
668
  run_record.run.status = new_status
661
669
  return True
662
670
 
663
- def get_pending_run_id(self) -> Optional[int]:
671
+ def get_pending_run_id(self) -> int | None:
664
672
  """Get the `run_id` of a run with `Status.PENDING` status, if any."""
665
673
  pending_run_id = None
666
674
 
@@ -673,7 +681,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
673
681
 
674
682
  return pending_run_id
675
683
 
676
- def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
684
+ def get_federation_options(self, run_id: int) -> ConfigRecord | None:
677
685
  """Retrieve the federation options for the specified `run_id`."""
678
686
  with self.lock:
679
687
  if run_id not in self.run_ids:
@@ -710,44 +718,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
710
718
  return True
711
719
  return False
712
720
 
713
- def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
714
- """Acknowledge a heartbeat received from a ServerApp for a given run.
721
+ def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
722
+ """Transition runs with expired tokens to failed status.
715
723
 
716
- A run with status `"running"` is considered alive as long as it sends heartbeats
717
- within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
718
- HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
719
- marked as `"completed:failed"`.
724
+ Parameters
725
+ ----------
726
+ expired_records : list[tuple[int, float]]
727
+ List of tuples containing (run_id, active_until timestamp)
728
+ for expired tokens.
720
729
  """
721
- with self.lock:
722
- # Search for the run
723
- record = self.run_ids.get(run_id)
724
-
725
- # Check if the run_id exists
726
- if record is None:
727
- log(ERROR, "`run_id` is invalid")
728
- return False
729
-
730
- with record.lock:
731
- # Check if runs are still active
732
- self._check_and_tag_inactive_run(run_ids={run_id})
733
-
734
- # Check if the run is of status "running"/"starting"
735
- current_status = record.run.status
736
- if current_status.status not in (Status.RUNNING, Status.STARTING):
737
- log(
738
- ERROR,
739
- 'Cannot acknowledge heartbeat for run with status "%s"',
740
- current_status.status,
730
+ for run_id, active_until in expired_records:
731
+ if not (run_record := self.run_ids.get(run_id)):
732
+ continue
733
+ with run_record.lock:
734
+ run_record.run.status = RunStatus(
735
+ status=Status.FINISHED,
736
+ sub_status=SubStatus.FAILED,
737
+ details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
741
738
  )
742
- return False
739
+ active_until_dt = datetime.fromtimestamp(active_until, tz=timezone.utc)
740
+ run_record.run.finished_at = active_until_dt.isoformat()
743
741
 
744
- # Update the `active_until` and `heartbeat_interval` for the given run
745
- current = now().timestamp()
746
- record.active_until = current + HEARTBEAT_PATIENCE * heartbeat_interval
747
- record.heartbeat_interval = heartbeat_interval
748
- return True
749
-
750
- def get_serverapp_context(self, run_id: int) -> Optional[Context]:
742
+ def get_serverapp_context(self, run_id: int) -> Context | None:
751
743
  """Get the context for the specified `run_id`."""
752
744
  return self.contexts.get(run_id)
753
745
 
@@ -766,7 +758,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
766
758
  run.logs.append((now().timestamp(), log_message))
767
759
 
768
760
  def get_serverapp_log(
769
- self, run_id: int, after_timestamp: Optional[float]
761
+ self, run_id: int, after_timestamp: float | None
770
762
  ) -> tuple[str, float]:
771
763
  """Get the serverapp logs for the specified `run_id`."""
772
764
  if run_id not in self.run_ids:
@@ -779,30 +771,3 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
779
771
  index = bisect_right(run.logs, (after_timestamp, ""))
780
772
  latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
781
773
  return "".join(log for _, log in run.logs[index:]), latest_timestamp
782
-
783
- def create_token(self, run_id: int) -> Optional[str]:
784
- """Create a token for the given run ID."""
785
- token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
786
- with self.lock_token_store:
787
- if run_id in self.token_store:
788
- return None # Token already created for this run ID
789
- self.token_store[run_id] = token
790
- self.token_to_run_id[token] = run_id
791
- return token
792
-
793
- def verify_token(self, run_id: int, token: str) -> bool:
794
- """Verify a token for the given run ID."""
795
- with self.lock_token_store:
796
- return self.token_store.get(run_id) == token
797
-
798
- def delete_token(self, run_id: int) -> None:
799
- """Delete the token for the given run ID."""
800
- with self.lock_token_store:
801
- token = self.token_store.pop(run_id, None)
802
- if token is not None:
803
- self.token_to_run_id.pop(token, None)
804
-
805
- def get_run_id_by_token(self, token: str) -> Optional[int]:
806
- """Get the run ID associated with a given token."""
807
- with self.lock_token_store:
808
- return self.token_to_run_id.get(token)