flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -0,0 +1,214 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower server interceptor."""
16
+
17
+
18
+ import base64
19
+ from logging import WARNING
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.logger import log
26
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
27
+ bytes_to_private_key,
28
+ bytes_to_public_key,
29
+ generate_shared_key,
30
+ verify_hmac,
31
+ )
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
+ CreateNodeRequest,
34
+ CreateNodeResponse,
35
+ DeleteNodeRequest,
36
+ DeleteNodeResponse,
37
+ PingRequest,
38
+ PingResponse,
39
+ PullTaskInsRequest,
40
+ PullTaskInsResponse,
41
+ PushTaskResRequest,
42
+ PushTaskResResponse,
43
+ )
44
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
45
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
46
+ from flwr.server.superlink.state import State
47
+
48
# gRPC invocation-metadata keys read by the interceptor below; both values are
# decoded with urlsafe base64 before use.
_PUBLIC_KEY_HEADER = "public-key"  # the calling client's public key
_AUTH_TOKEN_HEADER = "auth-token"  # HMAC over the serialized request

# Fleet API request messages whose handlers the interceptor wraps.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
    PingRequest,
]

# Matching Fleet API response messages returned by the wrapped handlers.
Response = Union[
    CreateNodeResponse,
    DeleteNodeResponse,
    PullTaskInsResponse,
    PushTaskResResponse,
    GetRunResponse,
    PingResponse,
]
68
+
69
+
70
+ def _get_value_from_tuples(
71
+ key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
72
+ ) -> bytes:
73
+ value = next((value for key, value in tuples if key == key_string), "")
74
+ if isinstance(value, str):
75
+ return value.encode()
76
+
77
+ return value
78
+
79
+
80
class AuthenticateServerInterceptor(grpc.ServerInterceptor):  # type: ignore
    """Server interceptor for client authentication.

    Every wrapped unary Fleet API call must present a ``public-key`` metadata
    header that is one of the client public keys known to ``State``.
    ``CreateNode`` requests are answered directly by this interceptor (the
    wrapped ``CreateNode`` handler is never invoked). All other requests must
    also carry a valid HMAC in the ``auth-token`` header and must refer to the
    ``node_id`` registered for the caller's public key.
    """

    def __init__(self, state: State):
        """Load the known client public keys and the server key pair.

        Parameters
        ----------
        state : State
            Source of the client public keys, the server key pair and the
            public-key -> node_id mapping.

        Raises
        ------
        ValueError
            If the server private key or public key is not available.
        """
        self.state = state

        # Fetched once here; whether later additions to `state` become visible
        # depends on whether this is a copy or a live view (TODO confirm).
        self.client_public_keys = state.get_client_public_keys()
        if len(self.client_public_keys) == 0:
            log(WARNING, "Authentication enabled, but no known public keys configured")

        private_key = self.state.get_server_private_key()
        public_key = self.state.get_server_public_key()

        if private_key is None or public_key is None:
            raise ValueError("Error loading authentication keys")

        self.server_private_key = bytes_to_private_key(private_key)
        # Sent back to clients in the initial metadata of `CreateNode`.
        self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)

    def intercept_service(
        self,
        continuation: Callable[[Any], Any],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        """Flower server interceptor authentication logic.

        Intercept all unary calls from clients and authenticate clients by validating
        auth metadata sent by the client. Continue RPC call if client is authenticated,
        else, terminate RPC call by setting context to abort.

        Returns ``None`` when no handler exists for the called method, so that
        gRPC can reply with ``UNIMPLEMENTED``.
        """
        # One of the method handlers in
        # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
        method_handler: Optional[grpc.RpcMethodHandler] = continuation(
            handler_call_details
        )
        # `continuation` yields None for methods the server does not serve;
        # wrapping it would fail with an AttributeError on `None`.
        if method_handler is None:
            return None
        return self._generic_auth_unary_method_handler(method_handler)

    @staticmethod
    def _decode_metadata_value(key: str, context: grpc.ServicerContext) -> bytes:
        """Return the urlsafe-base64-decoded invocation metadata value for ``key``.

        A missing header decodes to ``b""``. Malformed base64 aborts the RPC with
        ``UNAUTHENTICATED`` instead of letting ``binascii.Error`` (a
        ``ValueError`` subclass) escape as an unexpected server error.
        """
        try:
            return base64.urlsafe_b64decode(
                _get_value_from_tuples(key, context.invocation_metadata())
            )
        except ValueError:
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
            raise  # Only reached if `context.abort` did not raise.

    def _generic_auth_unary_method_handler(
        self, method_handler: grpc.RpcMethodHandler
    ) -> grpc.RpcMethodHandler:
        """Wrap ``method_handler.unary_unary`` with the authentication checks."""

        def _generic_method_handler(
            request: Request,
            context: grpc.ServicerContext,
        ) -> Response:
            # Per the gRPC docs `context.abort` raises, so every failed check
            # below terminates the RPC.
            client_public_key_bytes = self._decode_metadata_value(
                _PUBLIC_KEY_HEADER, context
            )
            if client_public_key_bytes not in self.client_public_keys:
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            if isinstance(request, CreateNodeRequest):
                return self._create_authenticated_node(
                    client_public_key_bytes, request, context
                )

            # Verify hmac value
            hmac_value = self._decode_metadata_value(_AUTH_TOKEN_HEADER, context)
            public_key = bytes_to_public_key(client_public_key_bytes)

            if not self._verify_hmac(public_key, request, hmac_value):
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            # Verify node_id
            node_id = self.state.get_node_id(client_public_key_bytes)

            if not self._verify_node_id(node_id, request):
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            return method_handler.unary_unary(request, context)  # type: ignore

        return grpc.unary_unary_rpc_method_handler(
            _generic_method_handler,
            request_deserializer=method_handler.request_deserializer,
            response_serializer=method_handler.response_serializer,
        )

    def _verify_node_id(
        self,
        node_id: Optional[int],
        request: Union[
            DeleteNodeRequest,
            PullTaskInsRequest,
            PushTaskResRequest,
            GetRunRequest,
            PingRequest,
        ],
    ) -> bool:
        """Check that ``request`` refers to the caller's registered ``node_id``.

        - ``PushTaskResRequest``: the producer of the first ``TaskRes`` must match
          (an empty ``task_res_list`` fails).
        - ``GetRunRequest``: ``node_id`` must be among the nodes of the run.
        - Otherwise: ``request.node.node_id`` must match.
        """
        if node_id is None:
            return False
        if isinstance(request, PushTaskResRequest):
            if len(request.task_res_list) == 0:
                return False
            return request.task_res_list[0].task.producer.node_id == node_id
        if isinstance(request, GetRunRequest):
            return node_id in self.state.get_nodes(request.run_id)
        return request.node.node_id == node_id

    def _verify_hmac(
        self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
    ) -> bool:
        """Verify ``hmac_value`` over the deterministic serialization of ``request``.

        The key is derived with ``generate_shared_key`` from the server private
        key and the client's public key.
        """
        shared_secret = generate_shared_key(self.server_private_key, public_key)
        return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)

    def _create_authenticated_node(
        self,
        public_key_bytes: bytes,
        request: CreateNodeRequest,
        context: grpc.ServicerContext,
    ) -> CreateNodeResponse:
        """Answer ``CreateNode`` for an authenticated client.

        Sends the encoded server public key as initial metadata. Reuses the
        ``node_id`` already bound to ``public_key_bytes`` (acknowledging a ping
        for it), or creates a new node bound to that key.
        """
        context.send_initial_metadata(
            (
                (
                    _PUBLIC_KEY_HEADER,
                    self.encoded_server_public_key,
                ),
            )
        )

        node_id = self.state.get_node_id(public_key_bytes)

        # Handle `CreateNode` here instead of calling the default method handler
        # Return previously assigned `node_id` for the provided `public_key`
        if node_id is not None:
            self.state.acknowledge_ping(node_id, request.ping_interval)
            return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))

        # No `node_id` exists for the provided `public_key`
        # Handle `CreateNode` here instead of calling the default method handler
        # Note: the innermost `CreateNode` method will never be called
        node_id = self.state.create_node(request.ping_interval, public_key_bytes)
        return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,14 +15,18 @@
15
15
  """Fleet API message handlers."""
16
16
 
17
17
 
18
+ import time
18
19
  from typing import List, Optional
19
20
  from uuid import UUID
20
21
 
22
+ from flwr.common.serde import user_config_to_proto
21
23
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
24
  CreateNodeRequest,
23
25
  CreateNodeResponse,
24
26
  DeleteNodeRequest,
25
27
  DeleteNodeResponse,
28
+ PingRequest,
29
+ PingResponse,
26
30
  PullTaskInsRequest,
27
31
  PullTaskInsResponse,
28
32
  PushTaskResRequest,
@@ -30,6 +34,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
34
  Reconnect,
31
35
  )
32
36
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
37
+ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
38
+ GetRunRequest,
39
+ GetRunResponse,
40
+ Run,
41
+ )
33
42
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
34
43
  from flwr.server.superlink.state import State
35
44
 
@@ -40,7 +49,7 @@ def create_node(
40
49
  ) -> CreateNodeResponse:
41
50
  """."""
42
51
  # Create node
43
- node_id = state.create_node()
52
+ node_id = state.create_node(ping_interval=request.ping_interval)
44
53
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
45
54
 
46
55
 
@@ -55,6 +64,15 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
55
64
  return DeleteNodeResponse()
56
65
 
57
66
 
67
def ping(request: PingRequest, state: State) -> PingResponse:
    """Acknowledge a heartbeat ping from a node.

    Parameters
    ----------
    request : PingRequest
        Identifies the pinging node (``request.node.node_id``) and carries the
        ``ping_interval`` it announces.
    state : State
        State in which the ping is recorded via ``acknowledge_ping``.

    Returns
    -------
    PingResponse
        ``success`` holds the value returned by ``state.acknowledge_ping``.
    """
    res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
    return PingResponse(success=res)
74
+
75
+
58
76
  def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
59
77
  """Pull TaskIns handler."""
60
78
  # Get node_id if client node is not anonymous
@@ -77,6 +95,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
77
95
  task_res: TaskRes = request.task_res_list[0]
78
96
  # pylint: enable=no-member
79
97
 
98
+ # Set pushed_at (timestamp in seconds)
99
+ task_res.task.pushed_at = time.time()
100
+
80
101
  # Store TaskRes in State
81
102
  task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
82
103
 
@@ -86,3 +107,22 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
86
107
  results={str(task_id): 0},
87
108
  )
88
109
  return response
110
+
111
+
112
def get_run(request: GetRunRequest, state: State) -> GetRunResponse:
    """Get run information.

    Parameters
    ----------
    request : GetRunRequest
        Carries the ``run_id`` to look up.
    state : State
        State queried via ``get_run``.

    Returns
    -------
    GetRunResponse
        The run's id, FAB id/version and override config, or an empty response
        (no ``run`` set) when ``state`` does not know the run.
    """
    run = state.get_run(request.run_id)

    if run is None:
        return GetRunResponse()

    return GetRunResponse(
        run=Run(
            run_id=run.run_id,
            fab_id=run.fab_id,
            fab_version=run.fab_version,
            override_config=user_config_to_proto(run.override_config),
        )
    )
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,9 +21,11 @@ from flwr.common.constant import MISSING_EXTRA_REST
21
21
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
22
  CreateNodeRequest,
23
23
  DeleteNodeRequest,
24
+ PingRequest,
24
25
  PullTaskInsRequest,
25
26
  PushTaskResRequest,
26
27
  )
28
+ from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
27
29
  from flwr.server.superlink.fleet.message_handler import message_handler
28
30
  from flwr.server.superlink.state import State
29
31
 
@@ -152,11 +154,67 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
152
154
  )
153
155
 
154
156
 
157
async def ping(request: Request) -> Response:
    """Serve the Fleet API ``Ping`` endpoint.

    Checks the request headers, parses the body as a serialized
    ``PingRequest``, delegates to ``message_handler.ping`` with the ``State``
    obtained from the app's state factory, and returns the serialized
    ``PingResponse`` with HTTP status 200.
    """
    _check_headers(request.headers)

    raw_body: bytes = await request.body()
    ping_msg = PingRequest()
    ping_msg.ParseFromString(raw_body)

    reply = message_handler.ping(
        request=ping_msg, state=app.state.STATE_FACTORY.state()
    )

    return Response(
        status_code=200,
        content=reply.SerializeToString(),
        headers={"Content-Type": "application/protobuf"},
    )
181
+
182
+
183
async def get_run(request: Request) -> Response:
    """Serve the Fleet API ``GetRun`` endpoint.

    Checks the request headers, parses the body as a serialized
    ``GetRunRequest``, delegates to ``message_handler.get_run`` with the
    ``State`` obtained from the app's state factory, and returns the serialized
    ``GetRunResponse`` with HTTP status 200.
    """
    _check_headers(request.headers)

    raw_body: bytes = await request.body()
    run_query = GetRunRequest()
    run_query.ParseFromString(raw_body)

    reply = message_handler.get_run(
        request=run_query, state=app.state.STATE_FACTORY.state()
    )

    return Response(
        status_code=200,
        content=reply.SerializeToString(),
        headers={"Content-Type": "application/protobuf"},
    )
209
+
210
+
155
211
# Fleet API REST endpoints, all registered for POST only and mounted under
# /api/v0/fleet/.
routes = [
    Route(f"/api/v0/fleet/{endpoint_name}", endpoint, methods=["POST"])
    for endpoint_name, endpoint in (
        ("create-node", create_node),
        ("delete-node", delete_node),
        ("pull-task-ins", pull_task_ins),
        ("push-task-res", push_task_res),
        ("ping", ping),
        ("get-run", get_run),
    )
]
161
219
 
162
220
  app: Starlette = Starlette(
@@ -38,7 +38,7 @@ else:
38
38
 
39
39
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
40
40
 
41
- pip install -U flwr["simulation"]
41
+ pip install -U "flwr[simulation]"
42
42
  """
43
43
 
44
44
 
@@ -29,12 +29,12 @@ BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]]
29
29
  class Backend(ABC):
30
30
  """Abstract base class for a Simulation Engine Backend."""
31
31
 
32
- def __init__(self, backend_config: BackendConfig, work_dir: str) -> None:
32
+ def __init__(self, backend_config: BackendConfig) -> None:
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- async def build(self) -> None:
37
- """Build backend asynchronously.
36
+ def build(self) -> None:
37
+ """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
40
40
  to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
54
54
  """Report whether a backend worker is idle and can therefore run a ClientApp."""
55
55
 
56
56
  @abstractmethod
57
- async def terminate(self) -> None:
57
+ def terminate(self) -> None:
58
58
  """Terminate backend."""
59
59
 
60
60
  @abstractmethod
61
- async def process_message(
61
+ def process_message(
62
62
  self,
63
63
  app: Callable[[], ClientApp],
64
64
  message: Message,
@@ -14,26 +14,24 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
- import pathlib
18
- from logging import ERROR, INFO
19
- from typing import Callable, Dict, List, Tuple, Union
17
+ from logging import DEBUG, ERROR
18
+ from typing import Callable, Dict, Tuple, Union
20
19
 
21
20
  import ray
22
21
 
23
- from flwr.client.client_app import ClientApp, LoadClientAppError
22
+ from flwr.client.client_app import ClientApp
23
+ from flwr.common.constant import PARTITION_ID_KEY
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.message import Message
27
- from flwr.simulation.ray_transport.ray_actor import (
28
- BasicActorPool,
29
- ClientAppActor,
30
- init_ray,
31
- )
27
+ from flwr.common.typing import ConfigsRecordValues
28
+ from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
32
29
  from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
33
30
 
34
31
  from .backend import Backend, BackendConfig
35
32
 
36
33
# Per-actor resource request, e.g. {"num_cpus": 2, "num_gpus": 0.0}.
ClientResourcesDict = Dict[str, Union[int, float]]
# Keyword arguments handed to `BasicActorPool` as `actor_kwargs`
# (may include a zero-argument `on_actor_init_fn` callable).
ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]]
37
35
 
38
36
 
39
37
  class RayBackend(Backend):
@@ -42,52 +40,28 @@ class RayBackend(Backend):
42
40
  def __init__(
43
41
  self,
44
42
  backend_config: BackendConfig,
45
- work_dir: str,
46
43
  ) -> None:
47
44
  """Prepare RayBackend by initialising Ray and creating the ActorPool."""
48
- log(INFO, "Initialising: %s", self.__class__.__name__)
49
- log(INFO, "Backend config: %s", backend_config)
45
+ log(DEBUG, "Initialising: %s", self.__class__.__name__)
46
+ log(DEBUG, "Backend config: %s", backend_config)
50
47
 
51
- if not pathlib.Path(work_dir).exists():
52
- raise ValueError(f"Specified work_dir {work_dir} does not exist.")
53
-
54
- # Init ray and append working dir if needed
55
- runtime_env = (
56
- self._configure_runtime_env(work_dir=work_dir) if work_dir else None
57
- )
58
- init_ray(runtime_env=runtime_env)
48
+ # Initialise ray
49
+ self.init_args_key = "init_args"
50
+ self.init_ray(backend_config)
59
51
 
60
52
  # Validate client resources
61
53
  self.client_resources_key = "client_resources"
54
+ client_resources = self._validate_client_resources(config=backend_config)
62
55
 
63
56
  # Create actor pool
64
- use_tf = backend_config.get("tensorflow", False)
65
- actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}
57
+ actor_kwargs = self._validate_actor_arguments(config=backend_config)
66
58
 
67
- client_resources = self._validate_client_resources(config=backend_config)
68
59
  self.pool = BasicActorPool(
69
60
  actor_type=ClientAppActor,
70
61
  client_resources=client_resources,
71
62
  actor_kwargs=actor_kwargs,
72
63
  )
73
64
 
74
- def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
75
- """Return list of files/subdirectories to exclude relative to work_dir.
76
-
77
- Without this, Ray will push everything to the Ray Cluster.
78
- """
79
- runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}
80
-
81
- excludes = []
82
- path = pathlib.Path(work_dir)
83
- for p in path.rglob("*"):
84
- # Exclude files need to be relative to the working_dir
85
- if p.is_file() and not str(p).endswith(".py"):
86
- excludes.append(str(p.relative_to(path)))
87
- runtime_env["excludes"] = excludes
88
-
89
- return runtime_env
90
-
91
65
  def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
92
66
  client_resources_config = config.get(self.client_resources_key)
93
67
  client_resources: ClientResourcesDict = {}
@@ -109,7 +83,7 @@ class RayBackend(Backend):
109
83
  else:
110
84
  client_resources = {"num_cpus": 2, "num_gpus": 0.0}
111
85
  log(
112
- INFO,
86
+ DEBUG,
113
87
  "`%s` not specified in backend config. Applying default setting: %s",
114
88
  self.client_resources_key,
115
89
  client_resources,
@@ -117,6 +91,29 @@ class RayBackend(Backend):
117
91
 
118
92
  return client_resources
119
93
 
94
def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict:
    """Build the keyword arguments handed to every actor in the pool.

    Reads the optional ``"actor"`` section of the backend config. When that
    section sets ``"tensorflow"`` to a truthy value, ``enable_tf_gpu_growth``
    is registered as the actor's ``on_actor_init_fn``.

    Returns an empty dict when the ``"actor"`` section is missing or empty.
    """
    actor_args: ActorArgsDict = {}
    actor_args_config = config.get("actor", False)
    if not actor_args_config:
        return actor_args

    # The flag must come from the user-supplied "actor" section; the
    # ``actor_args`` dict being built here is still empty at this point.
    if actor_args_config.get("tensorflow", False):
        actor_args["on_actor_init_fn"] = enable_tf_gpu_growth
    return actor_args
102
+
103
def init_ray(self, backend_config: BackendConfig) -> None:
    """Initialise Ray unless a Ray runtime is already initialised.

    The keyword arguments for ``ray.init`` are copied from the
    ``self.init_args_key`` section of ``backend_config``. When that section
    is absent or empty, ``ray.init()`` is called with no arguments. If Ray
    is already initialised, nothing happens and the section is not read.
    """
    if ray.is_initialized():
        return

    ray_init_args: Dict[str, ConfigsRecordValues] = dict(
        backend_config.get(self.init_args_key) or {}
    )
    ray.init(**ray_init_args)
116
+
120
117
  @property
121
118
  def num_workers(self) -> int:
122
119
  """Return number of actors in pool."""
@@ -126,12 +123,12 @@ class RayBackend(Backend):
126
123
  """Report whether the pool has idle actors."""
127
124
  return self.pool.is_actor_available()
128
125
 
129
def build(self) -> None:
    """Fill the actor pool up to its capacity.

    The actors created here are the ones this backend later submits jobs to.
    """
    capacity = self.pool.actors_capacity
    self.pool.add_actors_to_pool(capacity)
    log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
130
 
134
- async def process_message(
131
+ def process_message(
135
132
  self,
136
133
  app: Callable[[], ClientApp],
137
134
  message: Message,
@@ -141,35 +138,35 @@ class RayBackend(Backend):
141
138
 
142
139
  Return output message and updated context.
143
140
  """
144
- partition_id = message.metadata.partition_id
141
+ partition_id = context.node_config[PARTITION_ID_KEY]
145
142
 
146
143
  try:
147
- # Submite a task to the pool
148
- future = await self.pool.submit(
144
+ # Submit a task to the pool
145
+ future = self.pool.submit(
149
146
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
147
  (app, message, str(partition_id), context),
151
148
  )
152
149
 
153
- await future
154
-
155
150
  # Fetch result
156
151
  (
157
152
  out_mssg,
158
153
  updated_context,
159
- ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
154
+ ) = self.pool.fetch_result_and_return_actor_to_pool(future)
160
155
 
161
156
  return out_mssg, updated_context
162
157
 
163
- except LoadClientAppError as load_ex:
158
+ except Exception as ex:
164
159
  log(
165
160
  ERROR,
166
161
  "An exception was raised when processing a message by %s",
167
162
  self.__class__.__name__,
168
163
  )
169
- raise load_ex
164
+ # add actor back into pool
165
+ self.pool.add_actor_back_to_pool(future)
166
+ raise ex
170
167
 
171
def terminate(self) -> None:
    """Tear down every actor in the pool, then shut Ray down."""
    pool = self.pool
    pool.terminate_all_actors()
    ray.shutdown()
    log(DEBUG, "Terminated %s", type(self).__name__)