flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/compat/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,14 +15,11 @@
15
15
  """Flower driver app."""
16
16
 
17
17
 
18
- import sys
19
18
  from logging import INFO
20
- from pathlib import Path
21
- from typing import Optional, Union
19
+ from typing import Optional
22
20
 
23
21
  from flwr.common import EventType, event
24
- from flwr.common.address import parse_address
25
- from flwr.common.logger import log, warn_deprecated_feature
22
+ from flwr.common.logger import log
26
23
  from flwr.server.client_manager import ClientManager
27
24
  from flwr.server.history import History
28
25
  from flwr.server.server import Server, init_defaults, run_fl
@@ -32,33 +29,21 @@ from flwr.server.strategy import Strategy
32
29
  from ..driver import Driver
33
30
  from .app_utils import start_update_client_manager_thread
34
31
 
35
- DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
36
-
37
- ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
38
- [Driver] Error: Not connected.
39
-
40
- Call `connect()` on the `Driver` instance before calling any of the other `Driver`
41
- methods.
42
- """
43
-
44
32
 
45
33
  def start_driver( # pylint: disable=too-many-arguments, too-many-locals
46
34
  *,
47
- server_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
35
+ driver: Driver,
48
36
  server: Optional[Server] = None,
49
37
  config: Optional[ServerConfig] = None,
50
38
  strategy: Optional[Strategy] = None,
51
39
  client_manager: Optional[ClientManager] = None,
52
- root_certificates: Optional[Union[bytes, str]] = None,
53
- driver: Optional[Driver] = None,
54
40
  ) -> History:
55
41
  """Start a Flower Driver API server.
56
42
 
57
43
  Parameters
58
44
  ----------
59
- server_address : Optional[str]
60
- The IPv4 or IPv6 address of the Driver API server.
61
- Defaults to `"[::]:8080"`.
45
+ driver : Driver
46
+ The Driver object to use.
62
47
  server : Optional[flwr.server.Server] (default: None)
63
48
  A server implementation, either `flwr.server.Server` or a subclass
64
49
  thereof. If no instance is provided, then `start_driver` will create
@@ -74,50 +59,14 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
74
59
  An implementation of the class `flwr.server.ClientManager`. If no
75
60
  implementation is provided, then `start_driver` will use
76
61
  `flwr.server.SimpleClientManager`.
77
- root_certificates : Optional[Union[bytes, str]] (default: None)
78
- The PEM-encoded root certificates as a byte string or a path string.
79
- If provided, a secure connection using the certificates will be
80
- established to an SSL-enabled Flower server.
81
- driver : Optional[Driver] (default: None)
82
- The Driver object to use.
83
62
 
84
63
  Returns
85
64
  -------
86
65
  hist : flwr.server.history.History
87
66
  Object containing training and evaluation metrics.
88
-
89
- Examples
90
- --------
91
- Starting a driver that connects to an insecure server:
92
-
93
- >>> start_driver()
94
-
95
- Starting a driver that connects to an SSL-enabled server:
96
-
97
- >>> start_driver(
98
- >>> root_certificates=Path("/crts/root.pem").read_bytes()
99
- >>> )
100
67
  """
101
68
  event(EventType.START_DRIVER_ENTER)
102
69
 
103
- if driver is None:
104
- # Not passing a `Driver` object is deprecated
105
- warn_deprecated_feature("start_driver")
106
-
107
- # Parse IP address
108
- parsed_address = parse_address(server_address)
109
- if not parsed_address:
110
- sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
111
- host, port, is_v6 = parsed_address
112
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
113
-
114
- # Create the Driver
115
- if isinstance(root_certificates, str):
116
- root_certificates = Path(root_certificates).read_bytes()
117
- driver = Driver(
118
- driver_service_address=address, root_certificates=root_certificates
119
- )
120
-
121
70
  # Initialize the Driver API server and config
122
71
  initialized_server, initialized_config = init_defaults(
123
72
  server=server,
flwr/server/compat/app_utils.py CHANGED
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import threading
19
- import time
20
19
  from typing import Dict, Tuple
21
20
 
22
21
  from ..client_manager import ClientManager
@@ -60,6 +59,7 @@ def start_update_client_manager_thread(
60
59
  client_manager,
61
60
  f_stop,
62
61
  ),
62
+ daemon=True,
63
63
  )
64
64
  thread.start()
65
65
 
@@ -89,9 +89,9 @@ def _update_client_manager(
89
89
  for node_id in new_nodes:
90
90
  client_proxy = DriverClientProxy(
91
91
  node_id=node_id,
92
- driver=driver.grpc_driver, # type: ignore
92
+ driver=driver,
93
93
  anonymous=False,
94
- run_id=driver.run_id, # type: ignore
94
+ run_id=driver.run.run_id,
95
95
  )
96
96
  if client_manager.register(client_proxy):
97
97
  registered_nodes[node_id] = client_proxy
@@ -99,4 +99,5 @@ def _update_client_manager(
99
99
  raise RuntimeError("Could not register node.")
100
100
 
101
101
  # Sleep for 3 seconds
102
- time.sleep(3)
102
+ if not f_stop.is_set():
103
+ f_stop.wait(3)
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,16 +16,14 @@
16
16
 
17
17
 
18
18
  import time
19
- from typing import List, Optional
19
+ from typing import Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
- from flwr.common import serde
25
- from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
24
  from flwr.server.client_proxy import ClientProxy
27
25
 
28
- from ..driver.grpc_driver import GrpcDriver
26
+ from ..driver.driver import Driver
29
27
 
30
28
  SLEEP_TIME = 1
31
29
 
@@ -33,7 +31,7 @@ SLEEP_TIME = 1
33
31
  class DriverClientProxy(ClientProxy):
34
32
  """Flower client proxy which delegates work using the Driver API."""
35
33
 
36
- def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int):
34
+ def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
37
35
  super().__init__(str(node_id))
38
36
  self.node_id = node_id
39
37
  self.driver = driver
@@ -114,55 +112,38 @@ class DriverClientProxy(ClientProxy):
114
112
  timeout: Optional[float],
115
113
  group_id: Optional[int],
116
114
  ) -> RecordSet:
117
- task_ins = task_pb2.TaskIns( # pylint: disable=E1101
118
- task_id="",
119
- group_id=str(group_id) if group_id is not None else "",
120
- run_id=self.run_id,
121
- task=task_pb2.Task( # pylint: disable=E1101
122
- producer=node_pb2.Node( # pylint: disable=E1101
123
- node_id=0,
124
- anonymous=True,
125
- ),
126
- consumer=node_pb2.Node( # pylint: disable=E1101
127
- node_id=self.node_id,
128
- anonymous=self.anonymous,
129
- ),
130
- task_type=task_type,
131
- recordset=serde.recordset_to_proto(recordset),
132
- ),
133
- )
134
- push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
135
- task_ins_list=[task_ins]
136
- )
137
115
 
138
- # Send TaskIns to Driver API
139
- push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
116
+ # Create message
117
+ message = self.driver.create_message(
118
+ content=recordset,
119
+ message_type=task_type,
120
+ dst_node_id=self.node_id,
121
+ group_id=str(group_id) if group_id else "",
122
+ ttl=timeout,
123
+ )
140
124
 
141
- if len(push_task_ins_res.task_ids) != 1:
142
- raise ValueError("Unexpected number of task_ids")
125
+ # Push message
126
+ message_ids = list(self.driver.push_messages(messages=[message]))
127
+ if len(message_ids) != 1:
128
+ raise ValueError("Unexpected number of message_ids")
143
129
 
144
- task_id = push_task_ins_res.task_ids[0]
145
- if task_id == "":
146
- raise ValueError(f"Failed to schedule task for node {self.node_id}")
130
+ message_id = message_ids[0]
131
+ if message_id == "":
132
+ raise ValueError(f"Failed to send message to node {self.node_id}")
147
133
 
148
134
  if timeout:
149
135
  start_time = time.time()
150
136
 
151
137
  while True:
152
- pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
153
- node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
154
- task_ids=[task_id],
155
- )
156
-
157
- # Ask Driver API for TaskRes
158
- pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
159
-
160
- task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
161
- pull_task_res_res.task_res_list
162
- )
163
- if len(task_res_list) == 1:
164
- task_res = task_res_list[0]
165
- return serde.recordset_from_proto(task_res.task.recordset)
138
+ messages = list(self.driver.pull_messages(message_ids))
139
+ if len(messages) == 1:
140
+ msg: Message = messages[0]
141
+ if msg.has_error():
142
+ raise ValueError(
143
+ f"Message contains an Error (reason: {msg.error.reason}). "
144
+ "It originated during client-side execution of a message."
145
+ )
146
+ return msg.content
166
147
 
167
148
  if timeout is not None and time.time() > start_time + timeout:
168
149
  raise RuntimeError("Timeout reached")
flwr/server/compat/legacy_context.py CHANGED
@@ -18,7 +18,7 @@
18
18
  from dataclasses import dataclass
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import Context, RecordSet
21
+ from flwr.common import Context
22
22
 
23
23
  from ..client_manager import ClientManager, SimpleClientManager
24
24
  from ..history import History
@@ -35,9 +35,9 @@ class LegacyContext(Context):
35
35
  client_manager: ClientManager
36
36
  history: History
37
37
 
38
- def __init__(
38
+ def __init__( # pylint: disable=too-many-arguments
39
39
  self,
40
- state: RecordSet,
40
+ context: Context,
41
41
  config: Optional[ServerConfig] = None,
42
42
  strategy: Optional[Strategy] = None,
43
43
  client_manager: Optional[ClientManager] = None,
@@ -52,4 +52,5 @@ class LegacyContext(Context):
52
52
  self.strategy = strategy
53
53
  self.client_manager = client_manager
54
54
  self.history = History()
55
- super().__init__(state)
55
+
56
+ super().__init__(**vars(context))
flwr/server/driver/__init__.py CHANGED
@@ -17,8 +17,10 @@
17
17
 
18
18
  from .driver import Driver
19
19
  from .grpc_driver import GrpcDriver
20
+ from .inmemory_driver import InMemoryDriver
20
21
 
21
22
  __all__ = [
22
23
  "Driver",
23
24
  "GrpcDriver",
25
+ "InMemoryDriver",
24
26
  ]
flwr/server/driver/driver.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,85 +12,32 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
15
+ """Driver (abstract base class)."""
16
16
 
17
17
 
18
- import time
19
- from typing import Iterable, List, Optional, Tuple
18
+ from abc import ABC, abstractmethod
19
+ from typing import Iterable, List, Optional
20
20
 
21
- from flwr.common import Message, Metadata, RecordSet
22
- from flwr.common.serde import message_from_taskres, message_to_taskins
23
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
- CreateRunRequest,
25
- GetNodesRequest,
26
- PullTaskResRequest,
27
- PushTaskInsRequest,
28
- )
29
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
30
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
+ from flwr.common import Message, RecordSet
22
+ from flwr.common.typing import Run
31
23
 
32
- from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
33
24
 
25
+ class Driver(ABC):
26
+ """Abstract base Driver class for the Driver API."""
34
27
 
35
- class Driver:
36
- """`Driver` class provides an interface to the Driver API.
37
-
38
- Parameters
39
- ----------
40
- driver_service_address : Optional[str]
41
- The IPv4 or IPv6 address of the Driver API server.
42
- Defaults to `"[::]:9091"`.
43
- certificates : bytes (default: None)
44
- Tuple containing root certificate, server certificate, and private key
45
- to start a secure SSL-enabled server. The tuple is expected to have
46
- three bytes elements in the following order:
47
-
48
- * CA certificate.
49
- * server certificate.
50
- * server private key.
51
- """
52
-
53
- def __init__(
54
- self,
55
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
56
- root_certificates: Optional[bytes] = None,
57
- ) -> None:
58
- self.addr = driver_service_address
59
- self.root_certificates = root_certificates
60
- self.grpc_driver: Optional[GrpcDriver] = None
61
- self.run_id: Optional[int] = None
62
- self.node = Node(node_id=0, anonymous=True)
63
-
64
- def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
65
- # Check if the GrpcDriver is initialized
66
- if self.grpc_driver is None or self.run_id is None:
67
- # Connect and create run
68
- self.grpc_driver = GrpcDriver(
69
- driver_service_address=self.addr,
70
- root_certificates=self.root_certificates,
71
- )
72
- self.grpc_driver.connect()
73
- res = self.grpc_driver.create_run(CreateRunRequest())
74
- self.run_id = res.run_id
75
- return self.grpc_driver, self.run_id
76
-
77
- def _check_message(self, message: Message) -> None:
78
- # Check if the message is valid
79
- if not (
80
- message.metadata.run_id == self.run_id
81
- and message.metadata.src_node_id == self.node.node_id
82
- and message.metadata.message_id == ""
83
- and message.metadata.reply_to_message == ""
84
- ):
85
- raise ValueError(f"Invalid message: {message}")
28
+ @property
29
+ @abstractmethod
30
+ def run(self) -> Run:
31
+ """Run information."""
86
32
 
33
+ @abstractmethod
87
34
  def create_message( # pylint: disable=too-many-arguments
88
35
  self,
89
36
  content: RecordSet,
90
37
  message_type: str,
91
38
  dst_node_id: int,
92
39
  group_id: str,
93
- ttl: str,
40
+ ttl: Optional[float] = None,
94
41
  ) -> Message:
95
42
  """Create a new message with specified parameters.
96
43
 
@@ -110,36 +57,23 @@ class Driver:
110
57
  group_id : str
111
58
  The ID of the group to which this message is associated. In some settings,
112
59
  this is used as the FL round.
113
- ttl : str
60
+ ttl : Optional[float] (default: None)
114
61
  Time-to-live for the round trip of this message, i.e., the time from sending
115
- this message to receiving a reply. It specifies the duration for which the
116
- message and its potential reply are considered valid.
62
+ this message to receiving a reply. It specifies in seconds the duration for
63
+ which the message and its potential reply are considered valid. If unset,
64
+ the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
117
65
 
118
66
  Returns
119
67
  -------
120
68
  message : Message
121
69
  A new `Message` instance with the specified content and metadata.
122
70
  """
123
- _, run_id = self._get_grpc_driver_and_run_id()
124
- metadata = Metadata(
125
- run_id=run_id,
126
- message_id="", # Will be set by the server
127
- src_node_id=self.node.node_id,
128
- dst_node_id=dst_node_id,
129
- reply_to_message="",
130
- group_id=group_id,
131
- ttl=ttl,
132
- message_type=message_type,
133
- )
134
- return Message(metadata=metadata, content=content)
135
71
 
72
+ @abstractmethod
136
73
  def get_node_ids(self) -> List[int]:
137
74
  """Get node IDs."""
138
- grpc_driver, run_id = self._get_grpc_driver_and_run_id()
139
- # Call GrpcDriver method
140
- res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
141
- return [node.node_id for node in res.nodes]
142
75
 
76
+ @abstractmethod
143
77
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
144
78
  """Push messages to specified node IDs.
145
79
 
@@ -157,20 +91,8 @@ class Driver:
157
91
  An iterable of IDs for the messages that were sent, which can be used
158
92
  to pull replies.
159
93
  """
160
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
161
- # Construct TaskIns
162
- task_ins_list: List[TaskIns] = []
163
- for msg in messages:
164
- # Check message
165
- self._check_message(msg)
166
- # Convert Message to TaskIns
167
- taskins = message_to_taskins(msg)
168
- # Add to list
169
- task_ins_list.append(taskins)
170
- # Call GrpcDriver method
171
- res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
172
- return list(res.task_ids)
173
94
 
95
+ @abstractmethod
174
96
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
175
97
  """Pull messages based on message IDs.
176
98
 
@@ -187,15 +109,8 @@ class Driver:
187
109
  messages : Iterable[Message]
188
110
  An iterable of messages received.
189
111
  """
190
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
191
- # Pull TaskRes
192
- res = grpc_driver.pull_task_res(
193
- PullTaskResRequest(node=self.node, task_ids=message_ids)
194
- )
195
- # Convert TaskRes to Message
196
- msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
197
- return msgs
198
112
 
113
+ @abstractmethod
199
114
  def send_and_receive(
200
115
  self,
201
116
  messages: Iterable[Message],
@@ -229,28 +144,3 @@ class Driver:
229
144
  replies for all sent messages. A message remains valid until its TTL,
230
145
  which is not affected by `timeout`.
231
146
  """
232
- # Push messages
233
- msg_ids = set(self.push_messages(messages))
234
-
235
- # Pull messages
236
- end_time = time.time() + (timeout if timeout is not None else 0.0)
237
- ret: List[Message] = []
238
- while timeout is None or time.time() < end_time:
239
- res_msgs = self.pull_messages(msg_ids)
240
- ret.extend(res_msgs)
241
- msg_ids.difference_update(
242
- {msg.metadata.reply_to_message for msg in res_msgs}
243
- )
244
- if len(msg_ids) == 0:
245
- break
246
- # Sleep
247
- time.sleep(3)
248
- return ret
249
-
250
- def close(self) -> None:
251
- """Disconnect from the SuperLink if connected."""
252
- # Check if GrpcDriver is initialized
253
- if self.grpc_driver is None:
254
- return
255
- # Disconnect
256
- self.grpc_driver.disconnect()