flwr-1.22.0-py3-none-any.whl → flwr-1.24.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (301)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +34 -1
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/__init__.py +15 -6
  9. flwr/cli/auth_plugin/auth_plugin.py +94 -0
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
  12. flwr/cli/build.py +166 -53
  13. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
  14. flwr/cli/config_utils.py +101 -13
  15. flwr/cli/federation/__init__.py +24 -0
  16. flwr/cli/federation/ls.py +140 -0
  17. flwr/cli/federation/show.py +317 -0
  18. flwr/cli/install.py +91 -13
  19. flwr/cli/log.py +54 -11
  20. flwr/cli/login/login.py +41 -27
  21. flwr/cli/ls.py +177 -133
  22. flwr/cli/new/new.py +175 -40
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  31. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  34. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  35. flwr/cli/pull.py +12 -7
  36. flwr/cli/run/run.py +82 -31
  37. flwr/cli/run_utils.py +130 -0
  38. flwr/cli/stop.py +27 -9
  39. flwr/cli/supernode/__init__.py +25 -0
  40. flwr/cli/supernode/ls.py +268 -0
  41. flwr/cli/supernode/register.py +190 -0
  42. flwr/cli/supernode/unregister.py +140 -0
  43. flwr/cli/utils.py +464 -81
  44. flwr/client/__init__.py +2 -1
  45. flwr/client/dpfedavg_numpy_client.py +4 -1
  46. flwr/client/grpc_adapter_client/connection.py +12 -15
  47. flwr/client/grpc_rere_client/connection.py +68 -41
  48. flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
  49. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
  50. flwr/client/message_handler/message_handler.py +2 -2
  51. flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
  52. flwr/client/numpy_client.py +1 -1
  53. flwr/client/rest_client/connection.py +94 -51
  54. flwr/client/run_info_store.py +4 -5
  55. flwr/client/typing.py +1 -1
  56. flwr/clientapp/__init__.py +1 -2
  57. flwr/{client → clientapp}/client_app.py +9 -10
  58. flwr/clientapp/mod/centraldp_mods.py +16 -17
  59. flwr/clientapp/mod/localdp_mod.py +8 -9
  60. flwr/clientapp/typing.py +1 -1
  61. flwr/{client/clientapp → clientapp}/utils.py +4 -4
  62. flwr/common/address.py +1 -2
  63. flwr/common/args.py +3 -4
  64. flwr/common/config.py +13 -16
  65. flwr/common/constant.py +56 -13
  66. flwr/common/differential_privacy.py +3 -4
  67. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  68. flwr/common/exit/exit.py +15 -2
  69. flwr/common/exit/exit_code.py +39 -10
  70. flwr/common/exit/exit_handler.py +6 -2
  71. flwr/common/exit/signal_handler.py +5 -5
  72. flwr/common/grpc.py +6 -6
  73. flwr/common/inflatable_protobuf_utils.py +1 -1
  74. flwr/common/inflatable_utils.py +48 -31
  75. flwr/common/logger.py +19 -19
  76. flwr/common/message.py +4 -4
  77. flwr/common/object_ref.py +7 -7
  78. flwr/common/record/array.py +6 -6
  79. flwr/common/record/arrayrecord.py +18 -21
  80. flwr/common/record/configrecord.py +3 -3
  81. flwr/common/record/recorddict.py +5 -5
  82. flwr/common/record/typeddict.py +9 -2
  83. flwr/common/recorddict_compat.py +7 -10
  84. flwr/common/retry_invoker.py +20 -20
  85. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  86. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  87. flwr/common/serde.py +9 -6
  88. flwr/common/serde_utils.py +2 -2
  89. flwr/common/telemetry.py +9 -5
  90. flwr/common/typing.py +59 -43
  91. flwr/compat/client/app.py +39 -38
  92. flwr/compat/client/grpc_client/connection.py +13 -13
  93. flwr/compat/server/app.py +5 -6
  94. flwr/proto/appio_pb2.py +13 -3
  95. flwr/proto/appio_pb2.pyi +134 -65
  96. flwr/proto/appio_pb2_grpc.py +20 -0
  97. flwr/proto/appio_pb2_grpc.pyi +27 -0
  98. flwr/proto/clientappio_pb2.py +17 -7
  99. flwr/proto/clientappio_pb2.pyi +15 -0
  100. flwr/proto/clientappio_pb2_grpc.py +206 -40
  101. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  102. flwr/proto/control_pb2.py +72 -40
  103. flwr/proto/control_pb2.pyi +319 -87
  104. flwr/proto/control_pb2_grpc.py +339 -28
  105. flwr/proto/control_pb2_grpc.pyi +209 -37
  106. flwr/proto/error_pb2.py +13 -3
  107. flwr/proto/error_pb2.pyi +24 -6
  108. flwr/proto/error_pb2_grpc.py +20 -0
  109. flwr/proto/error_pb2_grpc.pyi +27 -0
  110. flwr/proto/fab_pb2.py +24 -10
  111. flwr/proto/fab_pb2.pyi +68 -20
  112. flwr/proto/fab_pb2_grpc.py +20 -0
  113. flwr/proto/fab_pb2_grpc.pyi +27 -0
  114. flwr/proto/federation_pb2.py +38 -0
  115. flwr/proto/federation_pb2.pyi +56 -0
  116. flwr/proto/federation_pb2_grpc.py +24 -0
  117. flwr/proto/federation_pb2_grpc.pyi +31 -0
  118. flwr/proto/fleet_pb2.py +45 -27
  119. flwr/proto/fleet_pb2.pyi +186 -70
  120. flwr/proto/fleet_pb2_grpc.py +277 -66
  121. flwr/proto/fleet_pb2_grpc.pyi +201 -55
  122. flwr/proto/grpcadapter_pb2.py +14 -4
  123. flwr/proto/grpcadapter_pb2.pyi +38 -16
  124. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  125. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  126. flwr/proto/heartbeat_pb2.py +17 -7
  127. flwr/proto/heartbeat_pb2.pyi +51 -22
  128. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  129. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  130. flwr/proto/log_pb2.py +13 -3
  131. flwr/proto/log_pb2.pyi +34 -11
  132. flwr/proto/log_pb2_grpc.py +20 -0
  133. flwr/proto/log_pb2_grpc.pyi +27 -0
  134. flwr/proto/message_pb2.py +15 -5
  135. flwr/proto/message_pb2.pyi +154 -86
  136. flwr/proto/message_pb2_grpc.py +20 -0
  137. flwr/proto/message_pb2_grpc.pyi +27 -0
  138. flwr/proto/node_pb2.py +16 -4
  139. flwr/proto/node_pb2.pyi +77 -4
  140. flwr/proto/node_pb2_grpc.py +20 -0
  141. flwr/proto/node_pb2_grpc.pyi +27 -0
  142. flwr/proto/recorddict_pb2.py +13 -3
  143. flwr/proto/recorddict_pb2.pyi +184 -107
  144. flwr/proto/recorddict_pb2_grpc.py +20 -0
  145. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  146. flwr/proto/run_pb2.py +40 -31
  147. flwr/proto/run_pb2.pyi +149 -84
  148. flwr/proto/run_pb2_grpc.py +20 -0
  149. flwr/proto/run_pb2_grpc.pyi +27 -0
  150. flwr/proto/serverappio_pb2.py +13 -3
  151. flwr/proto/serverappio_pb2.pyi +32 -8
  152. flwr/proto/serverappio_pb2_grpc.py +246 -65
  153. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  154. flwr/proto/simulationio_pb2.py +16 -8
  155. flwr/proto/simulationio_pb2.pyi +15 -0
  156. flwr/proto/simulationio_pb2_grpc.py +162 -41
  157. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  158. flwr/proto/transport_pb2.py +20 -10
  159. flwr/proto/transport_pb2.pyi +249 -160
  160. flwr/proto/transport_pb2_grpc.py +35 -4
  161. flwr/proto/transport_pb2_grpc.pyi +38 -8
  162. flwr/server/app.py +173 -127
  163. flwr/server/client_manager.py +4 -5
  164. flwr/server/client_proxy.py +10 -11
  165. flwr/server/compat/app.py +4 -5
  166. flwr/server/compat/app_utils.py +2 -1
  167. flwr/server/compat/grid_client_proxy.py +10 -12
  168. flwr/server/compat/legacy_context.py +3 -4
  169. flwr/server/fleet_event_log_interceptor.py +2 -1
  170. flwr/server/grid/grid.py +2 -3
  171. flwr/server/grid/grpc_grid.py +10 -8
  172. flwr/server/grid/inmemory_grid.py +4 -4
  173. flwr/server/run_serverapp.py +2 -3
  174. flwr/server/server.py +34 -39
  175. flwr/server/server_app.py +7 -8
  176. flwr/server/server_config.py +1 -2
  177. flwr/server/serverapp/app.py +34 -28
  178. flwr/server/serverapp_components.py +4 -5
  179. flwr/server/strategy/aggregate.py +9 -8
  180. flwr/server/strategy/bulyan.py +13 -11
  181. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  182. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  183. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  184. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  185. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  186. flwr/server/strategy/fedadagrad.py +18 -14
  187. flwr/server/strategy/fedadam.py +16 -14
  188. flwr/server/strategy/fedavg.py +16 -17
  189. flwr/server/strategy/fedavg_android.py +15 -15
  190. flwr/server/strategy/fedavgm.py +21 -18
  191. flwr/server/strategy/fedmedian.py +2 -3
  192. flwr/server/strategy/fedopt.py +11 -10
  193. flwr/server/strategy/fedprox.py +10 -9
  194. flwr/server/strategy/fedtrimmedavg.py +12 -11
  195. flwr/server/strategy/fedxgb_bagging.py +13 -11
  196. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  197. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  198. flwr/server/strategy/fedyogi.py +16 -14
  199. flwr/server/strategy/krum.py +12 -11
  200. flwr/server/strategy/qfedavg.py +16 -15
  201. flwr/server/strategy/strategy.py +6 -9
  202. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
  203. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  206. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
  208. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
  209. flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
  210. flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
  211. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  212. flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
  213. flwr/server/superlink/fleet/vce/vce_api.py +32 -13
  214. flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
  215. flwr/server/superlink/linkstate/linkstate.py +161 -62
  216. flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
  217. flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
  218. flwr/server/superlink/linkstate/utils.py +9 -60
  219. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  220. flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
  221. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  222. flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
  223. flwr/server/superlink/utils.py +4 -6
  224. flwr/server/typing.py +1 -1
  225. flwr/server/utils/tensorboard.py +15 -8
  226. flwr/server/utils/validator.py +2 -3
  227. flwr/server/workflow/default_workflows.py +5 -5
  228. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  229. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
  230. flwr/serverapp/strategy/bulyan.py +16 -15
  231. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  232. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  233. flwr/serverapp/strategy/fedadagrad.py +10 -11
  234. flwr/serverapp/strategy/fedadam.py +10 -11
  235. flwr/serverapp/strategy/fedavg.py +9 -10
  236. flwr/serverapp/strategy/fedavgm.py +17 -16
  237. flwr/serverapp/strategy/fedmedian.py +2 -2
  238. flwr/serverapp/strategy/fedopt.py +10 -11
  239. flwr/serverapp/strategy/fedprox.py +7 -8
  240. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  241. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  242. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  243. flwr/serverapp/strategy/fedyogi.py +9 -11
  244. flwr/serverapp/strategy/krum.py +7 -7
  245. flwr/serverapp/strategy/multikrum.py +9 -9
  246. flwr/serverapp/strategy/qfedavg.py +17 -16
  247. flwr/serverapp/strategy/strategy.py +6 -9
  248. flwr/serverapp/strategy/strategy_utils.py +7 -8
  249. flwr/simulation/app.py +46 -42
  250. flwr/simulation/legacy_app.py +12 -12
  251. flwr/simulation/ray_transport/ray_actor.py +11 -12
  252. flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
  253. flwr/simulation/run_simulation.py +44 -43
  254. flwr/simulation/simulationio_connection.py +4 -4
  255. flwr/supercore/cli/flower_superexec.py +3 -4
  256. flwr/supercore/constant.py +52 -0
  257. flwr/supercore/corestate/corestate.py +24 -3
  258. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  259. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  260. flwr/supercore/ffs/disk_ffs.py +1 -2
  261. flwr/supercore/ffs/ffs.py +1 -2
  262. flwr/supercore/ffs/ffs_factory.py +1 -2
  263. flwr/{common → supercore}/heartbeat.py +20 -25
  264. flwr/supercore/object_store/in_memory_object_store.py +1 -6
  265. flwr/supercore/object_store/object_store.py +1 -2
  266. flwr/supercore/object_store/object_store_factory.py +27 -8
  267. flwr/supercore/object_store/sqlite_object_store.py +253 -0
  268. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  269. flwr/supercore/primitives/asymmetric.py +117 -0
  270. flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
  271. flwr/supercore/sqlite_mixin.py +159 -0
  272. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  273. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  274. flwr/supercore/superexec/run_superexec.py +9 -13
  275. flwr/supercore/utils.py +20 -0
  276. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  277. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  278. flwr/superlink/auth_plugin/auth_plugin.py +88 -0
  279. flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
  280. flwr/superlink/federation/__init__.py +24 -0
  281. flwr/superlink/federation/federation_manager.py +64 -0
  282. flwr/superlink/federation/noop_federation_manager.py +71 -0
  283. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
  284. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  285. flwr/superlink/servicer/control/control_grpc.py +18 -17
  286. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  287. flwr/superlink/servicer/control/control_servicer.py +239 -63
  288. flwr/supernode/cli/flower_supernode.py +74 -26
  289. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  290. flwr/supernode/nodestate/nodestate.py +7 -8
  291. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  292. flwr/supernode/runtime/run_clientapp.py +43 -24
  293. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  294. flwr/supernode/start_client_internal.py +175 -51
  295. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  296. flwr-1.24.0.dist-info/RECORD +454 -0
  297. flwr/common/auth_plugin/auth_plugin.py +0 -149
  298. flwr/supercore/object_store/utils.py +0 -43
  299. flwr-1.22.0.dist-info/RECORD +0 -428
  300. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  301. {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -15,25 +15,31 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
- from logging import DEBUG, INFO
18
+ import threading
19
+ from logging import DEBUG, ERROR, INFO
19
20
 
20
21
  import grpc
21
22
  from google.protobuf.json_format import MessageToDict
22
23
 
24
+ from flwr.common.constant import PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
23
25
  from flwr.common.inflatable import UnexpectedObjectContentError
24
26
  from flwr.common.logger import log
25
27
  from flwr.common.typing import InvalidRunStatusException
26
28
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
27
29
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
28
30
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
- CreateNodeRequest,
30
- CreateNodeResponse,
31
- DeleteNodeRequest,
32
- DeleteNodeResponse,
31
+ ActivateNodeRequest,
32
+ ActivateNodeResponse,
33
+ DeactivateNodeRequest,
34
+ DeactivateNodeResponse,
33
35
  PullMessagesRequest,
34
36
  PullMessagesResponse,
35
37
  PushMessagesRequest,
36
38
  PushMessagesResponse,
39
+ RegisterNodeFleetRequest,
40
+ RegisterNodeFleetResponse,
41
+ UnregisterNodeFleetRequest,
42
+ UnregisterNodeFleetResponse,
37
43
  )
38
44
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
39
45
  SendNodeHeartbeatRequest,
@@ -63,49 +69,137 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
63
69
  state_factory: LinkStateFactory,
64
70
  ffs_factory: FfsFactory,
65
71
  objectstore_factory: ObjectStoreFactory,
72
+ enable_supernode_auth: bool,
66
73
  ) -> None:
67
74
  self.state_factory = state_factory
68
75
  self.ffs_factory = ffs_factory
69
76
  self.objectstore_factory = objectstore_factory
77
+ self.enable_supernode_auth = enable_supernode_auth
78
+ self.lock = threading.Lock()
70
79
 
71
- def CreateNode(
72
- self, request: CreateNodeRequest, context: grpc.ServicerContext
73
- ) -> CreateNodeResponse:
74
- """."""
75
- log(
76
- INFO,
77
- "[Fleet.CreateNode] Request heartbeat_interval=%s",
78
- request.heartbeat_interval,
79
- )
80
- log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
81
- response = message_handler.create_node(
82
- request=request,
83
- state=self.state_factory.state(),
84
- )
85
- log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
86
- log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
87
- return response
80
+ def RegisterNode(
81
+ self, request: RegisterNodeFleetRequest, context: grpc.ServicerContext
82
+ ) -> RegisterNodeFleetResponse:
83
+ """Register a node."""
84
+ # Prevent registration when SuperNode authentication is enabled
85
+ if self.enable_supernode_auth:
86
+ log(ERROR, "SuperNode registration is disabled through Fleet API.")
87
+ context.abort(
88
+ grpc.StatusCode.FAILED_PRECONDITION,
89
+ "SuperNode authentication is enabled. "
90
+ "All SuperNodes must be registered via the CLI.",
91
+ )
88
92
 
89
- def DeleteNode(
90
- self, request: DeleteNodeRequest, context: grpc.ServicerContext
91
- ) -> DeleteNodeResponse:
92
- """."""
93
- log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
94
- log(DEBUG, "[Fleet.DeleteNode] Request: %s", MessageToDict(request))
95
- return message_handler.delete_node(
96
- request=request,
97
- state=self.state_factory.state(),
98
- )
93
+ try:
94
+ response = message_handler.register_node(
95
+ request=request,
96
+ state=self.state_factory.state(),
97
+ )
98
+ log(DEBUG, "[Fleet.RegisterNode] Registered node_id=%s", response.node_id)
99
+ return response
100
+ except ValueError:
101
+ # Public key already in use
102
+ # This should NEVER happen due to the public keys should be automatically
103
+ # generated and unique for each SuperNode instance.
104
+ log(
105
+ ERROR,
106
+ "[Fleet.RegisterNode] Registration failed: %s",
107
+ PUBLIC_KEY_ALREADY_IN_USE_MESSAGE,
108
+ )
109
+ context.abort(
110
+ grpc.StatusCode.FAILED_PRECONDITION, PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
111
+ )
112
+
113
+ raise RuntimeError # Make mypy happy
114
+
115
+ def ActivateNode(
116
+ self, request: ActivateNodeRequest, context: grpc.ServicerContext
117
+ ) -> ActivateNodeResponse:
118
+ """Activate a node."""
119
+ try:
120
+ response = message_handler.activate_node(
121
+ request=request,
122
+ state=self.state_factory.state(),
123
+ )
124
+ log(INFO, "[Fleet.ActivateNode] Activated node_id=%s", response.node_id)
125
+ return response
126
+ except message_handler.InvalidHeartbeatIntervalError:
127
+ # Heartbeat interval is invalid
128
+ log(ERROR, "[Fleet.ActivateNode] Invalid heartbeat interval")
129
+ context.abort(
130
+ grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
131
+ )
132
+ except ValueError as e:
133
+ log(ERROR, "[Fleet.ActivateNode] Activation failed: %s", str(e))
134
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
135
+
136
+ raise RuntimeError # Make mypy happy
137
+
138
+ def DeactivateNode(
139
+ self, request: DeactivateNodeRequest, context: grpc.ServicerContext
140
+ ) -> DeactivateNodeResponse:
141
+ """Deactivate a node."""
142
+ try:
143
+ response = message_handler.deactivate_node(
144
+ request=request,
145
+ state=self.state_factory.state(),
146
+ )
147
+ log(INFO, "[Fleet.DeactivateNode] Deactivated node_id=%s", request.node_id)
148
+ return response
149
+ except ValueError as e:
150
+ log(ERROR, "[Fleet.DeactivateNode] Deactivation failed: %s", str(e))
151
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
152
+
153
+ raise RuntimeError # Make mypy happy
154
+
155
+ def UnregisterNode(
156
+ self, request: UnregisterNodeFleetRequest, context: grpc.ServicerContext
157
+ ) -> UnregisterNodeFleetResponse:
158
+ """Unregister a node."""
159
+ # Prevent unregistration when SuperNode authentication is enabled
160
+ if self.enable_supernode_auth:
161
+ log(ERROR, "SuperNode unregistration is disabled through Fleet API.")
162
+ context.abort(
163
+ grpc.StatusCode.FAILED_PRECONDITION,
164
+ "SuperNode authentication is enabled. "
165
+ "All SuperNodes must be unregistered via the CLI.",
166
+ )
167
+
168
+ try:
169
+ response = message_handler.unregister_node(
170
+ request=request,
171
+ state=self.state_factory.state(),
172
+ )
173
+ log(
174
+ DEBUG, "[Fleet.UnregisterNode] Unregistered node_id=%s", request.node_id
175
+ )
176
+ return response
177
+ except ValueError as e:
178
+ log(
179
+ ERROR,
180
+ "[Fleet.UnregisterNode] Unregistration failed: %s",
181
+ str(e),
182
+ )
183
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
184
+ raise RuntimeError from None # Make mypy happy
99
185
 
100
186
  def SendNodeHeartbeat(
101
187
  self, request: SendNodeHeartbeatRequest, context: grpc.ServicerContext
102
188
  ) -> SendNodeHeartbeatResponse:
103
189
  """."""
104
190
  log(DEBUG, "[Fleet.SendNodeHeartbeat] Request: %s", MessageToDict(request))
105
- return message_handler.send_node_heartbeat(
106
- request=request,
107
- state=self.state_factory.state(),
108
- )
191
+ try:
192
+ return message_handler.send_node_heartbeat(
193
+ request=request,
194
+ state=self.state_factory.state(),
195
+ )
196
+ except message_handler.InvalidHeartbeatIntervalError:
197
+ # Heartbeat interval is invalid
198
+ log(ERROR, "[Fleet.SendNodeHeartbeat] Invalid heartbeat interval")
199
+ context.abort(
200
+ grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
201
+ )
202
+ raise RuntimeError # Make mypy happy
109
203
 
110
204
  def PullMessages(
111
205
  self, request: PullMessagesRequest, context: grpc.ServicerContext
@@ -155,8 +249,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
155
249
  state=self.state_factory.state(),
156
250
  store=self.objectstore_factory.store(),
157
251
  )
158
- except InvalidRunStatusException as e:
159
- abort_grpc_context(e.message, context)
252
+ except (InvalidRunStatusException, ValueError) as e:
253
+ abort_grpc_context(str(e), context)
160
254
 
161
255
  return res
162
256
 
@@ -172,8 +266,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
172
266
  state=self.state_factory.state(),
173
267
  store=self.objectstore_factory.store(),
174
268
  )
175
- except InvalidRunStatusException as e:
176
- abort_grpc_context(e.message, context)
269
+ except (InvalidRunStatusException, ValueError) as e:
270
+ abort_grpc_context(str(e), context)
177
271
 
178
272
  return res
179
273
 
@@ -183,7 +277,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
183
277
  """Push an object to the ObjectStore."""
184
278
  log(
185
279
  DEBUG,
186
- "[ServerAppIoServicer.PushObject] Push Object with object_id=%s",
280
+ "[Fleet.PushObject] Push Object with object_id=%s",
187
281
  request.object_id,
188
282
  )
189
283
 
@@ -208,7 +302,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
208
302
  """Pull an object from the ObjectStore."""
209
303
  log(
210
304
  DEBUG,
211
- "[ServerAppIoServicer.PullObject] Pull Object with object_id=%s",
305
+ "[Fleet.PullObject] Pull Object with object_id=%s",
212
306
  request.object_id,
213
307
  )
214
308
 
@@ -16,7 +16,8 @@
16
16
 
17
17
 
18
18
  import datetime
19
- from typing import Any, Callable, Optional, cast
19
+ from collections.abc import Callable
20
+ from typing import Any, cast
20
21
 
21
22
  import grpc
22
23
  from google.protobuf.message import Message as GrpcMessage
@@ -29,15 +30,12 @@ from flwr.common.constant import (
29
30
  TIMESTAMP_HEADER,
30
31
  TIMESTAMP_TOLERANCE,
31
32
  )
32
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
33
- bytes_to_public_key,
34
- verify_signature,
35
- )
36
33
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
- CreateNodeRequest,
38
- CreateNodeResponse,
34
+ ActivateNodeRequest,
35
+ RegisterNodeFleetRequest,
39
36
  )
40
37
  from flwr.server.superlink.linkstate import LinkStateFactory
38
+ from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
41
39
 
42
40
  MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
43
41
  MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
@@ -53,22 +51,17 @@ def _unary_unary_rpc_terminator(
53
51
  return grpc.unary_unary_rpc_method_handler(terminate)
54
52
 
55
53
 
56
- class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
54
+ class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
57
55
  """Server interceptor for node authentication.
58
56
 
59
57
  Parameters
60
58
  ----------
61
59
  state_factory : LinkStateFactory
62
60
  A factory for creating new instances of LinkState.
63
- auto_auth : bool (default: False)
64
- If True, nodes are authenticated without requiring their public keys to be
65
- pre-stored in the LinkState. If False, only nodes with pre-stored public keys
66
- can be authenticated.
67
61
  """
68
62
 
69
- def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
63
+ def __init__(self, state_factory: LinkStateFactory):
70
64
  self.state_factory = state_factory
71
- self.auto_auth = auto_auth
72
65
 
73
66
  def intercept_service( # pylint: disable=too-many-return-statements
74
67
  self,
@@ -85,7 +78,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
78
  if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
79
  return continuation(handler_call_details)
87
80
 
88
- state = self.state_factory.state()
89
81
  metadata_dict = dict(handler_call_details.invocation_metadata)
90
82
 
91
83
  # Retrieve info from the metadata
@@ -96,11 +88,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
96
88
  except KeyError:
97
89
  return _unary_unary_rpc_terminator("Missing authentication metadata")
98
90
 
99
- if not self.auto_auth:
100
- # Abort the RPC call if the node public key is not found
101
- if node_pk_bytes not in state.get_node_public_keys():
102
- return _unary_unary_rpc_terminator("Public key not recognized")
103
-
104
91
  # Verify the signature
105
92
  node_pk = bytes_to_public_key(node_pk_bytes)
106
93
  if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
@@ -113,50 +100,40 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
113
100
  if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
114
101
  return _unary_unary_rpc_terminator("Invalid timestamp")
115
102
 
116
- # Continue the RPC call
117
- expected_node_id = state.get_node_id(node_pk_bytes)
118
- if not handler_call_details.method.endswith("CreateNode"):
119
- # All calls, except for `CreateNode`, must provide a public key that is
120
- # already mapped to a `node_id` (in `LinkState`)
121
- if expected_node_id is None:
122
- return _unary_unary_rpc_terminator("Invalid node ID")
123
- # One of the method handlers in
103
+ # Continue the RPC call: One of the method handlers in
124
104
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
125
105
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
126
- return self._wrap_method_handler(
127
- method_handler, expected_node_id, node_pk_bytes
128
- )
106
+ return self._wrap_method_handler(method_handler, node_pk_bytes)
129
107
 
130
108
  def _wrap_method_handler(
131
109
  self,
132
110
  method_handler: grpc.RpcMethodHandler,
133
- expected_node_id: Optional[int],
134
- node_public_key: bytes,
111
+ expected_public_key: bytes,
135
112
  ) -> grpc.RpcMethodHandler:
136
113
  def _generic_method_handler(
137
114
  request: GrpcMessage,
138
115
  context: grpc.ServicerContext,
139
116
  ) -> GrpcMessage:
140
- # Verify the node ID
141
- if not isinstance(request, CreateNodeRequest):
142
- try:
143
- if request.node.node_id != expected_node_id: # type: ignore
144
- raise ValueError
145
- except (AttributeError, ValueError):
146
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
117
+ # Note: This function runs in a different thread
118
+ # than the `intercept_service` function.
119
+
120
+ # Retrieve the public key
121
+ if isinstance(request, (RegisterNodeFleetRequest | ActivateNodeRequest)):
122
+ actual_public_key = request.public_key
123
+ else:
124
+ if hasattr(request, "node"):
125
+ node_id = request.node.node_id
126
+ else:
127
+ node_id = request.node_id # type: ignore[attr-defined]
128
+ actual_public_key = self.state_factory.state().get_node_public_key(
129
+ node_id
130
+ )
131
+
132
+ # Verify the public key
133
+ if actual_public_key != expected_public_key:
134
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
147
135
 
148
136
  response: GrpcMessage = method_handler.unary_unary(request, context)
149
-
150
- # Set the public key after a successful CreateNode request
151
- if isinstance(response, CreateNodeResponse):
152
- state = self.state_factory.state()
153
- try:
154
- state.set_node_public_key(response.node.node_id, node_public_key)
155
- except ValueError as e:
156
- # Remove newly created node if setting the public key fails
157
- state.delete_node(response.node.node_id)
158
- context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
159
-
160
137
  return response
161
138
 
162
139
  return grpc.unary_unary_rpc_method_handler(
@@ -15,29 +15,38 @@
15
15
  """Fleet API message handlers."""
16
16
 
17
17
  from logging import ERROR
18
- from typing import Optional
19
18
 
20
19
  from flwr.common import Message, log
21
- from flwr.common.constant import Status
20
+ from flwr.common.constant import (
21
+ HEARTBEAT_MAX_INTERVAL,
22
+ HEARTBEAT_MIN_INTERVAL,
23
+ NOOP_ACCOUNT_NAME,
24
+ NOOP_FLWR_AID,
25
+ Status,
26
+ )
22
27
  from flwr.common.inflatable import UnexpectedObjectContentError
23
28
  from flwr.common.serde import (
24
29
  fab_to_proto,
25
30
  message_from_proto,
26
31
  message_to_proto,
27
- user_config_to_proto,
32
+ run_to_proto,
28
33
  )
29
- from flwr.common.typing import Fab, InvalidRunStatusException
34
+ from flwr.common.typing import Fab, InvalidRunStatusException, Run
30
35
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
31
36
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
- CreateNodeRequest,
33
- CreateNodeResponse,
34
- DeleteNodeRequest,
35
- DeleteNodeResponse,
37
+ ActivateNodeRequest,
38
+ ActivateNodeResponse,
39
+ DeactivateNodeRequest,
40
+ DeactivateNodeResponse,
36
41
  PullMessagesRequest,
37
42
  PullMessagesResponse,
38
43
  PushMessagesRequest,
39
44
  PushMessagesResponse,
40
45
  Reconnect,
46
+ RegisterNodeFleetRequest,
47
+ RegisterNodeFleetResponse,
48
+ UnregisterNodeFleetRequest,
49
+ UnregisterNodeFleetResponse,
41
50
  )
42
51
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
43
52
  SendNodeHeartbeatRequest,
@@ -51,38 +60,59 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
51
60
  PushObjectRequest,
52
61
  PushObjectResponse,
53
62
  )
54
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
55
- from flwr.proto.run_pb2 import ( # pylint: disable=E0611
56
- GetRunRequest,
57
- GetRunResponse,
58
- Run,
59
- )
63
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
60
64
  from flwr.server.superlink.linkstate import LinkState
61
65
  from flwr.server.superlink.utils import check_abort
62
66
  from flwr.supercore.ffs import Ffs
63
67
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
64
- from flwr.supercore.object_store.utils import store_mapping_and_register_objects
65
68
 
66
69
 
67
- def create_node(
68
- request: CreateNodeRequest, # pylint: disable=unused-argument
70
+ class InvalidHeartbeatIntervalError(Exception):
71
+ """Invalid heartbeat interval exception."""
72
+
73
+
74
+ def register_node(
75
+ request: RegisterNodeFleetRequest,
69
76
  state: LinkState,
70
- ) -> CreateNodeResponse:
71
- """."""
72
- # Create node
73
- node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
74
- return CreateNodeResponse(node=Node(node_id=node_id))
77
+ ) -> RegisterNodeFleetResponse:
78
+ """Register a node (Fleet API only)."""
79
+ node_id = state.create_node(NOOP_FLWR_AID, NOOP_ACCOUNT_NAME, request.public_key, 0)
80
+ return RegisterNodeFleetResponse(node_id=node_id)
75
81
 
76
82
 
77
- def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
78
- """."""
79
- # Validate node_id
80
- if request.node.node_id == 0: # i.e. unset `node_id`
81
- return DeleteNodeResponse()
83
+ def activate_node(
84
+ request: ActivateNodeRequest,
85
+ state: LinkState,
86
+ ) -> ActivateNodeResponse:
87
+ """Activate a node."""
88
+ node_id = state.get_node_id_by_public_key(request.public_key)
89
+ if node_id is None:
90
+ raise ValueError("No SuperNode found with the given public key.")
91
+ _validate_heartbeat_interval(request.heartbeat_interval)
92
+ if not state.activate_node(node_id, request.heartbeat_interval):
93
+ raise ValueError(f"SuperNode with node ID {node_id} could not be activated.")
94
+ return ActivateNodeResponse(node_id=node_id)
95
+
96
+
97
+ def deactivate_node(
98
+ request: DeactivateNodeRequest,
99
+ state: LinkState,
100
+ ) -> DeactivateNodeResponse:
101
+ """Deactivate a node."""
102
+ if not state.deactivate_node(request.node_id):
103
+ raise ValueError(
104
+ f"SuperNode with node ID {request.node_id} could not be deactivated."
105
+ )
106
+ return DeactivateNodeResponse()
82
107
 
83
- # Update state
84
- state.delete_node(node_id=request.node.node_id)
85
- return DeleteNodeResponse()
108
+
109
+ def unregister_node(
110
+ request: UnregisterNodeFleetRequest,
111
+ state: LinkState,
112
+ ) -> UnregisterNodeFleetResponse:
113
+ """Unregister a node (Fleet API only)."""
114
+ state.delete_node(NOOP_FLWR_AID, request.node_id)
115
+ return UnregisterNodeFleetResponse()
86
116
 
87
117
 
88
118
  def send_node_heartbeat(
@@ -90,6 +120,7 @@ def send_node_heartbeat(
90
120
  state: LinkState, # pylint: disable=unused-argument
91
121
  ) -> SendNodeHeartbeatResponse:
92
122
  """."""
123
+ _validate_heartbeat_interval(request.heartbeat_interval)
93
124
  res = state.acknowledge_node_heartbeat(
94
125
  request.node.node_id, request.heartbeat_interval
95
126
  )
@@ -137,10 +168,11 @@ def push_messages(
137
168
  """Push Messages handler."""
138
169
  # Convert Message from proto
139
170
  msg = message_from_proto(message_proto=request.messages_list[0])
171
+ run_id = msg.metadata.run_id
140
172
 
141
173
  # Abort if the run is not running
142
174
  abort_msg = check_abort(
143
- msg.metadata.run_id,
175
+ run_id,
144
176
  [Status.PENDING, Status.STARTING, Status.FINISHED],
145
177
  state,
146
178
  store,
@@ -148,11 +180,12 @@ def push_messages(
148
180
  if abort_msg:
149
181
  raise InvalidRunStatusException(abort_msg)
150
182
 
151
- # Store Message in State
152
- message_id: Optional[str] = state.store_message_res(message=msg)
153
-
154
183
  # Store Message object to descendants mapping and preregister objects
155
- objects_to_push = store_mapping_and_register_objects(store, request=request)
184
+ objects_to_push: set[str] = set()
185
+ for object_tree in request.message_object_trees:
186
+ objects_to_push |= set(store.preregister(run_id, object_tree))
187
+ # Store Message in State
188
+ message_id: str | None = state.store_message_res(message=msg)
156
189
 
157
190
  # Build response
158
191
  response = PushMessagesResponse(
@@ -167,10 +200,8 @@ def get_run(
167
200
  request: GetRunRequest, state: LinkState, store: ObjectStore
168
201
  ) -> GetRunResponse:
169
202
  """Get run information."""
170
- run = state.get_run(request.run_id)
171
-
172
- if run is None:
173
- return GetRunResponse()
203
+ # Validate that the requesting SuperNode is part of the federation
204
+ run = _validate_node_in_federation(state, request.node.node_id, request.run_id)
174
205
 
175
206
  # Abort if the run is not running
176
207
  abort_msg = check_abort(
@@ -182,21 +213,16 @@ def get_run(
182
213
  if abort_msg:
183
214
  raise InvalidRunStatusException(abort_msg)
184
215
 
185
- return GetRunResponse(
186
- run=Run(
187
- run_id=run.run_id,
188
- fab_id=run.fab_id,
189
- fab_version=run.fab_version,
190
- override_config=user_config_to_proto(run.override_config),
191
- fab_hash=run.fab_hash,
192
- )
193
- )
216
+ return GetRunResponse(run=run_to_proto(run))
194
217
 
195
218
 
196
219
  def get_fab(
197
220
  request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
198
221
  ) -> GetFabResponse:
199
222
  """Get FAB."""
223
+ # Validate that the requesting SuperNode is part of the federation
224
+ _validate_node_in_federation(state, request.node.node_id, request.run_id)
225
+
200
226
  # Abort if the run is not running
201
227
  abort_msg = check_abort(
202
228
  request.run_id,
@@ -208,7 +234,7 @@ def get_fab(
208
234
  raise InvalidRunStatusException(abort_msg)
209
235
 
210
236
  if result := ffs.get(request.hash_str):
211
- fab = Fab(request.hash_str, result[0])
237
+ fab = Fab(request.hash_str, result[0], result[1])
212
238
  return GetFabResponse(fab=fab_to_proto(fab))
213
239
 
214
240
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
@@ -284,3 +310,28 @@ def confirm_message_received(
284
310
  store.delete(request.message_object_id)
285
311
 
286
312
  return ConfirmMessageReceivedResponse()
313
+
314
+
315
+ def _validate_heartbeat_interval(interval: float) -> None:
316
+ """Raise if heartbeat interval is out of bounds."""
317
+ if not HEARTBEAT_MIN_INTERVAL <= interval <= HEARTBEAT_MAX_INTERVAL:
318
+ raise InvalidHeartbeatIntervalError(
319
+ f"Heartbeat interval {interval} is out of bounds "
320
+ f"[{HEARTBEAT_MIN_INTERVAL}, {HEARTBEAT_MAX_INTERVAL}]."
321
+ )
322
+
323
+
324
+ def _validate_node_in_federation(
325
+ state: LinkState,
326
+ node_id: int,
327
+ run_id: int,
328
+ ) -> Run:
329
+ """Raise if the requesting SuperNode is not part of the federation the run belongs
330
+ to."""
331
+ run = state.get_run(run_id)
332
+ if not run:
333
+ raise ValueError(f"Run ID not found: {run_id}")
334
+
335
+ if not state.federation_manager.has_node(node_id, run.federation):
336
+ raise ValueError(f"SuperNode is not part of the federation '{run.federation}'.")
337
+ return run