flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237) hide show
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,32 +15,52 @@
15
15
  """Contextmanager for a gRPC request-response channel to the Flower server."""
16
16
 
17
17
 
18
+ import random
19
+ import threading
18
20
  from contextlib import contextmanager
19
21
  from copy import copy
20
22
  from logging import DEBUG, ERROR
21
23
  from pathlib import Path
22
- from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast
23
25
 
26
+ import grpc
27
+ from cryptography.hazmat.primitives.asymmetric import ec
28
+
29
+ from flwr.client.heartbeat import start_ping_loop
24
30
  from flwr.client.message_handler.message_handler import validate_out_message
25
31
  from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
26
32
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
33
+ from flwr.common.constant import (
34
+ PING_BASE_MULTIPLIER,
35
+ PING_CALL_TIMEOUT,
36
+ PING_DEFAULT_INTERVAL,
37
+ PING_RANDOM_RANGE,
38
+ )
27
39
  from flwr.common.grpc import create_channel
28
- from flwr.common.logger import log, warn_experimental_feature
40
+ from flwr.common.logger import log
29
41
  from flwr.common.message import Message, Metadata
30
42
  from flwr.common.retry_invoker import RetryInvoker
31
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Fab, Run
32
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
50
  CreateNodeRequest,
34
51
  DeleteNodeRequest,
52
+ PingRequest,
53
+ PingResponse,
35
54
  PullTaskInsRequest,
36
55
  PushTaskResRequest,
37
56
  )
38
57
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
58
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
59
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
60
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
61
 
42
- KEY_NODE = "node"
43
- KEY_METADATA = "in_message_metadata"
62
+ from .client_interceptor import AuthenticateClientInterceptor
63
+ from .grpc_adapter import GrpcAdapter
44
64
 
45
65
 
46
66
  def on_channel_state_change(channel_connectivity: str) -> None:
@@ -49,18 +69,24 @@ def on_channel_state_change(channel_connectivity: str) -> None:
49
69
 
50
70
 
51
71
  @contextmanager
52
- def grpc_request_response(
72
+ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
53
73
  server_address: str,
54
74
  insecure: bool,
55
75
  retry_invoker: RetryInvoker,
56
76
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
57
77
  root_certificates: Optional[Union[bytes, str]] = None,
78
+ authentication_keys: Optional[
79
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
80
+ ] = None,
81
+ adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None,
58
82
  ) -> Iterator[
59
83
  Tuple[
60
84
  Callable[[], Optional[Message]],
61
85
  Callable[[Message], None],
86
+ Optional[Callable[[], Optional[int]]],
62
87
  Optional[Callable[[], None]],
63
- Optional[Callable[[], None]],
88
+ Optional[Callable[[int], Run]],
89
+ Optional[Callable[[str], Fab]],
64
90
  ]
65
91
  ]:
66
92
  """Primitives for request/response-based interaction with a server.
@@ -87,6 +113,11 @@ def grpc_request_response(
87
113
  Path of the root certificate. If provided, a secure
88
114
  connection using the certificates will be established to an SSL-enabled
89
115
  Flower server. Bytes won't work for the REST API.
116
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
117
+ Tuple containing the elliptic curve private key and public key for
118
+ authentication from the cryptography library.
119
+ Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
120
+ Used to establish an authenticated connection with the server.
90
121
 
91
122
  Returns
92
123
  -------
@@ -94,60 +125,101 @@ def grpc_request_response(
94
125
  send : Callable
95
126
  create_node : Optional[Callable]
96
127
  delete_node : Optional[Callable]
128
+ get_run : Optional[Callable]
97
129
  """
98
- warn_experimental_feature("`grpc-rere`")
99
-
100
130
  if isinstance(root_certificates, str):
101
131
  root_certificates = Path(root_certificates).read_bytes()
102
132
 
133
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
134
+ if authentication_keys is not None:
135
+ interceptors = AuthenticateClientInterceptor(
136
+ authentication_keys[0], authentication_keys[1]
137
+ )
138
+
103
139
  channel = create_channel(
104
140
  server_address=server_address,
105
141
  insecure=insecure,
106
142
  root_certificates=root_certificates,
107
143
  max_message_length=max_message_length,
144
+ interceptors=interceptors,
108
145
  )
109
146
  channel.subscribe(on_channel_state_change)
110
- stub = FleetStub(channel)
111
-
112
- # Necessary state to validate messages to be sent
113
- state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}
114
147
 
115
- # Enable create_node and delete_node to store node
116
- node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
148
+ # Shared variables for inner functions
149
+ if adapter_cls is None:
150
+ adapter_cls = FleetStub
151
+ stub = adapter_cls(channel)
152
+ metadata: Optional[Metadata] = None
153
+ node: Optional[Node] = None
154
+ ping_thread: Optional[threading.Thread] = None
155
+ ping_stop_event = threading.Event()
117
156
 
118
157
  ###########################################################################
119
- # receive/send functions
158
+ # ping/create_node/delete_node/receive/send/get_run functions
120
159
  ###########################################################################
121
160
 
122
- def create_node() -> None:
161
+ def ping() -> None:
162
+ # Get Node
163
+ if node is None:
164
+ log(ERROR, "Node instance missing")
165
+ return
166
+
167
+ # Construct the ping request
168
+ req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
169
+
170
+ # Call FleetAPI
171
+ res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)
172
+
173
+ # Check if success
174
+ if not res.success:
175
+ raise RuntimeError("Ping failed unexpectedly.")
176
+
177
+ # Wait
178
+ rd = random.uniform(*PING_RANDOM_RANGE)
179
+ next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
180
+ next_interval *= PING_BASE_MULTIPLIER + rd
181
+ if not ping_stop_event.is_set():
182
+ ping_stop_event.wait(next_interval)
183
+
184
+ def create_node() -> Optional[int]:
123
185
  """Set create_node."""
124
- create_node_request = CreateNodeRequest()
186
+ # Call FleetAPI
187
+ create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
125
188
  create_node_response = retry_invoker.invoke(
126
189
  stub.CreateNode,
127
190
  request=create_node_request,
128
191
  )
129
- node_store[KEY_NODE] = create_node_response.node
192
+
193
+ # Remember the node and the ping-loop thread
194
+ nonlocal node, ping_thread
195
+ node = cast(Node, create_node_response.node)
196
+ ping_thread = start_ping_loop(ping, ping_stop_event)
197
+ return node.node_id
130
198
 
131
199
  def delete_node() -> None:
132
200
  """Set delete_node."""
133
201
  # Get Node
134
- if node_store[KEY_NODE] is None:
202
+ nonlocal node
203
+ if node is None:
135
204
  log(ERROR, "Node instance missing")
136
205
  return
137
- node: Node = cast(Node, node_store[KEY_NODE])
138
206
 
207
+ # Stop the ping-loop thread
208
+ ping_stop_event.set()
209
+
210
+ # Call FleetAPI
139
211
  delete_node_request = DeleteNodeRequest(node=node)
140
212
  retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
141
213
 
142
- del node_store[KEY_NODE]
214
+ # Cleanup
215
+ node = None
143
216
 
144
217
  def receive() -> Optional[Message]:
145
218
  """Receive next task from server."""
146
219
  # Get Node
147
- if node_store[KEY_NODE] is None:
220
+ if node is None:
148
221
  log(ERROR, "Node instance missing")
149
222
  return None
150
- node: Node = cast(Node, node_store[KEY_NODE])
151
223
 
152
224
  # Request instructions (task) from server
153
225
  request = PullTaskInsRequest(node=node)
@@ -167,7 +239,8 @@ def grpc_request_response(
167
239
  in_message = message_from_taskins(task_ins) if task_ins else None
168
240
 
169
241
  # Remember `metadata` of the in message
170
- state[KEY_METADATA] = copy(in_message.metadata) if in_message else None
242
+ nonlocal metadata
243
+ metadata = copy(in_message.metadata) if in_message else None
171
244
 
172
245
  # Return the message if available
173
246
  return in_message
@@ -175,18 +248,18 @@ def grpc_request_response(
175
248
  def send(message: Message) -> None:
176
249
  """Send task result back to server."""
177
250
  # Get Node
178
- if node_store[KEY_NODE] is None:
251
+ if node is None:
179
252
  log(ERROR, "Node instance missing")
180
253
  return
181
254
 
182
- # Get incoming message
183
- in_metadata = state[KEY_METADATA]
184
- if in_metadata is None:
255
+ # Get the metadata of the incoming message
256
+ nonlocal metadata
257
+ if metadata is None:
185
258
  log(ERROR, "No current message")
186
259
  return
187
260
 
188
261
  # Validate out message
189
- if not validate_out_message(message, in_metadata):
262
+ if not validate_out_message(message, metadata):
190
263
  log(ERROR, "Invalid out message")
191
264
  return
192
265
 
@@ -197,10 +270,31 @@ def grpc_request_response(
197
270
  request = PushTaskResRequest(task_res_list=[task_res])
198
271
  _ = retry_invoker.invoke(stub.PushTaskRes, request)
199
272
 
200
- state[KEY_METADATA] = None
273
+ # Cleanup
274
+ metadata = None
275
+
276
+ def get_run(run_id: int) -> Run:
277
+ # Call FleetAPI
278
+ get_run_request = GetRunRequest(run_id=run_id)
279
+ get_run_response: GetRunResponse = retry_invoker.invoke(
280
+ stub.GetRun,
281
+ request=get_run_request,
282
+ )
283
+
284
+ # Return fab_id and fab_version
285
+ return Run(
286
+ run_id,
287
+ get_run_response.run.fab_id,
288
+ get_run_response.run.fab_version,
289
+ user_config_from_proto(get_run_response.run.override_config),
290
+ )
291
+
292
+ def get_fab(fab_hash: str) -> Fab:
293
+ # Call FleetAPI
294
+ raise NotImplementedError
201
295
 
202
296
  try:
203
297
  # Yield methods
204
- yield (receive, send, create_node, delete_node)
298
+ yield (receive, send, create_node, delete_node, get_run, get_fab)
205
299
  except Exception as exc: # pylint: disable=broad-except
206
300
  log(ERROR, exc)
@@ -0,0 +1,140 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """GrpcAdapter implementation."""
16
+
17
+
18
+ import sys
19
+ from logging import DEBUG
20
+ from typing import Any, Type, TypeVar, cast
21
+
22
+ import grpc
23
+ from google.protobuf.message import Message as GrpcMessage
24
+
25
+ from flwr.common import log
26
+ from flwr.common.constant import (
27
+ GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY,
28
+ GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
29
+ )
30
+ from flwr.common.version import package_version
31
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
+ CreateNodeRequest,
34
+ CreateNodeResponse,
35
+ DeleteNodeRequest,
36
+ DeleteNodeResponse,
37
+ PingRequest,
38
+ PingResponse,
39
+ PullTaskInsRequest,
40
+ PullTaskInsResponse,
41
+ PushTaskResRequest,
42
+ PushTaskResResponse,
43
+ )
44
+ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
45
+ from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
46
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
47
+
48
+ T = TypeVar("T", bound=GrpcMessage)
49
+
50
+
51
+ class GrpcAdapter:
52
+ """Adapter class to send and receive gRPC messages via the ``GrpcAdapterStub``.
53
+
54
+ This class utilizes the ``GrpcAdapterStub`` to send and receive gRPC messages
55
+ which are defined and used by the Fleet API, as defined in ``fleet.proto``.
56
+ """
57
+
58
+ def __init__(self, channel: grpc.Channel) -> None:
59
+ self.stub = GrpcAdapterStub(channel)
60
+
61
+ def _send_and_receive(
62
+ self, request: GrpcMessage, response_type: Type[T], **kwargs: Any
63
+ ) -> T:
64
+ # Serialize request
65
+ container_req = MessageContainer(
66
+ metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version},
67
+ grpc_message_name=request.__class__.__qualname__,
68
+ grpc_message_content=request.SerializeToString(),
69
+ )
70
+
71
+ # Send via the stub
72
+ container_res = cast(
73
+ MessageContainer, self.stub.SendReceive(container_req, **kwargs)
74
+ )
75
+
76
+ # Handle control message
77
+ should_exit = (
78
+ container_res.metadata.get(GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, "false")
79
+ == "true"
80
+ )
81
+ if should_exit:
82
+ log(
83
+ DEBUG,
84
+ 'Received shutdown signal: exit flag is set to ``"true"``. Exiting...',
85
+ )
86
+ sys.exit(0)
87
+
88
+ # Check the grpc_message_name of the response
89
+ if container_res.grpc_message_name != response_type.__qualname__:
90
+ raise ValueError(
91
+ f"Invalid grpc_message_name. Expected {response_type.__qualname__}"
92
+ f", but got {container_res.grpc_message_name}."
93
+ )
94
+
95
+ # Deserialize response
96
+ response = response_type()
97
+ response.ParseFromString(container_res.grpc_message_content)
98
+ return response
99
+
100
+ def CreateNode( # pylint: disable=C0103
101
+ self, request: CreateNodeRequest, **kwargs: Any
102
+ ) -> CreateNodeResponse:
103
+ """."""
104
+ return self._send_and_receive(request, CreateNodeResponse, **kwargs)
105
+
106
+ def DeleteNode( # pylint: disable=C0103
107
+ self, request: DeleteNodeRequest, **kwargs: Any
108
+ ) -> DeleteNodeResponse:
109
+ """."""
110
+ return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
111
+
112
+ def Ping( # pylint: disable=C0103
113
+ self, request: PingRequest, **kwargs: Any
114
+ ) -> PingResponse:
115
+ """."""
116
+ return self._send_and_receive(request, PingResponse, **kwargs)
117
+
118
+ def PullTaskIns( # pylint: disable=C0103
119
+ self, request: PullTaskInsRequest, **kwargs: Any
120
+ ) -> PullTaskInsResponse:
121
+ """."""
122
+ return self._send_and_receive(request, PullTaskInsResponse, **kwargs)
123
+
124
+ def PushTaskRes( # pylint: disable=C0103
125
+ self, request: PushTaskResRequest, **kwargs: Any
126
+ ) -> PushTaskResResponse:
127
+ """."""
128
+ return self._send_and_receive(request, PushTaskResResponse, **kwargs)
129
+
130
+ def GetRun( # pylint: disable=C0103
131
+ self, request: GetRunRequest, **kwargs: Any
132
+ ) -> GetRunResponse:
133
+ """."""
134
+ return self._send_and_receive(request, GetRunResponse, **kwargs)
135
+
136
+ def GetFab( # pylint: disable=C0103
137
+ self, request: GetFabRequest, **kwargs: Any
138
+ ) -> GetFabResponse:
139
+ """."""
140
+ return self._send_and_receive(request, GetFabResponse, **kwargs)
@@ -0,0 +1,74 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Heartbeat utility functions."""
16
+
17
+
18
+ import threading
19
+ from typing import Callable
20
+
21
+ import grpc
22
+
23
+ from flwr.common.constant import PING_CALL_TIMEOUT
24
+ from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
25
+
26
+
27
+ def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None:
28
+ def wait_fn(wait_time: float) -> None:
29
+ if not stop_event.is_set():
30
+ stop_event.wait(wait_time)
31
+
32
+ def on_backoff(state: RetryState) -> None:
33
+ err = state.exception
34
+ if not isinstance(err, grpc.RpcError):
35
+ return
36
+ status_code = err.code()
37
+ # If ping call timeout is triggered
38
+ if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
39
+ # Avoid long wait time.
40
+ if state.actual_wait is None:
41
+ return
42
+ state.actual_wait = max(state.actual_wait - PING_CALL_TIMEOUT, 0.0)
43
+
44
+ def wrapped_ping() -> None:
45
+ if not stop_event.is_set():
46
+ ping_fn()
47
+
48
+ retrier = RetryInvoker(
49
+ exponential,
50
+ grpc.RpcError,
51
+ max_tries=None,
52
+ max_time=None,
53
+ on_backoff=on_backoff,
54
+ wait_function=wait_fn,
55
+ )
56
+ while not stop_event.is_set():
57
+ retrier.invoke(wrapped_ping)
58
+
59
+
60
+ def start_ping_loop(
61
+ ping_fn: Callable[[], None], stop_event: threading.Event
62
+ ) -> threading.Thread:
63
+ """Start a ping loop in a separate thread.
64
+
65
+ This function initializes a new thread that runs a ping loop, allowing for
66
+ asynchronous ping operations. The loop can be terminated through the provided stop
67
+ event.
68
+ """
69
+ thread = threading.Thread(
70
+ target=_ping_loop, args=(ping_fn, stop_event), daemon=True
71
+ )
72
+ thread.start()
73
+
74
+ return thread
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  """Client-side message handler."""
16
16
 
17
-
18
17
  from logging import WARN
19
18
  from typing import Optional, Tuple, cast
20
19
 
@@ -25,7 +24,7 @@ from flwr.client.client import (
25
24
  maybe_call_get_properties,
26
25
  )
27
26
  from flwr.client.numpy_client import NumPyClient
28
- from flwr.client.typing import ClientFn
27
+ from flwr.client.typing import ClientFnExt
29
28
  from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
30
29
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
30
  from flwr.common.recordset_compat import (
@@ -81,7 +80,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
81
80
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
81
  recordset = RecordSet()
83
82
  recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
- out_message = message.create_reply(recordset, ttl="")
83
+ out_message = message.create_reply(recordset)
85
84
  # Return TaskRes and sleep duration
86
85
  return out_message, sleep_duration
87
86
 
@@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
90
89
 
91
90
 
92
91
  def handle_legacy_message_from_msgtype(
93
- client_fn: ClientFn, message: Message, context: Context
92
+ client_fn: ClientFnExt, message: Message, context: Context
94
93
  ) -> Message:
95
94
  """Handle legacy message in the inner most mod."""
96
- client = client_fn(str(message.metadata.partition_id))
95
+ client = client_fn(context)
97
96
 
98
97
  # Check if NumPyClient is returend
99
98
  if isinstance(client, NumPyClient):
@@ -143,7 +142,7 @@ def handle_legacy_message_from_msgtype(
143
142
  raise ValueError(f"Invalid message type: {message_type}")
144
143
 
145
144
  # Return Message
146
- return message.create_reply(out_recordset, ttl="")
145
+ return message.create_reply(out_recordset)
147
146
 
148
147
 
149
148
  def _reconnect(
@@ -172,6 +171,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
172
171
  and out_meta.reply_to_message == in_meta.message_id
173
172
  and out_meta.group_id == in_meta.group_id
174
173
  and out_meta.message_type == in_meta.message_type
174
+ and out_meta.created_at > in_meta.created_at
175
175
  ):
176
176
  return True
177
177
  return False
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Mods."""
15
+ """Flower Built-in Mods."""
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
@@ -22,12 +22,12 @@ from .secure_aggregation import secagg_mod, secaggplus_mod
22
22
  from .utils import make_ffn
23
23
 
24
24
  __all__ = [
25
+ "LocalDpMod",
25
26
  "adaptiveclipping_mod",
26
27
  "fixedclipping_mod",
27
- "LocalDpMod",
28
28
  "make_ffn",
29
- "secagg_mod",
30
- "secaggplus_mod",
31
29
  "message_size_mod",
32
30
  "parameters_size_mod",
31
+ "secagg_mod",
32
+ "secaggplus_mod",
33
33
  ]
@@ -82,7 +82,9 @@ def fixedclipping_mod(
82
82
  clipping_norm,
83
83
  )
84
84
 
85
- log(INFO, "fixedclipping_mod: parameters are clipped by value: %s.", clipping_norm)
85
+ log(
86
+ INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
87
+ )
86
88
 
87
89
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
88
90
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
146
148
  )
147
149
  log(
148
150
  INFO,
149
- "adaptiveclipping_mod: parameters are clipped by value: %s.",
151
+ "adaptiveclipping_mod: parameters are clipped by value: %.4f.",
150
152
  clipping_norm,
151
153
  )
152
154
 
@@ -29,7 +29,7 @@ def message_size_mod(
29
29
  ) -> Message:
30
30
  """Message size mod.
31
31
 
32
- This mod logs the size in Bytes of the message being transmited.
32
+ This mod logs the size in bytes of the message being transmited.
33
33
  """
34
34
  message_size_in_bytes = 0
35
35
 
@@ -42,7 +42,7 @@ def message_size_mod(
42
42
  for m_record in msg.content.metrics_records.values():
43
43
  message_size_in_bytes += m_record.count_bytes()
44
44
 
45
- log(INFO, "Message size: %i Bytes", message_size_in_bytes)
45
+ log(INFO, "Message size: %i bytes", message_size_in_bytes)
46
46
 
47
47
  return call_next(msg, ctxt)
48
48
 
@@ -53,7 +53,7 @@ def parameters_size_mod(
53
53
  """Parameters size mod.
54
54
 
55
55
  This mod logs the number of parameters transmitted in the message as well as their
56
- size in Bytes.
56
+ size in bytes.
57
57
  """
58
58
  model_size_stats = {}
59
59
  parameters_size_in_bytes = 0
@@ -74,6 +74,6 @@ def parameters_size_mod(
74
74
  if model_size_stats:
75
75
  log(INFO, model_size_stats)
76
76
 
77
- log(INFO, "Total parameters transmited: %i Bytes", parameters_size_in_bytes)
77
+ log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
78
78
 
79
79
  return call_next(msg, ctxt)
@@ -128,7 +128,9 @@ class LocalDpMod:
128
128
  self.clipping_norm,
129
129
  )
130
130
  log(
131
- INFO, "LocalDpMod: parameters are clipped by value: %s.", self.clipping_norm
131
+ INFO,
132
+ "LocalDpMod: parameters are clipped by value: %.4f.",
133
+ self.clipping_norm,
132
134
  )
133
135
 
134
136
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
@@ -137,11 +139,14 @@ class LocalDpMod:
137
139
  add_localdp_gaussian_noise_to_params(
138
140
  fit_res.parameters, self.sensitivity, self.epsilon, self.delta
139
141
  )
142
+
143
+ noise_value_sd = (
144
+ self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
145
+ )
140
146
  log(
141
147
  INFO,
142
- "LocalDpMod: local DP noise with "
143
- "standard deviation: %s added to parameters.",
144
- self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon,
148
+ "LocalDpMod: local DP noise with %.4f stedv added to parameters",
149
+ noise_value_sd,
145
150
  )
146
151
 
147
152
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)