flwr-1.21.0-py3-none-any.whl → flwr-1.23.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -15,25 +15,31 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
- from logging import DEBUG, INFO
18
+ import threading
19
+ from logging import DEBUG, ERROR, INFO
19
20
 
20
21
  import grpc
21
22
  from google.protobuf.json_format import MessageToDict
22
23
 
24
+ from flwr.common.constant import PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
23
25
  from flwr.common.inflatable import UnexpectedObjectContentError
24
26
  from flwr.common.logger import log
25
27
  from flwr.common.typing import InvalidRunStatusException
26
28
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
27
29
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
28
30
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
- CreateNodeRequest,
30
- CreateNodeResponse,
31
- DeleteNodeRequest,
32
- DeleteNodeResponse,
31
+ ActivateNodeRequest,
32
+ ActivateNodeResponse,
33
+ DeactivateNodeRequest,
34
+ DeactivateNodeResponse,
33
35
  PullMessagesRequest,
34
36
  PullMessagesResponse,
35
37
  PushMessagesRequest,
36
38
  PushMessagesResponse,
39
+ RegisterNodeFleetRequest,
40
+ RegisterNodeFleetResponse,
41
+ UnregisterNodeFleetRequest,
42
+ UnregisterNodeFleetResponse,
37
43
  )
38
44
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
39
45
  SendNodeHeartbeatRequest,
@@ -63,49 +69,137 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
63
69
  state_factory: LinkStateFactory,
64
70
  ffs_factory: FfsFactory,
65
71
  objectstore_factory: ObjectStoreFactory,
72
+ enable_supernode_auth: bool,
66
73
  ) -> None:
67
74
  self.state_factory = state_factory
68
75
  self.ffs_factory = ffs_factory
69
76
  self.objectstore_factory = objectstore_factory
77
+ self.enable_supernode_auth = enable_supernode_auth
78
+ self.lock = threading.Lock()
70
79
 
71
- def CreateNode(
72
- self, request: CreateNodeRequest, context: grpc.ServicerContext
73
- ) -> CreateNodeResponse:
74
- """."""
75
- log(
76
- INFO,
77
- "[Fleet.CreateNode] Request heartbeat_interval=%s",
78
- request.heartbeat_interval,
79
- )
80
- log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
81
- response = message_handler.create_node(
82
- request=request,
83
- state=self.state_factory.state(),
84
- )
85
- log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
86
- log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
87
- return response
80
+ def RegisterNode(
81
+ self, request: RegisterNodeFleetRequest, context: grpc.ServicerContext
82
+ ) -> RegisterNodeFleetResponse:
83
+ """Register a node."""
84
+ # Prevent registration when SuperNode authentication is enabled
85
+ if self.enable_supernode_auth:
86
+ log(ERROR, "SuperNode registration is disabled through Fleet API.")
87
+ context.abort(
88
+ grpc.StatusCode.FAILED_PRECONDITION,
89
+ "SuperNode authentication is enabled. "
90
+ "All SuperNodes must be registered via the CLI.",
91
+ )
88
92
 
89
- def DeleteNode(
90
- self, request: DeleteNodeRequest, context: grpc.ServicerContext
91
- ) -> DeleteNodeResponse:
92
- """."""
93
- log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
94
- log(DEBUG, "[Fleet.DeleteNode] Request: %s", MessageToDict(request))
95
- return message_handler.delete_node(
96
- request=request,
97
- state=self.state_factory.state(),
98
- )
93
+ try:
94
+ response = message_handler.register_node(
95
+ request=request,
96
+ state=self.state_factory.state(),
97
+ )
98
+ log(DEBUG, "[Fleet.RegisterNode] Registered node_id=%s", response.node_id)
99
+ return response
100
+ except ValueError:
101
+ # Public key already in use
102
+ # This should NEVER happen due to the public keys should be automatically
103
+ # generated and unique for each SuperNode instance.
104
+ log(
105
+ ERROR,
106
+ "[Fleet.RegisterNode] Registration failed: %s",
107
+ PUBLIC_KEY_ALREADY_IN_USE_MESSAGE,
108
+ )
109
+ context.abort(
110
+ grpc.StatusCode.FAILED_PRECONDITION, PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
111
+ )
112
+
113
+ raise RuntimeError # Make mypy happy
114
+
115
+ def ActivateNode(
116
+ self, request: ActivateNodeRequest, context: grpc.ServicerContext
117
+ ) -> ActivateNodeResponse:
118
+ """Activate a node."""
119
+ try:
120
+ response = message_handler.activate_node(
121
+ request=request,
122
+ state=self.state_factory.state(),
123
+ )
124
+ log(INFO, "[Fleet.ActivateNode] Activated node_id=%s", response.node_id)
125
+ return response
126
+ except message_handler.InvalidHeartbeatIntervalError:
127
+ # Heartbeat interval is invalid
128
+ log(ERROR, "[Fleet.ActivateNode] Invalid heartbeat interval")
129
+ context.abort(
130
+ grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
131
+ )
132
+ except ValueError as e:
133
+ log(ERROR, "[Fleet.ActivateNode] Activation failed: %s", str(e))
134
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
135
+
136
+ raise RuntimeError # Make mypy happy
137
+
138
+ def DeactivateNode(
139
+ self, request: DeactivateNodeRequest, context: grpc.ServicerContext
140
+ ) -> DeactivateNodeResponse:
141
+ """Deactivate a node."""
142
+ try:
143
+ response = message_handler.deactivate_node(
144
+ request=request,
145
+ state=self.state_factory.state(),
146
+ )
147
+ log(INFO, "[Fleet.DeactivateNode] Deactivated node_id=%s", request.node_id)
148
+ return response
149
+ except ValueError as e:
150
+ log(ERROR, "[Fleet.DeactivateNode] Deactivation failed: %s", str(e))
151
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
152
+
153
+ raise RuntimeError # Make mypy happy
154
+
155
+ def UnregisterNode(
156
+ self, request: UnregisterNodeFleetRequest, context: grpc.ServicerContext
157
+ ) -> UnregisterNodeFleetResponse:
158
+ """Unregister a node."""
159
+ # Prevent unregistration when SuperNode authentication is enabled
160
+ if self.enable_supernode_auth:
161
+ log(ERROR, "SuperNode unregistration is disabled through Fleet API.")
162
+ context.abort(
163
+ grpc.StatusCode.FAILED_PRECONDITION,
164
+ "SuperNode authentication is enabled. "
165
+ "All SuperNodes must be unregistered via the CLI.",
166
+ )
167
+
168
+ try:
169
+ response = message_handler.unregister_node(
170
+ request=request,
171
+ state=self.state_factory.state(),
172
+ )
173
+ log(
174
+ DEBUG, "[Fleet.UnregisterNode] Unregistered node_id=%s", request.node_id
175
+ )
176
+ return response
177
+ except ValueError as e:
178
+ log(
179
+ ERROR,
180
+ "[Fleet.UnregisterNode] Unregistration failed: %s",
181
+ str(e),
182
+ )
183
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
184
+ raise RuntimeError from None # Make mypy happy
99
185
 
100
186
  def SendNodeHeartbeat(
101
187
  self, request: SendNodeHeartbeatRequest, context: grpc.ServicerContext
102
188
  ) -> SendNodeHeartbeatResponse:
103
189
  """."""
104
190
  log(DEBUG, "[Fleet.SendNodeHeartbeat] Request: %s", MessageToDict(request))
105
- return message_handler.send_node_heartbeat(
106
- request=request,
107
- state=self.state_factory.state(),
108
- )
191
+ try:
192
+ return message_handler.send_node_heartbeat(
193
+ request=request,
194
+ state=self.state_factory.state(),
195
+ )
196
+ except message_handler.InvalidHeartbeatIntervalError:
197
+ # Heartbeat interval is invalid
198
+ log(ERROR, "[Fleet.SendNodeHeartbeat] Invalid heartbeat interval")
199
+ context.abort(
200
+ grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
201
+ )
202
+ raise RuntimeError # Make mypy happy
109
203
 
110
204
  def PullMessages(
111
205
  self, request: PullMessagesRequest, context: grpc.ServicerContext
@@ -183,7 +277,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
183
277
  """Push an object to the ObjectStore."""
184
278
  log(
185
279
  DEBUG,
186
- "[ServerAppIoServicer.PushObject] Push Object with object_id=%s",
280
+ "[Fleet.PushObject] Push Object with object_id=%s",
187
281
  request.object_id,
188
282
  )
189
283
 
@@ -208,7 +302,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
208
302
  """Pull an object from the ObjectStore."""
209
303
  log(
210
304
  DEBUG,
211
- "[ServerAppIoServicer.PullObject] Pull Object with object_id=%s",
305
+ "[Fleet.PullObject] Pull Object with object_id=%s",
212
306
  request.object_id,
213
307
  )
214
308
 
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import datetime
19
- from typing import Any, Callable, Optional, cast
19
+ from typing import Any, Callable, cast
20
20
 
21
21
  import grpc
22
22
  from google.protobuf.message import Message as GrpcMessage
@@ -29,15 +29,12 @@ from flwr.common.constant import (
29
29
  TIMESTAMP_HEADER,
30
30
  TIMESTAMP_TOLERANCE,
31
31
  )
32
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
33
- bytes_to_public_key,
34
- verify_signature,
35
- )
36
32
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
- CreateNodeRequest,
38
- CreateNodeResponse,
33
+ ActivateNodeRequest,
34
+ RegisterNodeFleetRequest,
39
35
  )
40
36
  from flwr.server.superlink.linkstate import LinkStateFactory
37
+ from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
41
38
 
42
39
  MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
43
40
  MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
@@ -53,22 +50,17 @@ def _unary_unary_rpc_terminator(
53
50
  return grpc.unary_unary_rpc_method_handler(terminate)
54
51
 
55
52
 
56
- class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
53
+ class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
57
54
  """Server interceptor for node authentication.
58
55
 
59
56
  Parameters
60
57
  ----------
61
58
  state_factory : LinkStateFactory
62
59
  A factory for creating new instances of LinkState.
63
- auto_auth : bool (default: False)
64
- If True, nodes are authenticated without requiring their public keys to be
65
- pre-stored in the LinkState. If False, only nodes with pre-stored public keys
66
- can be authenticated.
67
60
  """
68
61
 
69
- def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
62
+ def __init__(self, state_factory: LinkStateFactory):
70
63
  self.state_factory = state_factory
71
- self.auto_auth = auto_auth
72
64
 
73
65
  def intercept_service( # pylint: disable=too-many-return-statements
74
66
  self,
@@ -85,7 +77,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
77
  if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
78
  return continuation(handler_call_details)
87
79
 
88
- state = self.state_factory.state()
89
80
  metadata_dict = dict(handler_call_details.invocation_metadata)
90
81
 
91
82
  # Retrieve info from the metadata
@@ -96,11 +87,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
96
87
  except KeyError:
97
88
  return _unary_unary_rpc_terminator("Missing authentication metadata")
98
89
 
99
- if not self.auto_auth:
100
- # Abort the RPC call if the node public key is not found
101
- if node_pk_bytes not in state.get_node_public_keys():
102
- return _unary_unary_rpc_terminator("Public key not recognized")
103
-
104
90
  # Verify the signature
105
91
  node_pk = bytes_to_public_key(node_pk_bytes)
106
92
  if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
@@ -113,50 +99,40 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
113
99
  if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
114
100
  return _unary_unary_rpc_terminator("Invalid timestamp")
115
101
 
116
- # Continue the RPC call
117
- expected_node_id = state.get_node_id(node_pk_bytes)
118
- if not handler_call_details.method.endswith("CreateNode"):
119
- # All calls, except for `CreateNode`, must provide a public key that is
120
- # already mapped to a `node_id` (in `LinkState`)
121
- if expected_node_id is None:
122
- return _unary_unary_rpc_terminator("Invalid node ID")
123
- # One of the method handlers in
102
+ # Continue the RPC call: One of the method handlers in
124
103
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
125
104
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
126
- return self._wrap_method_handler(
127
- method_handler, expected_node_id, node_pk_bytes
128
- )
105
+ return self._wrap_method_handler(method_handler, node_pk_bytes)
129
106
 
130
107
  def _wrap_method_handler(
131
108
  self,
132
109
  method_handler: grpc.RpcMethodHandler,
133
- expected_node_id: Optional[int],
134
- node_public_key: bytes,
110
+ expected_public_key: bytes,
135
111
  ) -> grpc.RpcMethodHandler:
136
112
  def _generic_method_handler(
137
113
  request: GrpcMessage,
138
114
  context: grpc.ServicerContext,
139
115
  ) -> GrpcMessage:
140
- # Verify the node ID
141
- if not isinstance(request, CreateNodeRequest):
142
- try:
143
- if request.node.node_id != expected_node_id: # type: ignore
144
- raise ValueError
145
- except (AttributeError, ValueError):
146
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
116
+ # Note: This function runs in a different thread
117
+ # than the `intercept_service` function.
118
+
119
+ # Retrieve the public key
120
+ if isinstance(request, (RegisterNodeFleetRequest, ActivateNodeRequest)):
121
+ actual_public_key = request.public_key
122
+ else:
123
+ if hasattr(request, "node"):
124
+ node_id = request.node.node_id
125
+ else:
126
+ node_id = request.node_id # type: ignore[attr-defined]
127
+ actual_public_key = self.state_factory.state().get_node_public_key(
128
+ node_id
129
+ )
130
+
131
+ # Verify the public key
132
+ if actual_public_key != expected_public_key:
133
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
147
134
 
148
135
  response: GrpcMessage = method_handler.unary_unary(request, context)
149
-
150
- # Set the public key after a successful CreateNode request
151
- if isinstance(response, CreateNodeResponse):
152
- state = self.state_factory.state()
153
- try:
154
- state.set_node_public_key(response.node.node_id, node_public_key)
155
- except ValueError as e:
156
- # Remove newly created node if setting the public key fails
157
- state.delete_node(response.node.node_id)
158
- context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
159
-
160
136
  return response
161
137
 
162
138
  return grpc.unary_unary_rpc_method_handler(
@@ -18,7 +18,12 @@ from logging import ERROR
18
18
  from typing import Optional
19
19
 
20
20
  from flwr.common import Message, log
21
- from flwr.common.constant import Status
21
+ from flwr.common.constant import (
22
+ HEARTBEAT_MAX_INTERVAL,
23
+ HEARTBEAT_MIN_INTERVAL,
24
+ NOOP_FLWR_AID,
25
+ Status,
26
+ )
22
27
  from flwr.common.inflatable import UnexpectedObjectContentError
23
28
  from flwr.common.serde import (
24
29
  fab_to_proto,
@@ -29,15 +34,19 @@ from flwr.common.serde import (
29
34
  from flwr.common.typing import Fab, InvalidRunStatusException
30
35
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
31
36
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
- CreateNodeRequest,
33
- CreateNodeResponse,
34
- DeleteNodeRequest,
35
- DeleteNodeResponse,
37
+ ActivateNodeRequest,
38
+ ActivateNodeResponse,
39
+ DeactivateNodeRequest,
40
+ DeactivateNodeResponse,
36
41
  PullMessagesRequest,
37
42
  PullMessagesResponse,
38
43
  PushMessagesRequest,
39
44
  PushMessagesResponse,
40
45
  Reconnect,
46
+ RegisterNodeFleetRequest,
47
+ RegisterNodeFleetResponse,
48
+ UnregisterNodeFleetRequest,
49
+ UnregisterNodeFleetResponse,
41
50
  )
42
51
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
43
52
  SendNodeHeartbeatRequest,
@@ -51,7 +60,6 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
51
60
  PushObjectRequest,
52
61
  PushObjectResponse,
53
62
  )
54
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
55
63
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
56
64
  GetRunRequest,
57
65
  GetRunResponse,
@@ -64,25 +72,52 @@ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
64
72
  from flwr.supercore.object_store.utils import store_mapping_and_register_objects
65
73
 
66
74
 
67
- def create_node(
68
- request: CreateNodeRequest, # pylint: disable=unused-argument
75
+ class InvalidHeartbeatIntervalError(Exception):
76
+ """Invalid heartbeat interval exception."""
77
+
78
+
79
+ def register_node(
80
+ request: RegisterNodeFleetRequest,
69
81
  state: LinkState,
70
- ) -> CreateNodeResponse:
71
- """."""
72
- # Create node
73
- node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
74
- return CreateNodeResponse(node=Node(node_id=node_id))
82
+ ) -> RegisterNodeFleetResponse:
83
+ """Register a node (Fleet API only)."""
84
+ node_id = state.create_node(NOOP_FLWR_AID, request.public_key, 0)
85
+ return RegisterNodeFleetResponse(node_id=node_id)
75
86
 
76
87
 
77
- def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
78
- """."""
79
- # Validate node_id
80
- if request.node.node_id == 0: # i.e. unset `node_id`
81
- return DeleteNodeResponse()
88
+ def activate_node(
89
+ request: ActivateNodeRequest,
90
+ state: LinkState,
91
+ ) -> ActivateNodeResponse:
92
+ """Activate a node."""
93
+ node_id = state.get_node_id_by_public_key(request.public_key)
94
+ if node_id is None:
95
+ raise ValueError("No SuperNode found with the given public key.")
96
+ _validate_heartbeat_interval(request.heartbeat_interval)
97
+ if not state.activate_node(node_id, request.heartbeat_interval):
98
+ raise ValueError(f"SuperNode with node ID {node_id} could not be activated.")
99
+ return ActivateNodeResponse(node_id=node_id)
100
+
101
+
102
+ def deactivate_node(
103
+ request: DeactivateNodeRequest,
104
+ state: LinkState,
105
+ ) -> DeactivateNodeResponse:
106
+ """Deactivate a node."""
107
+ if not state.deactivate_node(request.node_id):
108
+ raise ValueError(
109
+ f"SuperNode with node ID {request.node_id} could not be deactivated."
110
+ )
111
+ return DeactivateNodeResponse()
112
+
82
113
 
83
- # Update state
84
- state.delete_node(node_id=request.node.node_id)
85
- return DeleteNodeResponse()
114
+ def unregister_node(
115
+ request: UnregisterNodeFleetRequest,
116
+ state: LinkState,
117
+ ) -> UnregisterNodeFleetResponse:
118
+ """Unregister a node (Fleet API only)."""
119
+ state.delete_node(NOOP_FLWR_AID, request.node_id)
120
+ return UnregisterNodeFleetResponse()
86
121
 
87
122
 
88
123
  def send_node_heartbeat(
@@ -90,6 +125,7 @@ def send_node_heartbeat(
90
125
  state: LinkState, # pylint: disable=unused-argument
91
126
  ) -> SendNodeHeartbeatResponse:
92
127
  """."""
128
+ _validate_heartbeat_interval(request.heartbeat_interval)
93
129
  res = state.acknowledge_node_heartbeat(
94
130
  request.node.node_id, request.heartbeat_interval
95
131
  )
@@ -208,7 +244,7 @@ def get_fab(
208
244
  raise InvalidRunStatusException(abort_msg)
209
245
 
210
246
  if result := ffs.get(request.hash_str):
211
- fab = Fab(request.hash_str, result[0])
247
+ fab = Fab(request.hash_str, result[0], result[1])
212
248
  return GetFabResponse(fab=fab_to_proto(fab))
213
249
 
214
250
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
@@ -284,3 +320,12 @@ def confirm_message_received(
284
320
  store.delete(request.message_object_id)
285
321
 
286
322
  return ConfirmMessageReceivedResponse()
323
+
324
+
325
+ def _validate_heartbeat_interval(interval: float) -> None:
326
+ """Raise if heartbeat interval is out of bounds."""
327
+ if not HEARTBEAT_MIN_INTERVAL <= interval <= HEARTBEAT_MAX_INTERVAL:
328
+ raise InvalidHeartbeatIntervalError(
329
+ f"Heartbeat interval {interval} is out of bounds "
330
+ f"[{HEARTBEAT_MIN_INTERVAL}, {HEARTBEAT_MAX_INTERVAL}]."
331
+ )
@@ -25,14 +25,18 @@ from google.protobuf.message import Message as GrpcMessage
25
25
  from flwr.common.exit import ExitCode, flwr_exit
26
26
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
27
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
28
- CreateNodeRequest,
29
- CreateNodeResponse,
30
- DeleteNodeRequest,
31
- DeleteNodeResponse,
28
+ ActivateNodeRequest,
29
+ ActivateNodeResponse,
30
+ DeactivateNodeRequest,
31
+ DeactivateNodeResponse,
32
32
  PullMessagesRequest,
33
33
  PullMessagesResponse,
34
34
  PushMessagesRequest,
35
35
  PushMessagesResponse,
36
+ RegisterNodeFleetRequest,
37
+ RegisterNodeFleetResponse,
38
+ UnregisterNodeFleetRequest,
39
+ UnregisterNodeFleetResponse,
36
40
  )
37
41
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
38
42
  SendNodeHeartbeatRequest,
@@ -69,6 +73,8 @@ GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
69
73
  GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
70
74
  RestEndPoint = Callable[[Request], Awaitable[Response]]
71
75
 
76
+ routes = []
77
+
72
78
 
73
79
  def rest_request_response(
74
80
  grpc_request_type: type[GrpcRequest],
@@ -76,6 +82,7 @@ def rest_request_response(
76
82
  """Convert an async gRPC-based function into a RESTful HTTP endpoint."""
77
83
 
78
84
  def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
85
+
79
86
  async def wrapper(request: Request) -> Response:
80
87
  _check_headers(request.headers)
81
88
 
@@ -91,33 +98,64 @@ def rest_request_response(
91
98
  headers={"Content-Type": "application/protobuf"},
92
99
  )
93
100
 
101
+ # Register route
102
+ path = f"/api/v0/fleet/{func.__name__.replace('_', '-')}"
103
+ routes.append(Route(path, wrapper, methods=["POST"]))
94
104
  return wrapper
95
105
 
96
106
  return decorator
97
107
 
98
108
 
99
- @rest_request_response(CreateNodeRequest)
100
- async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
101
- """Create Node."""
109
+ @rest_request_response(RegisterNodeFleetRequest)
110
+ async def register_node(
111
+ request: RegisterNodeFleetRequest,
112
+ ) -> RegisterNodeFleetResponse:
113
+ """Register a node (Fleet API only)."""
114
+ # Get state from app
115
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
116
+
117
+ # Handle message
118
+ return message_handler.register_node(request=request, state=state)
119
+
120
+
121
+ @rest_request_response(ActivateNodeRequest)
122
+ async def activate_node(
123
+ request: ActivateNodeRequest,
124
+ ) -> ActivateNodeResponse:
125
+ """Activate a node."""
102
126
  # Get state from app
103
127
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
104
128
 
105
129
  # Handle message
106
- return message_handler.create_node(request=request, state=state)
130
+ return message_handler.activate_node(request=request, state=state)
107
131
 
108
132
 
109
- @rest_request_response(DeleteNodeRequest)
110
- async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
111
- """Delete Node Id."""
133
+ @rest_request_response(DeactivateNodeRequest)
134
+ async def deactivate_node(
135
+ request: DeactivateNodeRequest,
136
+ ) -> DeactivateNodeResponse:
137
+ """Deactivate a node."""
112
138
  # Get state from app
113
139
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
114
140
 
115
141
  # Handle message
116
- return message_handler.delete_node(request=request, state=state)
142
+ return message_handler.deactivate_node(request=request, state=state)
143
+
144
+
145
+ @rest_request_response(UnregisterNodeFleetRequest)
146
+ async def unregister_node(
147
+ request: UnregisterNodeFleetRequest,
148
+ ) -> UnregisterNodeFleetResponse:
149
+ """Unregister a node (Fleet API only)."""
150
+ # Get state from app
151
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
152
+
153
+ # Handle message
154
+ return message_handler.unregister_node(request=request, state=state)
117
155
 
118
156
 
119
157
  @rest_request_response(PullMessagesRequest)
120
- async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
158
+ async def pull_messages(request: PullMessagesRequest) -> PullMessagesResponse:
121
159
  """Pull PullMessages."""
122
160
  # Get state from app
123
161
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
@@ -128,7 +166,7 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
128
166
 
129
167
 
130
168
  @rest_request_response(PushMessagesRequest)
131
- async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
169
+ async def push_messages(request: PushMessagesRequest) -> PushMessagesResponse:
132
170
  """Push PushMessages."""
133
171
  # Get state from app
134
172
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
@@ -212,23 +250,6 @@ async def confirm_message_received(
212
250
  )
213
251
 
214
252
 
215
- routes = [
216
- Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
217
- Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
218
- Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
219
- Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
220
- Route("/api/v0/fleet/pull-object", pull_object, methods=["POST"]),
221
- Route("/api/v0/fleet/push-object", push_object, methods=["POST"]),
222
- Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
223
- Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
224
- Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
225
- Route(
226
- "/api/v0/fleet/confirm-message-received",
227
- confirm_message_received,
228
- methods=["POST"],
229
- ),
230
- ]
231
-
232
253
  app: Starlette = Starlette(
233
254
  debug=False,
234
255
  routes=routes,
@@ -18,7 +18,7 @@
18
18
  from abc import ABC, abstractmethod
19
19
  from typing import Callable
20
20
 
21
- from flwr.client.client_app import ClientApp
21
+ from flwr.clientapp.client_app import ClientApp
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.message import Message
24
24
  from flwr.common.typing import ConfigRecordValues
@@ -21,7 +21,7 @@ from typing import Callable, Optional, Union
21
21
 
22
22
  import ray
23
23
 
24
- from flwr.client.client_app import ClientApp
24
+ from flwr.clientapp.client_app import ClientApp
25
25
  from flwr.common.constant import PARTITION_ID_KEY
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.logger import log